blob: 5ea4d503e4fdd462f1496b5579999d8ab477dc40 [file] [log] [blame]
import itertools
import operator
import numpy as np
import scipy.special
import torch
from . import benchmark
# A template class for elementwise operations.
# A derived class will override the class instance to customize its behavior.
class ElementBench(benchmark.Benchmark):
    """Template benchmark for elementwise expression trees.

    Computes ``c = binary_op(unary_op(d1), unary_op(d2))
                   + binary_op(unary_op(d3), unary_op(d4))``
    over four 1-D tensors of length N.  Concrete benchmarks are generated by
    ``register_element_ops()``, which subclasses this class and overrides the
    customization class variables below.
    """

    # List of customization class variables.
    op_str = None  # operator name; "rand" in the name marks nondeterministic ops
    binary_op_pt_func = None  # torch-side binary op; None -> addition (see _eval)
    binary_op_np_func = None  # numpy-side binary op used by reference()
    unary_op_pt_func = None  # torch-side unary op; None -> identity (see _eval)
    unary_op_np_func = None  # numpy-side unary op used by reference()
    split_input = True  # True: 4 independent inputs; False: all derived from d1

    def __init__(self, mode, device, dtype, N):
        # Allocate four random 1-D input tensors of length N.
        super().__init__(mode, device, dtype)
        self.N = N
        self.d1 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
        # rand-based ops produce random output; presumably the harness uses
        # this flag to skip exact result comparison -- TODO confirm.
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate the expression tree with the supplied op callables.

        Works on both torch tensors and numpy arrays, since it only uses
        arithmetic operators and the given callables.  A falsy ``binary_op``
        defaults to addition; a falsy ``unary_op`` defaults to identity.
        """
        if not binary_op:
            def binary_op(x, y):
                return x + y
        if not unary_op:
            def unary_op(x):
                return x
        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # Shared-input mode: derive all four operands from d1.  The small
            # offsets keep the operands distinct; note d1 is transformed LAST
            # so that d2..d4 are computed from the original, untransformed d1.
            d2 = unary_op(d1 + 0.001)
            d3 = unary_op(d1 + 0.002)
            d4 = unary_op(d1 + 0.003)
            d1 = unary_op(d1)
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        """Torch-side computation, using the class's torch op functions."""
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        """Numpy-side reference computation mirroring forward()."""
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        """Return the benchmark configuration: just the tensor length."""
        return [self.N]

    @classmethod
    def module(cls):
        """Benchmark module name, e.g. "element_split_mul"."""
        return "element_" + cls.op_str

    def memory_workload(self):
        """Estimate memory traffic as element counts scaled by N.

        Returns "sol" (speed-of-light, i.e. minimal traffic) and
        "algorithmic" (traffic implied by the naive expression) counts.
        NOTE(review): the counting constants below are heuristic and were
        presumably tuned for this harness -- confirm before relying on them.
        """
        input_count = len(self.inputs)
        if self.mode == "fwd":
            if self.split_input:
                # Each input read once, plus one output write.
                sol_count = input_count + 1
                algorithmic_count = input_count + 1
            else:
                # Shared input: one read plus one write.
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                # rand-like ops ignore input values: count only the output write.
                sol_count = 1
                algorithmic_count = 1
        else:
            # Backward mode: forward traffic plus gradient reads/writes.
            if self.split_input:
                sol_count = (input_count + 1) + (1 + input_count)
                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1
        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Default config: a single 2**25-element tensor."""
        return [[1 << 25]]
def register_element_ops():
    """Create and register one ElementBench subclass per (op, layout) pair.

    Every op gets a "split" variant (four independent input tensors) and a
    "shared" variant (all four operands derived from one tensor).  Binary
    ops are registered before unary ops, preserving the original order.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # Bias the divisor away from zero so random inputs cannot divide by 0.
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]
    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        # torch.rand_like has no direct numpy twin; emulate via np.random.rand.
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]
    _register_element_variants(binary_op_list, "binary")
    _register_element_variants(unary_op_list, "unary")


def _register_element_variants(op_list, arity):
    """Register split/shared ElementBench subclasses for each op in op_list.

    op_list entries are [name, pt_func] (numpy func defaults to pt_func) or
    [name, pt_func, np_func].  ``arity`` is "binary" or "unary" and selects
    which pair of ElementBench class attributes receives the functions.
    """
    for split_input, entry in itertools.product([True, False], op_list):
        if len(entry) == 2:
            [op_str, op_pt_func] = entry
            op_np_func = op_pt_func
        else:
            [op_str, op_pt_func, op_np_func] = entry
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        # Make a fresh copy of ElementBench so each variant owns its config.
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        setattr(bm_cls, arity + "_op_pt_func", op_pt_func)
        setattr(bm_cls, arity + "_op_np_func", op_np_func)
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)
# Instantiate and register every elementwise benchmark variant at import time.
register_element_ops()
class SimpleElementBench(benchmark.Benchmark):
    """Benchmark for a minimal elementwise chain: two scalar adds on one tensor."""

    def __init__(self, mode, device, dtype, N):
        # Allocate a single random 1-D input tensor of length N.
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        """Torch-side computation: data + 0.001 + 0.002."""
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        """Numpy-side mirror of forward().

        BUGFIX: the previous implementation was copy-pasted from
        ElementBench and referenced ``self.d1..d4``, the ``*_op_np_func``
        class attributes, and ``self._eval`` -- none of which exist on this
        class -- so it raised AttributeError whenever invoked.
        """
        return self.numpy(self.data) + 0.001 + 0.002

    def config(self):
        """Return the benchmark configuration: just the tensor length."""
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        return "simple_element"

    def memory_workload(self):
        """Estimate memory traffic as element counts scaled by N.

        One input read plus one output write in both fwd and bwd modes
        (the original branched on self.mode but used identical counts).
        """
        sol_count = 2
        algorithmic_count = 2
        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Default config: a single 2**25-element tensor."""
        return [[1 << 25]]
# Expose the static-shape simple elementwise benchmark to the harness.
benchmark.register_benchmark_class(SimpleElementBench)
class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    """Dynamic-shape variant of SimpleElementBench.

    Re-draws the input tensor with a randomized length before each run via
    instantiate_input().
    """

    def __init__(self, mode, device, dtype, N):
        # Both bases are initialized explicitly; cooperative super() is not
        # used because their __init__ signatures differ.
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        return "simple_dynamic_element"

    def instantiate_input(self):
        # Draw a fresh randomized length around self.N and rebuild the input.
        (n,) = self.rand_shape([self.N])
        self.inputs = [
            self.rand(
                [n],
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )
        ]
# Expose the dynamic-shape simple elementwise benchmark to the harness.
benchmark.register_benchmark_class(DynamicSimpleElementBench)