# team-10/venv/Lib/site-packages/torch/jit/_ir_utils.py
# 2025-08-02 02:00:33 +02:00
#
# 33 lines
# 886 B
# Python

from types import TracebackType
from typing import Optional, Union
import torch
class _InsertPoint:
def __init__(
self,
insert_point_graph: torch._C.Graph,
insert_point: Union[torch._C.Node, torch._C.Block],
):
self.insert_point = insert_point
self.g = insert_point_graph
self.guard = None
def __enter__(self) -> None:
self.prev_insert_point = self.g.insertPoint()
self.g.setInsertPoint(self.insert_point)
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.g.setInsertPoint(self.prev_insert_point)
def insert_point_guard(
    self: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]
) -> _InsertPoint:
    """Return a context manager that redirects ``self``'s insertion point.

    This is written as a free function with a ``self`` parameter, presumably
    so it can be attached as a method on ``torch._C.Graph`` elsewhere. That
    attachment is not visible here — TODO confirm.

    Args:
        self: The graph whose insertion point will be changed.
        insert_point: Node or Block to insert at while the guard is active.

    Returns:
        An ``_InsertPoint`` that sets the insertion point on ``__enter__`` and
        restores the previous one on ``__exit__``.
    """
    return _InsertPoint(self, insert_point)