blob: db4ed965cf8495e2ba99fde9154c756d10c3f202 [file] [log] [blame]
#!/usr/bin/env python3
import argparse
import array
import glob
import os
import re
import sys
import subprocess
from torchgen.code_template import CodeTemplate
from dataclasses import dataclass
from typing import List
from tools.gen_vulkan_glsl import GLSLGenerator
# File names of the generated outputs. H_NAME is also used for the #include
# line written at the top of the generated .cpp.
H_NAME, CPP_NAME = "spv.h", "spv.cpp"
# Baseline substitutions passed to CodeTemplate.substitute() for every shader;
# main() layers any --env KEY=VALUE overrides on top of these.
DEFAULT_ENV = dict(precision="highp", format="rgba32f")
@dataclass
class ShaderInfo:
    """Per-shader metadata scraped from a shader source by getShaderInfo()."""

    # Integers parsed from a " * TILE_SIZE = (x, y, z)" line; empty if none was seen.
    tile_size: List[int]
    # One VK_DESCRIPTOR_TYPE_* name per "layout(set" line, in source order.
    layouts: List[str]
    # Raw tokens from " * WEIGHT_STORAGE = ..." / " * BIAS_STORAGE = ..." lines;
    # "" when the shader declares none (maps to api::StorageType::UNKNOWN).
    weight_storage_type: str = ""
    bias_storage_type: str = ""
def getName(filePath):
    """Return the final path component with every "/" and "." turned into "_".

    For example "a/b/foo.glsl" -> "foo_glsl".
    """
    underscoreTable = str.maketrans({"/": "_", ".": "_"})
    leaf = os.path.basename(filePath)
    return leaf.translate(underscoreTable)
def isDescriptorLine(lineStr):
    """Return a match object if the line starts with "layout(set", else None."""
    return re.match(r"layout\(set", lineStr)
def isTileSizeLine(lineStr):
    """Return a match object if the line starts with " * TILE_SIZE = (", else None."""
    return re.match(r" \* TILE_SIZE = \(", lineStr)
def findTileSizes(lineStr):
    """Parse the three integers from a " * TILE_SIZE = (x, y, z)" header line.

    Whitespace around the numbers and commas is tolerated. isTileSizeLine only
    checks the " * TILE_SIZE = (" prefix, so a line such as "(8,8,1)" used to be
    accepted there and then crash here with an opaque AttributeError.

    Args:
        lineStr: one line of shader source.

    Returns:
        A list of three ints.

    Raises:
        ValueError: if the line does not contain three comma-separated integers.
    """
    tile_size_id = r"^ \* TILE_SIZE = \(\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*\)"
    matches = re.search(tile_size_id, lineStr)
    if matches is None:
        raise ValueError("Malformed TILE_SIZE line: {!r}".format(lineStr))
    return [int(group) for group in matches.groups()]
def isWeightStorageTypeLine(lineStr):
    """Return a match object if the line starts with " * WEIGHT_STORAGE = ", else None."""
    return re.match(r" \* WEIGHT_STORAGE = ", lineStr)
def getWeightStorageType(lineStr):
    """Extract the storage token from a " * WEIGHT_STORAGE = <TOKEN>" line.

    Accepts "<LETTERS>_<digit>D" tokens (e.g. TEXTURE_2D, TEXTURE_3D) as before,
    and now also bare letter tokens such as BUFFER. storageTypeToEnum has a
    BUFFER entry, but the old "_<digit>D"-mandatory pattern could never produce
    it, so BUFFER lines crashed on ``None.group``.

    Raises:
        ValueError: if no storage token follows the prefix.
    """
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise ValueError("Malformed WEIGHT_STORAGE line: {!r}".format(lineStr))
    return matches.group(1)
def isBiasStorageTypeLine(lineStr):
    """Return a match object if the line starts with " * BIAS_STORAGE = ", else None."""
    bias_storage_prefix = r" \* BIAS_STORAGE = "
    return re.match(bias_storage_prefix, lineStr)
def getBiasStorageType(lineStr):
    """Extract the storage token from a " * BIAS_STORAGE = <TOKEN>" line.

    Accepts "<LETTERS>_<digit>D" tokens (e.g. TEXTURE_2D) as before, and now
    also bare letter tokens such as BUFFER. storageTypeToEnum maps BUFFER, but
    the previous pattern required the "_<digit>D" suffix and crashed on it.

    Raises:
        ValueError: if no storage token follows the prefix.
    """
    bias_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise ValueError("Malformed BIAS_STORAGE line: {!r}".format(lineStr))
    return matches.group(1)
# Regex -> Vulkan descriptor type for a "layout(set" line. determineDescriptorType
# walks these in insertion order and returns the first hit, so the order of the
# pairs below is significant.
typeIdMapping = dict([
    (r"image[123]D\b", "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE"),
    (r"sampler[123]D\b", "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER"),
    (r"\bbuffer\b", "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"),
    (r"\buniform\b", "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER"),
])

# WEIGHT_STORAGE / BIAS_STORAGE token -> C++ enumerator emitted into spv.cpp.
# The empty string (no such header line) maps to UNKNOWN.
storageTypeToEnum = {
    token: "api::StorageType::" + (token or "UNKNOWN")
    for token in ("TEXTURE_2D", "TEXTURE_3D", "BUFFER", "")
}
def determineDescriptorType(lineStr):
    """Map a "layout(set ..." declaration line to a VK_DESCRIPTOR_TYPE_* name.

    Patterns in typeIdMapping are tried in order and the first hit wins. A line
    can match several patterns; GLSL, for instance, declares both images and
    samplers with the ``uniform`` qualifier.

    Raises:
        ValueError: if no known descriptor kind appears on the line. This used to
            return None, which only failed later as an unhelpful TypeError when
            the layouts were joined into the generated C++.
    """
    for identifier, typeNum in typeIdMapping.items():
        if re.search(identifier, lineStr):
            return typeNum
    raise ValueError(
        "Unable to determine descriptor type from line: {!r}".format(lineStr))
def getShaderInfo(srcFilePath):
    """Scan a shader source file and collect the metadata for its ShaderInfo.

    Recognised lines:
      * "layout(set ..." declarations -> appended to ``layouts`` in file order;
      * " * TILE_SIZE = (x, y, z)"    -> ``tile_size``;
      * " * WEIGHT_STORAGE = ..."     -> ``weight_storage_type``;
      * " * BIAS_STORAGE = ..."       -> ``bias_storage_type``.
    If a tile-size or storage line appears more than once, the last one wins.
    """
    info = ShaderInfo([], [], "")
    with open(srcFilePath, 'r') as handle:
        for text in handle:
            # The four prefixes are mutually exclusive, so at most one branch fires.
            if isDescriptorLine(text):
                info.layouts.append(determineDescriptorType(text))
            elif isTileSizeLine(text):
                info.tile_size = findTileSizes(text)
            elif isWeightStorageTypeLine(text):
                info.weight_storage_type = getWeightStorageType(text)
            elif isBiasStorageTypeLine(text):
                info.bias_storage_type = getBiasStorageType(text)
    return info
def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
    """Run every ``*.glslt`` template under src_dir_path through a GLSLGenerator.

    Every ``*.yaml`` file found recursively under ``<src_dir_path>/templates`` is
    registered with the generator via add_params_yaml before any template is
    processed. Each template path is then handed to ``generator.generate``
    together with tmp_dir_path (presumably the output directory; genCppH later
    globs it for ``*.glsl``).

    Both the YAML files and the templates are processed in sorted path order.
    Previously only the templates were sorted, so the YAML registration order
    depended on the filesystem's directory listing and builds were not
    guaranteed to be reproducible.
    """
    template_dir_path = os.path.join(src_dir_path, "templates")
    parameter_yaml_files = sorted(
        glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True))

    generator = GLSLGenerator()
    for params_yaml in parameter_yaml_files:
        generator.add_params_yaml(params_yaml)  # type: ignore[no-untyped-call]

    templateSrcPaths = sorted(
        glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True))
    for glslt in templateSrcPaths:
        generator.generate(glslt, tmp_dir_path)  # type: ignore[no-untyped-call]
def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
    """Compile GLSL shaders to SPIR-V with glslc and write C++ sources embedding them.

    Steps:
      1. Collect ``*.glsl`` under srcDirPath, run genGLSLFromGLSLT(), then collect
         ``*.glsl`` under tmpDirPath (presumably the template outputs).
      2. For each collected file, substitute ``env`` via CodeTemplate, write the
         result to ``<tmpDirPath>/<name>.glsl`` and compile it with glslc
         (compute stage, vulkan1.0, -Werror) to ``<tmpDirPath>/<name>.spv``.
      3. Write hFilePath with one ``extern const api::ShaderInfo <name>_spv;`` per
         shader, and cppFilePath with the SPIR-V words as ``uint32_t`` arrays in
         an anonymous namespace plus the matching ``api::ShaderInfo`` definitions.

    Args:
        hFilePath: path of the header to write.
        cppFilePath: path of the C++ source to write.
        srcDirPath: root searched recursively for shaders; also passed to glslc
            as an include directory (``-I``).
        glslcPath: glslc executable.
        tmpDirPath: scratch directory for substituted GLSL and .spv outputs.
        env: mapping handed to ``CodeTemplate.substitute``.

    Raises:
        subprocess.CalledProcessError: if any glslc invocation fails.
        KeyError: if a shader declares a storage type missing from storageTypeToEnum.
    """
    print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
        hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))

    # Hand-written shaders in the source tree.
    vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
    templateSrcPaths = []
    for f in vexs:
        # NOTE(review): glob results are never this short, so this filter looks like a no-op.
        if len(f) > 1:
            templateSrcPaths.append(f)
    templateSrcPaths.sort()

    # Now add glsl files that are generated from templates
    genGLSLFromGLSLT(srcDirPath, tmpDirPath)
    # NOTE(review): if tmpDirPath is reused across runs, this glob also picks up the
    # substituted copies written below by an earlier run; they collapse to the same
    # spvPaths key, so the shader is still emitted only once.
    vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True)
    for f in vexs:
        if len(f) > 1:
            templateSrcPaths.append(f)
    templateSrcPaths.sort()
    print("templateSrcPaths:{}".format(templateSrcPaths))

    # Compiled .spv path -> original (pre-substitution) .glsl path; the original is
    # re-read by getShaderInfo() below for its header annotations.
    spvPaths = {}
    for templateSrcPath in templateSrcPaths:
        print("templateSrcPath {}".format(templateSrcPath))
        # "foo.glsl" -> "foo_glsl" -> "foo".
        # NOTE(review): replace() removes every "_glsl", not only the suffix.
        name = getName(templateSrcPath).replace("_glsl", "")
        print("name {}".format(name))
        codeTemplate = CodeTemplate.from_file(templateSrcPath)
        srcPath = tmpDirPath + "/" + name + ".glsl"
        content = codeTemplate.substitute(env)
        # When templateSrcPath sits directly in tmpDirPath this overwrites the file
        # just loaded; that assumes CodeTemplate.from_file already read it fully (confirm).
        with open(srcPath, 'w') as f:
            f.write(content)
        spvPath = tmpDirPath + "/" + name + ".spv"
        print("spvPath {}".format(spvPath))
        cmd = [
            glslcPath, "-fshader-stage=compute",
            srcPath, "-o", spvPath,
            "--target-env=vulkan1.0",
            "-I", srcDirPath,
            "-Werror"
        ]
        print("\nglslc cmd:", cmd)
        subprocess.check_call(cmd)
        spvPaths[spvPath] = templateSrcPath

    h = "#pragma once\n"
    h += "#include <stdint.h>\n"
    h += "#include <vector>\n"
    h += "#include <string>\n"
    h += "#include <ATen/native/vulkan/api/Types.h>\n"
    h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
    nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
    nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
    h += nsbegin
    # Forward declaration of ShaderInfo
    h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
    cpp = "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
    cpp += "#include <ATen/native/vulkan/api/Shader.h>\n"
    cpp += nsbegin

    shader_info_bin_code = []
    shader_info_cpp_code = []
    shader_info_h_code = []
    for spvPath, srcPath in spvPaths.items():
        # "<tmp>/foo.spv" -> "foo_spv"; used as the C++ symbol name.
        name = getName(spvPath)
        print("spvPath:{}".format(spvPath))
        # Read the SPIR-V as native-order unsigned ints; sizeBytes assumes the
        # 'I' typecode is 4 bytes on the build host.
        with open(spvPath, 'rb') as f:
            next_bin = array.array('I', f.read())
            sizeBytes = 4 * len(next_bin)
            shader_info_bin_code.append(
                "const uint32_t {}_bin[] = {{\n {}\n}};".format(
                    name,
                    ",\n ".join(str(x) for x in next_bin),
                )
            )
        shader_info = getShaderInfo(srcPath)
        # Brace-initializer "{x, y, z}" when TILE_SIZE was declared, otherwise an
        # empty vector expression.
        tile_size = (
            "{{{}}}".format(", ".join(str(x) for x in shader_info.tile_size))
            if (len(shader_info.tile_size) > 0)
            else "std::vector<uint32_t>()"
        )
        # Positional arguments of the api::ShaderInfo constructor (declared in
        # Shader.h, not visible here): kernel name, SPIR-V array, byte size,
        # descriptor layout list, tile size, weight storage, bias storage.
        # NOTE(review): replace() strips every "_spv", not only the suffix.
        shader_info_args = [
            "\"vulkan.{}\"".format(name.replace("_spv", "")),
            "{}_bin".format(name),
            str(sizeBytes),
            "{{{}}}".format(", ".join(shader_info.layouts)),
            tile_size,
            storageTypeToEnum[shader_info.weight_storage_type],
            storageTypeToEnum[shader_info.bias_storage_type],
        ]
        shader_info_h_code.append("extern const api::ShaderInfo {};".format(name))
        shader_info_cpp_code.append(
            "const api::ShaderInfo {}(\n {}\n);".format(
                name,
                ",\n ".join(shader_info_args),
            ),
        )

    # Keep the raw binaries file-local; only the ShaderInfo objects are exported.
    cpp += "namespace {{\n{}\n}} // namespace\n".format("\n".join(shader_info_bin_code))
    cpp += "{}\n".format("\n".join(shader_info_cpp_code))
    h += "{}\n".format("\n".join(shader_info_h_code))

    cpp += nsend
    h += nsend

    with open(hFilePath, "w") as f:
        f.write(h)
    with open(cppFilePath, "w") as f:
        f.write(cpp)
def parse_arg_env(items):
    """Turn a list of ``KEY=VALUE`` strings into a dict.

    Only the first "=" separates key from value, so values may themselves contain
    "=". Previously ``split("=")`` silently dropped everything after a second "=".
    Whitespace around key and value is stripped; for a repeated key, the later
    item wins.

    Args:
        items: iterable of ``KEY=VALUE`` strings, or None / empty.

    Returns:
        dict mapping keys to values (empty when ``items`` is falsy).

    Raises:
        ValueError: if an item contains no "=". This was previously an
            unexplained IndexError.
    """
    d = {}
    for item in items or ():
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError("Expected KEY=VALUE, got {!r}".format(item))
        d[key.strip()] = value.strip()
    return d
def main(argv):
    """Command-line entry point: parse arguments and generate the SPIR-V C++ sources.

    Args:
        argv: full argument vector including the program name, as passed from
            ``sys.argv``; ``argv[1:]`` is parsed. Previously this parameter was
            ignored and argparse always read ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument(
        '-i',
        '--glsl-path',
        help='',
        default='.')
    parser.add_argument(
        '-c',
        '--glslc-path',
        required=True,
        help='')
    parser.add_argument(
        '-t',
        '--tmp-dir-path',
        required=True,
        help='/tmp')
    parser.add_argument(
        '-o',
        '--output-path',
        required=True,
        help='')
    parser.add_argument(
        "--env",
        metavar="KEY=VALUE",
        nargs='*',
        help="Set a number of key-value pairs")
    options = parser.parse_args(argv[1:])

    # Copy first: assigning DEFAULT_ENV directly and then updating it mutated the
    # module-level defaults.
    env = dict(DEFAULT_ENV)
    env.update(parse_arg_env(options.env))

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    # Use the shared file-name constants so the header name matches the
    # #include genCppH writes into the .cpp.
    genCppH(
        hFilePath=os.path.join(options.output_path, H_NAME),
        cppFilePath=os.path.join(options.output_path, CPP_NAME),
        srcDirPath=options.glsl_path,
        glslcPath=options.glslc_path,
        tmpDirPath=options.tmp_dir_path,
        env=env)
if __name__ == '__main__':
    # main() returns None, which sys.exit treats as a success status (0).
    exit_status = main(sys.argv)
    sys.exit(exit_status)