blob: a923af9505a1a9c68133a436ed9df8f806081ad6 [file] [log] [blame]
import torch
from . import benchmark
class RNNEltwise(benchmark.Benchmark):
    """Benchmark for the element-wise tail of an LSTM cell.

    The matmul stage is assumed to be done already: the benchmark takes
    pre-summed gate activations of shape [b, 4*hs] (plus per-batch bias
    tensors of the same shape) and measures only the pointwise math —
    sigmoid/tanh gating and the cell/hidden-state updates.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs

        # All gate-sized operands are [b, 4 * hs]; the cell state is [b, hs].
        def tensor_of(width):
            return self.rand(
                [b, width],
                device=device,
                dtype=dtype,
                requires_grad=self.requires_grad,
            )

        self.input = tensor_of(4 * hs)
        self.hx = tensor_of(4 * hs)
        self.cx = tensor_of(hs)
        self.b_ih = tensor_of(4 * hs)
        self.b_hh = tensor_of(4 * hs)
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    def forward(self, input, hx, cx, b_ih, b_hh):
        # Fuse the four pre-activation terms, then split into the
        # input / forget / cell / output gates along dim 1.
        preact = input + hx + b_ih + b_hh
        i, f, g, o = preact.chunk(4, 1)
        i = i.sigmoid()
        f = f.sigmoid()
        g = g.tanh()
        o = o.sigmoid()
        # Standard LSTM cell-state and hidden-state updates.
        cy = f * cx + i * g
        hy = o * torch.tanh(cy)
        return hy, cy

    def config(self):
        # Configuration is the (batch, hidden-size) pair.
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        # Bytes moved: read every input tensor once, write hy and cy
        # (each the size of cx).
        def memsize(t):
            return t.numel() * t.element_size()

        reads = sum(memsize(t) for t in self.inputs)
        writes = 2 * memsize(self.cx)
        total = reads + writes
        return {"sol": total, "algorithmic": total}

    @staticmethod
    def default_configs():
        return [[64, 512]]
# Make the benchmark discoverable by the framework's registry.
benchmark.register_benchmark_class(RNNEltwise)
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise.

    Mixes in benchmark.DynamicShape so that operand shapes are
    re-randomized around the nominal (b, hs) on each instantiation,
    exercising dynamic-shape handling in the compiled kernels.
    """

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        # Perturb the nominal (b, hs) pair, then rebuild all five operands
        # at the new shape — gate tensors at [b, 4 * hs], cell state at [b, hs].
        b, hs = self.rand_shape([self.b, self.hs])

        def fresh(width):
            return self.rand(
                [b, width],
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )

        self.input = fresh(4 * hs)
        self.hx = fresh(4 * hs)
        self.cx = fresh(hs)
        self.b_ih = fresh(4 * hs)
        self.b_hh = fresh(4 * hs)
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    @staticmethod
    def module():
        return "dynamic_lstm"
# Make the dynamic-shape variant discoverable by the framework's registry.
benchmark.register_benchmark_class(DynamicLSTM)