IndexSelect constraints, add a missing constraint in constraint transformations, and modify tests accordingly (#81344)
- Constraints for IndexSelect, which replace the dimension of the input at the given index with the size of the index vector (or with Dyn when that size is unknown)
- There was a bug: for a constraint of the form "x \leq 4" where x is a tensor variable, we did not generate the constraint "x = Dyn" in the transformation step. This is fixed and the tests are modified accordingly (a short illustrative sketch follows below).
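
For illustration only (not part of the patch): a minimal sketch of how the new IndexSelect constraints can be exercised, mirroring the `test_index_select` test added in the diff below. The toy module `M` is made up for this sketch; the z3 constant numbering (`Const(2, tensor_type)` for `y`, `Const(3, tensor_type)` for the `index_select` node) follows the convention used in that test for this particular trace.

```python
import torch
import z3
from torch.fx import symbolic_trace
from torch.fx.tensor_type import Dyn, TensorType
from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn


class M(torch.nn.Module):
    def forward(self, x: TensorType([2050, 1024]), y: Dyn):
        # index_select along dim 0: the first output dimension becomes the
        # length of the index vector y; the second dimension stays 1024
        return x.index_select(0, y)


traced = symbolic_trace(M())
s = z3.Solver()
s.add(transform_all_constraints(traced, counter=0))
assert s.check() == z3.sat

y_var = z3.Const(2, tensor_type)   # the index vector y
out = z3.Const(3, tensor_type)     # the index_select result

# force the index vector to be Dyn; the constraints stay satisfiable,
# and dimension 0 of the output is forced to Dyn as well
s.add(y_var == z3_dyn)
assert s.check() == z3.sat
print(s.model()[out])
```

The "x \leq 4" fix shows up in the updated conditional test below: the previously-unsat invalid-index case is now satisfiable because the input may be Dyn, and it only becomes unsat once `input != z3_dyn` is added to the solver.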
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81344
Approved by: https://github.com/jansel, https://github.com/jamesr66a
diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py
index 7c95d79..0b941aa 100644
--- a/test/fx/test_z3_gradual_types.py
+++ b/test/fx/test_z3_gradual_types.py
@@ -9,7 +9,7 @@
from torch.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency
from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\
evaluate_conditional_with_constraints
-from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D
+from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.tensor_type import Dyn, TensorType
import torch
@@ -32,6 +32,40 @@
class HFOperations(unittest.TestCase):
+ def test_index_select(self):
+ class BasicBlock(torch.nn.Module):
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+
+ def forward(self, x: TensorType([2050, 1024]), y: Dyn):
+ index_select = x.index_select(0, y)
+ return index_select
+ symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock())
+ # print(symbolic_traced)
+ b = BasicBlock().forward(torch.rand(2050, 1024), torch.ones(8).int())
+ transformed = transform_all_constraints(symbolic_traced, counter=0)
+ s = z3.Solver()
+ s.add(transformed)
+ self.assertEqual(s.check(), z3.sat)
+ index_select = z3.Const(3, tensor_type)
+
+        # the second dimension of the result should not be affected since
+        # we select along dimension 0
+ self.assertEqual(s.model()[index_select].arg(1).arg(1), b.shape[1])
+
+ replacement_vector = z3.Const(2, tensor_type)
+
+ # we set the vector to Dyn
+ s = z3.Solver()
+ s.add(transformed)
+ self.assertEqual(s.check(), z3.sat)
+ index_select = z3.Const(3, tensor_type)
+ s.add(replacement_vector == z3_dyn)
+ self.assertEqual(s.check(), z3.sat)
+
+        # this implies that the dimension at index 0 should be Dyn
+ self.assertEqual(s.model()[index_select].arg(0).arg(0), 0)
+
def test_get_attr(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
@@ -176,52 +210,6 @@
assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3]
- def test_masked_fill(self):
- class BasicBlock(torch.nn.Module):
- def __init__(self):
- super(BasicBlock, self).__init__()
-
- def forward(self, x: TensorType([2, 4])):
- size = x.size()
- getitem = size[-1]
- arange = torch.arange(getitem)
- view = x.view(-1, getitem)
- lt = arange > view
- masked_fill = x.masked_fill_(lt, 0)
- return masked_fill
-
- B = BasicBlock().forward(torch.rand(2, 4))
- # print(B.shape)
-
- symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={})
- # print(symbolic_traced)
- transformed = transform_all_constraints(symbolic_traced, counter=0)
- s = z3.Solver()
- s.add(transformed)
- self.assertEqual(s.check(), z3.sat)
- masked_fill_res = z3.Const(10, tensor_type)
- self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0])
- self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1])
-
- # change the annotation to Dyn. This will migrate to an arbitirary type
- for n in symbolic_traced.graph.nodes:
- if n.op == 'placeholder':
- n.type = Dyn
-
- transformed = transform_all_constraints(symbolic_traced, counter=0)
- s = z3.Solver()
- s.add(transformed)
- self.assertEqual(s.check(), z3.sat)
-
- for n in symbolic_traced.graph.nodes:
- if n.op == 'placeholder':
- n.type = TensorType([Dyn, Dyn, Dyn, Dyn])
-
- transformed = transform_all_constraints(symbolic_traced, counter=0)
- s = z3.Solver()
- s.add(transformed)
- self.assertEqual(s.check(), z3.sat)
-
def test_layer_norm(self):
@@ -300,11 +288,15 @@
# migrate one of the parameters to a fully static shape so we can compare
input = z3.Const(1, tensor_type)
+ input_2 = z3.Const(2, tensor_type)
+ s1, s2 = z3.Ints('s1 s2')
+
output_long = z3.Const(8, tensor_type)
s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4)))
+ s.add(input_2 == tensor_type.tensor2(D(1, s1), D(1, s2)))
+
self.assertEquals(s.check(), z3.sat)
actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape
-
self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0])
self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1])
@@ -665,7 +657,7 @@
# check that the item is correct
self.assertEquals(s.model()[s1], s.model()[s2])
- # invalid index
+ # invalid index but should still be SAT because input will be Dyn
class BasicBlock(torch.nn.Module):
def __init__(self):
super(BasicBlock, self).__init__()
@@ -683,7 +675,9 @@
s = z3.Solver()
s.add(transformed)
- self.assertEquals(s.check(), z3.unsat)
+ self.assertEquals(s.check(), z3.sat)
+ s.add(input != z3_dyn)
+ self.assertEqual(s.check(), z3.unsat)
def test_view_mul(self):
class BasicBlock(torch.nn.Module):
@@ -882,6 +876,52 @@
class ComposeOperationsGradualTypes(unittest.TestCase):
+ def test_masked_fill(self):
+ class BasicBlock(torch.nn.Module):
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+
+ def forward(self, x: TensorType([2, 4])):
+ size = x.size()
+ getitem = size[-1]
+ arange = torch.arange(getitem)
+ view = x.view(-1, getitem)
+ lt = arange > view
+ masked_fill = x.masked_fill_(lt, 0)
+ return masked_fill
+
+ B = BasicBlock().forward(torch.rand(2, 4))
+ # print(B.shape)
+
+ symbolic_traced: torch.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={})
+ # print(symbolic_traced)
+ transformed = transform_all_constraints(symbolic_traced, counter=0)
+ s = z3.Solver()
+ s.add(transformed)
+ self.assertEqual(s.check(), z3.sat)
+ masked_fill_res = z3.Const(10, tensor_type)
+ self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0])
+ self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1])
+
+        # change the annotation to Dyn. This will migrate to an arbitrary type
+ for n in symbolic_traced.graph.nodes:
+ if n.op == 'placeholder':
+ n.type = Dyn
+
+ transformed = transform_all_constraints(symbolic_traced, counter=0)
+ s = z3.Solver()
+ s.add(transformed)
+ self.assertEqual(s.check(), z3.sat)
+
+ for n in symbolic_traced.graph.nodes:
+ if n.op == 'placeholder':
+ n.type = TensorType([Dyn, Dyn, Dyn, Dyn])
+
+ transformed = transform_all_constraints(symbolic_traced, counter=0)
+ s = z3.Solver()
+ s.add(transformed)
+ self.assertEqual(s.check(), z3.sat)
+
def test_add_reshape_1(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py
index 6fb7133..2b6415a68e0 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint.py
@@ -211,6 +211,49 @@
return False
+
+class IndexSelect(Constraint):
+
+ def __init__(self, tensor_size, input_var, dim_replace, index, output):
+ """
+ Args:
+ input_var: input to index_select
+ tensor_size: tensor size we are considering
+            dim_replace: the dimension of the output at position "index"
+            index: location of the dimension to replace in the input
+            output: variable to store the result
+ """
+ assert isinstance(input_var, TVar)
+ assert isinstance(output, TVar)
+ assert isinstance(dim_replace, DVar) or dim_replace == Dyn
+ assert isinstance(index, int)
+
+ self.input_var = input_var
+ self.tensor_size = tensor_size
+ self.dim_replace = dim_replace
+ self.index = index
+ self.output = output
+
+ def __repr__(self):
+
+ return f' {self.output} = ' \
+ f'IndexSelect({self.input_var}, ' \
+ f'tensor_size: {self.tensor_size}, ' \
+ f'{self.dim_replace}, ' \
+ f'{self.index})'
+
+ def __eq__(self, other):
+ if isinstance(other, IndexSelect):
+ return self.tensor_size == other.tensor_size and\
+ self.dim_replace == other.dim_replace and\
+ self.index == other.index and\
+ self.output == other.output and\
+ self.input_var == other.input_var
+ else:
+ return False
+
+
+
class GetItem(Constraint):
def __init__(self, tensor_size, index, res, input_var):
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index d0b14e1..90ddb3d 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -5,7 +5,7 @@
from torch.fx._symbolic_trace import _assert_is_none
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
- TVar, DVar, GetItemTensor
+ TVar, DVar, GetItemTensor, IndexSelect
from torch.fx.experimental.migrate_gradual_types.operation import \
op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add
from torch.fx.node import Target, Node
@@ -58,6 +58,38 @@
else:
raise NotImplementedError('Not yet implemented')
+
+@register_inference_rule("index_select")
+def index_select_inference_rule(n: Node, symbols, constraints, counter):
+ """
+    We constrain the second argument to be a 1-D vector or Dyn.
+    The output is the input shape with the dimension at the position
+    given by the first argument replaced by the size of that vector.
+ """
+ # print(n.args)
+ assert isinstance(n.args[0], Node)
+ assert isinstance(n.args[1], int)
+ assert isinstance(n.args[2], Node)
+
+ index_select, counter = gen_tvar(counter)
+ symbols[n] = index_select
+
+ dims, counter = gen_tensor_dims(1, counter)
+
+ # matching constraint
+ is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
+ is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
+
+ c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
+ for i in range(MAX_TENSOR_RANK)])])
+ c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
+ for i in range(MAX_TENSOR_RANK)])])
+
+ return [Disj([c2, c3])], counter
+
+
@register_inference_rule("expand")
def expand_inference_rule(n: Node, symbols, constraints, counter):
"""
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
index d8e6b9f..4be1209 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -1,4 +1,5 @@
# mypy: ignore-errors
+import copy
import itertools
from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar
@@ -6,7 +7,7 @@
from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
-from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor
+from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
@@ -37,6 +38,32 @@
return F()
+@register_transformation_rule(IndexSelect)
+def transform_index_select(constraint, counter):
+ """
+    The constraint considers the given tensor size, checks whether the index is valid,
+    and if so generates a constraint that replaces the input dimension at that index
+    with the required dimension
+ """
+ dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+ is_valid_index = valid_index(constraint.index, dims)
+ nat_constraints = gen_nat_constraints(dims)
+
+    # if the index is valid, replace the input dimension at that index with the new dimension;
+    # otherwise the dimensions are left unchanged and the clause will contain False
+    new_dims = copy.deepcopy(dims)
+    if is_valid_index == T():
+        new_dims[constraint.index] = constraint.dim_replace
+
+    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+                                   *nat_constraints,
+                                   is_valid_index,
+                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
+
+    return transformed_constraint, counter
+
@register_transformation_rule(GetItem)
def transform_get_item(constraint, counter):
"""
@@ -198,7 +225,7 @@
elif constraint.op == op_leq:
assert isinstance(constraint.rhs, int)
- disj = []
+ disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
for i in range(1, constraint.rhs + 1):
dims = []
for j in range(1, i + 1):
diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
index d8381d9..6c49b30 100644
--- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
+++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
@@ -107,6 +107,7 @@
raise NotImplementedError('operation not yet implemented')
else:
+ # print(constraint)
raise NotImplementedError('Operation not yet implemented')