| import torch |
| |
| from . import benchmark |
| |
| |
class RNNEltwise(benchmark.Benchmark):
    """Benchmark the elementwise (post-GEMM) portion of an LSTM cell.

    The inputs are pre-computed gate activations of shape [b, 4*hs]
    (input, hidden contribution, and the two bias terms) plus the cell
    state of shape [b, hs]; forward() performs only the pointwise gate
    math, so the benchmark isolates elementwise-fusion performance.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs

        # All gate-sized operands are [b, 4*hs]; the cell state is [b, hs].
        def fresh(width):
            return self.rand(
                [b, width],
                device=device,
                dtype=dtype,
                requires_grad=self.requires_grad,
            )

        self.input = fresh(4 * hs)
        self.hx = fresh(4 * hs)
        self.cx = fresh(hs)
        self.b_ih = fresh(4 * hs)
        self.b_hh = fresh(4 * hs)
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    def forward(self, input, hx, cx, b_ih, b_hh):
        """One LSTM-cell elementwise step; returns (hy, cy)."""
        gates = input + hx + b_ih + b_hh

        # Split the fused [b, 4*hs] tensor into the four [b, hs] gates.
        in_g, forget_g, cell_g, out_g = gates.chunk(4, 1)

        in_g = torch.sigmoid(in_g)
        forget_g = torch.sigmoid(forget_g)
        cell_g = torch.tanh(cell_g)
        out_g = torch.sigmoid(out_g)

        cy = forget_g * cx + in_g * cell_g
        hy = out_g * torch.tanh(cy)

        return hy, cy

    def config(self):
        """Benchmark configuration: [batch, hidden_size]."""
        return [self.b, self.hs]

    @staticmethod
    def module():
        # Name under which the harness reports this benchmark.
        return "rnn_eltwise"

    def memory_workload(self):
        """Bytes moved: all five inputs read plus hy and cy written."""

        def nbytes(t):
            return t.numel() * t.element_size()

        reads = sum(nbytes(t) for t in self.inputs)
        # hy and cy are both the same size as cx ([b, hs]).
        writes = 2 * nbytes(self.cx)
        total = reads + writes
        return {"sol": total, "algorithmic": total}

    @staticmethod
    def default_configs():
        return [[64, 512]]
| |
| |
# Make this workload discoverable by the benchmark harness (presumably keyed
# by the module() name, "rnn_eltwise" — registry lives in the benchmark module).
benchmark.register_benchmark_class(RNNEltwise)
| |
| |
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise.

    Mixes in benchmark.DynamicShape so that instantiate_input() can
    regenerate all operand tensors with a freshly perturbed
    (batch, hidden_size) on each invocation.
    """

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        """Rebuild all input tensors at a randomized shape."""
        b, hs = self.rand_shape([self.b, self.hs])

        # Same layout as RNNEltwise.__init__: gate operands are [b, 4*hs],
        # the cell state is [b, hs].
        def fresh(width):
            return self.rand(
                [b, width],
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )

        self.input = fresh(4 * hs)
        self.hx = fresh(4 * hs)
        self.cx = fresh(hs)
        self.b_ih = fresh(4 * hs)
        self.b_hh = fresh(4 * hs)
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    @staticmethod
    def module():
        # Name under which the harness reports this benchmark.
        return "dynamic_lstm"
| |
| |
# Register the dynamic-shape variant with the harness as well
# (reported as "dynamic_lstm" per its module() name).
benchmark.register_benchmark_class(DynamicLSTM)