blob: e97767c000398bbee222111b56c4f4114fc9536a [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import io
import os
import sys
import copy
import unittest
import torch
from typing import Optional
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
find_library_location,
)
from torch.testing import FileCheck
if __name__ == "__main__":
    # Direct execution is rejected; the message points at the supported
    # entry point (test/test_jit.py).
    direct_run_error = RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise direct_run_error
class TestTorchbind(JitTestCase):
    """Tests for the custom classes and ops registered under ``_TorchScriptTesting``.

    Every test uses ``torch.classes._TorchScriptTesting`` and/or
    ``torch.ops._TorchScriptTesting``, which ``setUp`` makes available by
    loading ``libtorchbind_test.so``.
    """

    def setUp(self):
        """Load the torchbind test library, skipping where the load is non-portable."""
        # NOTE(review): JitTestCase.setUp is not invoked here -- confirm that
        # skipping the base-class setup is intentional.
        if IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE:
            raise unittest.SkipTest("non-portable load_library call used in test")
        lib_file_path = find_library_location('libtorchbind_test.so')
        torch.ops.load_library(str(lib_file_path))

    def test_torchbind(self):
        """Compare eager vs. scripted results and script a module via ``__prepare_scriptable__``."""
        def test_equality(f, cmp_key):
            # Run ``f`` eagerly and scripted; the keyed results must agree.
            obj1 = f()
            obj2 = torch.jit.script(f)()
            keyed = (cmp_key(obj1), cmp_key(obj2))
            self.assertEqual(keyed[0], keyed[1])
            return keyed

        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val
        # The eager and scripted runs each build their own _Foo instance, so
        # compare them through the value reported by info().
        test_equality(f, lambda x: x.info())

        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment('foo')

        def f():
            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return ss.pop()
        test_equality(f, lambda x: x)

        def f():
            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            ss1.push(ss2.pop())
            return ss1.pop() + ss2.pop()
        test_equality(f, lambda x: x)

        # test nn module with prepare_scriptable function
        class NonJitableClass:
            def __init__(self, int1, int2):
                self.int1 = int1
                self.int2 = int2

            def return_vals(self):
                return self.int1, self.int2

        class CustomWrapper(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self) -> None:
                self.foo.increment(1)
                return

            def __prepare_scriptable__(self):
                # Swap the plain Python attribute for a _Foo instance built
                # from the same two ints before scripting.
                int1, int2 = self.foo.return_vals()
                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
                return CustomWrapper(foo)

        foo = CustomWrapper(NonJitableClass(1, 2))
        # Scripting must succeed; the scripted module itself is not inspected.
        torch.jit.script(foo)

    def test_torchbind_take_as_arg(self):
        """A scripted function can take a ``_StackString`` argument and return it."""
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString

        def foo(stackstring):
            # type: (StackString)
            stackstring.push("lel")
            return stackstring

        script_input = torch.classes._TorchScriptTesting._StackString([])
        scripted = torch.jit.script(foo)
        script_output = scripted(script_input)
        self.assertEqual(script_output.pop(), "lel")

    def test_torchbind_return_instance(self):
        """A scripted function can construct and return a ``_StackString``."""
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted = torch.jit.script(foo)
        # Ensure we are creating the object and calling __init__
        # rather than calling the __init__wrapper nonsense
        fc = (
            FileCheck()
            .check('prim::CreateObject()')
            .check('prim::CallMethod[name="__init__"]')
        )
        fc.run(str(scripted.graph))
        out = scripted()
        self.assertEqual(out.pop(), "mom")
        self.assertEqual(out.pop(), "hi")

    def test_torchbind_return_instance_from_method(self):
        """``clone()`` yields an instance unaffected by a later ``pop()`` on the source."""
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            clone = ss.clone()
            ss.pop()
            return ss, clone

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out[0].pop(), "hi")
        self.assertEqual(out[1].pop(), "mom")
        self.assertEqual(out[1].pop(), "hi")

    def test_torchbind_def_property_getter_setter(self):
        """Properties with a getter and a setter behave the same eagerly and scripted."""
        def foo_getter_setter_full():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getX method intentionally adds 2 to x
            old = fooGetterSetter.x
            # setX method intentionally adds 2 to x
            fooGetterSetter.x = old + 4
            new = fooGetterSetter.x
            return old, new

        self.checkScript(foo_getter_setter_full, ())

        def foo_getter_setter_lambda():
            foo = torch.classes._TorchScriptTesting._FooGetterSetterLambda(5)
            old = foo.x
            foo.x = old + 4
            new = foo.x
            return old, new

        self.checkScript(foo_getter_setter_lambda, ())

    def test_torchbind_def_property_just_getter(self):
        """A getter-only property is readable and rejects assignment."""
        def foo_just_getter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getY method intentionally adds 4 to x
            return fooGetterSetter, fooGetterSetter.y

        scripted = torch.jit.script(foo_just_getter)
        out, result = scripted()
        self.assertEqual(result, 10)

        # Assigning from Python fails at runtime ...
        with self.assertRaisesRegex(RuntimeError, 'can\'t set attribute'):
            out.y = 5

        def foo_not_setter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            old = fooGetterSetter.y
            fooGetterSetter.y = old + 4
            # getY method intentionally adds 4 to x
            return fooGetterSetter.y

        # ... and assigning inside a scripted function fails when scripting.
        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 'Tried to set read-only attribute: y',
                                                 'fooGetterSetter.y = old + 4'):
            torch.jit.script(foo_not_setter)

    def test_torchbind_def_property_readwrite(self):
        """``x`` of ``_FooReadWrite`` is writable; assigning ``y`` fails when scripting."""
        def foo_readwrite():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            old = fooReadWrite.x
            fooReadWrite.x = old + 4
            return fooReadWrite.x, fooReadWrite.y

        self.checkScript(foo_readwrite, ())

        def foo_readwrite_error():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            fooReadWrite.y = 5
            return fooReadWrite

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 'Tried to set read-only attribute: y',
                                                 'fooReadWrite.y = 5'):
            torch.jit.script(foo_readwrite_error)

    def test_torchbind_take_instance_as_method_arg(self):
        """A torchbind instance can be passed to another instance's method."""
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out.pop(), "hi")
        self.assertEqual(out.pop(), "mom")

    def test_torchbind_return_tuple(self):
        """A torchbind method can return a tuple to a scripted caller."""
        def f():
            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
            return val.return_a_tuple()

        scripted = torch.jit.script(f)
        tup = scripted()
        self.assertEqual(tup, (1337.0, 123))

    def test_torchbind_save_load(self):
        """A scripted function using torchbind classes survives export/import."""
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        self.getExportImportCopy(scripted)

    def test_torchbind_lambda_method(self):
        """``_StackString.top()`` is callable from a scripted function."""
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            return ss.top()

        scripted = torch.jit.script(foo)
        self.assertEqual(scripted(), "mom")

    def test_torchbind_class_attr_recursive(self):
        """A module holding a ``_Foo`` attribute can be scripted and called."""
        class FooBar(torch.nn.Module):
            def __init__(self, foo_model):
                super().__init__()
                self.foo_mod = foo_model

            def forward(self) -> int:
                return self.foo_mod.info()

            def to_ivalue(self):
                torchbind_model = torch.classes._TorchScriptTesting._Foo(self.foo_mod.info(), 1)
                return FooBar(torchbind_model)

        inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
        scripted = torch.jit.script(inst.to_ivalue())
        self.assertEqual(scripted(), 6)

    def test_torchbind_class_attribute(self):
        """A ``_StackString`` module attribute round-trips through export/import."""
        class FooBar1234(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        inst = FooBar1234()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # The reloaded stack is expected to hold these marker strings rather
        # than the values it was constructed with.
        self.assertEqual(eic(), "deserialized")
        for expected in ["deserialized", "was", "i"]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_getstate(self):
        """A ``_PickleTester`` module attribute round-trips through export/import."""
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
        # return {1, 3, 3, 7}. I tried to make this actually depend on the
        # values at instantiation in the test with some transformation, but
        # because it seems we serialize/deserialize multiple times, that
        # transformation isn't as you would expect it to be.
        self.assertEqual(eic(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_deepcopy(self):
        """``copy.deepcopy`` of a scripted module copies its ``_PickleTester`` attribute."""
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        copied = copy.deepcopy(scripted)
        self.assertEqual(copied.forward(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_python_deepcopy(self):
        """``copy.deepcopy`` of an unscripted module copies its ``_PickleTester`` attribute."""
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        copied = copy.deepcopy(inst)
        self.assertEqual(copied(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_tracing(self):
        """A module passing a torchbind attribute to a custom op can be traced."""
        class TryTracing(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.f)

        traced = torch.jit.trace(TryTracing(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pass_wrong_type(self):
        """Passing a Tensor where a torchbind instance is expected raises."""
        with self.assertRaisesRegex(RuntimeError, 'but instead found type \'Tensor\''):
            torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))

    def test_torchbind_tracing_nested(self):
        """Tracing works when the torchbind attribute lives on a submodule."""
        class TryTracingNest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class TryTracing123(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.nest = TryTracingNest()

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

        traced = torch.jit.trace(TryTracing123(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pickle_serialization(self):
        """``torch.save`` / ``torch.load`` round-trips a ``_PickleTester`` instance."""
        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
        b = io.BytesIO()
        torch.save(nt, b)
        b.seek(0)
        nt_loaded = torch.load(b)
        for exp in [7, 3, 3, 1]:
            self.assertEqual(nt_loaded.pop(), exp)

    def test_torchbind_instantiate_missing_class(self):
        """Instantiating an unregistered class raises a descriptive error."""
        with self.assertRaisesRegex(
            RuntimeError,
            'Tried to instantiate class \'foo.IDontExist\', but it does not exist!',
        ):
            torch.classes.foo.IDontExist(3, 4, 5)

    def test_torchbind_optional_explicit_attr(self):
        """A module attribute annotated ``Optional[<torchbind class>]`` can be scripted."""
        class TorchBindOptionalExplicitAttr(torch.nn.Module):
            foo : Optional[torch.classes._TorchScriptTesting._StackString]

            def __init__(self):
                super().__init__()
                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

            def forward(self) -> str:
                foo_obj = self.foo
                if foo_obj is not None:
                    return foo_obj.pop()
                else:
                    return '<None>'

        mod = TorchBindOptionalExplicitAttr()
        # Scripting must succeed; the scripted module itself is not inspected.
        torch.jit.script(mod)

    def test_torchbind_no_init(self):
        """Instantiating ``_NoInit`` raises an error mentioning ``torch::init``."""
        with self.assertRaisesRegex(RuntimeError, 'torch::init'):
            torch.classes._TorchScriptTesting._NoInit()

    def test_profiler_custom_op(self):
        """The autograd profiler records an event for the custom op call."""
        inst = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        with torch.autograd.profiler.profile() as prof:
            torch.ops._TorchScriptTesting.take_an_instance(inst)

        found_event = any(
            e.name == '_TorchScriptTesting::take_an_instance'
            for e in prof.function_events
        )
        self.assertTrue(found_event)

    def test_torchbind_getattr(self):
        """``getattr`` with a default returns the default for a missing field."""
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        self.assertEqual(None, getattr(foo, 'bar', None))

    def test_torchbind_attr_exception(self):
        """Accessing a missing field raises ``AttributeError``."""
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        with self.assertRaisesRegex(AttributeError, 'does not have a field'):
            foo.bar

    def test_lambda_as_constructor(self):
        """``_LambdaInit``'s third argument flips the sign of ``diff()``."""
        obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
        self.assertEqual(obj_no_swap.diff(), 1)

        obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
        self.assertEqual(obj_swap.diff(), -1)

    def test_staticmethod(self):
        """A static method on a torchbind class is callable from script."""
        def fn(inp: int) -> int:
            return torch.classes._TorchScriptTesting._StaticMethod.staticMethod(inp)

        self.checkScript(fn, (1,))

    def test_default_args(self):
        """Constructor and method default arguments agree between eager and script."""
        def fn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs()
            obj.increment(5)
            obj.decrement()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(5)
            obj.scale_add(3, 2)
            obj.divide(3)
            return obj.increment()

        self.checkScript(fn, ())

        def gn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs(5)
            obj.increment(3)
            obj.increment()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(3)
            obj.scale_add(3, 2)
            obj.divide(2)
            return obj.decrement()

        self.checkScript(gn, ())