| from torch.fx.experimental.graph_gradual_typechecker import Refine |
| from torch.fx.tensor_type import TensorType |
| from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] |
| |
| |
def infer_symbolic_types_single_pass(traced):
    """
    Run one round of symbolic type inference over ``traced``.

    A ``Refine`` object is built for ``traced`` and asked to refine; the
    constraints it collected are unified, and the resulting substitution
    is written back onto the node types of ``traced.graph``.
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)

| |
def infer_symbolic_types(traced):
    """
    Run two rounds of symbolic type inference over ``traced``.

    A single round is not always enough to recover all the information
    (broadcasting is one such case), so the refine / unify / substitute
    sequence is executed twice, each time with a fresh ``Refine`` object.
    Afterwards ``symbolic_relations()`` is invoked on the refiner from the
    second round.
    """
    refiner = None
    for _ in range(2):
        refiner = Refine(traced)
        refiner.refine()
        solution = unify_eq(refiner.constraints)
        substitute_all_types(traced.graph, solution)

    # Only the refiner from the final round is queried.
    refiner.symbolic_relations()

| |
def convert_eq(list_of_eq):
    """
    Split equality constraints into the shape the unification library
    expects.

    Each element of ``list_of_eq`` must expose ``lhs`` and ``rhs``
    attributes. Returns a pair ``(lhs_tuple, rhs_tuple)`` holding the
    left-hand and right-hand sides in their original order.
    """
    # Traverse the input exactly once so one-shot iterables also work.
    sides = [(eq.lhs, eq.rhs) for eq in list_of_eq]
    left_sides = tuple(left for left, _ in sides)
    right_sides = tuple(right for _, right in sides)
    return left_sides, right_sides

| |
| |
def unify_eq(list_of_eq):
    """
    Unify a collection of equality constraints.

    The constraints are first split into parallel left-hand / right-hand
    tuples by ``convert_eq``; those two tuples are then handed to
    ``unify`` and its result is returned unchanged.

    NOTE(review): the unification library presumably returns a falsy value
    rather than a mapping when the constraints cannot be unified -- confirm
    how callers are expected to handle that.
    """
    return unify(*convert_eq(list_of_eq))

| |
| |
def substitute_solution_one_type(mapping, t):
    """
    Apply the substitution ``mapping`` to a single type ``t``.

    * ``Var``: replaced by its entry in ``mapping`` when present.
    * ``TensorType``: a new ``TensorType`` is built whose arguments are
      looked up in ``mapping`` one by one (a direct lookup, not a
      recursive substitution).
    * ``list`` / ``tuple``: rebuilt as a list / tuple with every element
      substituted recursively.
    * anything else is returned untouched.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping else t

    if isinstance(t, TensorType):
        substituted_args = tuple(
            mapping[arg] if arg in mapping else arg for arg in t.__args__
        )
        return TensorType(substituted_args)

    if isinstance(t, list):
        return [substitute_solution_one_type(mapping, item) for item in t]

    if isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, item) for item in t)

    return t

| |
| |
def substitute_all_types(graph, mapping):
    """
    Apply the substitution ``mapping`` to the type of every node in
    ``graph``.

    The mapping is first resolved in place: whenever the value stored for
    a key is itself a key of ``mapping``, it is replaced by that key's
    value, and the sweep is repeated until a full pass changes nothing.
    Once the mapping is stable, each node's ``type`` is overwritten with
    the substituted type. Both ``mapping`` and the nodes of ``graph`` are
    mutated; nothing is returned.
    """
    while True:
        changed = False
        for key in mapping:
            previous = mapping[key]
            # Follow one link of the chain key -> previous -> mapping[previous].
            if previous in mapping:
                mapping[key] = mapping[previous]
            if previous != mapping[key]:
                changed = True
        if not changed:
            break

    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)

| |
def check_for_type_equality(g1, g2):
    """
    Equality check intended for fixed-point computations.

    Graphs are compared by the types of their nodes, not by graph
    structure: the nodes of ``g1`` and ``g2`` are paired up in order and
    the check succeeds only when every pair has equal ``type``.

    Graphs with a different number of nodes are never considered equal.
    (Pairing with a bare ``zip`` would silently ignore the extra nodes of
    the longer graph and could report equality on a matching prefix.)

    Returns:
        ``True`` if both graphs have the same number of nodes and all
        corresponding node types are equal, ``False`` otherwise.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    if len(nodes1) != len(nodes2):
        return False
    return not any(n.type != m.type for n, m in zip(nodes1, nodes2))