| # Generates C++ autograd functions for the derivatives of ATen operations |
| # |
| # This writes two sets of files: |
| # Functions.h/cpp: subclasses of autograd::Node |
| # python_functions.h/cpp: Python bindings for the above classes |
| # |
| from .gen_inplace_or_view_type import VIEW_FUNCTIONS |
| |
| from typing import List, Sequence, Tuple |
| |
| from tools.codegen.api.autograd import ( |
| Derivative, |
| DifferentiabilityInfo, |
| SavedAttribute, |
| uses_retain_variables, |
| uses_single_grad, |
| ) |
| from tools.codegen.api.types import ( |
| Binding, |
| BaseCType, |
| OptionalCType, |
| tensorT, |
| longT, |
| doubleT, |
| scalarT, |
| stringT, |
| boolT, |
| intArrayRefT, |
| tensorListT, |
| MutRefCType, |
| ListCType, |
| ArrayRefCType, |
| optionalIntArrayRefT, |
| ) |
| from tools.codegen.code_template import CodeTemplate |
| from tools.codegen.utils import FileManager |
| from tools.codegen.model import Argument |
| |
| FUNCTION_DECLARATION = CodeTemplate( |
| """\ |
| struct TORCH_API ${op} : public ${superclass} { |
| using ${superclass}::${superclass}; |
| variable_list apply(variable_list&& grads) override; |
| std::string name() const override { return "${op}"; } |
| void release_variables() override { |
| ${thread_lock} |
| ${release_variables} |
| } |
| ${will_release_variables} |
| ${saved_variables} |
| ${saved_list_sizes} |
| }; |
| """ |
| ) |
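| # As a rough illustration (not exact generated output; the real fields and formulas |
| # come from derivatives.yaml), FUNCTION_DECLARATION expands to something like the |
| # following for a hypothetical MulBackward0 node that saves `self` and `other`: |
| # |
| #   struct TORCH_API MulBackward0 : public TraceableFunction { |
| #     using TraceableFunction::TraceableFunction; |
| #     variable_list apply(variable_list&& grads) override; |
| #     std::string name() const override { return "MulBackward0"; } |
| #     void release_variables() override { |
| #       std::lock_guard<std::mutex> lock(mutex_); |
| #       self_.reset_data(); |
| #       other_.reset_data(); |
| #     } |
| #     SavedVariable self_; |
| #     SavedVariable other_; |
| #   }; |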
| |
| WILL_RELEASE_VARIABLES = CodeTemplate( |
| """\ |
| bool retain_variables = true; |
| void will_release_variables() override { |
| retain_variables = false; |
| } |
| """ |
| ) |
| |
| FUNCTION_DEFINITION = CodeTemplate( |
| """\ |
| variable_list ${op}::apply(variable_list&& grads) { |
| ${thread_lock} |
| ${asserts} |
| IndexRangeGenerator gen; |
| ${compute_index_ranges} |
| variable_list grad_inputs(gen.size()); |
| ${body} |
| return grad_inputs; |
| } |
| """ |
| ) |
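| # A matching sketch of the generated apply() for the hypothetical MulBackward0 above |
| # (illustrative only; the actual derivative formulas are taken from derivatives.yaml): |
| # |
| #   variable_list MulBackward0::apply(variable_list&& grads) { |
| #     std::lock_guard<std::mutex> lock(mutex_); |
| #     IndexRangeGenerator gen; |
| #     auto self_ix = gen.range(1); |
| #     auto other_ix = gen.range(1); |
| #     variable_list grad_inputs(gen.size()); |
| #     const auto& grad = grads[0]; |
| #     auto self = self_.unpack(); |
| #     auto other = other_.unpack(); |
| #     bool any_grad_defined = any_variable_defined(grads); |
| #     if (should_compute_output({ self_ix })) { |
| #       auto grad_result = any_grad_defined ? (grad * other) : Tensor(); |
| #       copy_range(grad_inputs, self_ix, grad_result); |
| #     } |
| #     if (should_compute_output({ other_ix })) { |
| #       auto grad_result = any_grad_defined ? (grad * self) : Tensor(); |
| #       copy_range(grad_inputs, other_ix, grad_result); |
| #     } |
| #     return grad_inputs; |
| #   } |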
| |
| GRAD_INPUT_MASK = CodeTemplate( |
| """\ |
| auto grad_input_mask = std::array<bool, ${n}>{ |
| ${masks} |
| };\ |
| """ |
| ) |
| |
| DERIVATIVE_SINGLE = CodeTemplate( |
| """\ |
| if (should_compute_output({ ${name}_ix })) { |
| auto grad_result = ${derivative}; |
| copy_range(grad_inputs, ${name}_ix, grad_result); |
| } |
| """ |
| ) |
| |
| DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate( |
| """\ |
| if (should_compute_output({ ${name}_ix })) { |
| copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result)); |
| } |
| """ |
| ) |
| |
| DERIVATIVE_MULTI = CodeTemplate( |
| """\ |
| if (should_compute_output({ ${idx_ranges} })) { |
| ${grad_input_mask} |
| auto grad_result = ${derivative}; |
| ${copy_ranges} |
| } |
| """ |
| ) |
| |
| # Generates Python bindings |
| # |
| # This generates the definitions for: |
| # (1) The PyTypeObject for each backward grad_fn subclassing Node |
| # (2) The entry for the PyTypeObject's tp_getset slot (an array of PyGetSetDef structs). |
| # We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs. |
| # Each PyGetSetDef holds a function pointer to a getter, which is defined in (3). |
| # (3) Getters for each of grad_fn's saved inputs and outputs. |
| # |
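| # As a rough sketch (illustrative, not exact generated output), for a hypothetical |
| # MulBackward0 that saves a tensor `self`, the generated pieces look roughly like: |
| # |
| #   // (3) getters (a "_saved_" getter plus a "_raw_saved_" getter for SavedVariables) |
| #   PyObject* THPMulBackward0_self_getter(THPCppFunction *self, void *_unused) { ... } |
| #   PyObject* THPMulBackward0_self_raw_getter(THPCppFunction *self, void *_unused) { ... } |
| # |
| #   // (2) the tp_getset array |
| #   static struct PyGetSetDef MulBackward0_properties[] = { |
| #     THP_FUNCTION_DEFAULT_PROPERTIES, |
| #     {(char*)"_saved_self", (getter)THPMulBackward0_self_getter, nullptr, nullptr, nullptr}, |
| #     {(char*)"_raw_saved_self", (getter)THPMulBackward0_self_raw_getter, nullptr, nullptr, nullptr}, |
| #     {nullptr} /* sentinel */ |
| #   }; |
| # |
| #   // (1) the PyTypeObject registration |
| #   static PyTypeObject MulBackward0Class; |
| #   addClass<MulBackward0>(MulBackward0Class, "MulBackward0", MulBackward0_properties); |
| # |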
| PY_FUNCTION_DEFINITION = CodeTemplate( |
| """\ |
| static PyTypeObject ${op}Class; |
| addClass<${op}>(${op}Class, "${op}", ${op}_properties); |
| """ |
| ) |
| |
| PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate( |
| """\ |
| ${all_getter_definitions} |
| |
| static struct PyGetSetDef ${op}_properties[] = { |
| THP_FUNCTION_DEFAULT_PROPERTIES, |
| ${all_getsetdef_structs} |
| {nullptr} /* sentinel */ |
| }; |
| |
| """ |
| ) |
| |
| PY_GETSETDEF_STRUCT = CodeTemplate( |
| """\ |
| {(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}""" |
| ) |
| |
| PY_RAW_GETSETDEF_STRUCT = CodeTemplate( |
| """\ |
| {(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}""" |
| ) |
| |
| # Getter templates |
| GETTER_DEFINITION = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| auto prop = static_cast<${op}*>(self->cdata.get())->${name}; |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| GETTER_DEFINITION_SAVEDVAR = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_; |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_; |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| const auto *node = static_cast<${op}*>(self->cdata.get()); |
| const auto& prop = node->${name}_; |
| if (node->${name}_released_) { |
| PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); |
| return nullptr; |
| } |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| const auto *node = static_cast<${op}*>(self->cdata.get()); |
| const auto& prop = node->${name}_; |
| if (node->${name}_released_) { |
| PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); |
| return nullptr; |
| } |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| GETTER_DEFINITION_OPT = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name}; |
| if (!opt_prop.has_value()) { |
| Py_RETURN_NONE; |
| } |
| auto prop = opt_prop.value(); |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate( |
| """\ |
| PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { |
| HANDLE_TH_ERRORS |
| auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name}; |
| if (!opt_prop.list.has_value()) { |
| Py_RETURN_NONE; |
| } |
| auto prop = opt_prop.list.value(); |
| ${body} |
| END_HANDLE_TH_ERRORS |
| } |
| """ |
| ) |
| |
| # Getter body |
| GETTER_BODY_SAVEDVAR = """\ |
| return THPVariable_Wrap(prop.unpack(self->cdata)); |
| """ |
| |
| GETTER_BODY_RAW_SAVEDVAR = """\ |
| pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference); |
| return obj.release().ptr(); |
| """ |
| |
| GETTER_BODY_VEC_SAVEDVAR = """\ |
| PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); |
| for (auto i: c10::irange(prop.size())) { |
| PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata))); |
| } |
| return tup; |
| """ |
| |
| GETTER_BODY_RAW_VEC_SAVEDVAR = """\ |
| PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); |
| for (auto i : c10::irange(prop.size())) { |
| pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference); |
| PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr()); |
| } |
| return tup; |
| """ |
| |
| GETTER_BODY_ARRAYREF_LONG = """\ |
| PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); |
| for (auto i : c10::irange(prop.size())) { |
| PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i])); |
| } |
| return tup; |
| """ |
| |
| GETTER_BODY_ARRAYREF_DOUBLE = """\ |
| PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); |
| for (auto i : c10::irange(prop.size())) { |
| PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i])); |
| } |
| return tup; |
| """ |
| |
| GETTER_BODY_INT64_T = """\ |
| return PyLong_FromUnsignedLong((int64_t) prop); |
| """ |
| |
| GETTER_BODY_DOUBLE = """\ |
| return PyFloat_FromDouble((double) prop); |
| """ |
| |
| GETTER_BODY_BOOL = """\ |
| if (prop) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| """ |
| |
| GETTER_BODY_STRING = """\ |
| return PyUnicode_FromStringAndSize(prop.data(), prop.size()); |
| """ |
| |
| GETTER_BODY_SCALAR = """\ |
| if (prop.isComplex()) { |
| auto cprop = prop.to<c10::complex<double>>(); |
| return PyComplex_FromDoubles(cprop.real(), cprop.imag()); |
| } else if (prop.isFloatingPoint()) { |
| return PyFloat_FromDouble(prop.to<double>()); |
| } else if (prop.isIntegral(/*includeBool=*/false)) { |
| return PyLong_FromLong(prop.to<int64_t>()); |
| } else if (prop.isBoolean()) { |
| if (prop.to<bool>()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| } else { |
| PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type"); |
| return nullptr; |
| } |
| """ |
| |
| MISC_GETTER_DEFS = { |
| OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T), |
| BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE), |
| OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE), |
| BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL), |
| BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR), |
| OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR), |
| } |
| |
| # These functions have backward passes that cannot be traced, and so their |
| # backward functions must be traced opaquely. |
| # VIEW_FUNCTIONS are not traceable because they use as_strided, which |
| # has an untraceable backwards, see |
| # https://github.com/pytorch/pytorch/issues/4250 |
| # TODO: This is probably not exhaustive, but it's a start |
| UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS |
| |
| |
| def gen_autograd_functions_lib( |
| out: str, |
| differentiability_infos: Sequence[DifferentiabilityInfo], |
| template_path: str, |
| ) -> None: |
| """Functions.h and Functions.cpp body |
| |
| These contain the auto-generated subclasses of torch::autograd::Node |
| for every differentiable torch function. |
| """ |
| |
| # Only create an autograd function if we are actually going to calculate a derivative. |
| infos = [info for info in differentiability_infos if info.args_with_derivatives] |
| declarations = [process_function(info, FUNCTION_DECLARATION) for info in infos] |
| definitions = [process_function(info, FUNCTION_DEFINITION) for info in infos] |
| |
| file_basename = "Functions" |
| fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) |
| for suffix in [".h", ".cpp"]: |
| fname = file_basename + suffix |
| fm.write_with_template( |
| fname, |
| fname, |
| lambda: { |
| "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname, |
| "autograd_function_declarations": declarations, |
| "autograd_function_definitions": definitions, |
| }, |
| ) |
| |
| |
| def gen_autograd_functions_python( |
| out: str, |
| differentiability_infos: Sequence[DifferentiabilityInfo], |
| template_path: str, |
| ) -> None: |
| |
| fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) |
| num_shards = 5 |
| fm.write( |
| "python_functions.h", |
| lambda: { |
| "generated_comment": f"@generated from {fm.template_dir}/python_functions.h", |
| "shard_forward_declare": [ |
| f"void initialize_autogenerated_functions_{i}();" |
| for i in range(num_shards) |
| ], |
| "shard_call": [ |
| f"initialize_autogenerated_functions_{i}();" for i in range(num_shards) |
| ], |
| }, |
| ) |
| |
| infos = [info for info in differentiability_infos if info.args_with_derivatives] |
| fm.write_sharded( |
| "python_functions.cpp", |
| infos, |
| key_fn=lambda info: info.name, |
| base_env={ |
| "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp", |
| }, |
| env_callable=lambda info: { |
| "py_function_initializers": [ |
| process_function(info, PY_FUNCTION_DEFINITION) |
| ], |
| "py_function_props_and_getters": [ |
| process_function(info, PY_FUNCTION_PROPS_AND_GETTERS) |
| ], |
| }, |
| num_shards=num_shards, |
| sharded_keys={"py_function_initializers", "py_function_props_and_getters"}, |
| ) |
| |
| |
| def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str: |
| saved_variables: List[str] = [] |
| release_variables: List[str] = [] |
| saved_list_sizes: List[str] = [] |
| unpack: List[str] = [] |
| asserts: List[str] = [] |
| compute_index_ranges: List[str] = [] |
| getter_definitions: List[str] = [] |
| py_getsetdef_structs: List[str] = [] |
| |
| for arg in info.args_with_derivatives: |
| if ( |
| arg.type == "at::TensorList" |
| or arg.type == "const c10::List<c10::optional<at::Tensor>> &" |
| ): |
| size = f"{arg.name}_size_" |
| saved_list_sizes.append(f"size_t {arg.name}_size_;") |
| else: |
| size = "1" |
| compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});") |
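| # For example (illustrative names), a plain Tensor argument `self` and a TensorList |
| # argument `tensors` would emit roughly: |
| #   auto self_ix = gen.range(1); |
| #   auto tensors_ix = gen.range(tensors_size_); |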
| |
| def save_var(var: SavedAttribute, is_output: bool) -> None: |
| name = var.nctype.name |
| type = var.nctype.type |
| should_append_getsetdef = True |
| should_append_raw_getsetdef = False |
| |
| if ( |
| type == BaseCType(tensorT) |
| or type == OptionalCType(BaseCType(tensorT)) |
| or type == MutRefCType(OptionalCType(BaseCType(tensorT))) |
| or (type == BaseCType(scalarT) and is_output) |
| ): |
| saved_variables.append(f"SavedVariable {name}_;") |
| release_variables.append(f"{name}_.reset_data();") |
| ptr = "shared_from_this()" if is_output else "" |
| unpack.append(f"auto {name} = {name}_.unpack({ptr});") |
| getter_definitions.append( |
| GETTER_DEFINITION_SAVEDVAR.substitute( |
| op=info.op, name=name, body=GETTER_BODY_SAVEDVAR |
| ) |
| ) |
| getter_definitions.append( |
| GETTER_DEFINITION_RAW_SAVEDVAR.substitute( |
| op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR |
| ) |
| ) |
| should_append_raw_getsetdef = True |
| elif type == BaseCType(tensorListT): |
| saved_variables.append(f"std::vector<SavedVariable> {name}_;") |
| saved_variables.append(f"bool {name}_released_ = false;") |
| # Just clear() is sufficient; we don't need to loop over and reset each variable, |
| # because each SavedVariable owns a tensor and a grad_fn, and removing the SavedVariable makes them go away as well. |
| release_variables.append(f"{name}_.clear();") |
| release_variables.append(f"{name}_released_ = true;") |
| unpack.append(f"auto {name} = unpack_list({name}_);") |
| asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") |
| getter_definitions.append( |
| GETTER_DEFINITION_VEC_SAVEDVAR.substitute( |
| op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR |
| ) |
| ) |
| getter_definitions.append( |
| GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( |
| op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR |
| ) |
| ) |
| should_append_raw_getsetdef = True |
| elif type == ListCType(OptionalCType(BaseCType(tensorT))): |
| saved_variables.append(f"std::vector<SavedVariable> {name}_;") |
| saved_variables.append(f"bool {name}_released_ = false;") |
| # Just clear() is sufficient; we don't need to loop over and reset each variable, |
| # because each SavedVariable owns a tensor and a grad_fn, and removing the SavedVariable makes them go away as well. |
| release_variables.append(f"{name}_.clear();") |
| release_variables.append(f"{name}_released_ = true;") |
| unpack.append(f"auto {name} = unpack_opt_list({name}_);") |
| asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") |
| getter_definitions.append( |
| GETTER_DEFINITION_VEC_SAVEDVAR.substitute( |
| op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR |
| ) |
| ) |
| getter_definitions.append( |
| GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( |
| op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR |
| ) |
| ) |
| should_append_raw_getsetdef = True |
| elif type == BaseCType(intArrayRefT): |
| saved_variables.append(f"std::vector<int64_t> {name};") |
| getter_definitions.append( |
| GETTER_DEFINITION.substitute( |
| op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG |
| ) |
| ) |
| elif type == BaseCType(optionalIntArrayRefT): |
| saved_variables.append(f"c10::OptionalArray<int64_t> {name};") |
| getter_definitions.append( |
| GETTER_DEFINITION_OPT_ARRAYREF.substitute( |
| op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG |
| ) |
| ) |
| elif type == OptionalCType(BaseCType(intArrayRefT)): |
| saved_variables.append(f"c10::OptionalArray<int64_t> {name};") |
| getter_definitions.append( |
| GETTER_DEFINITION_OPT_ARRAYREF.substitute( |
| op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG |
| ) |
| ) |
| elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))): |
| saved_variables.append(f"c10::OptionalArray<double> {name};") |
| getter_definitions.append( |
| GETTER_DEFINITION_OPT_ARRAYREF.substitute( |
| op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE |
| ) |
| ) |
| elif type == BaseCType(longT): |
| saved_variables.append(f"{type.cpp_type()} {name} = 0;") |
| getter_definitions.append( |
| GETTER_DEFINITION.substitute( |
| op=info.op, name=name, body=GETTER_BODY_INT64_T |
| ) |
| ) |
| elif type == BaseCType(stringT): |
| saved_variables.append(f"std::string {name};") |
| getter_definitions.append( |
| GETTER_DEFINITION.substitute( |
| op=info.op, name=name, body=GETTER_BODY_STRING |
| ) |
| ) |
| elif type == OptionalCType(BaseCType(stringT)): |
| saved_variables.append(f"c10::optional<std::string> {name};") |
| getter_definitions.append( |
| GETTER_DEFINITION_OPT.substitute( |
| op=info.op, name=name, body=GETTER_BODY_STRING |
| ) |
| ) |
| else: |
| saved_variables.append(f"{type.cpp_type()} {name};") |
| |
| if type in MISC_GETTER_DEFS: |
| getter_def, body = MISC_GETTER_DEFS[type] |
| getter_definitions.append( |
| getter_def.substitute(op=info.op, name=name, body=body) |
| ) |
| else: |
| # Types we don't expose Python bindings to yet: |
| # TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry, |
| # std::vector<std::vector<int64_t>>, std::vector<at::ScalarType> |
| should_append_getsetdef = False |
| |
| if should_append_getsetdef: |
| py_getsetdef_structs.append( |
| PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name) |
| ) |
| if should_append_raw_getsetdef: |
| py_getsetdef_structs.append( |
| PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name) |
| ) |
| |
| for var in info.all_saved_inputs: |
| save_var(var, is_output=False) |
| for var in info.all_saved_outputs: |
| save_var(var, is_output=True) |
| |
| # Lock the mutex when we release variables and in Node::apply to ensure thread safety. |
| # See Note [Thread Safety on Autograd Node]. |
| if len(release_variables) > 0: |
| thread_lock = "std::lock_guard<std::mutex> lock(mutex_);" |
| else: |
| thread_lock = "" |
| |
| if uses_retain_variables(info): |
| will_release_variables = WILL_RELEASE_VARIABLES.substitute() |
| else: |
| will_release_variables = "" |
| |
| body: List[str] = [] |
| |
| if uses_single_grad(info): |
| body.append("const auto& grad = grads[0];") |
| else: |
| # Generate aliases for gradients named for returned values. |
| body.extend( |
| f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];" |
| for name in info.used_named_gradients |
| ) |
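| # For a hypothetical op whose formulas reference named gradients grad_out0 and |
| # grad_out1 for outputs `out0` and `out1`, this emits roughly: |
| #   const auto& grad_out0 = grads[0]; |
| #   const auto& grad_out1 = grads[1]; |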
| |
| def emit_derivative( |
| derivative: Derivative, |
| args_with_derivatives: Sequence[Binding], |
| ) -> Tuple[bool, str]: |
| formula = derivative.formula |
| var_names = derivative.var_names |
| if len(var_names) == 1: |
| checks_any_grad_defined = False |
| if "not_implemented" not in formula: |
| matching_args = [ |
| arg for arg in args_with_derivatives if arg.name == var_names[0] |
| ] |
| if len(matching_args) == 1: |
| # We can add undefined grad support if the input variable is a Tensor |
| arg = matching_args[0] |
| if isinstance(arg.argument, Argument) and str( |
| arg.argument.type |
| ) in ("Tensor", "Tensor?"): |
| formula = "any_grad_defined ? (" + formula + ") : Tensor()" |
| checks_any_grad_defined = True |
| return ( |
| checks_any_grad_defined, |
| DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula), |
| ) |
| else: |
| if "grad_input_mask" in formula: |
| masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names] |
| grad_input_mask = GRAD_INPUT_MASK.substitute( |
| masks=masks, n=len(var_names) |
| ) |
| else: |
| grad_input_mask = "" |
| idx_ranges = ", ".join(f"{n}_ix" for n in var_names) |
| copy_ranges: List[str] = [] |
| for i, n in enumerate(var_names): |
| copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)) |
| return False, DERIVATIVE_MULTI.substitute( |
| idx_ranges=idx_ranges, |
| copy_ranges=copy_ranges, |
| derivative=formula, |
| grad_input_mask=grad_input_mask, |
| ) |
| |
| body.extend(unpack) |
| need_any_grad_defined_var = False |
| for derivative in info.derivatives: |
| checks_any_grad_defined, derivative_text = emit_derivative( |
| derivative, info.args_with_derivatives |
| ) |
| body.append(derivative_text) |
| need_any_grad_defined_var |= checks_any_grad_defined |
| # Since single-output derivative formulas need to check if grads are |
| # defined, hoist the check so it is performed only once, right before all |
| # the derivative blocks (i.e. just after the unpack statements). |
| if need_any_grad_defined_var: |
| body.insert( |
| -len(info.derivatives), |
| "bool any_grad_defined = any_variable_defined(grads);", |
| ) |
| |
| if info.name in UNTRACEABLE_FUNCTIONS: |
| superclass = "Node" |
| else: |
| superclass = "TraceableFunction" |
| |
| all_getsetdef_structs = ( |
| ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else "" |
| ) |
| all_getter_definitions = "\n".join(getter_definitions) |
| |
| return template.substitute( |
| op=info.op, |
| compute_index_ranges=compute_index_ranges, |
| saved_variables=saved_variables, |
| release_variables=release_variables, |
| saved_list_sizes=saved_list_sizes, |
| asserts=asserts, |
| thread_lock=thread_lock, |
| will_release_variables=will_release_variables, |
| body=body, |
| superclass=superclass, |
| all_getter_definitions=all_getter_definitions, |
| all_getsetdef_structs=all_getsetdef_structs, |
| ) |