constraints for type_as, long, int (#81265)

The constraints mostly just propagate type information: `int`, `long`, and `to` equate the output type with the input type, while `type_as` additionally requires its two arguments to be consistent and equates the output with the second argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81265
Approved by: https://github.com/jamesr66a
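For intuition, here is a minimal standalone sketch of the `int`/`long`/`to` rule, written in plain z3py with a made-up rank-2 tensor encoding (the real encoding lives in `migrate_gradual_types/z3_types.py` and is richer). Since these ops preserve the shape, the only constraint generated is `input = output`, so pinning the input to a static shape forces the output to the same shape. A second sketch after the diff illustrates the `op_consistency` constraint emitted by the new `type_as` rule.

```python
import z3

# Hypothetical rank-2 tensor encoding, for this sketch only:
# a pair of integer dimension sizes.
TensorT = z3.Datatype('TensorT')
TensorT.declare('tensor2', ('dim0', z3.IntSort()), ('dim1', z3.IntSort()))
TensorT = TensorT.create()

inp, out = z3.Consts('inp out', TensorT)

s = z3.Solver()
s.add(inp == out)                    # the equality rule for .int()/.long()/.to()
s.add(inp == TensorT.tensor2(2, 4))  # pin the input to a static shape
assert s.check() == z3.sat
m = s.model()
print(m.eval(TensorT.dim0(out)), m.eval(TensorT.dim1(out)))  # 2 4
```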
diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py
index af6e0ef..7c95d79 100644
--- a/test/fx/test_z3_gradual_types.py
+++ b/test/fx/test_z3_gradual_types.py
@@ -279,6 +279,36 @@
         self.assertEqual(s.model()[output].arg(0).arg(1), b[0])
         self.assertEqual(s.model()[output].arg(1).arg(1), b[1])
 
+    def test_ne_int_long_type_as(self):
+
+        class BasicBlock(torch.nn.Module):
+            def __init__(self):
+                super(BasicBlock, self).__init__()
+
+            def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn, Dyn])):
+                ne_int = torch.ne(x, y).int()
+                type_as = ne_int.type_as(y)
+                long = type_as.long()
+                return long
+
+        symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock())
+        transformed = transform_all_constraints(symbolic_traced, counter=0)
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+
+        # migrate one of the parameters to a fully static shape so we can
+        # compare against the concrete output shape
+
+        input = z3.Const(1, tensor_type)
+        output_long = z3.Const(8, tensor_type)
+        s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4)))
+        self.assertEqual(s.check(), z3.sat)
+        actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape
+
+        self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0])
+        self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1])
+
+
     def test_ne(self):
         s1, s2 = z3.Ints('s1 s2')
         s11, s22 = z3.Ints('s11 s22')
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index f380e46..5aafe5d 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -39,28 +39,6 @@
     return Conj([c1, c2, *nat_constraints]), counter
 
 
-# TODO
-@register_inference_rule("long")
-def long_inference_rule(n: Node, symbols, constraints, counter):
-    """
-    """
-    raise NotImplementedError('Not yet implemented')
-
-# TODO
-@register_inference_rule("type_as")
-def type_as_inference_rule(n: Node, symbols, constraints, counter):
-    """
-    """
-    raise NotImplementedError('Not yet implemented')
-
-# TODO
-@register_inference_rule("int")
-def int_inference_rule(n: Node, symbols, constraints, counter):
-    """
-    """
-    raise NotImplementedError('Not yet implemented')
-
-
 @register_inference_rule(getattr)
 def get_attr_inference_rule(n: Node, symbols, constraints, counter):
     """
@@ -118,7 +96,9 @@
 
 
 @register_inference_rule("to")
-def to_inference_rule(n: Node, symbols, constraints, counter):
+@register_inference_rule("int")
+@register_inference_rule("long")
+def equality_inference_rule(n: Node, symbols, constraints, counter):
     """
     We generate the constraint: input = output
     """
@@ -126,8 +106,30 @@
     output, counter = gen_tvar(counter)
     symbols[n] = output
     input = symbols[n.args[0]]
+    assert isinstance(input, TVar)
     return [BinConstraintT(input, output, op_eq)], counter
 
+
+@register_inference_rule("type_as")
+def type_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    from_arg = symbols[n.args[0]]
+    to_arg = symbols[n.args[1]]
+
+    assert isinstance(from_arg, TVar)
+    assert isinstance(to_arg, TVar)
+
+    return [BinConstraintT(from_arg, to_arg, op_consistency),
+            BinConstraintT(output, to_arg, op_eq)], counter
+
 @register_inference_rule("masked_fill_")
 def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
     """
@@ -649,6 +651,7 @@
     output, counter = gen_tvar(counter)
     symbols[n] = output
     input = symbols[n.args[0]]
+    assert isinstance(input, TVar)
     return [BinConstraintT(input, output, op_eq)], counter
 
 @register_inference_rule(torch.nn.Linear)
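As noted above, the `type_as` rule also emits an `op_consistency` constraint between its two arguments. A rough sketch of what consistency buys us, again in plain z3py with a hypothetical `(is_static, size)` dimension encoding that echoes the `D(1, 2)` constructors in the test (not the library's actual definition): a Dyn dimension is consistent with anything, while two static dimensions must agree.

```python
import z3

# Hypothetical dimension encoding: an (is_static flag, size) pair,
# echoing the D(1, 2) constructors used in the test above.
Dim = z3.Datatype('Dim')
Dim.declare('dim', ('is_static', z3.IntSort()), ('size', z3.IntSort()))
Dim = Dim.create()

def consistent(d1, d2):
    # A Dyn dimension (flag 0) is consistent with anything;
    # two static dimensions must have equal sizes.
    return z3.Or(Dim.is_static(d1) == 0,
                 Dim.is_static(d2) == 0,
                 Dim.size(d1) == Dim.size(d2))

s = z3.Solver()
x, y = z3.Consts('x y', Dim)
s.add(x == Dim.dim(1, 2))   # static dimension of size 2
s.add(Dim.is_static(y) == 0)  # y is Dyn
s.add(consistent(x, y))
print(s.check())            # sat: Dyn is consistent with any size

s.add(y == Dim.dim(1, 3))   # contradicts y being Dyn
print(s.check())            # unsat
```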