blob: ddbd90a025dabae4b629128284d0736052ed1628 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
from typing import Any, List
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, make_global
# The shared test helpers live in test/, which is the parent of the directory
# containing this file; put it on sys.path so those helpers can be imported.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # This module is meant to be driven through test/test_jit.py, so refuse to
    # run as a standalone script and tell the user how to invoke it instead.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestWith(JitTestCase):
"""
A suite of tests for with statements.
"""
def test_with_as(self):
"""
Check that with statements that use the 'as' keyword to bind expressions
to targets work as expected.
"""
@torch.jit.script
class Context(object):
"""
This class implements a basic context manager interface for use in
the unit tests. Unlike Context, the stateful part of this class
is a Tensor that is mutated in-place so that modifications made in the
JIT interpreter are visible outside of it.
"""
def __init__(self, start: int):
self.count = torch.tensor([start], dtype=torch.double)
def __enter__(self):
self.count.add_(0.3)
return self.count
def __exit__(self, type: Any, value: Any, tb: Any) -> bool:
self.count.sub_(0.3)
return True
make_global(Context)
def test_basic(x: torch.Tensor) -> torch.Tensor:
"""Basic test with one with-statement."""
c = Context(1)
with c as mult:
y = x + mult
y *= c.count
return y
def test_pass(x: torch.Tensor) -> torch.Tensor:
"""
Test with a pass statement inside a with-statement. Although
the body of the with is empty, __enter__ and __exit__ should
still be called.
"""
c = Context(1)
with c as mult:
pass
x *= c.count
return x
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that returning early from inside a with-statement works
as expected.
"""
with c as mult:
y = x + mult
return y
x = y + y
return x
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that conditionally returning early from inside a with-statement works
as expected.
"""
with c as mult:
y = x + mult
if mult > 0:
return y
x = y + y
return x
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that breaking early from inside a with-statement works
as expected.
"""
with c as mult:
for a in l:
if a == 0:
break
x += a * mult
return x
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that using continue inside a with-statement works
as expected.
"""
with c as mult:
for a in l:
if a == 0:
continue
x += a * mult
return x
def test_serial(x: torch.Tensor) -> torch.Tensor:
"""
Test two with-statements in a row.
"""
c = Context(1)
with c as mult:
y = x + mult
with c as mult:
y *= mult
return y
def test_nested(x: torch.Tensor) -> torch.Tensor:
"""
Test nested with-statements.
"""
c = Context(1)
with c as m:
with c as n:
y = x + n
y *= m
return y
def test_combined(x: torch.Tensor) -> torch.Tensor:
"""
Test a with-statement with multiple with items.
"""
c = Context(1)
d = Context(2)
with c as m, d as n:
y = x + (m + n)
return y
test_input = torch.randn(2, 2)
test_context = Context(2)
test_list = [2, 0, 1, 3, 0, 2]
self.checkScript(test_basic, (test_input,))
self.checkScript(test_pass, (test_input,))
self.checkScript(test_early_return, (test_input, test_context))
self.checkScript(test_break, (test_input, test_context, test_list))
self.checkScript(test_continue, (test_input, test_context, test_list))
self.assertEqual(test_context.count, 2)
self.checkScript(test_serial, (test_input,))
self.checkScript(test_nested, (test_input,))
self.checkScript(test_combined, (test_input,))
def test_with_no_as(self):
"""
Check that with statements that do not use the 'as' keyword to bind expressions
to targets work as expected.
"""
@torch.jit.script
class Context(object):
"""
This class implements a basic context manager interface for use in
the unit tests. Unlike Context, the stateful part of this class
is a Tensor that is mutated in-place so that modifications made in the
JIT interpreter are visible outside of it.
"""
def __init__(self, start: int):
self.count = torch.tensor([start], dtype=torch.double)
def __enter__(self):
self.count.add_(0.3)
return self.count
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
make_global(Context)
def test_basic(x: torch.Tensor) -> torch.Tensor:
"""Basic test with one with-statement."""
c = Context(1)
with c:
y = x + c.count
y *= c.count
return y
def test_pass(x: torch.Tensor) -> torch.Tensor:
"""
Test with a pass statement inside a with-statement. Although
the body of the with is empty, __enter__ and __exit__ should
still be called.
"""
c = Context(1)
with c:
pass
x *= c.count
return x
def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that returning early from inside a with-statement works
as expected.
"""
with c:
y = x + c.count
return y
x = y + y
return x
def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test that conditionally returning early from inside a with-statement works
as expected.
"""
with c:
y = x + c.count
if c.count > 0:
return y
x = y + y
return x
def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that breaking early from inside a with-statement works
as expected.
"""
with c:
for a in l:
if a == 0:
break
x += a * c.count
return x
def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor:
"""
Test that using continue inside a with-statement works
as expected.
"""
with c:
for a in l:
if a == 0:
continue
x += a * c.count
return x
def test_serial(x: torch.Tensor) -> torch.Tensor:
"""
Test two with-statements in a row.
"""
c = Context(1)
with c:
y = x + c.count
with c:
y *= c.count
return y
def test_nested(x: torch.Tensor) -> torch.Tensor:
"""
Test nested with-statements.
"""
c = Context(1)
with c:
with c:
y = x + c.count
y *= c.count
return y
def test_combined(x: torch.Tensor) -> torch.Tensor:
"""
Test a with-statement with multiple with items.
"""
c = Context(1)
d = Context(2)
with c, d:
y = x + (c.count + d.count)
return y
test_input = torch.randn(2, 2)
test_context = Context(2)
test_list = [2, 0, 1, 3, 0, 2]
self.checkScript(test_basic, (test_input,))
self.checkScript(test_pass, (test_input,))
self.checkScript(test_early_return, (test_input, test_context))
self.checkScript(test_break, (test_input, test_context, test_list))
self.checkScript(test_continue, (test_input, test_context, test_list))
self.assertEqual(test_context.count, 2)
self.checkScript(test_serial, (test_input,))
self.checkScript(test_nested, (test_input,))
self.checkScript(test_combined, (test_input,))
def test_with_exceptions(self):
"""
Check that exceptions thrown in the bodies of with-statements are
handled correctly.
"""
@torch.jit.script
class Context(object):
"""
This class implements a basic context manager interface for use in
the unit tests. Unlike Context, the stateful part of this class
is a Tensor that is mutated in-place so that modifications made in the
JIT interpreter are visible outside of it.
"""
def __init__(self, start: int):
self.count = torch.tensor([start], dtype=torch.double)
def __enter__(self):
self.count.add_(0.3)
return self.count
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
make_global(Context)
@torch.jit.script
def method_that_raises() -> torch.Tensor:
raise Exception("raised exception")
@torch.jit.script
def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test the case in which an exception is thrown while executing the body of a with-statement.
"""
with c as _:
x += method_that_raises()
return x
@torch.jit.script
def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test the case in which an exception is thrown while executing the body of a nested with-statement.
"""
with c as _:
with c as _:
x += method_that_raises()
return x
@torch.jit.script
def with_that_raises(c: Context) -> torch.Tensor:
a = torch.tensor([1])
with c as _:
a += method_that_raises()
return a
@torch.jit.script
def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor:
"""
Test the case in which an exception is thrown while there are active with-statements in two different
frames.
"""
with c as _:
x += with_that_raises(c)
return x
c = Context(1)
# checkScript and checkScriptRaisesRegex cannot be used because the string frontend will
# not compile class types (of which Context, the context manager being used for this test
# is one).
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
test_exception(torch.randn(2), c)
self.assertEqual(c.count, 1)
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
test_exception_nested(torch.randn(2), c)
self.assertEqual(c.count, 1)
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
test_exception_fn_call(torch.randn(2), c)
self.assertEqual(c.count, 1)
def test_with_errors(self):
"""
Check that errors related to with-statements are detected and reported correctly.
"""
@torch.jit.script
class NoEnterNoExit(object):
"""
This class is missing __enter__ and __exit__ methods.
"""
def __init__(self):
self.count = 1
@torch.jit.script
class BadEnter(object):
"""
This class has an __enter__ method with an incorrect signature.
"""
def __init__(self):
self.count = 1
def __enter__(self, incr: int):
self.count += incr
def __exit__(self, type: Any, value: Any, tb: Any):
pass
@torch.jit.script
class BadExit(object):
"""
This class has an __exit__ method with an incorrect signature.
"""
def __init__(self):
self.count = 1
def __enter__(self):
self.count += 1
def __exit__(self, type: Any, value: Any):
pass
@torch.jit.script
class ExitIncorrectTypes(object):
"""
This class has an __exit__ method with unsupported argument types.
"""
def __init__(self):
self.count = 1
def __enter__(self):
self.count += 1
def __exit__(self, type: Any, value: int, tb: int):
pass
def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor:
with cm as _:
pass
return x
def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor:
with cm as _:
pass
return x
def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor:
with cm as _:
pass
return x
def test_exit_incorrect_types(x: torch.Tensor, cm: ExitIncorrectTypes) -> torch.Tensor:
with cm as _:
pass
return x
def test_enter_without_object():
with "not_object" as obj:
pass
test_tensor = torch.randn(5, dtype=torch.double)
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"does not define __enter__ and __exit__ methods", "cm"
):
self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"__enter__ must have only one argument and one return value", "cm"
):
self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"__exit__ must have four arguments", "cm"
):
self.checkScript(test_bad_exit, (test_tensor, BadExit()))
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"argument 2 of __exit__ must have Any type", "cm"
):
self.checkScript(
test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"must return an object", "\"not_object\""):
self.checkScript(test_enter_without_object, ())
def test_with_no_grad(self):
"""
Check that torch.no_grad() works. Most of these are adapted from
corresponding tests for eager-mode no_grad.
"""
# Basic no_grad test.
def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
w = x + y
return w
s = torch.jit.script(test_no_grad)
x = torch.ones(5, 5, requires_grad=True)
y = torch.ones(5, 5) * 4
w = s(x, y)
self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.grad_fn)
# Test assignment of a grad-less Tensor to a Tensor with gradients
# in a no_grad block.
def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x[0] = y
return x
s = torch.jit.script(test_no_grad_assignment)
z = torch.randn(5)
w = s(x, z)
self.assertTrue(w.requires_grad)
self.assertIsNone(w.grad_fn)
# Check that @torch.jit.ignored functions respect no_grad when it is
# called in JIT mode.
class NoGradModule(torch.nn.Module):
def __init__(self):
super().__init__()
@torch.jit.ignore
def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
w = x + y
return w
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
w = self.adder(x, y)
return w
s = torch.jit.script(NoGradModule())
w = s(x, y)
self.assertFalse(w.requires_grad)
@skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
def test_with_record_function(self):
"""
Check that torch.autograd.profiler.record_function context manager is
torchscriptable.
"""
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.autograd.profiler.record_function("foo"):
# Nested record_function.
with torch.autograd.profiler.record_function("nested"):
a = x + y
return a
scripted = torch.jit.script(with_rf)
x, y = torch.ones(2), torch.ones(2)
with torch.autograd.profiler.profile() as p:
scripted(x, y)
# Need to call below to populate CPU children.
p.key_averages()
function_events = p.function_events
# Event with name "foo" should be recorded.
rf_events = [evt for evt in function_events if evt.name == "foo"]
self.assertEqual(len(rf_events), 1)
rf_event = rf_events[0]
child_events = rf_event.cpu_children
# Ensure we find nested record_function event
self.assertTrue("nested" in (child.name for child in child_events))
nested_function_event = [
evt for evt in function_events if evt.name == "nested"
][0]
# Nested record function should have child "aten::add"
nested_child_events = nested_function_event.cpu_children
self.assertTrue("aten::add" in (child.name for child in nested_child_events))