| #include <gtest/gtest.h> |
| |
| #include <c10/util/irange.h> |
| #include <torch/torch.h> |
| |
| #include <algorithm> |
| #include <memory> |
| #include <vector> |
| |
| #include <test/cpp/api/support.h> |
| |
| using namespace torch::nn; |
| using namespace torch::test; |
| |
| struct ParameterListTest : torch::test::SeedingFixture {}; |
| |
| TEST_F(ParameterListTest, ConstructsFromSharedPointer) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| ASSERT_TRUE(ta.requires_grad()); |
| ASSERT_FALSE(tb.requires_grad()); |
| ParameterList list(ta, tb, tc); |
| ASSERT_EQ(list->size(), 3); |
| } |
| |
| TEST_F(ParameterListTest, isEmpty) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| ParameterList list; |
| ASSERT_TRUE(list->is_empty()); |
| list->append(ta); |
| ASSERT_FALSE(list->is_empty()); |
| ASSERT_EQ(list->size(), 1); |
| } |
| |
| TEST_F(ParameterListTest, PushBackAddsAnElement) { |
| ParameterList list; |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| torch::Tensor td = torch::randn({1, 2, 3}); |
| ASSERT_EQ(list->size(), 0); |
| ASSERT_TRUE(list->is_empty()); |
| list->append(ta); |
| ASSERT_EQ(list->size(), 1); |
| list->append(tb); |
| ASSERT_EQ(list->size(), 2); |
| list->append(tc); |
| ASSERT_EQ(list->size(), 3); |
| list->append(td); |
| ASSERT_EQ(list->size(), 4); |
| } |
| TEST_F(ParameterListTest, ForEachLoop) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| torch::Tensor td = torch::randn({1, 2, 3}); |
| ParameterList list(ta, tb, tc, td); |
| std::vector<torch::Tensor> params = {ta, tb, tc, td}; |
| ASSERT_EQ(list->size(), 4); |
| int idx = 0; |
| for (const auto& pair : *list) { |
| ASSERT_TRUE( |
| torch::all(torch::eq(pair.value(), params[idx++])).item<bool>()); |
| } |
| } |
| |
| TEST_F(ParameterListTest, AccessWithAt) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| torch::Tensor td = torch::randn({1, 2, 3}); |
| std::vector<torch::Tensor> params = {ta, tb, tc, td}; |
| |
| ParameterList list; |
| for (auto& param : params) { |
| list->append(param); |
| } |
| ASSERT_EQ(list->size(), 4); |
| |
| // returns the correct module for a given index |
| for (const auto i : c10::irange(params.size())) { |
| ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item<bool>()); |
| } |
| |
| for (const auto i : c10::irange(params.size())) { |
| ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item<bool>()); |
| } |
| |
| // throws for a bad index |
| ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range"); |
| ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range"); |
| ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range"); |
| } |
| |
| TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| torch::Tensor td = torch::randn({1, 2, 3}); |
| torch::Tensor te = torch::randn({1, 2}); |
| torch::Tensor tf = torch::randn({1, 2, 3}); |
| ParameterList a(ta, tb); |
| ParameterList b(tc, td); |
| a->extend(*b); |
| |
| ASSERT_EQ(a->size(), 4); |
| ASSERT_TRUE(torch::all(torch::eq(a[0], ta)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(a[1], tb)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(a[2], tc)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(a[3], td)).item<bool>()); |
| |
| ASSERT_EQ(b->size(), 2); |
| ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>()); |
| |
| std::vector<torch::Tensor> c = {te, tf}; |
| b->extend(c); |
| |
| ASSERT_EQ(b->size(), 4); |
| ASSERT_TRUE(torch::all(torch::eq(b[0], tc)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(b[1], td)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(b[2], te)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(b[3], tf)).item<bool>()); |
| } |
| |
| TEST_F(ParameterListTest, PrettyPrintParameterList) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| ParameterList list(ta, tb, tc); |
| ASSERT_EQ( |
| c10::str(list), |
| "torch::nn::ParameterList(\n" |
| "(0): Parameter containing: [Float of size [1, 2]]\n" |
| "(1): Parameter containing: [Float of size [1, 2]]\n" |
| "(2): Parameter containing: [Float of size [1, 2]]\n" |
| ")"); |
| } |
| |
| TEST_F(ParameterListTest, IncrementAdd) { |
| torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); |
| torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); |
| torch::Tensor tc = torch::randn({1, 2}); |
| torch::Tensor td = torch::randn({1, 2, 3}); |
| torch::Tensor te = torch::randn({1, 2}); |
| torch::Tensor tf = torch::randn({1, 2, 3}); |
| ParameterList listA(ta, tb, tc); |
| ParameterList listB(td, te, tf); |
| std::vector<torch::Tensor> tensors{ta, tb, tc, td, te, tf}; |
| int idx = 0; |
| *listA += *listB; |
| ASSERT_TRUE(torch::all(torch::eq(listA[0], ta)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(listA[1], tb)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(listA[2], tc)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(listA[3], td)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(listA[4], te)).item<bool>()); |
| ASSERT_TRUE(torch::all(torch::eq(listA[5], tf)).item<bool>()); |
| for (const auto& P : listA->named_parameters(false)) |
| ASSERT_TRUE(torch::all(torch::eq(P.value(), tensors[idx++])).item<bool>()); |
| |
| ASSERT_EQ(idx, 6); |
| } |