| #include <ATen/core/boxing/impl/test_helpers.h> |
| #include <gtest/gtest.h> |
| |
| #include <ATen/core/op_registration/op_registration.h> |
| #include <torch/torch.h> |
| |
| #include <torch/csrc/autograd/FunctionsManual.h> |
| #include <torch/csrc/autograd/functions/basic_ops.h> |
| |
| #include <test/cpp/api/support.h> |
| |
| using namespace torch::autograd; |
| using namespace torch::test; |
| |
| #define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b))) |
| #define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b))) |
| |
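| // Renders the autograd graph rooted at `node` as nested "Name(...)" strings, |
| // e.g. something like "MulBackward0(AccumulateGrad()AccumulateGrad())". |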
| std::string graph_desc(std::shared_ptr<Node> node) { |
| if (!node) { |
| return "None"; |
| } |
| auto result = node->name() + "("; |
| auto next_edges = node->next_edges(); |
| for (auto& edge : next_edges) { |
| result += graph_desc(edge.function); |
| } |
| return result + ")"; |
| } |
| |
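| // simple_fn(x, y) = x + 2y + xy, so d/dx = 1 + y and d/dy = 2 + x; the |
| // expected-gradient checks in the tests below use these closed forms. |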
| Variable simple_fn(const Variable& x, const Variable& y) { |
| return x + 2 * y + x * y; |
| } |
| |
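| // UndefinedGrad feeds an undefined gradient to x during backward; the two |
| // tests below check that hooks returning void or a Tensor accept the |
| // undefined tensor without throwing. |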
| TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) { |
| auto x = at::zeros({}, at::kCPU); |
| x.requires_grad_(); |
| x.register_hook([](at::TensorBase x) { return; }); |
| auto y = torch::autograd::UndefinedGrad().apply({x}); |
| y[0].backward(); |
| } |
| |
| TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) { |
| auto x = at::zeros({}, at::kCPU); |
| x.requires_grad_(); |
| x.register_hook([](at::Tensor x) -> at::Tensor { return x; }); |
| auto y = torch::autograd::UndefinedGrad().apply({x}); |
| y[0].backward(); |
| } |
| |
| TEST(AutogradAPITests, BackwardSimpleTest) { |
| Variable x = torch::randn({2, 2}, torch::requires_grad()); |
| Variable y = torch::randn({2, 2}, torch::requires_grad()); |
| auto res = simple_fn(x, y); |
| backward({res.sum()}, {}); |
| |
| ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2})); |
| ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2); |
| } |
| |
| TEST(AutogradAPITests, BackwardTest) { |
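| // backward() is called twice; the first call passes create_graph=true (which |
| // also keeps the graph alive), so the gradients from the two calls |
| // accumulate and the expected values below are doubled. |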
| Variable x = torch::randn({2, 2}, torch::requires_grad()); |
| Variable y = torch::randn({2, 2}, torch::requires_grad()); |
| auto res = simple_fn(x, y); |
| backward({res}, {torch::ones({2, 2})}, {}, true); |
| |
| backward({res}, {torch::ones({2, 2})}); |
| |
| ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2}))); |
| ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2)); |
| } |
| |
| TEST(AutogradAPITests, GradSimpleTest) { |
| // basic grad |
| Variable x = torch::randn({2, 2}, torch::requires_grad()); |
| Variable y = torch::randn({2, 2}, torch::requires_grad()); |
| auto res = simple_fn(x, y); |
| auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})}); |
| |
| ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2})); |
| ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2); |
| } |
| |
| TEST(AutogradAPITests, GradTest) { |
| Variable x = torch::randn({2, 2}, torch::requires_grad()); |
| Variable y = torch::randn({2, 2}, torch::requires_grad()); |
| auto res = simple_fn(x, y); |
| res.backward(torch::ones({2, 2}), false, true); |
| |
| Variable x_grad = y + torch::ones({2, 2}); |
| Variable y_grad = x + torch::ones({2, 2}) * 2; |
| ASSERT_VARIABLE_EQ(x.grad(), x_grad); |
| ASSERT_VARIABLE_EQ(y.grad(), y_grad); |
| |
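| // With create_graph=true above, x.grad() = y + 1 and y.grad() = x + 2 stay |
| // connected to the graph, so d(2 * x.grad() + y.grad()) / dx = 1 and the |
| // double-backward result below is all ones. |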
| Variable grad_sum = 2 * x.grad() + y.grad(); |
| auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true); |
| |
| ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2})); |
| ASSERT_VARIABLE_EQ(x.grad(), x_grad); |
| ASSERT_VARIABLE_EQ(y.grad(), y_grad); |
| } |
| |
| TEST(AutogradAPITests, GradNonLeafTest) { |
| Variable x_init = torch::randn({2, 2}, torch::requires_grad()); |
| Variable x = x_init; |
| Variable y = torch::randn({2, 2}, torch::requires_grad()); |
| Variable grad_output = torch::ones({2, 2}); |
| |
| for (int i = 0; i < 5; ++i) { |
| auto res = simple_fn(x, y); |
| auto input_grads = grad({res}, {x}, {grad_output}, {}, true); |
| |
| Variable grad_x_expected = y + torch::ones({2, 2}); |
| ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected); |
| ASSERT_FALSE(x.grad().defined()); |
| ASSERT_FALSE(y.grad().defined()); |
| x = x + 0.05 * input_grads[0]; |
| } |
| |
| float val_init = simple_fn(x_init, y).sum().item().toFloat(); |
| float val_final = simple_fn(x, y).sum().item().toFloat(); |
| ASSERT_TRUE(val_final > val_init); |
| |
| x.backward(grad_output, false, true); |
| ASSERT_TRUE(x_init.grad().defined()); |
| ASSERT_TRUE(y.grad().defined()); |
| } |
| |
| TEST(AutogradAPITests, GradUnreachableTest) { |
| Variable x = torch::ones({1}, torch::requires_grad()); |
| Variable y = torch::ones({1}, torch::requires_grad()); |
| |
| Variable z = x * 2; |
| Variable w = y * 2; |
| |
| auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true); |
| ASSERT_VARIABLE_EQ(grad_res[0], x * 2); |
| ASSERT_FALSE(grad_res[1].defined()); |
| |
| // This is slightly different than the case above, because z doesn't even |
| // have a grad accumulator allocated. |
| z = torch::ones({1}, torch::requires_grad()); |
| grad_res = grad({x * 2}, {x, z}, {}, {}, false, true); |
| |
| ASSERT_VARIABLE_EQ(grad_res[0], x * 2); |
| ASSERT_FALSE(grad_res[1].defined()); |
| |
| // With allow_unused=false, an unreachable input yields an undefined gradient, so this should throw |
| ASSERT_THROWS_WITH( |
| grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True"); |
| } |
| |
| TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) { |
| // Test that certain nodes are not erroneously executed when an input |
| // is unreachable. See #39784 |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable var) { |
| return var; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| ADD_FAILURE() << "This node should not be executed!"; |
| return grad_output; |
| } |
| }; |
| |
| auto x = torch::randn(1, torch::requires_grad()); |
| auto x1 = torch::randn(1); |
| auto x2 = MyFunction::apply(x + x1); |
| |
| auto y = torch::randn(1, torch::requires_grad()); |
| auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true); |
| ASSERT_FALSE(grad_res[0].defined()); |
| } |
| |
| TEST(AutogradAPITests, EmptyInput) { |
| Variable x = torch::ones({1}, torch::requires_grad()); |
| ASSERT_THROWS_WITH( |
| grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs."); |
| } |
| |
| TEST(AutogradAPITests, RetainGrad) { |
| auto input = torch::rand({1, 3}, torch::requires_grad()); |
| auto h1 = input * 3; |
| auto out = (h1 * h1).sum(); |
| |
| { |
| // Warning when grad is accessed for non-leaf tensor |
| WarningCapture warnings; |
| ASSERT_FALSE(h1.grad().defined()); |
| ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos); |
| } |
| // It should be possible to call retain_grad() multiple times |
| h1.retain_grad(); |
| h1.retain_grad(); |
| { |
| // If retain_grad is true for a non-leaf tensor, |
| // there should not be any warning when grad is accessed |
| WarningCapture warnings; |
| ASSERT_FALSE(h1.grad().defined()); |
| ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos); |
| } |
| |
| // Gradient should be accumulated |
| out.backward({}, /*retain_graph=*/true); |
| ASSERT_VARIABLE_EQ(h1 * 2, h1.grad()); |
| out.backward({}, /*retain_graph=*/true); |
| ASSERT_VARIABLE_EQ(h1 * 4, h1.grad()); |
| |
| { |
| torch::NoGradGuard no_grad; |
| input.grad().zero_(); |
| } |
| // It should be a no-op for leaves |
| input.retain_grad(); |
| input.retain_grad(); |
| out.backward(); |
| ASSERT_VARIABLE_EQ(input * 18, input.grad()); |
| } |
| |
| TEST(AutogradAPITests, AnomalyMode) { |
| // Anomaly mode should emit the forward traceback as a warning and then throw an error |
| torch::autograd::DetectAnomalyGuard detect_anomaly; |
| { |
| WarningCapture warnings; |
| auto x = torch::tensor({5.0}, torch::requires_grad()); |
| auto y = x * x; |
| auto z = y * y; |
| y += 1; |
| ASSERT_THROWS_WITH(z.backward(), "inplace"); |
| ASSERT_TRUE( |
| warnings.str().find("Traceback of forward") != std::string::npos); |
| } |
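| // For y = x^1.5 at x = 0, dy/dx = 1.5 * x^0.5 is 0, but the second |
| // derivative 0.75 * x^-0.5 is infinite, so the double backward below |
| // produces NaN (0 * inf) and anomaly mode should flag it. |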
| auto double_backward_produce_nan = [](bool should_throw) { |
| auto x = torch::tensor({0.0}, torch::requires_grad()); |
| auto y = x.pow(1.5); |
| auto gr = |
| grad({y}, {x}, {}, /*retain_graph=*/true, /*create_graph=*/true); |
| if (should_throw) { |
| WarningCapture warnings; |
| ASSERT_THROWS_WITH( |
| grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan"); |
| auto msgs = warnings.messages(); |
| ASSERT_EQ(msgs.size(), 2); |
| ASSERT_TRUE( |
| msgs[0].find("Traceback of forward call that caused the error") != |
| std::string::npos); |
| ASSERT_TRUE( |
| msgs[1].find( |
| "Traceback of forward call that induced the previous calculation") != |
| std::string::npos); |
| } else { |
| grad({gr[0]}, {x}, {torch::tensor({0.0})}); |
| } |
| }; |
| |
| double_backward_produce_nan(true); |
| { |
| torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false); |
| double_backward_produce_nan(false); |
| { |
| torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true); |
| double_backward_produce_nan(true); |
| } |
| } |
| double_backward_produce_nan(true); |
| } |
| |
| TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward( |
| AutogradContext* ctx, |
| Variable var1, |
| Variable var2) { |
| ctx->save_for_backward({var1, var2}); |
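| // The comma operator evaluates var1 * var2 and discards it, so the input |
| // var1 is returned unchanged, as the test name describes. |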
| return var1 * var2, var1; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| return {}; |
| } |
| }; |
| |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| MyFunction::apply(x, y); |
| } |
| |
| TEST(CustomAutogradTest, CustomFunction) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward( |
| AutogradContext* ctx, |
| Variable var1, |
| int mul, |
| Variable var2) { |
| ctx->saved_data["mul"] = mul; |
| ctx->save_for_backward({var1, var2}); |
| return var1 + mul * var2 + var1 * var2; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| int mul = ctx->saved_data["mul"].toInt(); |
| auto saved = ctx->get_saved_variables(); |
| auto var1 = saved[0]; |
| auto var2 = saved[1]; |
| variable_list output = { |
| grad_output[0] + grad_output[0] * var2, |
| Variable(), |
| grad_output[0] * mul + grad_output[0] * var1}; |
| return output; |
| } |
| }; |
| |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| auto res = MyFunction::apply(x, 2, y); |
| auto go = torch::ones({}, torch::requires_grad()); |
| res.sum().backward(go, false, true); |
| |
| ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5})); |
| ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2); |
| } |
| |
| TEST(CustomAutogradTest, CustomFunctionWithTensorList) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, at::TensorList tensors) { |
| torch::autograd::variable_list vars; |
| for (const at::Tensor& tensor : tensors) { |
| vars.push_back(tensor); |
| } |
| ctx->save_for_backward(vars); |
| return tensors[0] + tensors[1] + tensors[0] * tensors[1]; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| auto saved = ctx->get_saved_variables(); |
| auto var1 = saved[0]; |
| auto var2 = saved[1]; |
| variable_list output = { |
| grad_output[0] + grad_output[0] * var2, |
| grad_output[0] + grad_output[0] * var1}; |
| return output; |
| } |
| }; |
| |
| at::Tensor x = torch::randn({5, 5}, torch::requires_grad()); |
| at::Tensor y = torch::randn({5, 5}, torch::requires_grad()); |
| torch::autograd::variable_list variables = {x, y}; |
| at::TensorList tensors = variables; |
| auto res = MyFunction::apply(tensors); |
| auto go = torch::ones({}, torch::requires_grad()); |
| res.sum().backward(go, false, true); |
| |
| ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5})); |
| ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5})); |
| } |
| |
| TEST(CustomAutogradTest, GraphTaskTrimEdges) { |
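| // When grad() is called with only a subset of the inputs, the graph task |
| // should trim the unused edges, and ctx->needs_input_grad() inside backward |
| // should reflect exactly the inputs that were requested. |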
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward( |
| AutogradContext* ctx, |
| Variable var1, |
| Variable var2, |
| int mul, |
| bool needs_input1_grad, |
| bool needs_input2_grad) { |
| // Record which inputs are expected to require gradients so that |
| // backward() can verify needs_input_grad() below. |
| ctx->saved_data["needs_input1_grad"] = needs_input1_grad; |
| ctx->saved_data["needs_input2_grad"] = needs_input2_grad; |
| |
| ctx->saved_data["mul"] = mul; |
| ctx->save_for_backward({var1, var2}); |
| return var1 + mul * var2 + var1 * var2; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| // Verify that the `needs_input_grad` method works correctly. |
| // This can only be checked from within the backward function. |
| auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool(); |
| auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool(); |
| IndexRange var1_idx = {0, 1}; |
| IndexRange var2_idx = {1, 2}; |
| EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad); |
| EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad); |
| EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad); |
| EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad); |
| EXPECT_EQ( |
| ctx->needs_input_grad({var1_idx, var2_idx}), |
| needs_input1_grad || needs_input2_grad); |
| |
| // calculate gradients |
| int mul = ctx->saved_data["mul"].toInt(); |
| auto saved = ctx->get_saved_variables(); |
| auto var1 = saved[0]; |
| auto var2 = saved[1]; |
| |
| Variable grad_var1, grad_var2; |
| if (ctx->needs_input_grad(0)) { |
| grad_var1 = grad_output[0] + grad_output[0] * var2; |
| } |
| if (ctx->needs_input_grad(1)) { |
| grad_var2 = grad_output[0] * mul + grad_output[0] * var1; |
| } |
| variable_list output = { |
| grad_var1, |
| grad_var2, |
| Variable(), |
| Variable(), |
| Variable(), |
| }; |
| return output; |
| } |
| }; |
| |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| auto go = torch::ones_like(x); |
| Variable out; |
| |
| // grad_x |
| out = MyFunction::apply( |
| x, |
| y, |
| 2, |
| /* needs_input1_grad= */ true, |
| /* needs_input2_grad= */ false); |
| auto grad_x = torch::autograd::grad({out}, {x}, {go})[0]; |
| ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5})); |
| |
| // grad_y |
| out = MyFunction::apply( |
| x, |
| y, |
| 2, |
| /* needs_input1_grad= */ false, |
| /* needs_input2_grad= */ true); |
| auto grad_y = torch::autograd::grad({out}, {y}, {go})[0]; |
| ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2); |
| |
| // grad_x and grad_y |
| out = MyFunction::apply( |
| x, |
| y, |
| 2, |
| /* needs_input1_grad= */ true, |
| /* needs_input2_grad= */ true); |
| auto grads = torch::autograd::grad({out}, {x, y}, {go}); |
| ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5})); |
| ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2); |
| } |
| |
| TEST(CustomAutogradTest, FunctionReturnsInput) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable var1) { |
| return var1; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| return {grad_output[0] * 2}; |
| } |
| }; |
| |
| Variable x(torch::ones(1, torch::requires_grad())); |
| MyFunction::apply(x).backward(torch::ones(1), true, true); |
| ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.)); |
| } |
| |
| TEST(CustomAutogradTest, FunctionReturnsUndefined) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable var) { |
| return var * 2; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| at::Tensor undefined_tensor; |
| return {undefined_tensor}; |
| } |
| }; |
| |
| auto x = torch::ones(1, torch::requires_grad()); |
| |
| MyFunction::apply(x).backward(); |
| ASSERT_FALSE(x.grad().defined()); |
| |
| MyFunction::apply(x.pow(2)).backward(); |
| ASSERT_FALSE(x.grad().defined()); |
| |
| MyFunction::apply(x).sum().backward(); |
| ASSERT_FALSE(x.grad().defined()); |
| |
| ASSERT_FALSE(torch::autograd::grad( |
| {MyFunction::apply(x)}, {x}, {}, false, false, true)[0] |
| .defined()); |
| } |
| |
| TEST(CustomAutogradTest, MaterializeGrads) { |
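| // With materialize_grads left at its default (true), an undefined incoming |
| // gradient is materialized as zeros before reaching backward(). |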
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable var) { |
| return var; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1)); |
| return grad_output; |
| } |
| }; |
| |
| auto x = torch::ones(1, torch::requires_grad()); |
| UndefinedGrad().apply({MyFunction::apply(x)})[0].backward(); |
| } |
| |
| TEST(CustomAutogradTest, DontMaterializeGrads) { |
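| // With set_materialize_grads(false), the undefined incoming gradient is |
| // passed through to backward() as-is. |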
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable var) { |
| ctx->set_materialize_grads(false); |
| return var; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| EXPECT_FALSE(grad_output[0].defined()); |
| return grad_output; |
| } |
| }; |
| |
| auto x = torch::ones(1, torch::requires_grad()); |
| UndefinedGrad().apply({MyFunction::apply(x)})[0].backward(); |
| } |
| |
| TEST(CustomAutogradTest, NoGradCustomFunction) { |
| // Custom Function should respect grad mode |
| struct MyOp : public Function<MyOp> { |
| static Variable forward(AutogradContext* ctx, Variable x) { |
| return x + 1; |
| } |
| |
| static variable_list backward(AutogradContext* ctx, variable_list dy) { |
| return dy; |
| } |
| }; |
| |
| auto x = torch::ones({5, 5}, torch::requires_grad()); |
| { |
| at::NoGradGuard no_grad; |
| auto y = MyOp::apply(x); |
| ASSERT_FALSE(y.requires_grad()); |
| } |
| } |
| |
| TEST(CustomAutogradTest, MarkDirty) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable v) { |
| // Change the value in place |
| auto v_data = v.data_ptr<float>(); |
| v_data[0] = 2; |
| ctx->mark_dirty({v}); |
| return v; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| return {(grad_output[0] * 2.0)}; |
| } |
| }; |
| |
| // Clone here because modifying leaves in place is not allowed |
| auto x = torch::randn({5, 5}, torch::requires_grad()).clone(); |
| auto version_before = x._version(); |
| auto out = MyFunction::apply(x); |
| auto version_after = x._version(); |
| ASSERT_TRUE(version_after >= (version_before + 1)); |
| out.sum().backward(); |
| } |
| |
| TEST(CustomAutogradTest, MarkNonDifferentiable) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable v) { |
| Variable output = v > 0; |
| ctx->mark_non_differentiable({output}); |
| return output; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| return {(grad_output[0] * 0.0)}; |
| } |
| }; |
| |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| auto mask = MyFunction::apply(x); |
| ASSERT_FALSE(mask.requires_grad()); |
| auto y = x.masked_fill(mask, 0); |
| y.sum().backward(); |
| } |
| |
| TEST(CustomAutogradTest, MarkNonDifferentiableMixed) { |
| struct MyFunction : public Function<MyFunction> { |
| static variable_list forward(AutogradContext* ctx, Variable input) { |
| Variable a = input + 1; |
| Variable b = input + 2; |
| ctx->mark_non_differentiable({a}); |
| return {a, b}; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| const Variable &grad_a = grad_output[0], &grad_b = grad_output[1]; |
| EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5})); |
| EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5})); |
| return {grad_b}; |
| } |
| }; |
| |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| auto out = MyFunction::apply(x); |
| |
| ASSERT_FALSE(out[0].requires_grad()); |
| ASSERT_TRUE(out[1].requires_grad()); |
| out[1].sum().backward(); |
| ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5})); |
| } |
| |
| TEST(CustomAutogradTest, MarkNonDifferentiableNone) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable input) { |
| auto output = input.clone(); |
| ctx->mark_non_differentiable({output}); |
| return output; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_outputs) { |
| return {}; |
| } |
| }; |
| |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| auto r = MyFunction::apply(x * x); |
| (r * x).sum().backward(); |
| } |
| |
| TEST(CustomAutogradTest, ReturnLeafInplace) { |
| struct Inplace : public Function<Inplace> { |
| static variable_list forward(AutogradContext* ctx, Variable a, Variable b) { |
| ctx->mark_dirty({a}); |
| return {a.add_(b), b + 2}; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| return {grad_output[0], grad_output[0] + grad_output[1]}; |
| } |
| }; |
| |
| Variable x = torch::randn({5, 5}); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| |
| auto out = Inplace::apply(x, y); |
| auto& q = out[0]; |
| ASSERT_TRUE(torch::equal(q, x)); |
| ASSERT_TRUE(q.requires_grad()); |
| q.sum().backward(); |
| ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5})); |
| } |
| |
| TEST(CustomAutogradTest, ReturnDuplicateInplace) { |
| struct DoubleInplace : public Function<DoubleInplace> { |
| static variable_list forward(AutogradContext* ctx, Variable x) { |
| x.mul_(2); |
| ctx->mark_dirty({x}); |
| return {x, x}; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_outputs) { |
| return {grad_outputs[0] * 2 + grad_outputs[1] * 2}; |
| } |
| }; |
| |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| |
| ASSERT_THROWS_WITH( |
| DoubleInplace::apply(x), "leaf Variable that requires grad"); |
| // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one |
| // output"); |
| |
| auto out = DoubleInplace::apply(x.clone()); |
| ASSERT_TRUE(torch::equal(out[0], out[1])); |
| } |
| |
| TEST(CustomAutogradTest, ReturnDuplicate) { |
| struct DoubleDuplicate : public Function<DoubleDuplicate> { |
| static variable_list forward(AutogradContext* ctx, Variable x) { |
| auto output = x * 2; |
| return {output, output}; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_outputs) { |
| return {grad_outputs[0] * 2 + grad_outputs[1] * 2}; |
| } |
| }; |
| |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| auto out = DoubleDuplicate::apply(x); |
| ASSERT_TRUE(torch::equal(out[0], out[1])); |
| } |
| |
| TEST(CustomAutogradTest, SaveEmptyForBackward) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable input) { |
| ctx->save_for_backward({Variable(), input, Variable()}); |
| return input * input; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| auto saved = ctx->get_saved_variables(); |
| EXPECT_FALSE(saved[0].defined()); |
| EXPECT_FALSE(saved[2].defined()); |
| return {saved[1] * 2 * grad_output[0]}; |
| } |
| }; |
| |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| auto y = MyFunction::apply(x); |
| y.sum().backward(); |
| ASSERT_VARIABLE_EQ(x.grad(), 2 * x); |
| } |
| |
| TEST(CustomAutogradTest, InvalidGradients) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext* ctx, Variable x) { |
| return x * 2; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_outputs) { |
| return { |
| torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))}; |
| } |
| }; |
| |
| auto input1 = |
| torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true)); |
| ASSERT_THROWS_WITH( |
| MyFunction::apply(input1).sum().backward(), "expected shape"); |
| auto input2 = |
| torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true)); |
| } |
| |
| TEST(CustomAutogradTest, NoGradInput) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext*, Variable x) { |
| return x; |
| } |
| |
| static variable_list backward( |
| AutogradContext*, |
| variable_list grad_outputs) { |
| return grad_outputs; |
| } |
| }; |
| |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y; |
| { |
| at::NoGradGuard no_grad; |
| y = MyFunction::apply(x); |
| } |
| |
| ASSERT_TRUE(x.requires_grad()); |
| ASSERT_FALSE(y.grad_fn()); |
| } |
| |
| TEST(CustomAutogradTest, TooManyGrads) { |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext*, Variable input) { |
| return input; |
| } |
| |
| static variable_list backward(AutogradContext*, variable_list grad_output) { |
| grad_output.insert(grad_output.end(), {Variable(), Variable()}); |
| return grad_output; |
| } |
| }; |
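| |
| // Minimal exercise of the function above; extra undefined gradients returned |
| // from backward are assumed to be ignored, so this backward call should |
| // succeed and x receives the single defined gradient. |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| MyFunction::apply(x).sum().backward(); |
| ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5})); |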
| } |
| |
| TEST(CustomAutogradTest, DepNoGrad) { |
| struct F1 : public Function<F1> { |
| static variable_list forward(AutogradContext* ctx, Variable input) { |
| auto out = torch::randn(input.sizes()); |
| ctx->mark_non_differentiable({out}); |
| return {input, out}; |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| return {grad_output[0]}; |
| } |
| }; |
| |
| struct F2 : public Function<F2> { |
| static Variable forward(AutogradContext*, Variable input, Variable ignore) { |
| return input; |
| } |
| |
| static variable_list backward(AutogradContext*, variable_list grad_output) { |
| return {grad_output[0], Variable()}; |
| } |
| }; |
| |
| auto x = torch::randn(5, torch::requires_grad()); |
| auto out = F1::apply(x); |
| Variable &a = out[0], &b = out[1]; |
| b = b + 1; // Separate F1 and F2 by another operation |
| ASSERT_TRUE(a.requires_grad()); |
| ASSERT_FALSE(b.requires_grad()); |
| |
| auto c = F2::apply(a, b); |
| c.backward(torch::ones(c.sizes()), false, false); |
| ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes())); |
| } |
| |
| TEST(CustomAutogradTest, Reentrant) { |
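| // Reenter::backward calls backward() again on a freshly built graph, |
| // exercising reentrant execution of the autograd engine. |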
| static Variable y_data = torch::randn({2, 2}); |
| struct Reenter : public Function<Reenter> { |
| static Variable forward(AutogradContext* ctx, Variable input) { |
| Variable output; |
| { |
| at::AutoGradMode enable_grad(true); |
| auto x = make_variable(input.tensor_data(), true); |
| auto y = make_variable(y_data.tensor_data(), true); |
| output = x * y; |
| |
| ctx->saved_data["x"] = x; |
| ctx->saved_data["y"] = y; |
| ctx->saved_data["output_var"] = output; |
| } |
| return output.detach(); |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| { |
| at::AutoGradMode enable_grad(true); |
| auto out = ctx->saved_data["output_var"].toTensor(); |
| out.sum().backward(); |
| } |
| return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]}; |
| } |
| }; |
| |
| auto x = torch::randn({2, 2}, torch::requires_grad()); |
| auto out = Reenter::apply(x); |
| out.sum().backward(); |
| ASSERT_VARIABLE_EQ(x.grad(), y_data); |
| } |
| |
| // NOTE: If this fails for apparently unrelated reasons in TSAN be aware of |
| // the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950 |
| TEST(CustomAutogradTest, DeepReentrant) { |
| struct DeepReenter : public Function<DeepReenter> { |
| static Variable forward(AutogradContext* ctx, Variable x) { |
| { |
| at::AutoGradMode enable_grad(true); |
| ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1; |
| } |
| return ctx->saved_data["x"].toTensor().detach(); |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) { |
| return grad_output; |
| } |
| { |
| at::AutoGradMode enable_grad(true); |
| apply(ctx->saved_data["x"].toTensor())[0].sum().backward(); |
| return grad_output; |
| } |
| } |
| }; |
| |
| // This should not stack overflow |
| auto v = |
| torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true)); |
| DeepReenter::apply(v).sum().backward(); |
| } |
| |
| TEST(CustomAutogradTest, ReentrantPriority) { |
| static std::vector<int> order; |
| |
| struct MyFunction : public Function<MyFunction> { |
| static Variable forward(AutogradContext*, Variable x) { |
| return x; |
| } |
| |
| static variable_list backward(AutogradContext*, variable_list grad) { |
| order.push_back(0); |
| return grad; |
| } |
| }; |
| |
| struct Reenter : public Function<Reenter> { |
| static Variable forward(AutogradContext* ctx, Variable x) { |
| { |
| at::AutoGradMode enable_grad(true); |
| ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1; |
| } |
| return ctx->saved_data["x"].toTensor().detach(); |
| } |
| |
| static variable_list backward( |
| AutogradContext* ctx, |
| variable_list grad_output) { |
| order.push_back(1); |
| if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) { |
| return grad_output; |
| } |
| { |
| at::AutoGradMode enable_grad(true); |
| apply(ctx->saved_data["x"].toTensor())[0].sum().backward(); |
| return grad_output; |
| } |
| } |
| }; |
| |
| auto a = MyFunction::apply( |
| torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true))); |
| auto b = Reenter::apply( |
| torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true))); |
| auto v = a * b; |
| v.backward(); |
| |
| // All the reentrant tasks should be prioritized over the MyFunction backward |
| // task. |
| ASSERT_EQ(order.size(), 10); |
| ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9); |
| ASSERT_EQ(order.back(), 0); |
| // Clear the static variable in case the test gets executed in a loop |
| order.clear(); |
| } |
| |
| TEST(CustomAutogradTest, Hooks) { |
| Variable x = torch::ones({5, 5}, torch::requires_grad()); |
| Variable y = torch::ones({5, 5}) * 4; |
| y.set_requires_grad(true); |
| |
| int counter = 0; |
| |
| std::function<void(int, Variable)> bw_hook( |
| [&counter](int inc, Variable grad) { counter += inc; }); |
| |
| Variable z = x * x + x * 2 + x * y + y; |
| x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); }); |
| auto hook_1 = |
| z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); }); |
| z.backward(torch::ones({5, 5}), true, true); |
| ASSERT_EQ(counter, 1); |
| |
| auto hook_2 = |
| z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); }); |
| z.backward(torch::ones({5, 5}), true, true); |
| ASSERT_EQ(counter, 4); |
| |
| z.remove_hook(hook_2); |
| z.backward(torch::ones({5, 5}), true, true); |
| ASSERT_EQ(counter, 5); |
| |
| std::function<Variable(Variable)> bw_hook_modify( |
| [](Variable grad) { return grad.mul(2); }); |
| |
| z.remove_hook(hook_1); |
| z.register_hook(bw_hook_modify); |
| y.grad().zero_(); |
| z.backward(torch::ones({5, 5}), true, false); |
| ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2); |
| |
| y.register_hook(bw_hook_modify); |
| y.grad().zero_(); |
| z.backward(torch::ones({5, 5}), false, false); |
| ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4); |
| |
| ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index"); |
| } |
| |
| TEST(CustomAutogradTest, HooksInplace) { |
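| // A hook registered before the in-place mul_(2) stays attached to the |
| // pre-mutation value and therefore sees the gradient chained through the |
| // multiplication (all 2s), while the hook registered afterwards sees the |
| // gradient of the final value (all 1s). |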
| auto a = torch::ones({5, 5}, torch::requires_grad()).clone(); |
| |
| int hook1_count = 0; |
| auto hook1 = ([&hook1_count](Variable grad) { |
| hook1_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2); |
| }); |
| |
| int hook2_count = 0; |
| auto hook2 = ([&hook2_count](Variable grad) { |
| hook2_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5})); |
| }); |
| |
| a.register_hook(hook1); |
| a.mul_(2); |
| a.register_hook(hook2); |
| |
| auto out = (a + 1).sum(); |
| out.backward(); |
| |
| ASSERT_EQ(hook1_count, 1); |
| ASSERT_EQ(hook2_count, 1); |
| } |
| |
| TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) { |
| auto a = torch::ones({5, 5}, torch::requires_grad()).clone(); |
| |
| int hook1_count = 0; |
| auto hook1 = ([&hook1_count](Variable grad) { |
| hook1_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2); |
| }); |
| |
| int hook2_count = 0; |
| auto hook2 = ([&hook2_count](Variable grad) { |
| hook2_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2); |
| }); |
| |
| int hook3_count = 0; |
| auto hook3 = ([&hook3_count](Variable grad) { |
| hook3_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5})); |
| }); |
| |
| a.register_hook(hook1); |
| a.retain_grad(); |
| a.register_hook(hook2); |
| |
| a.mul_(2); |
| a.register_hook(hook3); |
| |
| auto out = (a + 1).sum(); |
| out.backward(); |
| |
| ASSERT_EQ(hook1_count, 1); |
| ASSERT_EQ(hook2_count, 1); |
| ASSERT_EQ(hook3_count, 1); |
| |
| ASSERT_TRUE(a.retains_grad()); |
| ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5})); |
| } |
| |
| TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) { |
| auto a = torch::ones({5, 5}, torch::requires_grad()).clone(); |
| |
| int hook1_count = 0; |
| auto hook1 = ([&hook1_count](Variable grad) { |
| hook1_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4); |
| }); |
| |
| int hook2_count = 0; |
| auto hook2 = ([&hook2_count](Variable grad) { |
| hook2_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4); |
| }); |
| |
| int hook3_count = 0; |
| auto hook3 = ([&hook3_count](Variable grad) { |
| hook3_count++; |
| ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5})); |
| }); |
| |
| a.register_hook(hook1); |
| a.retain_grad(); |
| a.register_hook(hook2); |
| |
| a.mul_(2); |
| a.mul_(2); |
| a.register_hook(hook3); |
| |
| auto out = (a + 1).sum(); |
| out.backward(); |
| |
| ASSERT_EQ(hook1_count, 1); |
| ASSERT_EQ(hook2_count, 1); |
| ASSERT_EQ(hook3_count, 1); |
| |
| ASSERT_TRUE(a.retains_grad()); |
| ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5})); |
| } |
| |
| TEST(CustomAutogradTest, HookNone) { |
| struct NoneGradientFunction : public Function<NoneGradientFunction> { |
| static variable_list forward(AutogradContext* ctx, Variable x, Variable y) { |
| return {x, y}; |
| } |
| |
| static variable_list backward(AutogradContext* ctx, variable_list grad) { |
| return {grad[0], Variable()}; |
| } |
| }; |
| |
| bool was_called = false; |
| |
| auto hook = ([&was_called](Variable grad) { |
| ASSERT_TRUE(grad.defined()); |
| was_called = true; |
| }); |
| |
| auto x = torch::randn({5, 5}, torch::requires_grad()); |
| auto y = torch::randn({5, 5}); |
| |
| auto out = NoneGradientFunction::apply(x, y); |
| Variable rx = out[0], ry = out[1]; |
| |
| rx.register_hook(hook); |
| ry.register_hook(hook); |
| (rx + ry).sum().backward(); |
| ASSERT_TRUE(was_called); |
| } |
| |
| TEST(CustomAutogradTest, BackwardWithInputs) { |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| Variable z = x * x + x * y + y * y; |
| Variable x_grad_expected = 2 * x + y; |
| Variable y_grad_expected = x + 2 * y; |
| |
| z.backward(torch::ones({5, 5}), false, false, {x}); |
| |
| ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected); |
| ASSERT_FALSE(y.grad().defined()); |
| } |
| |
| TEST(CustomAutogradTest, BackwardWithEmptyInputs) { |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| Variable z = x * x + x * y + y * y; |
| Variable x_grad_expected = 2 * x + y; |
| Variable y_grad_expected = x + 2 * y; |
| ASSERT_THROWS_WITH( |
| z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}), |
| "cannot be empty"); |
| } |
| |
| TEST(CustomAutogradTest, BackwardWithNonLeafInputs) { |
| Variable x = torch::randn({5, 5}, torch::requires_grad()); |
| Variable y = torch::randn({5, 5}, torch::requires_grad()); |
| Variable z = x * x; |
| Variable w = y * z + x * y + y * y; |
| |
| Variable x_grad_expected = 2 * x * y + y; |
| Variable z_grad_expected = y; |
| |
| w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z}); |
| |
| ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected); |
| ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected); |
| ASSERT_FALSE(y.grad().defined()); |
| } |
| |
| TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) { |
| c10::WarningUtils::WarnAlways guard(true); |
| |
| torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true); |
| auto z = x * x; |
| { |
| WarningCapture warnings; |
| z.backward(torch::ones({5, 5}), std::nullopt, true); |
| ASSERT_TRUE( |
| warnings.str().find("Using backward() with create_graph=True") != |
| std::string::npos); |
| } |
| |
| { |
| WarningCapture warnings; |
| torch::autograd::backward({z}, {torch::ones({5, 5})}, std::nullopt, true); |
| ASSERT_TRUE( |
| warnings.str().find("Using backward() with create_graph=True") != |
| std::string::npos); |
| } |
| } |
| |
| /** |
| * Tests for AutogradNotImplementedFallback |
| * - Check that the NotImplemented node is created when inputs require grad, |
| * and that no such node is created when no inputs require grad |
| * - check_inplace logic |
| * - view ops |
| * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't |
| * test non-NDEBUG builds. |
| * - tensorlist input and output |
| * - multiple outputs / non-tensor output |
| * - rebase_history vs set_history |
| */ |
| namespace { |
| |
| torch::Tensor inplace_op( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| return self.add_(other); |
| } |
| |
| std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| other.add_(self); |
| self.add_(other); |
| return std::tuple<torch::Tensor, torch::Tensor>(self, other); |
| } |
| |
| std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| // This is not allowed. We test below that calling this through the boxed |
| // kernel raises an error |
| return std::tuple<torch::Tensor, torch::Tensor>(self, other); |
| } |
| |
| std::tuple<torch::Tensor, torch::Tensor> non_first_view_op( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| // This is not allowed. We test below that calling this through the boxed |
| // kernel raises an error |
| return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other); |
| } |
| |
| int64_t ret_single_non_tensor( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| return 12; |
| } |
| |
| torch::Tensor opt_op( |
| const torch::Tensor& self, |
| const std::optional<at::Tensor>& other) { |
| if (other.has_value()) { |
| return self + other.value(); |
| } else { |
| return self.clone(); |
| } |
| } |
| |
| torch::Tensor my_custom_op( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| return self + other; |
| } |
| |
| std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| auto a = self - other; |
| auto b = self + other; |
| return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12); |
| } |
| |
| torch::Tensor view_op(const torch::Tensor& self) { |
| return self.alias(); |
| } |
| |
| torch::Tensor view_op_with_extra_arg( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| return self.alias(); |
| } |
| |
| std::vector<torch::Tensor> ret_tensor_vector_view( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| return {self.alias(), self.alias()}; |
| } |
| |
| std::vector<at::Tensor> ret_tensor_vector( |
| const torch::Tensor& self, |
| const torch::Tensor& other) { |
| std::vector<at::Tensor> out; |
| out.push_back(self + other); |
| out.push_back(self - other); |
| return out; |
| } |
| |
| torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) { |
| const auto& res = self.clone(); |
| for (const auto& t : other) { |
| res.add_(t); |
| } |
| return res; |
| } |
| |
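| // Registers `fn` as the CPU kernel for the `_test::name` operator and |
| // installs the autogradNotImplementedFallback / ADInplaceOrView boxed |
| // fallbacks for it; each test below then looks the operator up through the |
| // dispatcher and calls it unboxed. |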
| #define REGISTER_TEST_OP(name, schema, fn) \ |
| auto m = MAKE_TORCH_LIBRARY(_test); \ |
| m.def(schema); \ |
| auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \ |
| auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \ |
| auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \ |
| m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \ |
| m_autograd.impl( \ |
| name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \ |
| m_inplaceorview.impl( \ |
| name, \ |
| c10::DispatchKey::ADInplaceOrView, \ |
| autogradNotImplementedInplaceOrViewFallback()); |
| |
| template <typename F> |
| void assertBasicChecks(F op) { |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| auto c = torch::tensor({1.}, {torch::kFloat32}); |
| |
| // If any input requires grad, the output gets a grad_fn whose backward is |
| // not implemented, so calling backward() should throw. |
| auto out1 = op(a, b); |
| ASSERT_THROWS_WITH(out1.backward(), "is not implemented"); |
| |
| // The output should not have a grad_fn if no inputs require grad |
| auto out2 = op(b, c); |
| ASSERT_THROWS_WITH( |
| out2.backward(), |
| "element 0 of tensors does not require grad and does not have a grad_fn"); |
| |
| // TODO: Forward AD Tests? |
| } |
| |
| } // namespace |
| |
| TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) { |
| REGISTER_TEST_OP( |
| "ret_single_non_tensor", |
| "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", |
| ret_single_non_tensor); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::ret_single_non_tensor", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>( |
| opHandle, _1, _2); |
| }; |
| |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| |
| ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b)); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, InplaceOp) { |
| REGISTER_TEST_OP( |
| "inplace_op", |
| "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", |
| inplace_op); |
| auto opHandle = |
| c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| torch::Tensor, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| |
| // Check in-place |
| ASSERT_THROWS_WITH( |
| op(a, b), |
| "a leaf Variable that requires grad is being used in an in-place operation"); |
| op(b, a); |
| a = a.clone(); |
| b = b.clone(); |
| auto c = op(a, b); |
| ASSERT_TRUE(torch::allclose(c, inplace_op(a, b))); |
| |
| // Test in-place on view |
| auto base = |
| torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); |
| auto view = base.view(-1); |
| auto t = torch::tensor({1.}, {torch::kFloat32}); |
| |
| torch::Tensor v_nograd; |
| { |
| c10::NoGradGuard guard; |
| v_nograd = base.view(-1); |
| op(v_nograd, t); |
| } |
| |
| ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode"); |
| ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl()); |
| ASSERT_THAT( |
| op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward")); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) { |
| REGISTER_TEST_OP( |
| "two_arg_inplace_op", |
| "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", |
| two_arg_inplace_op); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::two_arg_inplace_op", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| std::tuple<torch::Tensor, torch::Tensor>, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| |
| // Both are modified in-place! |
| ASSERT_THROWS_WITH( |
| op(a, b), |
| "a leaf Variable that requires grad is being used in an in-place operation"); |
| ASSERT_THROWS_WITH( |
| op(b, a), |
| "a leaf Variable that requires grad is being used in an in-place operation"); |
| |
| auto c = |
| torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); |
| auto d = |
| torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); |
| |
| auto saved_version_c = c._version(); |
| auto saved_version_d = d._version(); |
| op(c, d); |
| ASSERT_NE(c._version(), saved_version_c); |
| ASSERT_NE(d._version(), saved_version_d); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, OptOp) { |
| REGISTER_TEST_OP( |
| "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op); |
| auto opHandle = |
| c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", ""); |
| auto op = [&](const torch::Tensor& _1, |
| const std::optional<torch::Tensor>& _2) { |
| return callOpUnboxed< |
| torch::Tensor, |
| const torch::Tensor&, |
| const std::optional<torch::Tensor>&>(opHandle, _1, _2); |
| }; |
| |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| |
| ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b))); |
| ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {}))); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) { |
| REGISTER_TEST_OP( |
| "my_custom_op", |
| "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", |
| my_custom_op); |
| auto opHandle = |
| c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| torch::Tensor, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| |
| assertBasicChecks(op); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) { |
| REGISTER_TEST_OP( |
| "ret_tuple_non_tensor", |
| "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", |
| ret_tuple_non_tensor); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::ret_tuple_non_tensor", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| auto out = callOpUnboxed< |
| std::tuple<torch::Tensor, torch::Tensor, int64_t>, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| auto [out0, out1, out2] = std::move(out); |
| return out0; |
| }; |
| |
| assertBasicChecks(op); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, ViewOp) { |
| REGISTER_TEST_OP( |
| "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op); |
| auto opHandle = |
| c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", ""); |
| auto op = [&](const torch::Tensor& _1) { |
| return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1); |
| }; |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| auto v = op(b); |
| ASSERT_TRUE(v.is_view()); |
| ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); |
| |
| auto b1 = |
| torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); |
| auto v1 = op(b1); |
| ASSERT_TRUE(v1.is_view()); |
| ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl()); |
| |
| // Test inplace on view |
| auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| |
| // rebase_history should raise when it tries to refresh the view's grad_fn |
| ASSERT_THROWS_WITH( |
| v1.add_(t), "which does not have a derivative implemented is forbidden"); |
| // base should not be aware of the views, so this is still okay |
| b1.add_(t); |
| ASSERT_THROWS_WITH( |
| v1.grad_fn(), |
| "which does not have a derivative implemented is forbidden"); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) { |
| REGISTER_TEST_OP( |
| "view_op_with_extra_arg", |
| "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)", |
| view_op_with_extra_arg); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::view_op_with_extra_arg", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| torch::Tensor, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| assertBasicChecks(op); |
| auto a = torch::tensor({1.}, {torch::kFloat32}); |
| auto b = torch::tensor({2.}, {torch::kFloat32}); |
| auto out1 = op(a, b); |
| ASSERT_TRUE(out1.is_view()); |
| ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl()); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) { |
| REGISTER_TEST_OP( |
| "ret_tensor_vector_view", |
| "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)", |
| ret_tensor_vector_view); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::ret_tensor_vector_view", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| std::vector<at::Tensor>, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| auto a = torch::tensor({1.}, {torch::kFloat32}); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| auto out = op(a, b); |
| ASSERT_TRUE(out[0].is_view()); |
| ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl()); |
| ASSERT_TRUE(out[1].is_view()); |
| ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl()); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, DoubleViewOP) { |
| REGISTER_TEST_OP( |
| "two_pairs_of_view_op", |
| "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", |
| two_pairs_of_view_op); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::two_pairs_of_view_op", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| std::tuple<torch::Tensor, torch::Tensor>, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| ASSERT_THROWS_WITH( |
| op(a, b), |
| "Expected only a single output in the operator schema to have a non-write alias annotation"); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) { |
| REGISTER_TEST_OP( |
| "non_first_view_op", |
| "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))", |
| non_first_view_op); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::non_first_view_op", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| std::tuple<torch::Tensor, torch::Tensor>, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2); |
| }; |
| auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| ASSERT_THROWS_WITH( |
| op(a, b), "can only create view relationships between the first"); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, RetTensorVector) { |
| REGISTER_TEST_OP( |
| "ret_tensor_vector", |
| "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", |
| ret_tensor_vector); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::ret_tensor_vector", ""); |
| auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { |
| return callOpUnboxed< |
| std::vector<at::Tensor>, |
| const torch::Tensor&, |
| const torch::Tensor&>(opHandle, _1, _2)[0]; |
| }; |
| assertBasicChecks(op); |
| } |
| |
| TEST(TestAutogradNotImplementedFallback, TensorlistOp) { |
| REGISTER_TEST_OP( |
| "tensorlist_op", |
| "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", |
| tensorlist_op); |
| auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( |
| "_test::tensorlist_op", ""); |
| auto op = [&](torch::Tensor _1, at::TensorList _2) { |
| return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>( |
| opHandle, _1, _2); |
| }; |
| |
| auto a = torch::tensor({1.}, {torch::kFloat32}); |
| auto b = torch::tensor({1.}, {torch::kFloat32}); |
| auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); |
| std::vector<torch::Tensor> vec = {b, c}; |
| auto out = op(a, vec); |
| |
| ASSERT_THROWS_WITH( |
| torch::autograd::grad({out}, {vec[0]}), |
| "element 0 of the input tensors does not require grad"); |
| ASSERT_THROWS_WITH( |
| torch::autograd::grad({out}, {vec[1]}), "is not implemented"); |
| |
| ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec))); |
| } |
| |
| // TODO add these tests if needed |
| // test_once_differentiable |
| // test_sparse_backward |
| // test_save_output_nr |
| // test_free_deep_graph_pyfunction |
| // test_naughty_anomaly_access |
| // test_naughty_autograd-function_stashing_ctx |
| // test_custom_autograd_repeated_grad_grad |
| // test_return_leaf |
| // test_anomaly_detect_nan |
| // test_no_grad_copy |