constraints for type_as, long, int (#81265)
The constraints are are mostly just propagating type information
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81265
Approved by: https://github.com/jamesr66a
diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py
index af6e0ef..7c95d79 100644
--- a/test/fx/test_z3_gradual_types.py
+++ b/test/fx/test_z3_gradual_types.py
@@ -279,6 +279,36 @@
self.assertEqual(s.model()[output].arg(0).arg(1), b[0])
self.assertEqual(s.model()[output].arg(1).arg(1), b[1])
+ def test_ne_int_long_type_as(self):
+
+ class BasicBlock(torch.nn.Module):
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+
+ def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn, Dyn])):
+ ne_int = torch.ne(x, y).int()
+ type_as = ne_int.type_as(y)
+ long = type_as.long()
+ return long
+
+ symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock())
+ transformed = transform_all_constraints(symbolic_traced, counter=0)
+ s = z3.Solver()
+ s.add(transformed)
+ self.assertEquals(s.check(), z3.sat)
+
+ # migrate one of the parameters to a fully static shape so we can compare
+
+ input = z3.Const(1, tensor_type)
+ output_long = z3.Const(8, tensor_type)
+ s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4)))
+ self.assertEquals(s.check(), z3.sat)
+ actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape
+
+ self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0])
+ self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1])
+
+
def test_ne(self):
s1, s2 = z3.Ints('s1 s2')
s11, s22 = z3.Ints('s11 s22')
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index f380e46..5aafe5d 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -39,28 +39,6 @@
return Conj([c1, c2, *nat_constraints]), counter
-# TODO
-@register_inference_rule("long")
-def long_inference_rule(n: Node, symbols, constraints, counter):
- """
- """
- raise NotImplementedError('Not yet implemented')
-
-# TODO
-@register_inference_rule("type_as")
-def type_as_inference_rule(n: Node, symbols, constraints, counter):
- """
- """
- raise NotImplementedError('Not yet implemented')
-
-# TODO
-@register_inference_rule("int")
-def int_inference_rule(n: Node, symbols, constraints, counter):
- """
- """
- raise NotImplementedError('Not yet implemented')
-
-
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
"""
@@ -118,7 +96,9 @@
@register_inference_rule("to")
-def to_inference_rule(n: Node, symbols, constraints, counter):
+@register_inference_rule("int")
+@register_inference_rule("long")
+def equality_inference_rule(n: Node, symbols, constraints, counter):
"""
We generate the constraint: input = output
"""
@@ -126,8 +106,30 @@
output, counter = gen_tvar(counter)
symbols[n] = output
input = symbols[n.args[0]]
+ assert isinstance(input, TVar)
return [BinConstraintT(input, output, op_eq)], counter
+
+@register_inference_rule("type_as")
+def type_inference_rule(n: Node, symbols, constraints, counter):
+ """
+ We generate the constraint: input = output
+ """
+ assert isinstance(n.args[0], Node)
+ assert isinstance(n.args[1], Node)
+
+ output, counter = gen_tvar(counter)
+ symbols[n] = output
+
+ from_arg = symbols[n.args[0]]
+ to_arg = symbols[n.args[1]]
+
+ assert isinstance(from_arg, TVar)
+ assert isinstance(to_arg, TVar)
+
+ return [BinConstraintT(from_arg, to_arg, op_consistency),
+ BinConstraintT(output, to_arg, op_eq)], counter
+
@register_inference_rule("masked_fill_")
def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
"""
@@ -649,6 +651,7 @@
output, counter = gen_tvar(counter)
symbols[n] = output
input = symbols[n.args[0]]
+ assert isinstance(input, TVar)
return [BinConstraintT(input, output, op_eq)], counter
@register_inference_rule(torch.nn.Linear)