import csv
from collections import defaultdict

import torch
import yaml


def get_ops_for_key(key):
    # Needs modified PyTorch C++ code to work
    if key is None:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key()
    else:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
    cleaned_ops = []
    for i in ops:
        if "aten::" not in i:
            continue
        cleaned_ops.append(i[6:].strip())  # drop the "aten::" prefix
    return set(cleaned_ops)
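

# A quick illustration of get_ops_for_key. It depends on the private
# torch._C._dispatch_get_registrations_for_dispatch_key binding, which (per the
# comment above) only exists in a modified PyTorch build, so treat this as a
# sketch rather than guaranteed behavior. The result is a set of bare op names
# with the "aten::" prefix stripped:
#
#   batched = get_ops_for_key("FuncTorchBatched")
#   # e.g. {"add", "sin", "mul", ...}  (illustrative contents only)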


def gen_data(special_op_lists, analysis_name):
    all_ops = get_ops_for_key(None)
    composite_ops = get_ops_for_key("CompositeImplicitAutograd")
    noncomposite_ops = all_ops - composite_ops

    with open("../../aten/src/ATen/native/native_functions.yaml") as f:
        ops = yaml.load(f.read(), Loader=yaml.CLoader)

    with open("annotated_ops") as f:
        annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}

    uniq_ops = []
    uniq_names = set()
    overload_types = defaultdict(list)
    for op in ops:
        func_str = op["func"]
        name = func_str[: func_str.index("(")]
        if "." in name:
            uniq_name = name[: name.index(".")]
            overload_types[name[name.index(".") + 1 :]].append(name)
        else:
            uniq_name = name
        op["name"] = uniq_name
        op["full_name"] = name  # overload-qualified name, e.g. "add.Tensor"
        op["ret_type"] = func_str[func_str.index("->") + 3 :]
        if uniq_name in uniq_names:
            continue
        uniq_names.add(uniq_name)
        uniq_ops.append(op)

    def annotate_ops(ops, is_unique):
        categorization = defaultdict(int)
        for op in ops:
            if op["name"][-1] == "_":
                categorization["inplace"] += 1
                op["meta"] = "inplace"
                continue
            if not is_unique and "a!" in op["func"].lower():
                categorization["out"] += 1
                op["meta"] = "out"
                continue
            if "conv" in op["name"]:
                categorization["conv"] += 1
                op["meta"] = "conv"
                continue
            if "pool" in op["name"]:
                categorization["pool"] += 1
                op["meta"] = "pool"
                continue
            if "backward" in op["name"]:
                categorization["backward"] += 1
                op["meta"] = "backward"
                continue
            if op["name"][0] == "_" and op["name"][1] != "_":
                categorization["private"] += 1
                op["meta"] = "private"
                continue
            if "batch_norm" in op["name"]:
                categorization["batch_norm"] += 1
                op["meta"] = "batch_norm"
                continue
            if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
                categorization["non_tensor"] += 1
                op["meta"] = "non_tensor"
                continue
            if (
                "cudnn" in op["name"]
                or "mkldnn" in op["name"]
                or "miopen" in op["name"]
                or "native" in op["name"]
                or "thnn" in op["name"]
                or "slow" in op["name"]
            ):
                categorization["backend"] += 1
                op["meta"] = "backend"
                continue
            if op["name"] in annotated_ops:
                categorization["core"] += 1
                op["meta"] = "core " + annotated_ops[op["name"]]
                continue
            categorization["core"] += 1
            op["meta"] = "core unknown"
        return categorization
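
    # annotate_ops assigns each op the first matching bucket, in order:
    # inplace, out, conv, pool, backward, private, batch_norm, non_tensor,
    # backend. Anything left over is "core", tagged with its hand-written
    # label from annotated_ops when one exists, else "core unknown".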
    annotate_ops(ops, is_unique=False)

    with open(analysis_name, "w") as f:
        for op in ops:
            info = [
                op["full_name"],
                op["meta"],
                op["full_name"] not in noncomposite_ops,
            ] + [check(op) for check in special_op_lists]
            f.write(",".join(str(i) for i in info) + "\n")
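

# Each gen_data call writes one comma-separated row per native_functions.yaml
# entry: full_name, meta category, whether the op is CompositeImplicitAutograd,
# then one column per check in special_op_lists. An illustrative (not actual)
# row might look like:
#   add.Tensor,core arithmetic,False,True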


def name_check(lst):
    return lambda x: x["name"] in lst


def full_name_check(lst):
    return lambda x: x["full_name"] in lst


# Generates batching rule data
gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")


def remove_suffix(input_string, suffix):
    if suffix and input_string.endswith(suffix):
        return input_string[: -len(suffix)]
    return input_string


def remove_prefix(input_string, prefix):
    if prefix and input_string.startswith(prefix):
        return input_string[len(prefix) :]
    return input_string
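

# Note: these helpers mirror str.removesuffix / str.removeprefix, which were
# only added in Python 3.9; defining them here keeps the script working on
# older interpreters. For example:
#   remove_suffix("add.default", ".default")  # -> "add"
#   remove_prefix("linalg_det", "linalg_")    # -> "det"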

# Generates decomposition coverage data
with open("run_ops.txt") as f:
    opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f]
with open("count_ops.txt") as f:
    # Parse counts as ints so the values match defaultdict(int)'s default of 0
    opinfo_counts = [int(i.strip()) for i in f]
opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))


def count_fn(x):
    return opinfo_counts[x["full_name"]]
with open("run_decompositions.txt") as f:
decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f.readlines()]
with open("public_api") as f:
ref_api = [i.strip() for i in f.readlines()]


def has_ref_impl(x):
    name = x["name"]
    for prefix in ["linalg_", "special_"]:
        name = remove_prefix(name, prefix)
    prefixes = ["nn.functional", "fft", "special", "linalg"]
    return (
        any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
    )
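

# Sketch of has_ref_impl on a hypothetical entry, assuming public_api contains
# the line "linalg.det": has_ref_impl({"name": "linalg_det"}) strips the
# "linalg_" prefix, finds "linalg.det" in ref_api, and returns True.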


gen_data(
    [
        full_name_check(opinfo_ops),
        full_name_check(decomposed_ops),
        count_fn,
        has_ref_impl,
    ],
    "decompositions.txt",
)