| # Owner(s): ["module: dynamo"] |
| from typing import Callable, Dict, List, NamedTuple, Optional |
| |
| import torch |
| |
| import torch._dynamo |
| from torch._dynamo.test_case import run_tests, TestCase |
| from torch._dynamo.testing import CompileCounter, same |
| |
| """ |
| This is an example of a pure-python version of autograd implemented by |
| @zdevito. It represents a rather challenging test case for TorchDynamo |
| to push the limits of what it can do. |
| """ |
| |
| |
# Monotonically increasing counter used by fresh_name() to number variables;
# reset to 0 by reset_tape().
_name: int = 0
| |
| |
def fresh_name() -> str:
    """Return the next unique variable name in the sequence v0, v1, v2, ..."""
    global _name
    new_name = f"v{_name}"
    _name += 1
    return new_name
| |
| |
class Variable:
    """A tensor wrapper that records every operation onto ``gradient_tape``.

    Each Variable holds a concrete tensor ``value`` and a unique ``name``
    that keys its gradient during the backward pass in ``grad()``.
    """

    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        # Auto-generate v0, v1, ... when no (truthy) name is supplied.
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes.
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None):
        """Construct a leaf Variable whose value was not produced by autograd."""
        return Variable(value, name)

    def __repr__(self):
        return repr(self.value)

    # Pointwise operators. Each delegates to an operator_* function (defined
    # later in this file) that computes the forward value and records a
    # TapeEntry with the matching backward rule.
    def __mul__(self, rhs: "Variable") -> "Variable":
        return operator_mul(self, rhs)

    def __add__(self, rhs: "Variable") -> "Variable":
        return operator_add(self, rhs)

    def sum(self, name: Optional[str] = None) -> "Variable":
        return operator_sum(self, name)

    def expand(self, sizes: List[int]) -> "Variable":
        return operator_expand(self, sizes)
| |
| |
class TapeEntry(NamedTuple):
    """One recorded autograd op: its input/output variable names plus the
    closure that applies the chain rule backward through it."""

    # names of the inputs to the original computation
    inputs: List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # apply chain rule: maps dL/doutputs -> dL/dinputs for this op.
    # NOTE: Callable's argument types must be wrapped in a list —
    # Callable[[args], ret] — the previous Callable[List[...], List[...]]
    # form was malformed. Quoted as a forward reference to Variable.
    propagate: "Callable[[List[Variable]], List[Variable]]"
| |
| |
# Global record of every executed op; grad() replays it in reverse order.
gradient_tape: List[TapeEntry] = []
| |
| |
def reset_tape():
    """Clear the recorded tape and restart variable numbering from v0."""
    global _name
    _name = 0
    gradient_tape.clear()
| |
| |
def grad(L, desired_results: List[Variable]) -> List[Variable]:
    """Reverse-mode sweep over ``gradient_tape``: return dL/dX for each X in
    ``desired_results``.

    A Variable that never contributed to the loss gets gradient None.
    """
    # dL/dX for every value X reached so far, keyed by variable name.
    dL_d: Dict[str, Variable] = {}
    # The sweep is seeded with dL/dL, which is 1.
    dL_d[L.name] = Variable(torch.ones(()))
    # print(f'd{L.name} ------------------------')

    # Look up gradients by name; a variable never used to compute the loss
    # maps to None (see the zero-pathway note below).
    def gather_grad(entries: List[str]):
        return [dL_d.get(entry) for entry in entries]

    # Walk the tape backwards, applying each recorded chain rule.
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(g is None for g in dL_doutputs):
            # Optimize for the case where some gradient pathways are zero:
            # nothing downstream of this op influenced the loss, so skip it.
            continue

        # Chain-rule propagation specific to this op.
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input. Each use of a
        # variable produces some gradient for that use, and the multivariate
        # chain rule tells us it is safe to sum all the contributions.
        for input_name, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input_name in dL_d:
                dL_d[input_name].value += dL_dinput.value
            else:
                dL_d[input_name] = dL_dinput

    # print some information to understand the values of each intermediate
    # for name, value in dL_d.items():
    #     print(f'd{L.name}_d{name} = {value.name}')
    # print(f'------------------------')

    return gather_grad(desired.name for desired in desired_results)
| |
| |
def operator_mul(self: Variable, rhs: Variable) -> Variable:
    """Pointwise-multiply two Variables, recording the op on the tape."""
    # Peephole optimization: multiplying by the float 1.0 is the identity
    # (operator_add relies on this when it propagates gradients).
    if isinstance(rhs, float) and rhs == 1.0:
        return self

    # Forward computation.
    r = Variable(self.value * rhs.value)
    # print(f'{r.name} = {self.name} * {rhs.name}')

    # Backprop: with r = self * rhs, the partials are dr/dself = rhs and
    # dr/drhs = self, so chain-rule the incoming gradient through each.
    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        return [dL_dr * rhs, dL_dr * self]

    # Finally, record the compute we did on the tape.
    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
    )
    return r
| |
| |
def operator_add(self: Variable, rhs: Variable) -> Variable:
    """Pointwise-add two Variables, recording the op on the tape.

    Follows the same pattern as operator_mul but captures no variables.
    """
    r = Variable(self.value + rhs.value)
    # print(f'{r.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        # Both partials of an add are 1.0, so gradients pass through
        # unchanged (the *1.0 hits operator_mul's peephole).
        dL_dself = dL_dr * 1.0
        dL_drhs = dL_dr * 1.0
        return [dL_dself, dL_drhs]

    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
    )
    return r
| |
| |
def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
    """Reduce a Variable to its scalar sum, recording the op on the tape."""
    r = Variable(torch.sum(self.value), name=name)
    # print(f'{r.name} = {self.name}.sum()')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        # The gradient of a sum broadcasts back to the input's shape.
        return [dL_dr.expand(*self.value.size())]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r
| |
| |
def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
    """Broadcast a scalar Variable to ``sizes``, recording the op on the tape."""
    assert self.value.dim() == 0  # only works for scalars
    r = Variable(self.value.expand(sizes))
    # print(f'{r.name} = {self.name}.expand({sizes})')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        # Summing is the adjoint of broadcasting: it undoes the expand.
        return [dL_dr.sum()]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r
| |
| |
def simple(a, b):
    """Tiny two-op workload: compute (a + b) * b."""
    return (a + b) * b
| |
| |
class TestPythonAutograd(TestCase):
    """Run the pure-python autograd above through TorchDynamo, checking that
    each workload compiles as a single frame with the expected op count."""

    def _common(self, fn, expected_ops):
        """Compile ``fn`` with dynamo, run it on two argument sets, compare
        against eager execution, and assert one captured frame containing
        exactly ``expected_ops`` ops.
        """
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        # reset_tape between every run so each call records onto a fresh tape
        # and variable numbering stays deterministic across runs.
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        # Forward pass only: add + mul + sum = 3 ops.
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        # Same as test_forwards1 but with reset_tape calls inside the
        # compiled region.
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        # Forward + backward traced together in one frame: 8 ops.
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        # Same as test_backwards1 but with reset_tape calls inside the
        # compiled region.
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        # Compile forward and grad as separate functions and verify both
        # frames are captured and match the eager results.
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])

        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # force two frames
        grad2 = opt_grad(loss2, [v1, v2])

        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)
| |
| |
# Entry point: run the tests under torch._dynamo's test harness.
if __name__ == "__main__":
    run_tests()