Add IndexSelect constraints, add a missing constraint in constraint transformations, and modify tests accordingly (#81344)

- Add constraints for IndexSelect, which allow us to replace a dimension of the input with the size of the index vector (or with Dyn); see the first sketch below.
- Fix a bug: for a constraint of the form "x \leq 4" where x is a tensor, we did not generate the disjunct "x = Dyn" in the transformation step. This is fixed and the tests are modified accordingly; see the second sketch below.
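
For context, a minimal sketch (not part of the change) of the shape behavior the new IndexSelect constraint encodes, using the same sizes as the new test:

```python
import torch

# index_select along dim 0 replaces the first dimension of the input with the
# length of the index vector; the remaining dimensions are unchanged.
x = torch.rand(2050, 1024)
idx = torch.ones(8).int()          # a 1-D index tensor, as in the new test
out = x.index_select(0, idx)
assert out.shape == torch.Size([8, 1024])
```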
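
And a rough sketch of why the missing disjunct matters: a rank-at-most constraint over the z3 tensor datatype should be satisfiable by a fully dynamic tensor, which requires the "x = Dyn" case in the disjunction. This is a hand-written approximation of the constraint generated for "x \leq 2"; it assumes the tensor1/tensor2 constructors, D, and z3_dyn exported by migrate_gradual_types.z3_types (tensor2, D, and z3_dyn appear in the tests below, tensor1 is assumed):

```python
import z3
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D

# approximation of the disjunction now emitted for "x <= 2": x is Dyn, or a
# rank-1 tensor, or a rank-2 tensor. The z3_dyn disjunct was previously missing.
x = z3.Const('x', tensor_type)
d1, d2 = z3.Ints('d1 d2')
rank_at_most_2 = z3.Or(x == z3_dyn,
                       x == tensor_type.tensor1(D(1, d1)),
                       x == tensor_type.tensor2(D(1, d1), D(1, d2)))

s = z3.Solver()
s.add(rank_at_most_2)
s.add(x == z3_dyn)   # without the first disjunct this combination would be unsat
assert s.check() == z3.sat
```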
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81344
Approved by: https://github.com/jansel, https://github.com/jamesr66a
diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py
index 7c95d79..0b941aa 100644
--- a/test/fx/test_z3_gradual_types.py
+++ b/test/fx/test_z3_gradual_types.py
@@ -9,7 +9,7 @@
 from torch.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency
 from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\
     evaluate_conditional_with_constraints
-from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D
+from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx.tensor_type import Dyn, TensorType
 import torch
@@ -32,6 +32,40 @@
 
 class HFOperations(unittest.TestCase):
 
+    def test_index_select(self):
+        class BasicBlock(torch.nn.Module):
+            def __init__(self):
+                super(BasicBlock, self).__init__()
+
+            def forward(self, x: TensorType([2050, 1024]), y: Dyn):
+                index_select = x.index_select(0, y)
+                return index_select
+        symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock())
+        # print(symbolic_traced)
+        b = BasicBlock().forward(torch.rand(2050, 1024), torch.ones(8).int())
+        transformed = transform_all_constraints(symbolic_traced, counter=0)
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+        index_select = z3.Const(3, tensor_type)
+
+        # the second dimension of the result should not be affected since
+        # we select along dimension 0
+        self.assertEqual(s.model()[index_select].arg(1).arg(1), b.shape[1])
+
+        replacement_vector = z3.Const(2, tensor_type)
+
+        # we set the index vector to Dyn
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+        index_select = z3.Const(3, tensor_type)
+        s.add(replacement_vector == z3_dyn)
+        self.assertEqual(s.check(), z3.sat)
+
+        # this implies that the dimension at index 0 should be Dyn
+        self.assertEqual(s.model()[index_select].arg(0).arg(0), 0)
+
     def test_get_attr(self):
         class BasicBlock(torch.nn.Module):
             def __init__(self):
@@ -176,52 +210,6 @@
         assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3]
 
 
-    def test_masked_fill(self):
-        class BasicBlock(torch.nn.Module):
-            def __init__(self):
-                super(BasicBlock, self).__init__()
-
-            def forward(self, x: TensorType([2, 4])):
-                size = x.size()
-                getitem = size[-1]
-                arange = torch.arange(getitem)
-                view = x.view(-1, getitem)
-                lt = arange > view
-                masked_fill = x.masked_fill_(lt, 0)
-                return masked_fill
-
-        B = BasicBlock().forward(torch.rand(2, 4))
-        # print(B.shape)
-
-        symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={})
-        # print(symbolic_traced)
-        transformed = transform_all_constraints(symbolic_traced, counter=0)
-        s = z3.Solver()
-        s.add(transformed)
-        self.assertEqual(s.check(), z3.sat)
-        masked_fill_res = z3.Const(10, tensor_type)
-        self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0])
-        self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1])
-
-        # change the annotation to Dyn. This will migrate to an arbitirary type
-        for n in symbolic_traced.graph.nodes:
-            if n.op == 'placeholder':
-                n.type = Dyn
-
-        transformed = transform_all_constraints(symbolic_traced, counter=0)
-        s = z3.Solver()
-        s.add(transformed)
-        self.assertEqual(s.check(), z3.sat)
-
-        for n in symbolic_traced.graph.nodes:
-            if n.op == 'placeholder':
-                n.type = TensorType([Dyn, Dyn, Dyn, Dyn])
-
-        transformed = transform_all_constraints(symbolic_traced, counter=0)
-        s = z3.Solver()
-        s.add(transformed)
-        self.assertEqual(s.check(), z3.sat)
-
 
     def test_layer_norm(self):
 
@@ -300,11 +288,15 @@
         # migrate one of the parameters to a fully static shape so we can compare
 
         input = z3.Const(1, tensor_type)
+        input_2 = z3.Const(2, tensor_type)
+        s1, s2 = z3.Ints('s1 s2')
+
         output_long = z3.Const(8, tensor_type)
         s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4)))
+        s.add(input_2 == tensor_type.tensor2(D(1, s1), D(1, s2)))
+
         self.assertEquals(s.check(), z3.sat)
         actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape
-
         self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0])
         self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1])
 
@@ -665,7 +657,7 @@
         # check that the item is correct
         self.assertEquals(s.model()[s1], s.model()[s2])
 
-        # invalid index
+        # invalid index, but should still be SAT because the input can be Dyn
         class BasicBlock(torch.nn.Module):
             def __init__(self):
                 super(BasicBlock, self).__init__()
@@ -683,7 +675,9 @@
         s = z3.Solver()
         s.add(transformed)
 
-        self.assertEquals(s.check(), z3.unsat)
+        self.assertEqual(s.check(), z3.sat)
+        s.add(input != z3_dyn)
+        self.assertEqual(s.check(), z3.unsat)
 
     def test_view_mul(self):
         class BasicBlock(torch.nn.Module):
@@ -882,6 +876,52 @@
 
 class ComposeOperationsGradualTypes(unittest.TestCase):
 
+    def test_masked_fill(self):
+        class BasicBlock(torch.nn.Module):
+            def __init__(self):
+                super(BasicBlock, self).__init__()
+
+            def forward(self, x: TensorType([2, 4])):
+                size = x.size()
+                getitem = size[-1]
+                arange = torch.arange(getitem)
+                view = x.view(-1, getitem)
+                lt = arange > view
+                masked_fill = x.masked_fill_(lt, 0)
+                return masked_fill
+
+        B = BasicBlock().forward(torch.rand(2, 4))
+        # print(B.shape)
+
+        symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={})
+        # print(symbolic_traced)
+        transformed = transform_all_constraints(symbolic_traced, counter=0)
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+        masked_fill_res = z3.Const(10, tensor_type)
+        self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0])
+        self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1])
+
+        # change the annotation to Dyn. This will migrate to an arbitrary type
+        for n in symbolic_traced.graph.nodes:
+            if n.op == 'placeholder':
+                n.type = Dyn
+
+        transformed = transform_all_constraints(symbolic_traced, counter=0)
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+
+        for n in symbolic_traced.graph.nodes:
+            if n.op == 'placeholder':
+                n.type = TensorType([Dyn, Dyn, Dyn, Dyn])
+
+        transformed = transform_all_constraints(symbolic_traced, counter=0)
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+
     def test_add_reshape_1(self):
         class BasicBlock(torch.nn.Module):
             def __init__(self):
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py
index 6fb7133..2b6415a68e0 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint.py
@@ -211,6 +211,49 @@
             return False
 
 
+
+class IndexSelect(Constraint):
+
+    def __init__(self, tensor_size, input_var, dim_replace, index, output):
+        """
+        Args:
+            input_var: input to index_select
+            tensor_size: the tensor rank (number of dimensions) we are considering
+            dim_replace: the dimension of the output at "index"
+            index: location of the dimension to replace in the input
+            output: variable to store the result
+        """
+        assert isinstance(input_var, TVar)
+        assert isinstance(output, TVar)
+        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
+        assert isinstance(index, int)
+
+        self.input_var = input_var
+        self.tensor_size = tensor_size
+        self.dim_replace = dim_replace
+        self.index = index
+        self.output = output
+
+    def __repr__(self):
+
+        return f' {self.output} = ' \
+               f'IndexSelect({self.input_var}, ' \
+               f'tensor_size: {self.tensor_size}, ' \
+               f'{self.dim_replace}, ' \
+               f'{self.index})'
+
+    def __eq__(self, other):
+        if isinstance(other, IndexSelect):
+            return self.tensor_size == other.tensor_size and\
+                self.dim_replace == other.dim_replace and\
+                self.index == other.index and\
+                self.output == other.output and\
+                self.input_var == other.input_var
+        else:
+            return False
+
+
+
 class GetItem(Constraint):
 
     def __init__(self, tensor_size, index, res, input_var):
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index d0b14e1..90ddb3d 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -5,7 +5,7 @@
 from torch.fx._symbolic_trace import _assert_is_none
 from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
     Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
-    TVar, DVar, GetItemTensor
+    TVar, DVar, GetItemTensor, IndexSelect
 from torch.fx.experimental.migrate_gradual_types.operation import \
     op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add
 from torch.fx.node import Target, Node
@@ -58,6 +58,38 @@
     else:
         raise NotImplementedError('Not yet implemented')
 
+
+@register_inference_rule("index_select")
+def index_select_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We constrain the second argument (the index tensor) to be a vector or Dyn.
+    The output has the same shape as the input, except that the dimension at
+    the position given by the first argument is replaced by the size of that vector.
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], int)
+    assert isinstance(n.args[2], Node)
+
+    index_select, counter = gen_tvar(counter)
+    symbols[n] = index_select
+
+    dims, counter = gen_tensor_dims(1, counter)
+
+    # the index tensor is either a rank-1 tensor or Dyn
+    is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
+    is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
+
+    c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
+                                for i in range(MAX_TENSOR_RANK)])])
+    c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
+                             for i in range(MAX_TENSOR_RANK)])])
+
+    return [Disj([c2, c3])], counter
+
+
 @register_inference_rule("expand")
 def expand_inference_rule(n: Node, symbols, constraints, counter):
     """
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
index d8e6b9f..4be1209 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -1,4 +1,5 @@
 # mypy: ignore-errors
+import copy
 import itertools
 from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
 from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar
@@ -6,7 +7,7 @@
 from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
 from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
 from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
-from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor
+from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
 from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
 from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
 from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
@@ -37,6 +38,32 @@
         return F()
 
 
+@register_transformation_rule(IndexSelect)
+def transform_index_select(constraint, counter):
+    """
+    The constraint considers the given tensor size, checks whether the index is valid,
+    and if so generates a constraint that replaces the input dimension at that index
+    with the required dimension
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    is_valid_index = valid_index(constraint.index, dims)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # if the index is valid then replace the input dimension with the new dimension
+    # otherwise the dimension will not be replaced and the clause will contain False
+    new_dims = copy.deepcopy(dims)
+    if is_valid_index == T():
+        new_dims[constraint.index] = constraint.dim_replace
+
+    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+                                   *nat_constraints,
+                                   is_valid_index,
+                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
+
+    return transformed_constraint, counter
+
 @register_transformation_rule(GetItem)
 def transform_get_item(constraint, counter):
     """
@@ -198,7 +225,7 @@
 
     elif constraint.op == op_leq:
         assert isinstance(constraint.rhs, int)
-        disj = []
+        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
         for i in range(1, constraint.rhs + 1):
             dims = []
             for j in range(1, i + 1):
diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
index d8381d9..6c49b30 100644
--- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
+++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
@@ -107,6 +107,7 @@
                 raise NotImplementedError('operation not yet implemented')
 
         else:
+            # print(constraint)
 
             raise NotImplementedError('Operation not yet implemented')