| #include <gtest/gtest.h> |
| #include <torch/torch.h> |
| #include <algorithm> |
| #include <memory> |
| #include <vector> |
| |
| #include <test/cpp/api/support.h> |
| |
| using namespace torch::nn; |
| using namespace torch::test; |
| |
| struct ModuleDictTest : torch::test::SeedingFixture {}; |
| |
| TEST_F(ModuleDictTest, ConstructsFromList) { |
| struct M : Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| }; |
| |
| std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = { |
| {"module_1", std::make_shared<M>(1)}, |
| {"module_2", std::make_shared<M>(2)}, |
| {"module_3", std::make_shared<M>(3)}}; |
| ModuleDict dict(list); |
| ASSERT_EQ(dict->size(), 3); |
| } |
| |
| TEST_F(ModuleDictTest, ConstructsFromordereddict) { |
| struct M : Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| }; |
| |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"module_1", std::make_shared<M>(1)}, |
| {"module_2", std::make_shared<M>(2)}, |
| {"module_3", std::make_shared<M>(3)}, |
| }; |
| ModuleDict dict(ordereddict); |
| ASSERT_EQ(dict->size(), 3); |
| } |
| |
| TEST_F(ModuleDictTest, UpdatePopClearContains) { |
| struct M : Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| }; |
| |
| ModuleDict dict; |
| ASSERT_TRUE(dict->empty()); |
| // Update by List |
| std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = { |
| {"module_1", std::make_shared<M>(1)}}; |
| dict->update(list1); |
| ASSERT_EQ(dict->size(), 1); |
| ASSERT_TRUE(dict->contains("module_1")); |
| // Update by OrderedDict |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"module_2", std::make_shared<M>(2)}}; |
| dict->update(ordereddict); |
| ASSERT_EQ(dict->size(), 2); |
| ASSERT_TRUE(dict->contains("module_2")); |
| // Update by another ModuleDict |
| std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = { |
| {"module_3", std::make_shared<M>(3)}}; |
| ModuleDict updatedict(list2); |
| dict->update(*updatedict); |
| ASSERT_EQ(dict->size(), 3); |
| ASSERT_TRUE(dict->contains("module_3")); |
| // Pop |
| dict->pop("module_1"); |
| ASSERT_EQ(dict->size(), 2); |
| // Pop unexist |
| ASSERT_THROWS_WITH(dict->pop("module_4"), " 'module_4' is not defined"); |
| // Clear |
| dict->clear(); |
| ASSERT_EQ(dict->size(), 0); |
| } |
| |
| TEST_F(ModuleDictTest, UpdateExist) { |
| struct M : Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| }; |
| std::vector<std::pair<std::string, std::shared_ptr<Module>>> list1 = { |
| {"module_1", std::make_shared<M>(1)}, |
| {"module_2", std::make_shared<M>(2)}}; |
| ModuleDict dict(list1); |
| ASSERT_EQ(dict->at<M>("module_2").value, 2); |
| // Update by list |
| std::vector<std::pair<std::string, std::shared_ptr<Module>>> list2 = { |
| {"module_2", std::make_shared<M>(0)}, |
| {"module_3", std::make_shared<M>(3)}}; |
| dict->update(list2); |
| ASSERT_EQ(dict->size(), 3); |
| ASSERT_EQ(dict->at<M>("module_2").value, 0); |
| // Update by ordereddict |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"module_3", std::make_shared<M>(0)}, |
| {"module_4", std::make_shared<M>(4)}}; |
| dict->update(ordereddict); |
| ASSERT_EQ(dict->size(), 4); |
| ASSERT_EQ(dict->at<M>("module_3").value, 0); |
| // Update by ModuleDict |
| std::vector<std::pair<std::string, std::shared_ptr<Module>>> list3 = { |
| {"module_4", std::make_shared<M>(0)}, |
| {"module_1", std::make_shared<M>(0)}}; |
| ModuleDict dict2(list3); |
| dict->update(*dict2); |
| ASSERT_EQ(dict->size(), 4); |
| ASSERT_EQ(dict->at<M>("module_1").value, 0); |
| ASSERT_EQ(dict->at<M>("module_4").value, 0); |
| } |
| |
| TEST_F(ModuleDictTest, Keys) { |
| struct M : Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| }; |
| |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"linear", Linear(10, 3).ptr()}, |
| {"conv", Conv2d(1, 2, 3).ptr()}, |
| {"dropout", Dropout(0.5).ptr()}, |
| }; |
| ModuleDict dict(ordereddict); |
| const auto& keys = dict->keys(); |
| std::vector<std::string> expected{"linear", "conv", "dropout"}; |
| ASSERT_EQ(keys, expected); |
| ASSERT_THROWS_WITH(dict["batch"], " 'batch' is not defined"); |
| |
| ASSERT_TRUE(dict["linear"]->as<Linear>()); |
| ASSERT_TRUE(dict["conv"]->as<Conv2d>()); |
| ASSERT_TRUE(dict["dropout"]->as<Dropout>()); |
| } |
| |
| TEST_F(ModuleDictTest, Values) { |
| struct M : Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| }; |
| |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"module_1", std::make_shared<M>(1)}, |
| {"module_2", std::make_shared<M>(2)}, |
| }; |
| ModuleDict dict(ordereddict); |
| const auto& values = dict->values(); |
| const auto& expected = ordereddict.values(); |
| ASSERT_EQ(values, expected); |
| ASSERT_TRUE(std::equal( |
| dict->begin(), |
| dict->end(), |
| ordereddict.begin(), |
| [](const auto& lhs, const auto& rhs) { |
| return lhs.value().get() == rhs.value().get(); |
| })); |
| } |
| |
| TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) { |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"linear", Linear(10, 3).ptr()}, |
| {"conv", Conv2d(1, 2, 3).ptr()}, |
| {"dropout", Dropout(0.5).ptr()}, |
| {"batch", BatchNorm2d(5).ptr()}, |
| {"embedding", Embedding(4, 10).ptr()}, |
| {"lstm", LSTM(4, 5).ptr()}}; |
| ModuleDict dict(ordereddict); |
| } |
| |
| TEST_F(ModuleDictTest, HasReferenceSemantics) { |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"linear1", Linear(2, 3).ptr()}, |
| {"linear2", Linear(3, 4).ptr()}, |
| {"linear3", Linear(4, 5).ptr()}, |
| }; |
| ModuleDict first(ordereddict); |
| ModuleDict second(ordereddict); |
| |
| ASSERT_EQ(first->size(), second->size()); |
| ASSERT_TRUE(std::equal( |
| first->begin(), |
| first->end(), |
| second->begin(), |
| [](const auto& lhs, const auto& rhs) { |
| return lhs.value().get() == rhs.value().get(); |
| })); |
| } |
| |
| void iscloneable_helper(torch::Device device) { |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"linear", Linear(2, 3).ptr()}, |
| {"relu", Functional(torch::relu).ptr()}, |
| {"batch", BatchNorm1d(3).ptr()}, |
| }; |
| ModuleDict dict(ordereddict); |
| dict->to(device); |
| ModuleDict clone = |
| std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device)); |
| ASSERT_EQ(dict->size(), clone->size()); |
| |
| for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end(); |
| ++it, ++it_c) { |
| // The key should be same |
| ASSERT_EQ(it->key(), it_c->key()); |
| // The modules should be the same kind (type). |
| ASSERT_EQ(it->value()->name(), it_c->value()->name()); |
| // But not pointer-equal (distinct objects). |
| ASSERT_NE(it->value(), it_c->value()); |
| } |
| |
| // Verify that the clone is deep, i.e. parameters of modules are cloned too. |
| torch::NoGradGuard no_grad; |
| |
| auto params1 = dict->named_parameters(); |
| auto params2 = clone->named_parameters(); |
| ASSERT_EQ(params1.size(), params2.size()); |
| for (auto& param : params1) { |
| ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); |
| ASSERT_EQ(param->device(), params2[param.key()].device()); |
| ASSERT_TRUE(param->allclose(params2[param.key()])); |
| param->add_(2); |
| } |
| for (auto& param : params1) { |
| ASSERT_FALSE(param->allclose(params2[param.key()])); |
| } |
| } |
| |
| TEST_F(ModuleDictTest, IsCloneable) { |
| iscloneable_helper(torch::kCPU); |
| } |
| |
| TEST_F(ModuleDictTest, IsCloneable_CUDA) { |
| iscloneable_helper({torch::kCUDA, 0}); |
| } |
| |
| TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) { |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict1 = { |
| {"linear", Linear(10, 3).ptr()}, |
| {"conv", Conv2d(1, 2, 3).ptr()}, |
| {"test", Dropout(0.5).ptr()}, |
| }; |
| ModuleDict dict(ordereddict1); |
| |
| auto modules = dict->children(); |
| ASSERT_TRUE(modules[0]->as<Linear>()); |
| ASSERT_TRUE(modules[1]->as<Conv2d>()); |
| ASSERT_TRUE(modules[2]->as<Dropout>()); |
| |
| // Update Existing |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict2 = { |
| {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}}; |
| dict->update(ordereddict2); |
| |
| modules = dict->children(); |
| ASSERT_TRUE(modules[0]->as<Linear>()); |
| ASSERT_TRUE(modules[1]->as<Conv2d>()); |
| // Keep Order |
| ASSERT_TRUE(modules[2]->as<BatchNorm2d>()); |
| ASSERT_TRUE(modules[3]->as<LSTM>()); |
| } |
| |
| TEST_F(ModuleDictTest, CloneToDevice_CUDA) { |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"linear", Linear(2, 3).ptr()}, |
| {"relu", Functional(torch::relu).ptr()}, |
| {"batch", BatchNorm1d(3).ptr()}, |
| }; |
| ModuleDict dict(ordereddict); |
| torch::Device device(torch::kCUDA, 0); |
| ModuleDict clone = |
| std::dynamic_pointer_cast<ModuleDictImpl>(dict->clone(device)); |
| for (const auto& p : clone->parameters()) { |
| ASSERT_EQ(p.device(), device); |
| } |
| for (const auto& b : clone->buffers()) { |
| ASSERT_EQ(b.device(), device); |
| } |
| } |
| |
// Golden test for the ModuleDict pretty-printer. Formatting the dict with
// c10::str must produce the "torch::nn::ModuleDict(" header, then one
// "(key): <module repr>" line per entry in insertion order, then ")".
// NOTE(review): the comparison is byte-for-byte, so the whitespace inside the
// expected literals must match the printer's indentation exactly.
TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
  torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
      {"linear", Linear(10, 3).ptr()},
      {"conv", Conv2d(1, 2, 3).ptr()},
      {"dropout", Dropout(0.5).ptr()},
      {"batch", BatchNorm2d(5).ptr()},
      {"embedding", Embedding(4, 10).ptr()},
      {"lstm", LSTM(4, 5).ptr()}};
  ModuleDict dict(ordereddict);

  ASSERT_EQ(
      c10::str(dict),
      "torch::nn::ModuleDict(\n"
      " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      " (conv): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      " (batch): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");
}
| |
| TEST_F(ModuleDictTest, InvalidAt) { |
| torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { |
| {"linear", Linear(10, 3).ptr()}}; |
| ModuleDict dict(ordereddict); |
| ASSERT_THROWS_WITH( |
| dict->at<torch::nn::Dropout2dImpl>("linear"), "Unable to cast module"); |
| } |