import itertools
import operator

import numpy as np
import scipy.special

import torch

from . import benchmark


# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class ElementBench(benchmark.Benchmark):
    # List of customization class variables.
    op_str = None
    binary_op_pt_func = None
    binary_op_np_func = None
    unary_op_pt_func = None
    unary_op_np_func = None
    split_input = True

    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.d1 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
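        # Ops containing "rand" are nondeterministic, so their output cannot be
        # checked against the numpy reference.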
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
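        # Fall back to trivial ops so the same expression serves both the
        # unary-only and binary-only benchmark variants.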
        if not binary_op:

            def binary_op(x, y):
                return x + y

        if not unary_op:

            def unary_op(x):
                return x

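        # With split_input, the unary op is applied to four independent inputs;
        # otherwise all four operands are derived from d1.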
        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            d2 = unary_op(d1 + 0.001)
            d3 = unary_op(d1 + 0.002)
            d4 = unary_op(d1 + 0.003)
            d1 = unary_op(d1)
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        return [self.N]

    @classmethod
    def module(cls):
        return "element_" + cls.op_str

    def memory_workload(self):
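        # Counts below are in units of N-element buffers read or written:
        # "sol" models the minimal (speed-of-light) memory traffic, while
        # "algorithmic" models the traffic of a straightforward, unfused
        # implementation.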
        input_count = len(self.inputs)
        if self.mode == "fwd":
            if self.split_input:
                sol_count = input_count + 1
                algorithmic_count = input_count + 1
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1
        else:
            if self.split_input:
                sol_count = (input_count + 1) + (1 + input_count)
                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 25]]
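

# For illustration only (a hypothetical class, intentionally not registered):
# register_element_ops() below builds subclasses of ElementBench dynamically
# with type(). A hand-written equivalent of the generated
# "ElementBench_split_mul" would look roughly like this:
class _ExampleSplitMulBench(ElementBench):
    op_str = "split_mul"
    binary_op_pt_func = operator.mul
    binary_op_np_func = operator.mul
    split_input = True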


def register_element_ops():
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
| ["max", torch.max, np.maximum], |
| ["min", torch.min, np.minimum], |
| ] |
| |
| unary_op_list = [ |
| ["erf", torch.erf, scipy.special.erf], |
| ["exp", torch.exp, np.exp], |
| ["sin", torch.sin, np.sin], |
| ["cos", torch.cos, np.cos], |
| ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)], |
| ] |
| |
| for split_input, binary_op in itertools.product([True, False], binary_op_list): |
        # Build a specialized subclass of ElementBench for this binary op.
        if len(binary_op) == 2:
            [op_str, op_pt_func] = binary_op
            op_np_func = op_pt_func
        elif len(binary_op) == 3:
            [op_str, op_pt_func, op_np_func] = binary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        bm_cls.binary_op_pt_func = op_pt_func
        bm_cls.binary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)

    for split_input, unary_op in itertools.product([True, False], unary_op_list):
        # Build a specialized subclass of ElementBench for this unary op.
        if len(unary_op) == 2:
            [op_str, op_pt_func] = unary_op
            op_np_func = op_pt_func
        elif len(unary_op) == 3:
            [op_str, op_pt_func, op_np_func] = unary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        bm_cls.unary_op_pt_func = op_pt_func
        bm_cls.unary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


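# Each op above is registered in a "split_<op>" and a "shared_<op>" flavor;
# the resulting benchmark modules are named "element_" + op_str, e.g.
# "element_split_mul" and "element_shared_mul".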
register_element_ops()


class SimpleElementBench(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        # Numpy mirror of forward().
        data = self.numpy(self.data)
        a = data + 0.001
        b = a + 0.002
        return b

    def config(self):
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        return "simple_element"

    def memory_workload(self):
        # Both forward and backward are modeled as one N-element buffer read
        # plus one written.
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 25]]


benchmark.register_benchmark_class(SimpleElementBench)


class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    def __init__(self, mode, device, dtype, N):
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        return "simple_dynamic_element"

    def instantiate_input(self):
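        # Re-draw the input size so each instantiation exercises a fresh dynamic shape.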
        (N,) = self.rand_shape([self.N])
        data = self.rand(
            [N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad
        )
        self.inputs = [data]


benchmark.register_benchmark_class(DynamicSimpleElementBench)