from types import TracebackType from typing import Optional, Union import torch class _InsertPoint: def __init__( self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block], ): self.insert_point = insert_point self.g = insert_point_graph self.guard = None def __enter__(self) -> None: self.prev_insert_point = self.g.insertPoint() self.g.setInsertPoint(self.insert_point) def __exit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.g.setInsertPoint(self.prev_insert_point) def insert_point_guard( self: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block] ) -> _InsertPoint: return _InsertPoint(self, insert_point)