blob: 73a6f8f6d330ae51ff54620aa5607babbb1c8a98 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional
import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
to push the limits of what it can do.
"""
_name: int = 0


def fresh_name() -> str:
    """Create a new unique variable name: v0, v1, v2, ..."""
    global _name
    new_name = "v" + str(_name)
    _name += 1
    return new_name
class Variable:
    """A tensor value paired with a unique name, so that operations on it
    can be recorded onto the global gradient tape by the operator_* helpers."""

    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        # a falsy name (None or "") gets a fresh autogenerated one
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes.
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None):
        return Variable(value, name)

    def __repr__(self):
        return repr(self.value)

    # This performs a pointwise multiplication of a Variable, tracking gradients
    def __mul__(self, rhs: "Variable") -> "Variable":
        # defined later in the notebook
        return operator_mul(self, rhs)

    def __add__(self, rhs: "Variable") -> "Variable":
        return operator_add(self, rhs)

    def sum(self, name: Optional[str] = None) -> "Variable":
        return operator_sum(self, name)

    def expand(self, sizes: List[int]) -> "Variable":
        return operator_expand(self, sizes)
class TapeEntry(NamedTuple):
    """One recorded operation on the gradient tape."""

    # names of the inputs to the original computation
    inputs: List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # apply chain rule: maps dL/doutputs to dL/dinputs
    propagate: "Callable[[List[Variable]], List[Variable]]"
gradient_tape: List[TapeEntry] = []


def reset_tape():
    """Forget all recorded operations and restart variable numbering at v0."""
    global _name
    _name = 0
    gradient_tape.clear()
def grad(L: Variable, desired_results: List[Variable]) -> List[Variable]:
    """Reverse-mode autodiff over the recorded gradient_tape.

    L is the loss Variable to differentiate; returns dL/dX for each Variable
    in desired_results, with None for any variable the loss never used.
    """
    # this map holds dL/dX for all values X
    dL_d: Dict[str, Variable] = {}
    # It starts by initializing the 'seed' dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))
    # print(f'd{L.name} ------------------------')

    # look up dL_dentries. If a variable is never used to compute the loss,
    # we consider its gradient None, see the note below about zeros for more information.
    def gather_grad(entries: List[str]):
        return [dL_d[entry] if entry in dL_d else None for entry in entries]

    # propagate the gradient information backward
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(dL_doutput is None for dL_doutput in dL_doutputs):
            # optimize for the case where some gradient pathways are zero. See
            # the note below for more details.
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input.
        # Each use of a variable produces some gradient dL_dinput for that
        # use. The multivariate chain rule tells us it is safe to sum
        # all the contributions together.
        for input, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input not in dL_d:
                dL_d[input] = dL_dinput
            else:
                dL_d[input].value += dL_dinput.value

    # print some information to understand the values of each intermediate
    # for name, value in dL_d.items():
    #     print(f'd{L.name}_d{name} = {value.name}')
    # print(f'------------------------')
    return gather_grad(desired.name for desired in desired_results)
def operator_mul(self: Variable, rhs: Variable) -> Variable:
    """Pointwise multiply two Variables and record the op on the tape."""
    # peephole optimization: multiplying by the scalar 1.0 is a no-op
    if isinstance(rhs, float) and rhs == 1.0:
        return self

    # forward pass
    out = Variable(self.value * rhs.value)
    # print(f'{out.name} = {self.name} * {rhs.name}')

    # record which names flowed into and out of this op
    inputs = [self.name, rhs.name]
    outputs = [out.name]

    def propagate(dL_doutputs: List[Variable]):
        """Chain rule for out = self * rhs: d(out)/dself = rhs, d(out)/drhs = self."""
        (dL_dout,) = dL_doutputs
        return [dL_dout * rhs, dL_dout * self]

    # finally, record the compute on the tape
    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
    return out
def operator_add(self: Variable, rhs: Variable) -> Variable:
    """Pointwise add two Variables and record the op on the tape.

    Unlike mul, the backward pass captures no input variables, because both
    partial derivatives are the constant 1.0.
    """
    out = Variable(self.value + rhs.value)
    # print(f'{out.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        """Chain rule for out = self + rhs: both partials are 1.0."""
        (dL_dout,) = dL_doutputs
        # multiplying by the float 1.0 hits the peephole in operator_mul,
        # so no extra tape entries are created here
        return [dL_dout * 1.0, dL_dout * 1.0]

    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[out.name], propagate=propagate)
    )
    return out
def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
    """Reduce a Variable to a scalar with torch.sum and record the op."""
    out = Variable(torch.sum(self.value), name=name)
    # print(f'{out.name} = {self.name}.sum()')

    def propagate(dL_doutputs: List[Variable]):
        """Gradient of sum broadcasts dL/dout back to the input's shape."""
        (dL_dout,) = dL_doutputs
        return [dL_dout.expand(*self.value.size())]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[out.name], propagate=propagate)
    )
    return out
def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
    """Broadcast a scalar Variable up to `sizes` and record the op."""
    assert self.value.dim() == 0  # only works for scalars
    out = Variable(self.value.expand(sizes))
    # print(f'{out.name} = {self.name}.expand({sizes})')

    def propagate(dL_doutputs: List[Variable]):
        """Gradient of expand sums contributions back down to the scalar."""
        (dL_dout,) = dL_doutputs
        return [dL_dout.sum()]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[out.name], propagate=propagate)
    )
    return out
def simple(a, b):
    """Toy forward computation used by the tests: (a + b) * b."""
    return (a + b) * b
class TestPythonAutograd(TestCase):
    """TorchDynamo tests over the pure-python autograd defined above.

    NOTE(review): these tests assert exact frame and op counts, so they are
    sensitive to the precise structure of the traced functions.
    """

    def _common(self, fn, expected_ops):
        """Compile fn with dynamo, compare against eager runs on the same
        inputs, and check exactly one frame with expected_ops ops is captured."""
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        # compiled runs; the global tape is reset between runs so each one
        # starts from a clean state
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        # eager reference runs must produce the same results
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        # both argument sets should reuse a single compiled frame
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        # forward pass only; tape resets happen outside the compiled region
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        # same as test_forwards1 but with reset_tape() calls inside fn
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        # forward + backward traced within a single compiled function
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        # same as test_backwards1 but with explicit tape resets inside fn
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        # forward and backward compiled as two separate dynamo frames
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        # eager reference values
        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])
        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # force two frames
        grad2 = opt_grad(loss2, [v1, v2])
        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)
if __name__ == "__main__":
    # run_tests() is the torch._dynamo.test_case entry point imported above
    run_tests()