# blob: dd2d72880431cd8804a77dc979891ff398496953 [file] [log] [blame]
import torch
from typing import Union
class _InsertPoint(object):
def __init__(self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]):
self.insert_point = insert_point
self.g = insert_point_graph
self.guard = None
def __enter__(self):
self.prev_insert_point = self.g.insertPoint()
self.g.setInsertPoint(self.insert_point)
def __exit__(self, *args):
self.g.setInsertPoint(self.prev_insert_point)
def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
    """Build a context manager that sets ``self``'s insertion point.

    Written as a free function whose first parameter is named ``self``. This
    presumably lets it be attached as a graph method elsewhere (not visible
    here, so confirm against the caller).

    Args:
        self: the graph whose insertion point is redirected.
        insert_point: node or block to insert at inside the ``with`` block.

    Returns:
        An ``_InsertPoint`` wrapping ``self`` and ``insert_point``.
    """
    scoped_guard = _InsertPoint(insert_point_graph=self, insert_point=insert_point)
    return scoped_guard