[dynamo] Compile torchvision augmentations (#100292)
Resolves https://github.com/pytorch/pytorch/issues/100112
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100292
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 855cba4..00f545b 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -3111,6 +3111,23 @@
opt_fn(inp1, inp2, inp3, inp4, c)
self.assertEqual(cnt.frame_count, 3)
+ def test_torch_variable_type(self):
+ # from torchvision
+ def check_type(obj, types_or_checks):
+ for type_or_check in types_or_checks:
+ if (
+ isinstance(obj, type_or_check)
+ if isinstance(type_or_check, type)
+ else type_or_check(obj)
+ ):
+ return True
+ return False
+
+ opt_check_type = torch._dynamo.optimize("eager")(check_type)
+ ref = check_type(torch.randn(4), [torch.Tensor])
+ res = opt_check_type(torch.randn(4), [torch.Tensor])
+ self.assertEqual(ref, res)
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 48a8531..6dc68d3 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -734,7 +734,7 @@
def is_safe_constant(v):
if istype(v, (tuple, frozenset)):
return all(map(is_safe_constant, v))
- return istype(
+ return isinstance(v, (enum.Enum, type)) or istype(
v,
(
types.CodeType,
@@ -749,7 +749,7 @@
torch.device,
torch.dtype,
),
- ) or isinstance(v, enum.Enum)
+ )
def check_constant_args(args, kwargs):
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 00e4162..099afea 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -237,6 +237,13 @@
assert user_cls is collections.defaultdict
self.default_factory = default_factory
+ def is_python_constant(self):
+ # Return false for unsupported defaults. This ensures that a bad handler
+ # path is not taken in BuiltinVariable for getitem.
+ if self.default_factory not in [list, tuple, dict] and not self.items:
+ return False
+ return super().is_python_constant()
+
def call_method(
self,
tx,
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 10a143a..e332ba8 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -174,7 +174,7 @@
def python_type(self):
if isinstance(self.value, (torch.Tensor, torch.nn.Module)):
return type(self.value)
- if type(self.value) is type:
+ if isinstance(self.value, type):
return type
return super().python_type()