| # Owner(s): ["module: functorch"] |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
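
# Tests for first-class dimensions (functorch.dim): Dim objects bind to tensor
# sizes and can be used directly in indexing and arithmetic, e.g.
#   i, j = dims()
#   x = torch.rand(3, 4)[i, j]  # bind dims to the tensor's dimensions
#   r = x.order(i, j)           # convert back to a positional tensor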
| from functorch.dim import Tensor, Dim, dims, dimlists, stack, DimensionBindError, DimList |
| |
| from attn_ft import BertSelfAttention as BertSelfAttentionA, Linear |
| from attn_positional import BertSelfAttention as BertSelfAttentionB |
| |
| from torch.testing._internal.common_utils import TestCase, run_tests, TEST_CUDA |
| |
| from unittest import skip, skipIf |
| import torch |
| import gc |
| |
| from functorch._C import dim as _C |
| |
| try: |
| from torchvision.models import resnet18 |
| except ImportError: |
| resnet18 = None |
| |
| _test_c, _parse_test, _set_pointwise_optimize = _C._test_c, _C._parse_test, _C._set_pointwise_optimize |
| |
| from contextlib import contextmanager |
| from time import perf_counter |
| |
| measure_perf = False |
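# set measure_perf = True to profile the timing tests with torchdim's
# magic_trace; otherwise the no-op stand-in below keeps those blocks runnable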
| if measure_perf: |
| from torchdim.magic_trace import magic_trace |
| else: |
| @contextmanager |
| def magic_trace(*args, **kwargs): |
| yield |
| |
| @contextmanager |
| def measure(what): |
| b = perf_counter() |
| yield |
| e = perf_counter() |
| print(f"{what}: {e - b:.20f} seconds") |
| |
| def triu(A): |
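    # build the upper-triangular mask by comparing the dim objects themselves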
| i, j = dims() |
| a = A[i, j] |
| zero = torch.tensor(0, dtype=torch.float) # XXX - torch.where is janky... |
| return torch.where(i <= j, a, zero).order(i, j) |
| |
| def gpu_time(lmb, name, r=100): |
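    # average GPU time of lmb(): r untimed warmup calls, then r calls
    # bracketed by CUDA events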
| b = torch.cuda.Event(enable_timing=True) |
| e = torch.cuda.Event(enable_timing=True) |
| # with magic_trace(name + ".fxt"): |
| for _ in range(r): |
| lmb() |
| b.record() |
| for _ in range(r): |
| lmb() |
| e.record() |
| e.synchronize() |
| elapsed = b.elapsed_time(e) |
| # with torch.profiler.profile(schedule=torch.profiler.schedule( |
| # wait=0, |
| # warmup=1, |
| # active=2), on_trace_ready=tensorboard_trace_handler(name), with_stack=True) as profiler: |
| # for _ in range(3): |
| # lmb() |
| # profiler.step() |
| print(name, elapsed / r) |
| return elapsed / r |
| |
| class TestMin(TestCase): |
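    # every test runs with gc disabled: setUp snapshots all live tensor/dim
    # objects (and allocated CUDA memory for *_cuda tests) so that tearDown
    # can assert the test leaked nothing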
| |
| def setUp(self): |
| super().setUp() |
| gc.disable() |
| gc.collect() |
| self.interesting = set() |
| for o in gc.get_objects(): |
| if isinstance(o, (torch.Tensor, Dim, Tensor, DimList)): |
| self.interesting.add(id(o)) |
| if 'cuda' in self._testMethodName: |
| self.mem_allocated = torch.cuda.memory_allocated() |
| |
    def tearDown(self):
        # anything tensor-like that is alive now but was not at setUp leaked
        leaked = []
        for o in gc.get_objects():
            if isinstance(o, (torch.Tensor, Dim, Tensor, DimList)) and id(o) not in self.interesting:
                leaked.append(o)

        extra_memory = 0
        if 'cuda' in self._testMethodName:
            extra_memory += torch.cuda.memory_allocated() - self.mem_allocated

        # nolevels = _n_levels_in_use() == 0
        if extra_memory != 0 or len(leaked) != 0:
            import refcycle
            refcycle.garbage().export_image('garbage.pdf')
        gc.collect()
        # assert nolevels, f"cleanup failed? {_n_levels_in_use()}"
        assert extra_memory == 0, f'extra cuda memory left allocated: {extra_memory}'
        assert len(leaked) == 0, \
            f'leaked {len(leaked)} torch.Tensor, Dim, Tensor, or DimList objects:' \
            f' {[type(t) for t in leaked]}'
| |
    def test_manual_stuff(self):
        A_ = torch.rand(3, 4)
| B_ = torch.rand(4, 5) |
| i, j, k = dims() |
| A = A_[i, k] |
| B = B_[k, j] |
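        # the product broadcasts over (i, j, k); summing out k is a matmul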
| C = (A.expand(j) * B.expand(i)).sum(k) |
| self.assertTrue(torch.allclose(C.order(i, j), torch.mm(A_, B_))) |
| self.assertTrue(torch.allclose(torch.triu(A_, 0), triu(A_))) |
| |
| D_ = torch.randint(0, 3, (6,)) |
| d = dims() |
| D = D_[d] |
| |
| A.index([i], [D]).order(k, d) |
| |
| def attn(self, batch_size=1, sequence_length=4, hidden_size=6, num_attention_heads=3, linear=Linear, device=None, time=False): |
| def maybe_to(x): |
| return x if device is None else x.to(device) |
| |
| attention_probs_dropout_prob = 0. |
| A = maybe_to(BertSelfAttentionA(hidden_size, num_attention_heads, attention_probs_dropout_prob, linear=linear)) |
        B = maybe_to(BertSelfAttentionB(hidden_size, num_attention_heads, attention_probs_dropout_prob))

        A.load_state_dict(B.state_dict())
| hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size)) |
| b_out = B(hidden_state) |
| a_out = A(hidden_state) |
| self.assertTrue(torch.allclose(a_out, b_out)) # why does a simple matmul not do the right thing? |
| |
| if time: |
| gpu_time(lambda: B(hidden_state), "positional", r=3) |
| gpu_time(lambda: A(hidden_state), "first_class", r=3) |
| |
| for approach in ('relative_key', 'relative_key_query'): |
| A = maybe_to(BertSelfAttentionA(hidden_size, num_attention_heads, |
| attention_probs_dropout_prob, approach, sequence_length, linear=linear)) |
| B = maybe_to(BertSelfAttentionB(hidden_size, num_attention_heads, |
| attention_probs_dropout_prob, approach, sequence_length)) |
| A.load_state_dict(B.state_dict()) |
| |
| hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size)) |
| b_out = B(hidden_state) |
| a_out = A(hidden_state) |
| self.assertTrue(torch.allclose(a_out, b_out)) |
| |
| if time: |
| gpu_time(lambda: B(hidden_state), "positional", r=3) |
| gpu_time(lambda: A(hidden_state), "first_class", r=3) |
| |
| A = maybe_to(BertSelfAttentionA(hidden_size, num_attention_heads, |
| attention_probs_dropout_prob, None, None, linear=linear)) |
| B = maybe_to(BertSelfAttentionB(hidden_size, num_attention_heads, |
| attention_probs_dropout_prob, None, None)) |
| A.load_state_dict(B.state_dict()) |
| |
| hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size)) |
| past_key_value = (maybe_to(torch.rand(batch_size, num_attention_heads, |
| sequence_length, hidden_size // num_attention_heads)), |
| maybe_to(torch.rand(batch_size, num_attention_heads, |
| sequence_length, hidden_size // num_attention_heads))) |
| |
| b_out = B(hidden_state, past_key_value=past_key_value) |
| a_out = A(hidden_state, past_key_value=past_key_value) |
| self.assertTrue(torch.allclose(a_out, b_out)) |
| |
| if time: |
| gpu_time(lambda: B(hidden_state), "positional", r=3) |
| gpu_time(lambda: A(hidden_state), "first_class", r=3) |
| |
| def test_attn(self): |
| self.attn() |
| |
| def test_inplace(self): |
| # some embeddings table |
| embeddings = torch.zeros(10, 3) |
| |
| # some sparse updates to the embeddings |
| indices = torch.arange(2) + 1 |
| values = torch.rand(2, 3) |
| |
| i, n, f = dims() |
| |
| embeddings[indices[i], f] += values[i, f] |
| |
    def test_adapt(self):
        def f():
            ci, co = dims()

        # python 3.11 adapts bytecode after a number of iterations,
        # so check that we still match dim names correctly on every call
        for i in range(10):
            f()
| |
| @skipIf(not TEST_CUDA, "no CUDA") |
| def test_attn_cuda(self): |
| # size from the BERT paper, 90% pretraining of sequence length 128 |
| self.attn(batch_size=256, hidden_size=768, sequence_length=128, |
| num_attention_heads=12, device='cuda', time=measure_perf, linear=torch.nn.Linear) |
| |
| def test_stack(self): |
| i, j, d = dims() |
| A = torch.rand(4, 5) |
| r = stack([A[i, j]], d, j) |
| # a, b = r.unbind(d) |
| # self.assertTrue(torch.allclose(a.order(i, j), i.expand(j).order(i, j))) |
| # self.assertTrue(torch.allclose(b.order(i, j), j.expand(i).order(i, j))) |
| |
| def test_max(self): |
| ap = torch.rand(2, 3, 2) |
| i, j, k = dims() |
| a = ap[i, j, k] |
| r, i0 = a.max(dim=k) |
| self.assertTrue(torch.allclose(r.order(i, j), ap.max(2)[0])) |
| |
| def test_mm(self): |
| i, j, k, q = dims() |
| a = torch.rand(3, 4) |
| b = torch.rand(4, 5) |
| a_ = a[i, k] |
| b_ = b[k, j] |
| q.size = 1 |
| r = (a_.expand(j, q) * b_.expand(i, q)).sum(k).order(q, i, j) |
| # r = (a_*b_).sum(k).order(q, i, j) |
| # print(r) |
| # print(a @ b) |
| |
| def test_with_dims_split(self): |
| a = torch.arange(3 * 12).view(3, 12) |
| i, j, k = dims() |
| k.size = 4 |
| r = a[i, [j, k]] |
| x = r.order(i, [j, k]) |
| self.assertTrue(torch.allclose(a, x)) |
| |
| def test_hello(self): |
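        # broad smoke test of the dim API: reductions, ordering, indexing,
        # dim arithmetic, dimlists, and splitting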
| A = torch.rand(3, 4) |
| B = torch.rand(4, 5) |
        i, j, k = dims()

        # r = A[i]*4
        r = (A[i, k] * B[k, j]).sum(k).order(i, j)
| assert torch.allclose(r, A @ B) |
| |
| assert A.sum() == A[i].sum((0, i)) |
| assert A.sum() == A[i].sum((-1, i)) |
| |
| assert torch.allclose(A.sum(), A[i].sum(0, keepdim=True).sum((0, i))) |
| assert torch.allclose(A[i].std(i, True), A.std(0, True)) |
| |
| assert torch.allclose(A[i, k].max(i)[0].order(k), A.max(0)[0]) |
| assert torch.allclose(A.sort(1)[0], A[i, k].sort(k)[0].order(i, k)) |
| # XXX - chunk changes the size of a dimension, has to take a new dimension... |
| # assert torch.allclose(A.chunk(2,1)[0], A[i, k].chunk(2, k)[0].order(i, k)) |
| assert torch.allclose(A[i].renorm(1, i, 7).order(i), A.renorm(1, 0, 7)) |
| kk = dims() |
| # assert torch.allclose( torch.stack([A, A], 1), stack([A[i,k], A[i, k]], kk, k).order(i, kk, k)) |
| |
| k2 = dims() |
| # r = cat((A[i, k], A[i,k]), k, k2) |
| # assert torch.allclose(torch.cat([A, A], 1), r.order(i, k2)) |
| # assert k2.size == 2*k.size |
| |
| assert torch.allclose(A.expand(5, -1, -1), A[i, k].expand(j).order(j, i, k)) |
| z = dims() |
| C = torch.arange(2) |
| assert torch.allclose(A[:, 0:2], A[i, k].index(k, C[z]).order(i, z)) |
| |
| o, l = dims() |
| o.size = 2 |
| r = A[i, k].index(k, (o, l)) |
| assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2)) |
| rr = r.index((o, l), k) |
| assert torch.allclose(A, rr.order(i, k)) |
| |
| r = i + k - 1 |
| r2 = torch.arange(3)[:, None] + torch.arange(4)[None, :] - 1 |
| assert torch.allclose(r.order(i, k), r2) |
| |
| # test with ... |
| assert torch.allclose(A.T, A[..., k].order(k)) |
| |
| # test with dimlist |
| a_, b_ = dimlists() |
| assert torch.allclose(A[i, a_].order(*a_, i), A.T) |
| # test with one bound dimlist |
| assert torch.allclose(A[:, a_].order(*a_), A.T) |
| # test with a dimlist that will end up empty |
| assert torch.allclose(A[i, b_, k].order(i, k, *b_), A) |
| # test with too few things |
| (A[i] + i) |
| assert torch.allclose((A[i] + i).order(i), A + torch.arange(3)[:, None]) |
| # test with too many elements |
        with self.assertRaises(IndexError):
            A[1, ..., 1, 1]
| c, d = dims() |
| c.size = 2 |
| assert torch.allclose(A[i, [c, d]].order(i, c, d), A.view(3, 2, 2)) |
| |
| assert torch.allclose(A[c + 1, c + 0].order(c), A[torch.arange(2) + 1, torch.arange(2)]) |
        with self.assertRaises(DimensionBindError):
            A[..., 3, ...]
| |
| C = torch.rand(4, 7) |
| c_, x, y, z = dims() |
| |
| a, b, c = C.split((3, 3, 1), dim=1) |
| s = dims() |
| ref = C.split((3, 3, 1), dim=1) |
| t = C[s, c_].split((x, y, z), dim=c_) |
| for a, b, d in zip(ref, t, (x, y, z)): |
| assert torch.allclose(a, b.order(s, d)) |
| |
| D = torch.rand(3, 4, 5) |
        assert torch.allclose(D.transpose(0, 1).flatten(1, 2), D[i, k, j].order((i, j)).order(k))

        r = [id(x) for x in torch.rand_like(A[i, k]).dims]
| assert id(i) in r and id(k) in r |
| r = [id(x) for x in torch.nn.functional.dropout(A[i, k]).dims] |
| assert id(i) in r and id(k) in r |
| |
| def test_simple(self): |
| i, j, k = dims() |
| x = torch.rand(3, 4) |
| z = x[i, j] |
| (z + z + z + z) |
| (z.order(i, j)) |
| |
| def test_mm_fuse(self): |
| i, j, k = dims() |
| A = torch.rand(3, 4) |
| B = torch.rand(4, 5) |
| |
| C = (A[i, k] * B[k, j]).sum(k).order(i, j) |
| assert torch.allclose(C, A @ B) |
| |
| def test_time_mm_fuse(self): |
| i, j, k = dims() |
| A = torch.rand(3, 4) |
        B = torch.rand(4, 5)

        for _ in range(10):
| r0 = A @ B |
| |
| for _ in range(10): |
| a = A[i, k] |
| b = B[k, j] |
| r1 = (a * b).sum(k) |
| |
| with measure('pp'): |
| for _ in range(10000): |
| A @ B |
| # magic_trace_stop_indicator() |
| |
| with measure('fc'): |
| for _ in range(10000): |
| (A[i, k] * B[k, j]).sum(k).order(i, j) |
| |
| with magic_trace('f.fxt'): |
| for _ in range(10000): |
| (A[i, k] * B[k, j]).sum(k).order(i, j) |
| |
| with magic_trace('p.fxt'): |
| for _ in range(10000): |
| A @ B |
| |
        # magic_trace_stop_indicator()

        assert torch.allclose(r1.order(i, j), r0)
| |
| def test_compare_dims(self): |
| i, j = dims() |
| i.size = 3 |
| j.size = 4 |
| (i < j) # noqa: B015 |
| |
| def test_c(self): |
| _test_c() |
| |
| def test_seg(self): |
| A = torch.rand(3, 4) |
| i, k = dims() |
| i.size = 4 |
| k.size = 3 |
| r = i + k - 1 |
| |
| def test_expand(self): |
| A = torch.rand(3, 4) |
| i = dims() |
        assert list(A[i].expand(2, 4).order(i).size()) == [3, 2, 4]

    def test_parse(self):
| self.assertEqual(("x", None, None, None), _parse_test(1, 0, "x")) |
| self.assertEqual(("x", None, "y", None), _parse_test(1, 0, "x", c="y")) |
| self.assertEqual(("x", None, "y", "z"), _parse_test(1, 0, "x", d="z", c="y")) |
| |
| self.assertEqual(("x", "4", None, None), _parse_test(2, 0, "x", b="4")) |
| self.assertEqual(("x", "y", "z", "q"), _parse_test(2, 0, "x", "y", "z", "q")) |
| with self.assertRaises(TypeError): |
| _parse_test(2, 0, "x", "y", "z", "q", "5") |
| with self.assertRaises(TypeError): |
| _parse_test(2, 0, "x", "y", b="y") |
| |
| with self.assertRaises(TypeError): |
| _parse_test(2, 0, "x", c="y") |
| with self.assertRaises(TypeError): |
| _parse_test(2, 0, "x") |
| |
| def test_network(self): |
| if resnet18 is None: |
| self.skipTest('no torchvision') |
| rn = resnet18(norm_layer=lambda x: torch.nn.BatchNorm2d(x, track_running_stats=False)) |
| rn.train() |
| img = torch.rand(1, 1, 2, 3, 224, 224) |
| imgf = img.view(2, 3, 224, 224) |
| |
| i, j = dims() |
| r = rn(img[i, j]) |
| r = r.order(i, j).view(2, 1000) |
| r2 = rn(imgf) |
| assert torch.allclose(r2, r, atol=1e-06) |
| |
| def test_dim_args(self): |
| a = dimlists() |
| assert isinstance(a, DimList) |
| a = dims() |
| b = dimlists() |
| assert isinstance(a, Dim) |
| assert isinstance(b, DimList) |
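        # dims() infers each dim's name from the variable it is assigned to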
| assert str(a) == 'a' |
| a, b = dims(sizes=[3, 4]) |
| assert a.size == 3 |
| assert b.size == 4 |
| a = dims(sizes=[3]) |
| b = dimlists(sizes=[4]) |
| assert len(b) == 4 |
| a = dims() |
| b = dimlists(sizes=[[4, 5]]) |
| assert b[0].size == 4 |
| assert b[1].size == 5 |
| |
| def test_diag(self): |
| i = dims() |
| A = torch.rand(4, 4) |
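        # indexing with the same dim twice extracts the diagonal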
| (A[i, i]) |
| |
| def test_softmax_split(self): |
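        # numerically stable softmax computed in chunks indexed by g: take a
        # max and exp-sum per chunk, then rescale each chunk by exp(m_b - m)
        # so every chunk agrees with the global max m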
| a = torch.rand(16) |
| g, i = dims(sizes=[2, None]) |
| a2 = a[[i, g], ] |
| |
| m_b, _ = a2.max(i) |
| f_b = torch.exp(a2 - m_b) |
| l_b = f_b.sum(i) |
| |
| m, _ = m_b.max(g) |
| c = torch.exp(m_b - m) |
| f = (c * f_b).order((i, g)) |
| l = (c * l_b).sum(g) |
| assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0)) |
| |
| def test_index(self): |
| A = torch.rand(3, 4) |
| B = torch.rand(4, 5) |
| i, j, k = dims() |
| |
| o, l = dims() |
| o.size = 2 |
| r = A[i, k].index(k, [o, l]) |
| assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2)) |
| rr = r.index([o, l], k) |
| assert torch.allclose(A, rr.order(i, k)) |
| z = dims() |
| C = torch.arange(2) |
| x = A[i, k].index(k, C[z]).order(i, z) |
| assert torch.allclose(A[:, 0:2], x) |
| |
| C = torch.rand(3, 4, 5) |
| ik = dims() |
| assert torch.allclose(C.index((0, 2), ik).order(ik), C.permute(0, 2, 1).reshape(15, 4)) |
| |
| # failures that came up from monkey patching some operators... |
| def test_monkey(self): |
| A = torch.rand(3, 4) |
| A[0, 0] = 5 |
| x = torch.randn(3, 4, 4, 4, 3) |
| x_clone1 = x.clone() |
| ia = torch.tensor([0, 2, 1]) |
| ib = torch.tensor([0, 2, 1]) |
| first_shape = x[:, ia, None, ib, 0].shape |
| x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1) |
| x = torch.autograd.Variable(torch.tensor([])) |
| z = torch.autograd.Variable(torch.IntTensor([1, 2, 3])) |
| a = [z[2], z[0] + 3] |
| x.new(a) |
| # self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4]) |
| |
| def test_index_placement(self): |
| A = torch.rand(1, 2, 3, 4) |
| |
| i, j = dims(sizes=[2, 4]) |
| |
| a = A[:, i + 0, :, j + 0] |
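        # i + 0 is a dim expression, so it indexes like an index tensor rather
        # than binding positionally; advanced indexing places the indexed dims
        # in front, hence the comparison with A.permute(1, 3, 0, 2) below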
| r = a.order(i, j) |
| |
| assert torch.allclose(A.permute(1, 3, 0, 2), r) |
| |
| def test_order(self): |
| i, j = dims() |
| A = torch.rand(3, 4, 5) |
| assert torch.allclose(A[i].order(1, i), A.permute(2, 0, 1)) |
| |
| def test_mask(self): |
| a = torch.rand(5) |
| i, j = dims(sizes=[a.size(0), a.size(0)]) |
| ((i >= j) * a[i]).sum(j).order(i) |
| |
| def test_eq(self): |
| i, j = dims(sizes=[3, 3]) |
| assert (i == j).sum((i, j)) == 3 |
| |
| def test_dims_with_size(self): |
| x = dims(3) |
| assert len(x) == 3 and isinstance(x[0], Dim) |
| |
| class Foo: |
| pass |
| y = Foo() |
| z, y.x, q = dims(3) |
| assert str(z) == "z" |
| assert str(y.x) == "d1" |
| assert str(q) == "d2" |
| |
| def test_dir(self): |
| i, j = dims(sizes=[3, 3]) |
| dir(i <= j) |
| |
| def test_doc(self): |
| assert Tensor.clamp.__doc__ == torch.Tensor.clamp.__doc__ |
| |
    def test_embed(self):
        embeddings = torch.rand(8, 32)
| ids = torch.tensor([1, 0, 3, 4]) |
| |
| # slow but Pythonic |
| values_ = torch.empty(4, 32) |
| for batch in range(ids.size(0)): |
| for feature in range(embeddings.size(1)): |
| values_[batch, feature] = embeddings[ids[batch], feature] |
| |
| # with torchdim, single indexing kernel |
| batch, feature = dims(2) |
| values = embeddings[ids[batch], feature].order(batch, feature) |
| |
| assert torch.allclose(values, values_) |
| |
| def test_functorch(self): |
| A = torch.rand(3, 4, 5) |
| B = torch.rand(3, 4, 5) |
| C = torch.rand(5, 2) |
| |
| i, j = dims() |
| |
| AA = torch.mm(A[i], C) # 3, 4, 2 |
| BB = torch.mm(B[j], C) # 3, 4, 2 |
| assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2] |
| |
| def test_permute_orig(self): |
| d = dims(1) |
| t_fc = torch.rand(1, 2, 3, 4)[d] |
| assert t_fc.permute(dims=(1, 0, 2)).shape == t_fc.permute(1, 0, 2).shape |
| |
| def test_order_keyword(self): |
| d = dims(1) |
| t = torch.rand(3)[d] |
| self.assertRaises(TypeError, lambda: t.order(wrong=3)) |
| |
| def test_big_split(self): |
| total = 0 |
| l = [] |
| while total < 6400: |
| l.append(torch.randint(2, 10, (1,)).item()) |
| total += l[-1] |
| x = torch.randn(total, 1) |
| x.split(l, 0) |
| |
| skip_functorch_only = ['test_time_mm_fuse', 'test_attn_cuda'] |
| class TestMinFunctorchOnly(TestMin): |
| def setUp(self): |
| super().setUp() |
| _set_pointwise_optimize(False) |
| |
| def tearDown(self): |
| _set_pointwise_optimize(True) |
| super().tearDown() |
| |
| for n in skip_functorch_only: |
| setattr(TestMinFunctorchOnly, n, skip("skip_functorch_only")(lambda self: None)) |
| |
| if __name__ == '__main__': |
| run_tests() |