blob: 5665687446b1a2cd8da3f5d7df85570f4e71c857 [file] [log] [blame] [edit]
# Owner(s): ["module: unknown"]
import unittest
from typing import Dict, Optional
import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule
from typing import List
def linear_shim(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
output = input.matmul(weight.t())
if bias is not None:
output += bias
ret = output
return ret
torch.nn.functional.linear = linear_shim
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, hid_dim, n_heads, dropout, device):
super().__init__()
assert hid_dim % n_heads == 0
self.hid_dim = hid_dim
self.n_heads = n_heads
self.head_dim = hid_dim // n_heads
self.fc_q = nn.Linear(hid_dim, hid_dim)
self.fc_k = nn.Linear(hid_dim, hid_dim)
self.fc_v = nn.Linear(hid_dim, hid_dim)
self.fc_o = nn.Linear(hid_dim, hid_dim)
# self.dropout = nn.Dropout(dropout)
self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
def forward(self, query, key, value, mask):
batch_size = query.shape[0]
Q = self.fc_q(query)
K = self.fc_k(key)
V = self.fc_v(value)
Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
# energy = energy.masked_fill(mask == 0, -1e10)
attention = torch.softmax(energy, dim=-1)
# x = torch.matmul(self.dropout(attention), V)
x = torch.matmul(attention, V)
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(batch_size, -1, self.hid_dim)
x = self.fc_o(x)
return x, attention
# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
layers = nn.ModuleList()
for i in range(0, len(ln) - 1):
n = ln[i]
m = ln[i + 1]
LL = nn.Linear(int(n), int(m), bias=True)
mean = 0.0 # std_dev = np.sqrt(variance)
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
LL.weight.data = torch.tensor(W, requires_grad=True)
LL.bias.data = torch.tensor(bt, requires_grad=True)
layers.append(LL)
if i == sigmoid_layer:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
with torch.no_grad():
s = torch.jit.script(torch.nn.Sequential(*layers))
s.eval()
return s
def trivial_graph(a, b, c):
s = torch.tensor([[3, 3], [3, 3]])
return a + b * c + s
def elementwise_square_addition(input1, input2):
return input1 * input1 + input2 * input2
def fork_wait_graph1(input1, input2):
fut = torch.jit.fork(elementwise_square_addition, input1, input2)
return torch.jit.wait(fut)
def fork_wait_graph2(input1, input2):
fut = torch.jit.fork(loop_graph, input1, input2, 5)
return torch.jit.wait(fut)
"""
graph with multiple fork/wait operations
:param input: torch.tensor input to forked subgraph
:param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
futures : List[torch.jit.Future[torch.Tensor]] = []
for _ in range(iters):
futures.append(torch.jit.fork(torch.neg, input))
results = []
for future in futures:
results.append(torch.jit.wait(future))
return torch.sum(torch.stack(results))
"""
graph with multi-level fork/wait operations
:param input: torch.tensor input to forked subgraph
:param num_forks: number of top level forks
:param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
futures : List[torch.jit.Future[torch.Tensor]] = []
for _ in range(num_forks):
futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
results = []
for future in futures:
results.append(torch.jit.wait(future))
return torch.sum(torch.stack(results))
def add_tensor(input1, input2):
return input1 + input2
def fork_wait_graph_exception(input1, input2):
fut = torch.jit.fork(add_tensor, input1, input2)
return torch.jit.wait(fut)
def loop_graph(a, b, iters: int):
c = a + b * 2
for i in range(iters):
c = c + b
c *= 2
c -= a
return c
def output_graph(a, b, c, iters: int):
s = torch.tensor([[3, 3], [3, 3]])
k = a + b * c + s
d: Dict[int, torch.Tensor] = {}
for i in range(iters):
d[i] = k + i
return d
class SubModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = 11
self.b = 2
def forward(self, x):
return self.a + self.b + x
class SubModule2(nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = 12
self.b = 2
def forward(self, x):
self.b = 30
return self.a + self.b + x
class TestModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.sub1 = SubModule()
self.sub2 = SubModule2()
self.a = 3
self.b = 4
def forward(self, x):
self.b = 20
return self.sub1(x) + self.a + self.b + self.sub2(x)
class TestStaticModule(TestCase):
"""
Test Case: To test simple fork/wait operation in a graph
fork is called on simple addition operation on input tensors
"""
def test_fork_wait_1(self):
inp1 = torch.ones(5, 5)
inp2 = torch.randn(5, 5)
torch_graph = torch.jit.script(fork_wait_graph1)
output_ref = torch_graph(inp1, inp2)
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module(inp1, inp2)
torch.testing.assert_close(output_test, output_ref)
"""
Test Case: To test simple fork/wait operation with
StaticRuntime runAsync API returning future
"""
def test_fork_wait_1_async(self):
inp1 = torch.ones(5, 5)
inp2 = torch.randn(5, 5)
torch_graph = torch.jit.script(fork_wait_graph1)
output_ref = torch_graph(inp1, inp2)
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module.runAsync((inp1, inp2), {})
output_test.wait()
torch.testing.assert_close(output_test.value(), output_ref)
"""
Test Case: To test fork/wait operation in a graph on
a loop subgraph performing mix of operations
"""
def test_fork_wait_2(self):
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
torch_graph = torch.jit.script(fork_wait_graph2)
output_ref = torch_graph(inp1, inp2)
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module(inp1, inp2)
torch.testing.assert_close(output_test, output_ref)
"""
Test Case: To test fork/wait operation on a loop
subgraph with StaticRuntime runAsync API returning future
"""
def test_fork_wait_2_async(self):
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
torch_graph = torch.jit.script(fork_wait_graph2)
output_ref = torch_graph(inp1, inp2)
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module.runAsync((inp1, inp2), {})
output_test.wait()
torch.testing.assert_close(output_test.value(), output_ref)
"""
Test Case: To test fork/wait operation in a graph on
having multiple fork/wait operations
"""
def test_fork_wait_3(self):
input = torch.ones(3, 3)
num_forks = 10
torch_graph = torch.jit.script(fork_wait_graph3)
output_ref = torch_graph(input, num_forks)
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module(input, num_forks)
torch.testing.assert_close(output_test, output_ref)
"""
Test Case: To test fork/wait operation in a graph with
multiple fork/wait operations on runAsync API returning future
"""
def test_fork_wait_3_async(self):
input = torch.ones(3, 3)
num_forks = 10
torch_graph = torch.jit.script(fork_wait_graph3)
output_ref = torch_graph(input, num_forks)
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module.runAsync((input, num_forks), {})
output_test.wait()
torch.testing.assert_close(output_test.value(), output_ref)
"""
Test Case: To test fork/wait operation in a graph on
multiple nested fork/wait operations
"""
@unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
def test_fork_wait_4(self):
input = torch.ones(3, 3)
num_forks = 10
num_child_forks = 10
torch_graph = torch.jit.script(fork_wait_graph4)
static_runtime_module = StaticModule(torch_graph)
output_ref = torch_graph(input, num_forks, num_child_forks)
output_test = static_runtime_module(input, num_forks, num_child_forks)
torch.testing.assert_close(output_test, output_ref)
"""
Test Case: To test fork/wait operation in a graph with multiple
nested fork/wait operations on runAsync API returning future
"""
@unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
def test_fork_wait_4_async(self):
input = torch.ones(3, 3)
num_forks = 10
num_child_forks = 10
torch_graph = torch.jit.script(fork_wait_graph4)
static_runtime_module = StaticModule(torch_graph)
output_ref = torch_graph(input, num_forks, num_child_forks)
output_test = static_runtime_module.runAsync(
(input, num_forks, num_child_forks), {})
output_test.wait()
torch.testing.assert_close(output_test.value(), output_ref)
"""
Test Case: To test exception handling in fork/wait
operation. Add.Tensor op is called for tensors with
non-matching dims on the forked subgraph and the
exception raised by subgraph is set on future returned
by prim::fork to parent graph. Returned exception is
checked for substring expected_error_msg as declared below
"""
def test_fork_wait_exception(self):
# incompatible tensors for add due to shape mismatch
input1 = torch.randn(4, 7)
input2 = torch.randn(4, 5)
torch_graph = torch.jit.script(fork_wait_graph_exception)
try:
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module(input1, input2)
except Exception as error:
expected_error_msg = (
"The size of tensor a (7) must match the size "
"of tensor b (5) at non-singleton dimension 1"
)
# test fails if error does not contain expected substr
if str(error).find(expected_error_msg) == -1:
raise RuntimeError(
"Tried execution of add.Tensors with incompatible shape. "
"Exception raised by forked runtime execution does "
f'not contain expected substring: "{expected_error_msg}"'
) from error
"""
Test Case: To test exception handling in fork/wait
operation with runAsync API. Add.Tensor op is called for
tensors with non-matching dims on the forked subgraph
and the exception raised by subgraph is set on future returned
by prim::fork to parent graph. Returned exception is
checked for substring expected_error_msg as declared below
"""
def test_fork_wait_exception_async(self):
# incompatible tensors for add due to shape mismatch
input1 = torch.randn(4, 7)
input2 = torch.randn(4, 5)
torch_graph = torch.jit.script(fork_wait_graph_exception)
try:
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module.runAsync(
(input1, input2), {})
except Exception as error:
expected_error_msg = (
"The size of tensor a (7) must match the size "
"of tensor b (5) at non-singleton dimension 1"
)
# test fails if error does not contain expected substr
if str(error).find(expected_error_msg) == -1:
raise RuntimeError(
"Tried execution of add.Tensors with incompatible shape. "
"Exception raised by forked runtime execution does "
f'not contain expected substring: "{expected_error_msg}"'
) from error
def test_multihead_attention_layer(self):
HID_DIM = 256
QUERY_LEN = 8
BATCH_SIZE = 128
LAYERS = 3
HEADS = 8
DROPOUT = 0.1
device = torch.device("cpu")
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
with torch.no_grad():
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
attention.eval()
attention = torch.jit.script(attention)
attention.eval()
o_ref = attention(src, src, src, src_mask)
attention_a = StaticModule(attention)
o_test = attention_a(src, src, src, src_mask)
o_test_kw = attention_a(src, src, value=src, mask=src_mask)
for a, b in zip(o_ref, o_test):
torch.testing.assert_close(a, b)
for a, b in zip(o_ref, o_test_kw):
torch.testing.assert_close(a, b)
def test_multihead_attention_layer_benchmark(self):
HID_DIM = 256
QUERY_LEN = 8
BATCH_SIZE = 128
LAYERS = 3
HEADS = 8
DROPOUT = 0.1
device = torch.device("cpu")
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
with torch.no_grad():
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
attention.eval()
attention = torch.jit.script(attention)
attention_a = StaticModule(attention)
attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
metrics = attention_a.benchmark_individual_ops(
[src, src, src, src_mask], {}, 2, 2
)
def test_mlp(self):
# Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
ln_bot = [512, 512, 64]
sigmoid_bot = -1
ln_top = [100, 1024, 1024, 1024, 1]
sigmoid_top = 3
bot_l = create_mlp(ln_bot, sigmoid_bot)
bot_l_acc = StaticModule(bot_l)
top_l = create_mlp(ln_top, sigmoid_top)
top_l_acc = StaticModule(top_l)
with torch.no_grad():
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
ref_bot = bot_l(bot_inp)
acc_bot = bot_l_acc(bot_inp)
torch.testing.assert_close(acc_bot, ref_bot)
ref_top = top_l(top_inp)
acc_top = top_l_acc(top_inp)
torch.testing.assert_close(acc_top, ref_top)
for _ in range(5):
with torch.no_grad():
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
ref_bot = bot_l(bot_inp)
acc_bot = bot_l_acc(bot_inp)
torch.testing.assert_close(acc_bot, ref_bot)
ref_top = top_l(top_inp)
acc_top = top_l_acc(top_inp)
torch.testing.assert_close(acc_top, ref_top)
def test_trivial_graph(self):
s = torch.full((2, 2), 2)
tg = torch.jit.script(trivial_graph)
o_ref = tg(s, s, s)
tg_a = StaticModule(tg)
o_test = tg_a(s, s, s)
torch.testing.assert_close(o_ref, o_test)
def test_leaky_relu(self):
s = torch.randn(5, 5)
tg = torch.jit.script(nn.LeakyReLU(0.1))
o_ref = tg(s)
tg_a = StaticModule(tg)
o_test = tg_a(s)
torch.testing.assert_close(o_ref, o_test)
def test_attr(self):
"""
TorchScript IR of TestModule() after freezing:
graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
%x.1 : Tensor):
%18 : int = prim::Constant[value=30]()
%30 : int = prim::Constant[value=13]()
%3 : int = prim::Constant[value=20]()
%2 : int = prim::Constant[value=1]()
%self.sub2.a : int = prim::Constant[value=12]()
%self.a : int = prim::Constant[value=3]()
= prim::SetAttr[name="b"](%self, %3)
%17 : Tensor = aten::add(%x.1, %30, %2)
%7 : Tensor = aten::add(%17, %self.a, %2)
%b.1 : int = prim::GetAttr[name="b"](%self)
%9 : Tensor = aten::add(%7, %b.1, %2)
%sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
= prim::SetAttr[name="b"](%sub2, %18)
%b : int = prim::GetAttr[name="b"](%sub2)
%22 : int = aten::add(%self.sub2.a, %b)
%23 : Tensor = aten::add(%x.1, %22, %2)
%12 : Tensor = aten::add(%9, %23, %2)
return (%12)
"""
# test prim::SetAttr and prim::GetAttr impl in Static Runtime
m = TestModule()
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
ms = torch.jit.script(m)
sm = StaticModule(ms)
output_sm = sm(input)
torch.testing.assert_close(output_s, output_sm)
sm.benchmark([input], {}, 2, 2)
sm.benchmark_individual_ops([input], {}, 2, 2)
sm.benchmark([], {"x": input}, 2, 2)
sm.benchmark_individual_ops([], {"x": input}, 2, 2)
@unittest.skip("Temporarily disabled")
def test_fusion_trivial_graph(self):
s = torch.full((2, 2), 2)
tg = torch.jit.script(trivial_graph)
o_ref = tg(s, s, s)
torch._C._fuse_to_static_module(tg.graph)
assert "StaticSubgraph" in str(tg.graph)
o_test = tg(s, s, s)
torch.testing.assert_close(o_ref, o_test)
@unittest.skip("Temporarily disabled")
def test_fusion_multihead_attention_layer(self):
HID_DIM = 256
QUERY_LEN = 8
BATCH_SIZE = 128
LAYERS = 3
HEADS = 8
DROPOUT = 0.1
device = torch.device("cpu")
attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
with torch.no_grad():
src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
attention.eval()
attention = torch.jit.script(attention)
attention.eval()
o_ref = attention(src, src, src, src_mask)
torch._C._fuse_to_static_module(attention._c)
o_test = attention(src, src, src, src_mask)
for a, b in zip(o_ref, o_test):
torch.testing.assert_close(a, b)
@unittest.skip("Temporarily disabled")
def test_fusion_loop(self):
a = torch.randn(5, 5)
b = torch.randn(5, 5)
c = 4
lg = torch.jit.script(loop_graph)
o_ref = lg(a, b, c)
torch._C._fuse_to_static_module(lg.graph)
assert "StaticSubgraph" in str(lg.graph)
o_test = lg(a, b, c)
torch.testing.assert_close(o_ref, o_test)
@unittest.skip("Temporarily disabled")
def test_fusion_outputs(self):
a = torch.randn(2, 2)
b = torch.randn(2, 2)
c = 4
og = torch.jit.script(output_graph)
o_ref = og(a, b, b, c)
torch._C._fuse_to_static_module(og.graph)
assert "StaticSubgraph" in str(og.graph)
o_test = og(a, b, b, c)
for i in o_ref.keys():
torch.testing.assert_close(o_ref[i], o_test[i])
def test_create_object(self):
class Foo: # noqa: B903
def __init__(self, x: torch.Tensor) -> None:
self.x = x
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, y: torch.Tensor) -> torch.Tensor:
foo = Foo(y)
return y * foo.x
mod = torch.jit.script(Mod()).eval()
y = torch.randn((1, ))
expected = mod(y)
static_mod = StaticModule(torch.jit.freeze(mod))
actual = static_mod(y)
self.assertEqual(expected, actual)
if __name__ == "__main__":
run_tests()