| import dataclasses |
| import typing |
| import unittest |
| |
| from tools.autograd import gen_autograd_functions |
| from tools.autograd import load_derivatives |
| import tools.codegen.model |
| |
| |
class TestCreateDerivative(unittest.TestCase):
    """Tests for how load_derivatives handles gradient references in formulas."""

    @staticmethod
    def _build_function(signature: str) -> "typing.Tuple[typing.Any, typing.Any]":
        """Parse ``signature`` and return ``(schema, native_function)``.

        The native function is DEFAULT_NATIVE_FUNCTION with its ``func``
        field swapped for the freshly parsed schema.
        """
        parsed = tools.codegen.model.FunctionSchema.parse(signature)
        return parsed, dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=parsed)

    @staticmethod
    def _load_info(
        signature: str,
        parsed: typing.Any,
        function: typing.Any,
        formulas: typing.Dict[str, str],
    ) -> typing.Any:
        """Call create_differentiability_info for a single function.

        ``formulas`` maps argument names to derivative formulas; the
        entries are placed after the "name" entry of the definition.
        """
        return load_derivatives.create_differentiability_info(
            defn={"name": signature, **formulas},
            functions_by_signature={parsed.signature(): [function]},
            functions_by_schema={signature: function},
            op_counter=typing.Counter[str](),
        )

    def test_named_grads(self) -> None:
        """A formula written with grad_<name> reports those named gradients."""
        _, function = self._build_function(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )

        result = load_derivatives.create_derivative(
            function,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )

        self.assertSetEqual(result.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        """A bool output contributes no entry to available_named_gradients."""
        signature = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        parsed, function = self._build_function(signature)

        info = self._load_info(
            signature, parsed, function, {"a": "grads[0]", "b": "grads[2]"}
        )

        # grad_y is not present because y is a
        # bool and thus not differentiable.
        expected = ["grad_x", "grad_z"]
        self.assertSequenceEqual(info.available_named_gradients, expected)

    def test_indexed_grads(self) -> None:
        """A formula written with grads[i] reports no named gradients."""
        _, function = self._build_function(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )

        result = load_derivatives.create_derivative(
            function,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )

        self.assertSetEqual(result.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        """Mixing grad_<name> and grads[i] in one definition raises RuntimeError."""
        signature = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        # Parse outside the assertRaises context so that only the
        # differentiability-info construction is expected to fail.
        parsed, function = self._build_function(signature)

        # Uh-oh, the derivatives reference gradients by
        # name and by index.
        mixed_formulas = {"a": "grad_x", "b": "grads[1]"}

        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            self._load_info(signature, parsed, function, mixed_formulas)
| |
| |
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for the code emitted by gen_autograd_functions.process_function.

    Both tests check that a named gradient is bound to the grads[] slot
    counted over differentiable outputs only.
    """

    def test_non_differentiable_output_invalid_type(self) -> None:
        """Output y is a bool, so grad_z must be bound to grads[1]."""
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = tools.codegen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        differentiability_info = load_derivatives.create_differentiability_info(
            defn={
                "name": specification,
                "a": "grad_x",
                "b": "grad_z",
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        # Use unittest assertions rather than bare ``assert`` so the checks
        # still run under ``python -O`` and report the generated text on
        # failure.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        """Output y is marked False in output_differentiability, so grad_z
        must be bound to grads[1]."""
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = tools.codegen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        differentiability_info = load_derivatives.create_differentiability_info(
            defn={
                "name": specification,
                "a": "grad_x",
                "b": "grad_z",
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        # Use unittest assertions rather than bare ``assert`` so the checks
        # still run under ``python -O`` and report the generated text on
        # failure.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)
| |
| |
# Minimal NativeFunction shared by the tests in this module. It is never
# used directly: each test derives its own copy with
# dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=...).
# from_yaml's second result is not needed here and is discarded.
DEFAULT_NATIVE_FUNCTION, _ = tools.codegen.model.NativeFunction.from_yaml(
    {"func": "func() -> bool"},
    loc=tools.codegen.model.Location(__file__, 1),
)


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()