blob: a64fb45591f2ba087fb4d6a3c42c9b7ed7fc6ec4 [file] [log] [blame] [edit]
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import array
import codecs
import copy
import glob
import io
import os
import re
import sys
from itertools import product
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import textwrap
from dataclasses import dataclass
from typing import Any
import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[assignment, misc]
# Names of the generated C++ files; main() joins them onto --output-path.
# NOTE(review): genCppFiles receives the header path but only writes the
# .cpp file — confirm whether a header is still expected.
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"
# Base template environment. main() layers TYPES_ENV, FUNCS_ENV and any
# ``--env KEY=VALUE`` overrides on top of it before shaders are preprocessed.
DEFAULT_ENV: dict[str, Any] = {
    "PRECISION": "highp",
    # Image formats SPVGenerator.create_shader_params assigns to FORMAT based on
    # a shader's DTYPE; FLOAT_IMAGE_FORMAT is the fallback for unlisted DTYPEs.
    "FLOAT_IMAGE_FORMAT": "rgba16f",
    "INT_IMAGE_FORMAT": "rgba32i",
    "UINT_IMAGE_FORMAT": "rgba32ui",
}
# Type-dependent lookup tables that end up as globals of every template.
# Inner keys are dtype names; IMAGE_T and SAMPLER_T are additionally keyed by
# dimensionality (3 or 2). Presumably templates index these with their DTYPE
# (and NDIM) parameters — confirm against the .glsl templates.
TYPES_ENV: dict[str, Any] = {
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
    "IMAGE_T": {
        3: {
            "float": "image3D",
            "half": "image3D",
            "int": "iimage3D",
            "uint": "uimage3D",
        },
        2: {
            "float": "image2D",
            "half": "image2D",
            "int": "iimage2D",
            "uint": "uimage2D",
        },
    },
    "SAMPLER_T": {
        3: {
            "float": "sampler3D",
            "half": "sampler3D",
            "int": "isampler3D",
            "uint": "usampler3D",
        },
        2: {
            "float": "sampler2D",
            "half": "sampler2D",
            "int": "isampler2D",
            "uint": "usampler2D",
        },
    },
    # 4-component vector type per dtype.
    "VEC4_T": {
        "float": "vec4",
        "half": "vec4",
        "int": "ivec4",
        "uint": "uvec4",
        "int8": "vec4",
        "uint8": "uvec4",
    },
    # Scalar type per dtype.
    "T": {
        "float": "float",
        "half": "float",
        "int": "int",
        "uint": "uint",
        "int8": "int",
        # NOTE(review): "uint8" is not a core GLSL scalar type, and VEC4_T maps
        # uint8 to uvec4 — confirm "uint8" (rather than e.g. "uint") is intended.
        "uint8": "uint8",
    },
}
# Helper callables exposed to templates. GET_POS[3] returns the position
# expression unchanged; GET_POS[2] appends ".xy" to it.
FUNCS_ENV: dict[str, Any] = {
    "GET_POS": {
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    }
}
def extract_filename(path: str, keep_ext: bool = True) -> Any:
    """Return the last component of ``path``, optionally cut at its first dot.

    With ``keep_ext=False`` everything from the first ``.`` onward is dropped,
    so ``"dir/conv2d.glsl"`` gives ``"conv2d"`` and ``"a.b.c"`` gives ``"a"``.
    """
    base = os.path.basename(path)
    if keep_ext:
        return base
    stem, _sep, _rest = base.partition(".")
    return stem
############################
# SPIR-V Code Generation #
############################
# https://gist.github.com/pypt/94d747fe5180851196eb
class UniqueKeyLoader(Loader):
    """YAML loader whose mappings must not repeat a key.

    ``SPVGenerator.parseTemplateYaml`` loads shader-template files with this
    loader, so a key written twice in one mapping raises ``ConstructorError``.

    NOTE(review): this override never calls ``flatten_mapping``, so YAML merge
    keys (``<<``) are presumably not supported here — confirm before using them.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build a dict from ``node``; reject non-mapping nodes, unhashable keys and duplicates.

        Each key is constructed and validated before its value is constructed.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )

        constructed = {}
        for pair_key_node, pair_value_node in node.value:
            candidate = self.construct_object(pair_key_node, deep=deep)  # type: ignore[no-untyped-call]

            # Keys must be usable as dict keys.
            try:
                hash(candidate)
            except TypeError as err:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found unacceptable key ",
                    pair_key_node.start_mark,
                ) from err

            # Reject a key that already appeared earlier in this mapping.
            if candidate in constructed:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    "found duplicate key",
                    pair_key_node.start_mark,
                )

            constructed[candidate] = self.construct_object(pair_value_node, deep=deep)  # type: ignore[no-untyped-call]
        return constructed
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def extract_leading_whitespace(line: str) -> str:
    """Return the run of whitespace characters at the start of ``line`` (possibly empty)."""
    leading = re.match(r"\s*", line)
    return "" if leading is None else leading[0]
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def escape(line: str) -> str:
    """Translate one template line into a Python expression that evaluates to its text.

    Literal segments become Python string literals and each ``${expr}``
    becomes ``str(expr)``; the pieces are joined with ``+``.

    Literal segments are quoted with ``repr`` so that backslashes (e.g. a GLSL
    ``#define`` continued with a trailing ``\\``), quotes and control
    characters survive unchanged. Escaping only ``"`` by hand let a trailing
    backslash swallow the closing quote, and the generated code failed to compile.

    Raises:
        ValueError: if a ``${`` has no matching ``}``.
    """
    output_parts = []
    remaining = line
    while "${" in remaining:
        start_pos = remaining.index("${")
        end_pos = remaining.find("}", start_pos + 2)
        if end_pos == -1:
            raise ValueError(f"unterminated '${{' in template line: {line!r}")
        if start_pos != 0:
            output_parts.append(repr(remaining[:start_pos]))
        output_parts.append("str(" + remaining[start_pos + 2 : end_pos] + ")")
        remaining = remaining[end_pos + 1 :]
    if remaining:
        output_parts.append(repr(remaining))
    # An empty line (e.g. a whitespace-only line whose indentation is entirely
    # the enclosing Python block's) must still yield a valid expression,
    # otherwise the caller would generate ``print(, file=...)``.
    return " + ".join(output_parts) if output_parts else repr("")
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
    input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand a ``$``-annotated template (XNNPACK ``xngen`` syntax) into text.

    The template is translated line by line into a Python program, which is
    then executed with a copy of ``variables`` as its globals:

    * A line whose stripped form starts with ``$`` (but not ``${``) is Python:
      every ``$`` is removed and the rest is emitted as a statement. A trailing
      ``:`` opens a block whose body is the following, more-indented lines.
    * Any other line becomes a ``print`` of its text, with the enclosing Python
      block's indentation removed and each ``${expr}`` replaced by ``str(expr)``.
    * Lines containing ``LINT`` are dropped. Empty lines are emitted at the
      Python indentation of the next non-empty line, so inside a loop they repeat.

    Args:
        input_text: Template source.
        variables: Names visible to the template code. The dict is copied, so
            the caller's mapping itself is not modified.
        input_path: File name given to ``compile`` for the generated program,
            so errors raised by template code report it.

    Returns:
        The expanded text.

    Raises:
        AssertionError: on indentation the translator cannot follow.
        Any exception raised while compiling or running the template code.
    """
    input_lines = input_text.splitlines()
    python_lines = []
    blank_lines = 0
    last_indent = ""
    # Bound up front so trailing blank lines can be flushed even when the input
    # contains nothing but blank lines (this used to raise UnboundLocalError).
    python_indent = ""

    # List of tuples (template_indent, python_indent), one per open block.
    indent_stack = [("", "")]

    # True when the next non-empty line is the first line inside a Python code
    # block (i.e. right after a ``$for``/``$if``/... line ending in ``:``).
    # It starts False: top-level lines belong to no block, so their
    # indentation is kept verbatim. Starting it True, as xngen does, turned an
    # indented first line into an indented first Python statement, which fails
    # to compile; for a non-indented first line both choices behave the same.
    python_block_start = False

    for i, input_line in enumerate(input_lines):
        if input_line == "":
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            assert input_indent.startswith(last_indent), (
                f"{input_path}:{i + 1}: expected indentation starting with "
                f"{last_indent!r} inside the preceding block"
            )
            extra_python_indent = input_indent[len(last_indent) :]
            python_indent = indent_stack[-1][1] + extra_python_indent
            indent_stack.append((input_indent, python_indent))
            assert input_indent.startswith(indent_stack[-1][0])
        else:
            # Close every block this line is no longer indented into.
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]
        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            # Python statement line.
            if stripped_input_line.endswith(":"):
                python_block_start = True
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            # Text line: print it with the Python block indentation removed.
            assert input_line.startswith(python_indent), (
                f"{input_path}:{i + 1}: line does not start with the enclosing "
                f"block indentation {python_indent!r}"
            )
            while blank_lines != 0:
                python_lines.append(python_indent + "print(file=OUT_STREAM)")
                blank_lines -= 1
            python_lines.append(
                python_indent
                + f"print({escape(input_line[len(python_indent) :])}, file=OUT_STREAM)"
            )
        last_indent = input_indent

    # Flush blank lines that trail the last non-empty line.
    while blank_lines != 0:
        python_lines.append(python_indent + "print(file=OUT_STREAM)")
        blank_lines -= 1

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()
class SPVGenerator:
    """Expand GLSL shader templates and, optionally, compile them to SPIR-V.

    Every ``*.glsl*`` file found recursively under the source directories is
    registered under its file name up to the first ``.``. Every ``*.yaml``
    file is parsed as a set of template definitions (see
    ``parseTemplateYaml``). Each YAML template expands into one output shader
    per variant; GLSL files with no YAML entry become a single output shader
    generated with the default parameters.
    """

    def __init__(
        self,
        src_dir_paths: str | list[str],
        env: dict[Any, Any],
        glslc_path: str | None,
    ) -> None:
        """Discover sources, parse templates and build the output shader map.

        Args:
            src_dir_paths: One directory or a list of directories to search.
                They are also passed to glslc as ``-I`` include paths.
            env: Base preprocessing environment shared by every shader.
            glslc_path: Path to ``glslc``. ``None`` means only the expanded GLSL
                is written and no SPIR-V is compiled.
        """
        if isinstance(src_dir_paths, str):
            self.src_dir_paths = [src_dir_paths]
        else:
            self.src_dir_paths = src_dir_paths

        self.env = env
        self.glslc_path = glslc_path

        # Shader stem -> path of its GLSL source.
        self.glsl_src_files: dict[str, str] = {}
        self.template_yaml_files: list[str] = []

        self.addSrcAndYamlFiles(self.src_dir_paths)

        # Template name -> list of fully resolved per-variant parameter dicts.
        self.shader_template_params: dict[Any, Any] = {}
        for yaml_file in self.template_yaml_files:
            self.parseTemplateYaml(yaml_file)

        # Output shader name -> (GLSL source path, preprocessing environment).
        self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
        self.constructOutputMap()

    def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
        """Collect ``*.glsl*`` sources and ``*.yaml`` template files, recursively.

        A later source with the same stem replaces an earlier one in
        ``self.glsl_src_files``.
        """
        for src_path in src_dir_paths:
            # Collect glsl source files
            glsl_files = glob.glob(
                os.path.join(src_path, "**", "*.glsl*"), recursive=True
            )
            for file in glsl_files:
                if len(file) > 1:
                    self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
            # Collect template yaml files
            yaml_files = glob.glob(
                os.path.join(src_path, "**", "*.yaml"), recursive=True
            )
            for file in yaml_files:
                if len(file) > 1:
                    self.template_yaml_files.append(file)

    def generateVariantCombinations(
        self,
        iterated_params: dict[str, Any],
        exclude_params: set[str] | None = None,
    ) -> list[Any]:
        """Return the cartesian product of the iterated parameter values.

        Args:
            iterated_params: Parameter name -> list of ``{VALUE: ..., SUFFIX: ...}``
                entries (``SUFFIX`` optional).
            exclude_params: Parameter names to leave out of the product.

        Returns:
            A list of tuples, each holding one ``(param_name, suffix, value)``
            triple per included parameter. ``suffix`` is always a ``str``: the
            entry's ``SUFFIX`` if given, otherwise its ``VALUE``; a null suffix
            becomes ``""``. With nothing to iterate, the result is ``[()]``.
        """
        if exclude_params is None:
            exclude_params = set()
        all_iterated_params = []
        for param_name, value_list in iterated_params.items():
            if param_name in exclude_params:
                continue
            param_values = []
            for value in value_list:
                suffix = value.get("SUFFIX", value["VALUE"])
                # YAML yields ints/bools for e.g. ``VALUE: 3`` and None for an
                # empty ``SUFFIX:``. Callers measure and concatenate the
                # suffix, so normalize it to a string (len(3) used to crash).
                suffix = "" if suffix is None else str(suffix)
                param_values.append((param_name, suffix, value["VALUE"]))
            all_iterated_params.append(param_values)

        return list(product(*all_iterated_params))

    def parseTemplateYaml(self, yaml_file: str) -> None:
        """Parse one template YAML file into ``self.shader_template_params``.

        Expected layout for each top-level template name::

            <template>:
              parameter_names_with_default_values: {PARAM: default, ...}
              generate_variant_forall:          # optional default iteration
                PARAM: [{VALUE: v, SUFFIX: s}, ...]
              shader_variants:
                - NAME: <shader name>
                  PARAM: override               # optional
                  generate_variant_forall: ...  # optional, per variant

        Raises:
            KeyError: if a template name is defined more than once.
        """
        with open(yaml_file, encoding="utf-8") as f:
            contents = yaml.load(f, Loader=UniqueKeyLoader)
        # An empty YAML document loads as None; it simply defines no templates.
        if not contents:
            return
        for template_name, params_dict in contents.items():
            if template_name in self.shader_template_params:
                raise KeyError(f"{template_name} params file is defined twice")
            self.shader_template_params[template_name] = self._expandTemplateVariants(
                template_name, params_dict
            )

    def _expandTemplateVariants(
        self, template_name: str, params_dict: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Expand one template's YAML entry into a list of concrete parameter dicts.

        Each variant starts from a deep copy of the defaults and applies its own
        overrides. When a ``generate_variant_forall`` applies (the variant's own,
        else the template's), one dict is produced per value combination of the
        parameters the variant does not set explicitly. Each non-empty suffix
        is appended to ``NAME`` as ``_<suffix>``.
        """
        default_params = params_dict["parameter_names_with_default_values"]
        params_names = set(default_params.keys()).union({"NAME"})
        default_iterated_params = params_dict.get("generate_variant_forall", None)

        expanded: list[dict[str, Any]] = []
        for variant in params_dict["shader_variants"]:
            variant_params_names = set(variant.keys())
            invalid_keys = (
                variant_params_names - params_names - {"generate_variant_forall"}
            )
            assert len(invalid_keys) == 0, (
                f"{template_name}: variant {variant.get('NAME')!r} uses unknown "
                f"parameter(s) {sorted(map(str, invalid_keys))}"
            )

            iterated_params = variant.get(
                "generate_variant_forall", default_iterated_params
            )

            if iterated_params is None:
                # No iteration: a single shader with the variant's overrides.
                variant_params = copy.deepcopy(default_params)
                for key in variant:
                    variant_params[key] = variant[key]
                expanded.append(variant_params)
                continue

            # Parameters the variant sets explicitly are not iterated over.
            for combination in self.generateVariantCombinations(
                iterated_params, variant_params_names
            ):
                variant_params = copy.deepcopy(default_params)
                for key in variant:
                    if key != "generate_variant_forall":
                        variant_params[key] = variant[key]
                variant_name = variant["NAME"]
                for param_name, suffix, value in combination:
                    variant_params[param_name] = value
                    if len(suffix) > 0:
                        variant_name = f"{variant_name}_{suffix}"
                variant_params["NAME"] = variant_name
                expanded.append(variant_params)

        return expanded

    def create_shader_params(
        self, variant_params: dict[str, Any] | None = None
    ) -> dict[str, str]:
        """Build a shader's preprocessing environment.

        The result is a deep copy of ``self.env`` overlaid with
        ``variant_params``, plus a ``FORMAT`` entry (the image format) derived
        from ``DTYPE``, which defaults to ``"float"``.
        """
        if variant_params is None:
            variant_params = {}
        shader_params = copy.deepcopy(self.env)
        for key, value in variant_params.items():
            shader_params[key] = value

        shader_dtype = shader_params.get("DTYPE", "float")

        if shader_dtype == "int":
            shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
        elif shader_dtype == "uint":
            shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
        elif shader_dtype == "int32":
            shader_params["FORMAT"] = "rgba32i"
        elif shader_dtype == "uint32":
            shader_params["FORMAT"] = "rgba32ui"
        elif shader_dtype == "int8":
            shader_params["FORMAT"] = "rgba8i"
        elif shader_dtype == "uint8":
            shader_params["FORMAT"] = "rgba8ui"
        elif shader_dtype == "float32":
            shader_params["FORMAT"] = "rgba32f"
        # Assume float by default
        else:
            shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]

        return shader_params

    def constructOutputMap(self) -> None:
        """Populate ``self.output_shader_map`` (shader name -> (source, params)).

        Template variants are added first. Every GLSL source without a template
        entry is then added under its own name with the default parameters.
        A later entry with the same name replaces an earlier one.

        Raises:
            KeyError: if a template with variants has no matching GLSL source.
        """
        for shader_name, params in self.shader_template_params.items():
            if not params:
                continue
            if shader_name not in self.glsl_src_files:
                raise KeyError(
                    f"shader template {shader_name!r} has no matching GLSL source file"
                )
            source_glsl = self.glsl_src_files[shader_name]
            for variant in params:
                self.output_shader_map[variant["NAME"]] = (
                    source_glsl,
                    self.create_shader_params(variant),
                )

        for shader_name, source_glsl in self.glsl_src_files.items():
            if shader_name not in self.shader_template_params:
                self.output_shader_map[shader_name] = (
                    source_glsl,
                    self.create_shader_params(),
                )

    def generateSPV(self, output_dir: str) -> dict[str, str]:
        """Write each output shader's expanded GLSL, and its SPIR-V if glslc is set.

        Returns:
            SPIR-V path -> expanded GLSL path for each compiled shader. The
            mapping is empty when ``glslc_path`` is None.

        Raises:
            subprocess.CalledProcessError: if glslc exits with an error.
        """
        output_file_map = {}
        for shader_name, (source_glsl, shader_params) in self.output_shader_map.items():
            # newline="" disables newline translation, so line endings pass
            # through untouched (as with the binary-mode codecs.open, which is
            # deprecated as of Python 3.14).
            with open(source_glsl, encoding="utf-8", newline="") as input_file:
                input_text = input_file.read()
            # Passing the source path makes errors in template code name the real file.
            output_text = preprocess(input_text, shader_params, source_glsl)

            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
            with open(glsl_out_path, "w", encoding="utf-8", newline="") as output_file:
                output_file.write(output_text)

            # If no GLSL compiler is specified, then only write out the generated GLSL shaders.
            # This is mainly for testing purposes.
            if self.glslc_path is not None:
                spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

                cmd = [
                    self.glslc_path,
                    "-fshader-stage=compute",
                    glsl_out_path,
                    "-o",
                    spv_out_path,
                    "--target-env=vulkan1.0",
                    "-Werror",
                ] + [
                    arg
                    for src_dir_path in self.src_dir_paths
                    for arg in ["-I", src_dir_path]
                ]

                print("glslc cmd:", cmd)
                subprocess.check_call(cmd)

                output_file_map[spv_out_path] = glsl_out_path

        return output_file_map
##############################################
# Shader Info and Shader Registry Handling #
##############################################
@dataclass
class ShaderInfo:
    """Metadata that ``getShaderInfo`` scrapes from a generated GLSL file."""

    # Three integers parsed from a `` * TILE_SIZE = (x, y, z)`` line; empty if absent.
    tile_size: list[int]
    # One VK_DESCRIPTOR_TYPE_* name per ``layout(set ...)`` line, in file order.
    layouts: list[str]
    # Storage type name from a `` * WEIGHT_STORAGE = ...`` line; "" if absent.
    weight_storage_type: str = ""
    # Storage type name from a `` * BIAS_STORAGE = ...`` line; "" if absent.
    bias_storage_type: str = ""
    # (op name, registry keys) from a `` * REGISTER_FOR = (...)`` line; None if absent.
    register_for: tuple[str, list[str]] | None = None
def getName(filePath: str) -> str:
    """Return the base name of ``filePath`` with every ``/`` and ``.`` turned into ``_``."""
    name = os.path.basename(filePath)
    for separator in ("/", "."):
        name = name.replace(separator, "_")
    return name
def isDescriptorLine(lineStr: str) -> bool:
    """Return True if the line starts with a ``layout(set`` resource declaration."""
    return bool(re.match(r"^layout\(set", lineStr))
def isTileSizeLine(lineStr: str) -> bool:
    """Return True if the line begins with the `` * TILE_SIZE = (`` comment marker."""
    return bool(re.match(r"^ \* TILE_SIZE = \(", lineStr))
def findTileSizes(lineStr: str) -> list[int]:
    """Parse the three integers of a `` * TILE_SIZE = (x, y, z)`` line.

    Raises:
        AssertionError: if the line does not have exactly that shape.
    """
    found = re.match(r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)", lineStr)
    if found is None:
        raise AssertionError("matches is None in findTileSizes")
    return [int(group) for group in found.groups()]
def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line begins with the `` * WEIGHT_STORAGE = `` comment marker."""
    return bool(re.match(r"^ \* WEIGHT_STORAGE = ", lineStr))
def getWeightStorageType(lineStr: str) -> str:
    """Return the storage type named on a `` * WEIGHT_STORAGE = ...`` line.

    Accepts ``<LETTERS>_<digit>D`` names such as ``TEXTURE_2D`` and plain
    names such as ``BUFFER``; both forms appear as keys of
    ``storageTypeToEnum``. The previous pattern required the ``_<digit>D``
    part, so ``BUFFER`` was rejected.

    Raises:
        AssertionError: if no storage type name follows the marker.
    """
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return matches.group(1)
def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line begins with the `` * BIAS_STORAGE = `` comment marker."""
    bias_marker = r"^ \* BIAS_STORAGE = "
    return bool(re.match(bias_marker, lineStr))
def getBiasStorageType(lineStr: str) -> str:
    """Return the storage type named on a `` * BIAS_STORAGE = ...`` line.

    Accepts ``<LETTERS>_<digit>D`` names such as ``TEXTURE_2D`` and plain
    names such as ``BUFFER``; both forms appear as keys of
    ``storageTypeToEnum``. The previous pattern required the ``_<digit>D``
    part, so ``BUFFER`` was rejected.

    Raises:
        AssertionError: if no storage type name follows the marker.
    """
    bias_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return matches.group(1)
def isRegisterForLine(lineStr: str) -> bool:
    """Return True for a `` * REGISTER_FOR = ('op', ['key', ...])`` comment line.

    The line must name an op and list at least one quoted registry key.
    """
    pattern = (
        r"^ \* REGISTER_FOR = \('([A-Za-z0-9_]+)'\s*,\s*\['([A-Za-z0-9_]+)'.*\]\)"
    )
    return bool(re.match(pattern, lineStr))
def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
    """Extract ``(op_name, registry_keys)`` from a ``REGISTER_FOR`` comment line.

    Every single-quoted identifier on the line is collected; the first is the
    op name and the rest are the registry keys.

    Raises:
        AssertionError: if the line contains no single-quoted identifier.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    matches = re.findall(register_for_pattern, lineStr)
    # re.findall returns a (possibly empty) list, never None, so test for
    # emptiness; otherwise an empty result surfaced as an IndexError below.
    if not matches:
        raise AssertionError("no quoted identifiers found in findRegisterFor")
    return (matches[0], matches[1:])
# Regex -> VK_DESCRIPTOR_TYPE_* name, used by determineDescriptorType on each
# ``layout(set ...)`` line. Patterns are tried in this (insertion) order and the
# first match wins. The order presumably matters because GLSL image and
# sampler declarations also carry the ``uniform`` qualifier, so they must be
# recognized before the generic ``uniform`` pattern. ``image[123]D`` and
# ``sampler[123]D`` have no leading word boundary, so they also match the
# ``i``/``u`` prefixed variants (e.g. ``uimage3D``).
typeIdMapping = {
    r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
    r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
    r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
    r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER",
}
# Storage type names from WEIGHT_STORAGE / BIAS_STORAGE shader comments, mapped
# to the C++ enumerator that generateShaderInfoStr emits. "" (the ShaderInfo
# default when no such comment exists) maps to UNKNOWN.
storageTypeToEnum = {
    "TEXTURE_2D": "api::StorageType::TEXTURE_2D",
    "TEXTURE_3D": "api::StorageType::TEXTURE_3D",
    "BUFFER": "api::StorageType::BUFFER",
    "": "api::StorageType::UNKNOWN",
}
def determineDescriptorType(lineStr: str) -> str:
    """Map a ``layout(set ...)`` declaration line to its VK_DESCRIPTOR_TYPE_* name.

    The patterns in ``typeIdMapping`` are tried in insertion order, and the
    first one found in the line decides the result.

    Raises:
        AssertionError: if no pattern matches.
    """
    descriptor_type = next(
        (
            vk_type
            for pattern, vk_type in typeIdMapping.items()
            if re.search(pattern, lineStr)
        ),
        None,
    )
    if descriptor_type is None:
        raise AssertionError(
            "No matching descriptor type for " + lineStr + " in determineDescriptorType"
        )
    return descriptor_type
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan a generated GLSL file for the metadata needed to register it.

    Each line is checked independently for a descriptor ``layout(set ...)``
    declaration and for the `` * TILE_SIZE``, `` * WEIGHT_STORAGE``,
    `` * BIAS_STORAGE`` and `` * REGISTER_FOR`` comment markers. For the
    single-valued fields, a later occurrence overwrites an earlier one.
    """
    shader_info = ShaderInfo([], [], "")
    # generateSPV writes the GLSL as UTF-8, so read it back the same way
    # rather than with the platform's default (locale) encoding.
    with open(srcFilePath, encoding="utf-8") as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info
##########################
# C++ File Generation #
##########################
# Skeleton of the generated C++ source (written to CPP_SRC_NAME), filled in by
# genCppFiles via str.format with:
#   spv_bin_arrays        - the ``const uint32_t <name>_bin[]`` SPIR-V arrays
#   register_shader_infos - one register_shader(...) statement per shader
#   shader_info_registry  - register_op_dispatch(...) statements from REGISTER_FOR
# Doubled braces (``{{`` / ``}}``) are literal C++ braces after formatting.
cpp_template = """
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>
using namespace at::native::vulkan;
namespace at {{
namespace native {{
namespace vulkan {{
namespace {{
{spv_bin_arrays}
}}
static void register_fn() {{
{register_shader_infos}
{shader_info_registry}
}}
static const api::ShaderRegisterInit register_shaders(&register_fn);
}}
}}
}}
"""
def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
    """Render a compiled SPIR-V file as a C ``uint32_t`` array definition.

    The file is read as native unsigned ints (``array`` typecode ``"I"``), and
    each element is written on its own line inside ``<name>_bin[]``.

    Returns:
        ``(size_bytes, source_text)``, where ``size_bytes`` is 4 times the
        element count.
    """
    words = array.array("I")
    with open(spvPath, "rb") as spv_file:
        words.frombytes(spv_file.read())
    body = textwrap.indent(",\n".join(map(str, words)), " ")
    return 4 * len(words), f"const uint32_t {name}_bin[] = {{\n{body}\n}};"
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Build the C++ statement that registers one shader's ``api::ShaderInfo``.

    The constructor arguments are, in order: the quoted name, the
    ``<name>_bin`` array, its size in bytes, the brace-enclosed descriptor
    layout list, the tile size (``std::vector<uint32_t>()`` when none was
    found), and the weight and bias storage-type enumerators.

    Raises:
        KeyError: if a storage type is not a key of ``storageTypeToEnum``.
    """
    if len(shader_info.tile_size) > 0:
        tile_size = "{" + ", ".join(map(str, shader_info.tile_size)) + "}"
    else:
        tile_size = "std::vector<uint32_t>()"

    layouts = "{" + ",\n ".join(shader_info.layouts) + "}"

    constructor_args = [
        f'"{name}"',
        f"{name}_bin",
        str(sizeBytes),
        layouts,
        tile_size,
        storageTypeToEnum[shader_info.weight_storage_type],
        storageTypeToEnum[shader_info.bias_storage_type],
    ]
    args_block = textwrap.indent(",\n".join(constructor_args), " ")

    statement = (
        "api::shader_registry().register_shader(\n api::ShaderInfo(\n"
        + args_block
        + "));\n"
    )
    return textwrap.indent(statement, " ")
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Build the C++ ``register_op_dispatch`` calls for a shader's REGISTER_FOR entry.

    One statement is emitted per registry key. Each is indented by one space,
    and the statements are joined with newlines. The key is upper-cased to
    form the ``api::DispatchKey`` enumerator.

    Returns:
        The statements, or ``""`` when the shader has no REGISTER_FOR entry or
        lists no keys.
    """
    if shader_info.register_for is None:
        return ""

    (op_name, registry_keys) = shader_info.register_for
    # Collect a statement for every key. Overwriting a single variable in the
    # loop kept only the last key, and raised UnboundLocalError with no keys.
    dispatch_statements = [
        textwrap.indent(
            f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");',
            " ",
        )
        for registry_key in registry_keys
    ]
    return "\n".join(dispatch_statements)
def genCppFiles(
    spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds and registers every compiled shader.

    Args:
        spv_files: SPIR-V path -> expanded GLSL path (the mapping that main()
            obtains from ``SPVGenerator.generateSPV``).
        cpp_header_path: Accepted for interface compatibility; this function
            does not write a header.
        cpp_src_file_path: Destination of the generated C++ source.
    """
    spv_bin_strs = []
    register_shader_info_strs = []
    shader_registry_strs = []

    for spvPath, srcPath in spv_files.items():
        # getName turns "conv2d.spv" into "conv2d_spv". Strip only that trailing
        # marker: replacing every "_spv" also mangled shader names that merely
        # contain it (e.g. "foo_spv_bar"), breaking the registered name.
        name = getName(spvPath)
        if name.endswith("_spv"):
            name = name[: -len("_spv")]

        sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
        spv_bin_strs.append(spv_bin_str)

        shader_info = getShaderInfo(srcPath)

        register_shader_info_strs.append(
            generateShaderInfoStr(shader_info, name, sizeBytes)
        )

        if shader_info.register_for is not None:
            shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))

    spv_bin_arrays = "\n".join(spv_bin_strs)
    register_shader_infos = "\n".join(register_shader_info_strs)
    shader_info_registry = "\n".join(shader_registry_strs)

    cpp = cpp_template.format(
        spv_bin_arrays=spv_bin_arrays,
        register_shader_infos=register_shader_infos,
        shader_info_registry=shader_info_registry,
    )

    with open(cpp_src_file_path, "w") as fw:
        fw.write(cpp)
##########
# Main #
##########
def parse_arg_env(items: list[str] | None) -> dict[Any, Any]:
    """Turn ``KEY=VALUE`` strings from ``--env`` into a dict.

    Keys and values are stripped of surrounding whitespace. Only the first
    ``=`` separates key from value, so values may themselves contain ``=``
    (``split("=")`` used to silently drop everything after a second ``=``).
    ``None`` or an empty list yields ``{}``.

    Raises:
        ValueError: if an item contains no ``=``.
    """
    d: dict[Any, Any] = {}
    for item in items or ():
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"--env expects KEY=VALUE, got {item!r}")
        d[key.strip()] = value.strip()
    return d
def main(argv: list[str]) -> int:
    """Generate SPIR-V for all GLSL shaders and write the C++ registration source.

    Args:
        argv: Argument vector in ``sys.argv`` form (as passed by invoke_main);
            ``argv[0]`` is the program name and ``argv[1:]`` is parsed.

    Returns:
        0 on success. Failures propagate as exceptions, or as argparse's
        SystemExit for bad arguments.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        help='List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths "path1 path2 path3"',
        default=["."],
    )
    parser.add_argument("-c", "--glslc-path", required=True, help="")
    parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
    parser.add_argument("-o", "--output-path", required=True, help="")
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )

    # Honor the argv parameter (it used to be ignored in favour of sys.argv).
    options = parser.parse_args(argv[1:])

    # Compose a fresh environment instead of updating DEFAULT_ENV in place, so
    # the module-level defaults are not mutated by type tables or CLI overrides.
    # Precedence is unchanged: DEFAULT_ENV < TYPES_ENV < FUNCS_ENV < --env.
    env: dict[str, Any] = {**DEFAULT_ENV, **TYPES_ENV, **FUNCS_ENV}
    env.update(parse_arg_env(options.env))

    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        f"{options.output_path}/{CPP_H_NAME}",
        f"{options.output_path}/{CPP_SRC_NAME}",
    )

    return 0
def invoke_main() -> None:
    """Run main() on the process arguments and exit with its return value as the status."""
    raise SystemExit(main(sys.argv))


if __name__ == "__main__":
    invoke_main()  # pragma: no cover