| #include <gtest/gtest.h> |
| |
| #include <c10/util/irange.h> |
| #include <torch/torch.h> |
| |
| #include <test/cpp/api/support.h> |
| |
| using namespace torch::nn; |
| using namespace torch::test; |
| |
// Minimal module used to exercise Module::name() demangling and as<T>() casts.
struct AGIUnit : torch::nn::Module {};
| |
// Same class name inside a namespace: name() should report the qualified
// "test::AGIUnit". AGIUnit2 instead overrides the name explicitly via the
// Module("Foo") base constructor.
namespace test {
struct AGIUnit : torch::nn::Module {};
struct AGIUnit2 : torch::nn::Module {
  AGIUnit2() : torch::nn::Module("Foo") {}
};
} // namespace test
| |
// Test fixture based on SeedingFixture (presumably seeds the RNG per test for
// determinism — see test/cpp/api/support.h).
struct ModuleTest : torch::test::SeedingFixture {};
| |
| TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) { |
| Linear module(3, 4); |
| ASSERT_TRUE(module->is_training()); |
| |
| module->eval(); |
| ASSERT_FALSE(module->is_training()); |
| |
| module->train(); |
| ASSERT_TRUE(module->is_training()); |
| } |
| |
| TEST_F(ModuleTest, ZeroGrad) { |
| Linear module(3, 4); |
| auto weight = torch::ones({8, 3}, torch::requires_grad()); |
| auto loss = module(weight).sum(); |
| loss.backward(); |
| for (auto& parameter : module->parameters()) { |
| // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
| auto grad = parameter.grad(); |
| ASSERT_TRUE(grad.defined()); |
| ASSERT_NE(grad.sum().item<float>(), 0); |
| } |
| module->zero_grad(); |
| for (auto& parameter : module->parameters()) { |
| // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
| auto grad = parameter.grad(); |
| ASSERT_FALSE(grad.defined()); |
| } |
| } |
| |
TEST_F(ModuleTest, ZeroGradWithUndefined) {
  // Only `x` participates in the backward pass, so only `x` ever gets a
  // gradient; `y.grad()` stays undefined throughout.
  struct TestModule : torch::nn::Module {
    TestModule() {
      x = register_parameter("x", torch::ones(5, torch::requires_grad()));
      y = register_parameter("y", torch::ones(5, torch::requires_grad()));
    }
    torch::Tensor x, y;
  };

  TestModule module;
  auto z = module.x * 2;
  z.sum().backward();

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  // With set_to_none=false the existing grad tensor is kept but zero-filled.
  module.zero_grad(false); // set_to_none = false

  ASSERT_TRUE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());

  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);

  // Default zero_grad() releases the grad tensors entirely (set to undefined).
  module.zero_grad();

  ASSERT_FALSE(module.x.grad().defined());
  ASSERT_FALSE(module.y.grad().defined());
}
| |
| TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) { |
| struct TestModel : public torch::nn::Module {}; |
| ASSERT_THROWS_WITH( |
| TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)), |
| "Submodule name must not contain a dot (got 'name.with.dot')"); |
| ASSERT_THROWS_WITH( |
| TestModel{}.register_module("", torch::nn::Linear(3, 4)), |
| "Submodule name must not be empty"); |
| } |
| |
| TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) { |
| struct TestModel : public torch::nn::Module {}; |
| TestModel model; |
| model.register_module("linear", torch::nn::Linear(3, 4)); |
| ASSERT_THROWS_WITH( |
| model.register_module("linear", torch::nn::Linear(3, 4)), |
| "Submodule 'linear' already defined"); |
| } |
| |
| TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) { |
| torch::nn::Module model; |
| ASSERT_THROWS_WITH( |
| model.replace_module("linear", torch::nn::Linear(3, 4)), |
| "Submodule 'linear' is not defined"); |
| } |
| |
TEST_F(ModuleTest, ReplaceModule) {
  // Replacing a registered submodule swaps it in the module's internal
  // registry; the returned pointer is assigned back to the member so both
  // stay in sync.
  struct TestModel : public torch::nn::Module {
    torch::nn::Linear l1{nullptr};
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(3, 4));
    }
  };
  auto model = std::make_shared<TestModel>();
  model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6));
  // The new (5, 6) layer is visible through the parameter dict...
  ASSERT_EQ(model->named_parameters()["l1.weight"].size(0), 6);
  // ...and the member aliases the registered submodule.
  ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as<Linear>());
}
| |
| TEST_F(ModuleTest, UnregisterModule) { |
| struct TestModel : public torch::nn::Module {}; |
| TestModel model; |
| ASSERT_THROWS_WITH( |
| model.unregister_module("linear"), |
| "No Module with name `linear` is registered"); |
| model.register_module("linear", torch::nn::Linear(3, 4)); |
| model.unregister_module("linear"); |
| ASSERT_TRUE(model.children().empty()); |
| } |
| |
| TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) { |
| struct TestModel : public torch::nn::Module {}; |
| ASSERT_THROWS_WITH( |
| TestModel{}.register_parameter("name.with.dot", torch::ones(5)), |
| "Parameter name must not contain a dot (got 'name.with.dot')"); |
| ASSERT_THROWS_WITH( |
| TestModel{}.register_parameter("", torch::ones(5)), |
| "Parameter name must not be empty"); |
| } |
| |
| TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) { |
| struct TestModel : public torch::nn::Module {}; |
| TestModel model; |
| model.register_parameter("p", torch::ones(5)); |
| ASSERT_THROWS_WITH( |
| model.register_parameter("p", torch::ones(5)), |
| "Parameter 'p' already defined"); |
| } |
| |
TEST_F(ModuleTest, RegisterParameterUndefinedTensor) {
  struct TestModel : public torch::nn::Module {};
  {
    // An undefined tensor registers the slot but contributes no parameter.
    TestModel model;
    model.register_parameter(
        "undefined_tensor", torch::Tensor(), /*requires_grad=*/false);
    ASSERT_EQ(model.parameters().size(), 0);
  }
  {
    WarningCapture warnings;

    // With the default requires_grad=true, registering an undefined tensor
    // warns (exactly once) that the flag is ignored.
    TestModel model;
    model.register_parameter("undefined_tensor", torch::Tensor());
    ASSERT_EQ(model.parameters().size(), 0);

    ASSERT_EQ(
        count_substr_occurrences(
            warnings.str(),
            "Ignoring the `requires_grad=true` function parameter"),
        1);
  }
}
| |
| TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) { |
| struct TestModel : public torch::nn::Module {}; |
| ASSERT_THROWS_WITH( |
| TestModel{}.register_buffer("name.with.dot", torch::ones(5)), |
| "Buffer name must not contain a dot (got 'name.with.dot')"); |
| ASSERT_THROWS_WITH( |
| TestModel{}.register_buffer("", torch::ones(5)), |
| "Buffer name must not be empty"); |
| } |
| |
| TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) { |
| struct TestModel : public torch::nn::Module {}; |
| TestModel model; |
| model.register_buffer("p", torch::ones(5)); |
| ASSERT_THROWS_WITH( |
| model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined"); |
| } |
| |
TEST_F(ModuleTest, CanGetName) {
  // EXPECT_* (non-fatal) instead of ASSERT_* because demangling may fail.
  AGIUnit agi;
  // Call it twice just to make sure there are no bugs in the lazy
  // initialization semantics.
  EXPECT_EQ(agi.name(), "AGIUnit");
  EXPECT_EQ(agi.name(), "AGIUnit");
  // Namespaced type: the reported name includes the namespace qualifier.
  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
  // A name passed explicitly to the Module base constructor wins.
  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
}
| |
TEST_F(ModuleTest, AsCastsModulesCorrectly) {
  // as<T>() must work through a ModuleHolder, a shared_ptr<Module>, and a
  // plain Module&, returning nullptr when the dynamic type does not match.
  Linear module(3, 4);
  ASSERT_EQ(module->as<Linear>(), module.get());
  ASSERT_EQ(module->as<LinearImpl>(), module.get());
  ASSERT_EQ(module->as<Module>(), module.get());
  ASSERT_EQ(module->as<AGIUnit>(), nullptr);

  std::shared_ptr<Module> raw = module.ptr();
  ASSERT_EQ(raw->as<Linear>(), module.get());
  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
  ASSERT_EQ(raw->as<Module>(), module.get());
  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);

  Module& raw_ref = *raw.get();
  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
  ASSERT_EQ(raw_ref.as<Module>(), module.get());
  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
  // Typical usage pattern: downcast-and-use in one conditional.
  if (auto* linear = raw_ref.as<Linear>()) {
    ASSERT_EQ(linear->weight.ndimension(), 2);
  }

  // A module that is not wrapped in a holder can cast itself too.
  AGIUnit unit;
  ASSERT_EQ(unit.as<Linear>(), nullptr);
  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
}
| |
// Shared body for the CPU and CUDA test variants below: converting a module
// to a device or dtype must leave undefined parameter/buffer slots untouched
// instead of attempting to convert them.
void test_DeviceOrDtypeConversionSkipsUndefinedTensor(
    torch::Device to_device,
    torch::Dtype to_dtype) {
  {
    // Case 1: Undefined tensors as parameters (bias disabled -> undefined)
    Linear module(LinearOptions(10, 20).bias(false));
    ASSERT_TRUE(module->weight.defined());
    ASSERT_FALSE(module->bias.defined());

    module->to(to_device);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.device().type(), to_device.type());
    ASSERT_FALSE(module->bias.defined());

    module->to(to_dtype);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.dtype(), to_dtype);
    ASSERT_FALSE(module->bias.defined());
  }
  {
    // Case 2: Undefined tensors as buffers (running stats disabled)
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(true));
    ASSERT_TRUE(module->weight.defined());
    ASSERT_FALSE(module->running_mean.defined());

    module->to(to_device);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.device().type(), to_device.type());
    ASSERT_FALSE(module->running_mean.defined());

    module->to(to_dtype);
    ASSERT_TRUE(module->weight.defined());
    ASSERT_EQ(module->weight.dtype(), to_dtype);
    ASSERT_FALSE(module->running_mean.defined());
  }
}
| |
// CPU variant of the undefined-tensor conversion check above.
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCPU, torch::kDouble);
}
| |
// CUDA variant (the _CUDA suffix gates this test on CUDA availability).
TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) {
  test_DeviceOrDtypeConversionSkipsUndefinedTensor(
      torch::kCUDA, torch::kDouble);
}
| |
TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) {
  {
    // bias(false) leaves the bias slot undefined: only `weight` is reported.
    Linear module(LinearOptions(10, 20).bias(false));

    auto params = module->parameters();
    ASSERT_EQ(params.size(), 1);
    auto named_params = module->named_parameters();
    ASSERT_EQ(named_params.size(), 1);

    // The flat view, named view, and member all alias the same tensor.
    ASSERT_TRUE(pointer_equal(params[0], named_params["weight"]));
    ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight));
  }
  {
    // No running stats -> no buffers at all.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(false).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 0);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 0);
  }
  {
    // Running stats enabled -> three buffers, aliasing the members.
    BatchNorm1d module(
        BatchNorm1dOptions(5).track_running_stats(true).affine(false));

    auto buffers = module->buffers();
    ASSERT_EQ(buffers.size(), 3);
    auto named_buffers = module->named_buffers();
    ASSERT_EQ(named_buffers.size(), 3);

    ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_mean"], module->running_mean));
    ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"]));
    ASSERT_TRUE(
        pointer_equal(named_buffers["running_var"], module->running_var));
    ASSERT_TRUE(
        pointer_equal(buffers[2], named_buffers["num_batches_tracked"]));
    ASSERT_TRUE(pointer_equal(
        named_buffers["num_batches_tracked"], module->num_batches_tracked));
  }
}
| |
TEST_F(ModuleTest, Conversion_MultiCUDA) {
  // to() must move all parameters between CPU and specific CUDA devices and
  // convert their dtype.
  Linear module(128, 64);
  for (auto& parameter : module->parameters()) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
    ASSERT_EQ(parameter.dtype(), torch::kFloat32);
  }
  {
    // CPU -> CUDA:0 -> CUDA:1 (device index is respected).
    module->to({torch::kCUDA, 0});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 0);
    }
    module->to({torch::kCUDA, 1});
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
      ASSERT_EQ(parameter.device().index(), 1);
    }
  }
  {
    // And back to CPU.
    module->to(torch::Device(torch::kCPU));
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
    }
  }
  {
    // Dtype-only conversion.
    module->to(torch::kFloat64);
    for (auto& parameter : module->parameters()) {
      ASSERT_EQ(parameter.dtype(), torch::kFloat64);
    }
  }
}
| |
| TEST_F(ModuleTest, Conversion_NoGrad_MultiCUDA) { |
| Linear module(128, 64); |
| for (auto& parameter : module->parameters()) { |
| parameter.requires_grad_(false); |
| } |
| { |
| module->to(torch::kInt32); |
| for (auto& parameter : module->parameters()) { |
| ASSERT_EQ(parameter.dtype(), torch::kInt32); |
| } |
| } |
| { |
| module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8); |
| for (auto& parameter : module->parameters()) { |
| ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA); |
| ASSERT_EQ(parameter.device().index(), 1); |
| } |
| for (auto& parameter : module->parameters()) { |
| ASSERT_EQ(parameter.dtype(), torch::kUInt8); |
| } |
| } |
| } |
| |
| TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) { |
| struct UnCloneable : Module {}; |
| UnCloneable module; |
| ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented"); |
| } |
| |
TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
  // Any user-provided clone() override suppresses the base-class throw.
  struct Cloneable : Module {
    std::shared_ptr<Module> clone(
        const torch::optional<torch::Device>& device =
            torch::nullopt) const override {
      return nullptr;
    }
  };
  Cloneable module;
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_NO_THROW({ module.clone(); });
}
| |
// Module with three Linear submodules and one buffer, used by the clone
// tests below to verify that cloning deep-copies parameters and buffers.
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TestDistinctParametersModule
    : public Cloneable<TestDistinctParametersModule> {
  TestDistinctParametersModule() {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }
  // Cloneable requires reset() to (re-)register all submodules and buffers.
  void reset() override {
    l1 = register_module("l1", Linear(10, 3));
    l2 = register_module("l2", Linear(3, 5));
    l3 = register_module("l3", Linear(5, 100));
    buffer = register_buffer("buf", torch::ones({2, 2}));
  }

  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
  torch::Tensor buffer;
};
| |
| void testDistinctParameters( |
| std::shared_ptr<Module> m1, |
| std::shared_ptr<Module> m2) { |
| auto params1 = m1->named_parameters(); |
| auto params2 = m2->named_parameters(); |
| ASSERT_EQ(params1.size(), 6); |
| ASSERT_EQ(params2.size(), 6); |
| for (auto& param : params1) { |
| ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); |
| ASSERT_TRUE(param->allclose(params2[param.key()])); |
| param->add_(2); |
| } |
| for (auto& param : params1) { |
| ASSERT_FALSE(param->allclose(params2[param.key()])); |
| } |
| |
| auto buffers1 = m1->named_buffers(); |
| auto buffers2 = m2->named_buffers(); |
| ASSERT_EQ(buffers1.size(), 1); |
| ASSERT_EQ(buffers2.size(), 1); |
| for (auto& buffer : buffers1) { |
| ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()])); |
| ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()])); |
| buffer->add_(2); |
| } |
| for (auto& buffer : buffers1) { |
| ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()])); |
| } |
| } |
| |
TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
  auto module = std::make_shared<TestDistinctParametersModule>();
  // Clone under no-grad so tensor copies are not tracked by autograd.
  torch::NoGradGuard no_grad;
  auto module2 = module->clone();
  testDistinctParameters(module, module2);
}
| |
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_CUDA) {
  // Same as above, but cloning onto an explicitly specified CUDA device.
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device device(torch::kCUDA, 0);
  module->to(device);
  auto module2 = module->clone(device);
  testDistinctParameters(module, module2);
}
| |
TEST_F(ModuleTest, CloneCreatesDistinctParametersExplicitDevice_MultiCUDA) {
  // Cloning onto a different device than the source module lives on.
  auto module = std::make_shared<TestDistinctParametersModule>();
  torch::NoGradGuard no_grad;
  torch::Device d0(torch::kCUDA, 0);
  torch::Device d1(torch::kCUDA, 1);
  module->to(d0);
  auto module2 = module->clone(d1);

  // Source stays on d0; the clone lives on d1.
  for (auto& param : module->parameters()) {
    ASSERT_EQ(param.device(), d0);
  }

  for (auto& param : module2->parameters()) {
    ASSERT_EQ(param.device(), d1);
  }

  // need to move the module back to d0 as allclose expects two tensors on
  // the same device.
  module2->to(d0);
  testDistinctParameters(module, module2);
}
| |
TEST_F(ModuleTest, ClonePreservesExternalReferences) {
  // The clone's `weight` member must alias the clone's own parameter-dict
  // entry (external references are re-wired to the copied tensors), while
  // its storage stays distinct from the original's.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }
    torch::Tensor weight;
  };
  auto module = std::make_shared<TestModule>();
  {
    torch::NoGradGuard no_grad;
    module->weight += 1;
  }
  // Sanity: member and registered parameter alias the same storage.
  ASSERT_TRUE(
      pointer_equal(module->weight, module->named_parameters()["weight"]));
  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));

  auto module2 = std::dynamic_pointer_cast<TestModule>(
      std::shared_ptr<Module>(module->clone()));
  // Distinct storage, same values, aliasing preserved inside the clone.
  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
  ASSERT_TRUE(
      pointer_equal(module2->weight, module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
  ASSERT_TRUE(module2->weight.allclose(module->weight));
  ASSERT_FALSE(
      pointer_equal(module2->weight, module->named_parameters()["weight"]));
}
| |
TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
  // Cloning a nested module deep-copies submodule tensors AND plain data
  // members (`value`).
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      weight = register_parameter("weight", torch::ones({4, 4}));
    }

    torch::Tensor weight;
    int value = 0;
  };
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct NestedModule : public Cloneable<NestedModule> {
    NestedModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      module = register_module("module", std::make_shared<TestModule>());
    }
    std::shared_ptr<TestModule> module;
  };

  auto a = std::make_shared<NestedModule>();
  {
    torch::NoGradGuard no_grad;
    a->module->weight += 1;
    a->module->value = 123;
  }

  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

  // Submodule tensor: distinct storage, equal values; aliasing intact in b.
  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
  ASSERT_TRUE(pointer_equal(
      b->module->weight, b->module->named_parameters()["weight"]));
  ASSERT_TRUE(
      b->module->named_parameters()["weight"].allclose(a->module->weight));
  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
  // Non-tensor state is copied as well.
  ASSERT_EQ(b->module->value, a->module->value);
}
| |
TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
  // Cloning without an explicit device keeps parameters/buffers on the
  // device the original module was moved to.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 0);

  m.to(device);

  auto clone = m.clone();
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
| |
TEST_F(
    ModuleTest,
    CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
  // clone(device) moves every parameter and buffer to the requested device
  // even though the source module is on CPU.
  // NOLINTNEXTLINE(bugprone-exception-escape)
  struct TestModule : public Cloneable<TestModule> {
    TestModule() {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      reset();
    }
    void reset() override {
      l1 = register_module("l1", Linear(10, 3));
      l2 = register_module("l2", Linear(3, 5));
      l3 = register_module("l3", Linear(5, 100));
      buffer = register_buffer("buf", torch::ones({2, 2}));
    }

    Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
    torch::Tensor buffer;
  };

  TestModule m;
  torch::Device device(torch::kCUDA, 1);
  // everything is on CPU here
  auto clone = m.clone(device);
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device().type(), device.type());
    ASSERT_EQ(parameter.device().index(), device.index());
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device().type(), device.type());
    ASSERT_EQ(buffer.device().index(), device.index());
  }
}
| |
// Module exposing three parameters "a", "b", "c" with distinct values.
struct ParameterTestModule : Module {
  ParameterTestModule() {
    a = register_parameter("a", torch::zeros({2, 2}));
    b = register_parameter("b", torch::ones({2, 2}));
    c = register_parameter("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
| |
| TEST_F(ModuleTest, HasCorrectNumberOfParameters) { |
| ParameterTestModule module; |
| ASSERT_EQ(module.parameters().size(), 3); |
| ASSERT_EQ(module.named_parameters().size(), 3); |
| } |
| |
| TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) { |
| ParameterTestModule module; |
| auto parameters = module.named_parameters(); |
| ASSERT_TRUE(parameters.contains("a")); |
| ASSERT_TRUE(parameters.contains("b")); |
| ASSERT_TRUE(parameters.contains("c")); |
| } |
| |
// Module exposing three buffers "a", "b", "c" with distinct values.
struct BufferTestModule : Module {
  BufferTestModule() {
    a = register_buffer("a", torch::zeros({2, 2}));
    b = register_buffer("b", torch::ones({2, 2}));
    c = register_buffer("c", torch::ones({2, 2}) * 2);
  }

  torch::Tensor a, b, c;
};
| |
| TEST_F(ModuleTest, HasCorrectNumberOfBuffers) { |
| BufferTestModule module; |
| ASSERT_EQ(module.buffers().size(), 3); |
| ASSERT_EQ(module.named_buffers().size(), 3); |
| } |
| |
| TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) { |
| BufferTestModule module; |
| auto buffers = module.named_buffers(); |
| ASSERT_TRUE(buffers.contains("a")); |
| ASSERT_TRUE(buffers.contains("b")); |
| ASSERT_TRUE(buffers.contains("c")); |
| } |
| |
// Impl with a default and a value constructor, used to test how the
// TORCH_MODULE holder forwards construction arguments.
struct AImpl : torch::nn::Module {
  AImpl() : x_(123) {}
  AImpl(int x) : x_(x) {}
  int x_;
};
TORCH_MODULE(A);
| |
TEST_F(
    ModuleTest,
    DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
  // A default-constructed holder default-constructs the Impl (x_ == 123).
  A a;
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 123);
}
| |
TEST_F(
    ModuleTest,
    ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
  // Constructor arguments are forwarded to the matching Impl constructor.
  A a(5);
  ASSERT_TRUE(a);
  ASSERT_FALSE(a.is_empty());
  ASSERT_EQ(a->x_, 5);
}
| |
TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
  // nullptr yields an empty holder; dereferencing it must throw.
  A a = nullptr;
  ASSERT_FALSE(a);
  ASSERT_TRUE(a.is_empty());
  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
}
| |
// Module with two parameters and two buffers of the given size; used by the
// traversal tests below.
struct TestModule : public torch::nn::Module {
  TestModule(int64_t size) {
    p1 = register_parameter("p1", torch::randn({size}));
    p2 = register_parameter("p2", torch::randn({size}));
    b1 = register_buffer("b1", torch::randn({size}));
    b2 = register_buffer("b2", torch::randn({size}));
  }

  // Identity forward so the module can be placed inside a Sequential.
  torch::Tensor forward(torch::Tensor input) {
    return input;
  }

  torch::Tensor p1, p2, b1, b2;
};
| |
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
  // modules() includes the root itself by default, followed by the children
  // in registration order.
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
| |
TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  // With include_self=false only the children are returned.
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules =
      model->modules(/*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }
}
| |
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model.ptr(), model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality. The first entry is the model itself with the
    // empty name; the children are keyed "0", "1", "2".
    ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
| |
TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
  // Without self, the entries are exactly the children keyed "0", "1", "2".
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(
          /*name_prefix=*/std::string(), /*include_self=*/false);
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
| |
TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
  // children() never includes the module itself.
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].get(), expected[i].get());
  }

  // For this flat model, this should be true.
  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
}
| |
TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
  // named_children() returns the direct children keyed "0", "1", "2".
  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();
  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
      model[0], model[1], model[2]};
  ASSERT_EQ(modules.size(), expected.size());
  for (const auto i : c10::irange(expected.size())) {
    // Assert pointer equality.
    ASSERT_EQ(modules[i].key(), std::to_string(i));
    ASSERT_EQ(modules[i].value().get(), expected[i].get());
  }
}
| |
| TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) { |
| TestModule module(1); |
| std::vector<torch::Tensor> parameters = module.parameters(); |
| ASSERT_EQ(parameters.size(), 2); |
| ASSERT_EQ(parameters[0].data_ptr<float>(), module.p1.data_ptr<float>()); |
| ASSERT_EQ(parameters[1].data_ptr<float>(), module.p2.data_ptr<float>()); |
| } |
| |
TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
  // named_parameters() preserves registration order and aliases the members.
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> parameters =
      module.named_parameters();
  ASSERT_EQ(parameters.size(), 2);
  ASSERT_EQ(parameters[0].key(), "p1");
  ASSERT_EQ(parameters[0]->data_ptr<float>(), module.p1.data_ptr<float>());
  ASSERT_EQ(parameters[1].key(), "p2");
  ASSERT_EQ(parameters[1]->data_ptr<float>(), module.p2.data_ptr<float>());
}
| |
| TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) { |
| TestModule module(1); |
| std::vector<torch::Tensor> buffers = module.buffers(); |
| ASSERT_EQ(buffers.size(), 2); |
| ASSERT_EQ(buffers[0].data_ptr<float>(), module.b1.data_ptr<float>()); |
| ASSERT_EQ(buffers[1].data_ptr<float>(), module.b2.data_ptr<float>()); |
| } |
| |
TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
  // named_buffers() preserves registration order and aliases the members.
  TestModule module(1);
  torch::OrderedDict<std::string, torch::Tensor> buffers =
      module.named_buffers();
  ASSERT_EQ(buffers.size(), 2);
  ASSERT_EQ(buffers[0].key(), "b1");
  ASSERT_EQ(buffers[0]->data_ptr<float>(), module.b1.data_ptr<float>());
  ASSERT_EQ(buffers[1].key(), "b2");
  ASSERT_EQ(buffers[1]->data_ptr<float>(), module.b2.data_ptr<float>());
}
| |
// Recursive container: stores an integer payload in `tensor` and registers
// any child containers under names "0", "1", ...
struct TestContainer : torch::nn::Module {
  TestContainer(int64_t number, std::vector<TestContainer> modules = {})
      : tensor(torch::tensor(number)) {
    for (const auto i : c10::irange(modules.size())) {
      register_module(
          std::to_string(i),
          std::make_shared<TestContainer>(std::move(modules[i])));
    }
  }
  torch::Tensor tensor;
};
| |
| int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) { |
| return std::dynamic_pointer_cast<TestContainer>(module) |
| ->tensor.item<int64_t>(); |
| } |
| |
// Builds the 10-node tree used by the deep-model traversal tests:
// 0 -> {1 -> {2, 3}, 4, 5 -> {6, 7 -> {8, 9}}}; the values equal the
// pre-order (depth-first) visitation index.
std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
  return std::make_shared<TestContainer>(TestContainer(
      0,
      {TestContainer(1, {TestContainer(2), TestContainer(3)}),
       TestContainer(4),
       TestContainer(
           5,
           {TestContainer(6),
            TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
}
| |
// Expected (fully-qualified name, value) pairs for the container built by
// make_deeply_nested_test_container(), traversed with
// name_prefix="test_prefix".
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  return {
      {"test_prefix", 0},
      {"test_prefix.0", 1},
      {"test_prefix.0.0", 2},
      {"test_prefix.0.1", 3},
      {"test_prefix.1", 4},
      {"test_prefix.2", 5},
      {"test_prefix.2.0", 6},
      {"test_prefix.2.1", 7},
      {"test_prefix.2.1.0", 8},
      {"test_prefix.2.1.1", 9}};
}
| |
TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
  // modules() flattens the tree in pre-order DFS; node values encode that
  // order (0..9).
  auto model = make_deeply_nested_test_container();
  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();

  ASSERT_EQ(modules.size(), 10);
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(get_test_container_item(modules[i]), i);
  }
}
| |
TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
  // named_modules() yields fully-qualified dotted names rooted at the given
  // prefix, in pre-order DFS.
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_modules(/*name_prefix=*/"test_prefix");
  auto expected = make_key_value_pairs_for_deeply_nested_container();

  ASSERT_EQ(modules.size(), expected.size());

  for (const auto i : c10::irange(expected.size())) {
    ASSERT_EQ(modules[i].key(), expected[i].first);
    ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
  }
}
| |
| TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) { |
| auto model = make_deeply_nested_test_container(); |
| std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children(); |
| |
| ASSERT_EQ(modules.size(), 3); |
| ASSERT_EQ(get_test_container_item(modules[0]), 1); |
| ASSERT_EQ(get_test_container_item(modules[1]), 4); |
| ASSERT_EQ(get_test_container_item(modules[2]), 5); |
| } |
| |
TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
  // named_children() is shallow and keyed by the registration names.
  auto model = make_deeply_nested_test_container();
  torch::OrderedDict<std::string, std::shared_ptr<torch::nn::Module>> modules =
      model->named_children();

  ASSERT_EQ(modules.size(), 3);

  ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
  ASSERT_EQ(modules[0].key(), "0");

  ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
  ASSERT_EQ(modules[1].key(), "1");

  ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
  ASSERT_EQ(modules[2].key(), "2");
}
| |
| TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) { |
| auto model = make_deeply_nested_test_container(); |
| int64_t index = 0; |
| model->apply([&index](torch::nn::Module& module) { |
| ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++); |
| }); |
| ASSERT_EQ(index, 10); |
| } |
| |
| TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) { |
| std::shared_ptr<const TestContainer> model = |
| make_deeply_nested_test_container(); |
| int64_t index = 0; |
| model->apply([&index](const torch::nn::Module& module) { |
| ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++); |
| }); |
| ASSERT_EQ(index, 10); |
| } |
| |
| TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) { |
| auto model = make_deeply_nested_test_container(); |
| auto expected = make_key_value_pairs_for_deeply_nested_container(); |
| int64_t index = 0; |
| model->apply( |
| [&index, expected](const std::string& name, torch::nn::Module& module) { |
| ASSERT_EQ(name, expected[index].first); |
| ASSERT_EQ( |
| module.as<TestContainer>()->tensor.item<int64_t>(), |
| expected[index++].second); |
| }, |
| /*name_prefix=*/"test_prefix"); |
| ASSERT_EQ(index, 10); |
| } |
| |
| TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) { |
| std::shared_ptr<const TestContainer> model = |
| make_deeply_nested_test_container(); |
| auto expected = make_key_value_pairs_for_deeply_nested_container(); |
| int64_t index = 0; |
| model->apply( |
| [&index, &expected]( |
| const std::string& name, const torch::nn::Module& module) { |
| ASSERT_EQ(name, expected[index].first); |
| ASSERT_EQ( |
| module.as<const TestContainer>()->tensor.item<int64_t>(), |
| expected[index++].second); |
| }, |
| /*name_prefix=*/"test_prefix"); |
| ASSERT_EQ(index, 10); |
| } |
| |
| TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) { |
| auto model = make_deeply_nested_test_container(); |
| int64_t index = 0; |
| model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) { |
| ASSERT_EQ(get_test_container_item(module), index++); |
| }); |
| ASSERT_EQ(index, 10); |
| } |
| |
| TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) { |
| auto model = make_deeply_nested_test_container(); |
| auto expected = make_key_value_pairs_for_deeply_nested_container(); |
| int64_t index = 0; |
| model->apply( |
| [&index, &expected]( |
| const std::string& name, |
| const std::shared_ptr<torch::nn::Module>& module) { |
| ASSERT_EQ(name, expected[index].first); |
| ASSERT_EQ(get_test_container_item(module), expected[index++].second); |
| }, |
| /*name_prefix=*/"test_prefix"); |
| ASSERT_EQ(index, 10); |
| } |
| |
| TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) { |
| { |
| TestModule module(1); |
| ASSERT_THROWS_WITH( |
| module.modules(), |
| "It looks like you attempted to retrieve " |
| "your top-level module as a shared_ptr") |
| } |
| { |
| TestModule module(1); |
| // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
| ASSERT_NO_THROW(module.modules(/*include_self=*/false)); |
| } |
| { |
| auto module = std::make_shared<TestModule>(1); |
| // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) |
| ASSERT_NO_THROW(module->modules()); |
| } |
| } |
| |
// Module with no parameters, buffers, submodules, or pretty_print override;
// used below to check the default printed form (just the class name).
struct EmptyModule : torch::nn::Module {};
| |
| TEST_F(ModuleTest, PrettyPrint) { |
| struct TestModule : torch::nn::Module { |
| TestModule(int x, float y) : x_(x), y_(y) {} |
| |
| void pretty_print(std::ostream& stream) const override { |
| stream << "TestModule(x=" << x_ << ", y=" << y_ << ")"; |
| } |
| |
| int x_; |
| float y_; |
| }; |
| |
| ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule"); |
| ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)"); |
| } |
| |
| struct ModuleWithNonTensorForwardImpl : torch::nn::Module { |
| int64_t forward(torch::Tensor x) { |
| return x.numel(); |
| } |
| }; |
| TORCH_MODULE(ModuleWithNonTensorForward); |
| |
| TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) { |
| ModuleWithNonTensorForward m; |
| ASSERT_EQ(m(torch::ones(123)), 123); |
| } |