# blob: b8f68d7073acc5fa7b36bc3c944a6a1ac85bb564 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
# flake8: noqa
from dataclasses import dataclass, field, InitVar
from hypothesis import given, settings, strategies as st
from torch.testing._internal.jit_utils import JitTestCase
from typing import List, Optional
import sys
import torch
import unittest
from enum import Enum
# Example jittable dataclass
@dataclass(order=True)
class Point:
    # Two float coordinates plus a derived `norm` tensor.
    x: float
    y: float
    # Defaults to None so callers only pass x and y; __post_init__ then
    # replaces it with the computed value.
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        # norm = (x^2 + y^2) ** 0.5, computed on tensors built from the floats.
        x_squared = torch.tensor(self.x) ** 2
        y_squared = torch.tensor(self.y) ** 2
        self.norm = (x_squared + y_squared) ** 0.5
class MixupScheme(Enum):
    # Enum whose member values are lists of strings. test_no_source relies on
    # this: scripting MixupParams (which defaults to MixupScheme.INPUT) is
    # expected to raise RuntimeError.
    INPUT = [
        "input",
    ]
    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]
@dataclass
class MixupParams:
    # Decorated with @dataclass but declares no annotated fields and supplies
    # its own __init__. test_no_source expects torch.jit.script(MixupParams)
    # to raise RuntimeError.
    def __init__(
        self,
        alpha: float = 0.125,
        scheme: MixupScheme = MixupScheme.INPUT,
    ):
        # Store both arguments unchanged.
        self.alpha = alpha
        self.scheme = scheme
class MixupScheme2(Enum):
    # Integer-valued counterpart of MixupScheme; used as the `scheme` type of
    # MixupParams2 and MixupParams3.
    A = 1
    B = 2
@dataclass
class MixupParams2:
    # Same shape as MixupParams, but `scheme` is the integer-valued
    # MixupScheme2. test_no_source expects torch.jit.script(MixupParams2) to
    # succeed.
    def __init__(
        self,
        alpha: float = 0.125,
        scheme: MixupScheme2 = MixupScheme2.A,
    ):
        # Store both arguments unchanged.
        self.alpha = alpha
        self.scheme = scheme
@dataclass
class MixupParams3:
    # Never passed to torch.jit.script directly in this file; it only appears
    # as a parameter annotation in test_use_unregistered_dataclass_raises,
    # where scripting the annotated function is expected to raise OSError.
    def __init__(
        self,
        alpha: float = 0.125,
        scheme: MixupScheme2 = MixupScheme2.A,
    ):
        # Store both arguments unchanged.
        self.alpha = alpha
        self.scheme = scheme
# Hypothesis strategy for floats in [-1e4, 1e4] with NaN excluded, so that the
# Meta internal tooling doesn't raise an overflow error.
NonHugeFloats = st.floats(
    min_value=-1e4,
    max_value=1e4,
    allow_nan=False,
)
class TestDataclasses(JitTestCase):
    # Exercises torch.jit.script on dataclasses: InitVar fields,
    # __post_init__, generated comparators, default factories, user-defined
    # __eq__, and the failure modes around the Mixup* fixtures.

    @classmethod
    def tearDownClass(cls):
        # Clear the JIT class registry once the tests in this class are done.
        torch._C._jit_clear_class_registry()

    def test_init_vars(self):
        # `norm_p` is declared as an InitVar, so it is accepted by the
        # constructor and forwarded to __post_init__.
        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                # p-norm of (x, y): (x^p + y^p) ** (1 / p)
                x_term = torch.tensor(self.x) ** norm_p
                y_term = torch.tensor(self.y) ** norm_p
                self.norm = (x_term + y_term) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            point = Point2(x, y, p)
            return point.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        ScriptedPoint = torch.jit.script(Point)

        def fn(x: float, y: float):
            point = ScriptedPoint(x, y)
            return point.norm

        self.checkScript(fn, [x, y])

    @settings(deadline=None)
    @given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
    def test_comparators(self, pt1, pt2):
        # Each generated argument is an (x, y) pair.
        (x1, y1), (x2, y2) = pt1, pt2
        ScriptedPoint = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            lhs = ScriptedPoint(x1, y1)
            rhs = ScriptedPoint(x2, y2)
            return (
                lhs == rhs,
                # lhs != rhs,   # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                lhs < rhs,
                lhs <= rhs,
                lhs > rhs,
                lhs >= rhs,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    def test_default_factories(self):
        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        # Everything below runs under assertRaises: a NotImplementedError from
        # any of these statements satisfies the check, and statements after
        # the one that raises are never executed.
        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: 'CustomEq') -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            # Same `a`, different `b`: equal only if the custom __eq__ is used.
            first = CustomEq(a, b1)
            second = CustomEq(a, b2)
            return first == second

        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        with self.assertRaises(RuntimeError):
            # uses list in Enum is not supported
            torch.jit.script(MixupParams)

        # Runs outside the assertRaises block, so any exception fails the test.
        torch.jit.script(MixupParams2)  # don't throw

    def test_use_unregistered_dataclass_raises(self):
        # MixupParams3 is only referenced through the annotation; it is never
        # scripted on its own before this function is compiled.
        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)