[Vulkan] Remove GLSL Code Gen (#91912)

@bypass-github-export-checks

GLSL Code Gen is unused, so this diff removes:
- The GLSL parts of ShaderSource
- Everything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself
- The gen_vulkan_glsl script

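It also includes some additional refactoring: the template generator from
gen_vulkan_glsl.py moves into gen_vulkan_spv.py (renamed from GLSLGenerator to
VulkanShaderGenerator), and the functions in gen_vulkan_spv.py gain type
annotations.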

Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91912
Approved by: https://github.com/mcr229
diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl
index 58a49fd..3d685f4 100644
--- a/tools/BUCK.bzl
+++ b/tools/BUCK.bzl
@@ -211,18 +211,6 @@
         srcs = [
             "gen_vulkan_spv.py",
         ],
-        base_module = "",
-        deps = [
-            torchgen_deps,
-            ":gen_aten_vulkan_glsl_lib",
-        ],
-    )
-
-    python_library(
-        name = "gen_aten_vulkan_glsl_lib",
-        srcs = [
-            "gen_vulkan_glsl.py",
-        ],
         base_module = "tools",
         deps = [
             torchgen_deps,
@@ -231,12 +219,11 @@
 
     python_binary(
         name = "gen_aten_vulkan_spv_bin",
-        main_module = "gen_vulkan_spv",
+        main_module = "tools.gen_vulkan_spv",
         visibility = [
             "PUBLIC",
         ],
         deps = [
-            ":gen_aten_vulkan_glsl_lib",
             ":gen_aten_vulkan_spv_lib",
         ],
     )
@@ -249,7 +236,6 @@
         contacts = contacts,
         visibility = ["PUBLIC"],
         deps = [
-            ":gen_aten_vulkan_glsl_lib",
             ":gen_aten_vulkan_spv_lib",
         ],
     )
diff --git a/tools/gen_vulkan_glsl.py b/tools/gen_vulkan_glsl.py
deleted file mode 100644
index 6d89da0..0000000
--- a/tools/gen_vulkan_glsl.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import copy
-import os
-
-from collections import OrderedDict
-
-import yaml
-from torchgen.code_template import CodeTemplate
-from yaml.constructor import ConstructorError
-from yaml.nodes import MappingNode
-
-try:
-    from yaml import CLoader as Loader
-except ImportError:
-    from yaml import Loader  # type: ignore[misc]
-
-# https://gist.github.com/pypt/94d747fe5180851196eb
-class UniqueKeyLoader(Loader):
-    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
-        if not isinstance(node, MappingNode):
-            raise ConstructorError(
-                None,
-                None,
-                "expected a mapping node, but found %s" % node.id,
-                node.start_mark,
-            )
-        mapping = {}
-        for key_node, value_node in node.value:
-            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
-            try:
-                hash(key)
-            except TypeError as e:
-                raise ConstructorError(
-                    "while constructing a mapping",
-                    node.start_mark,
-                    "found unacceptable key ",
-                    key_node.start_mark,
-                ) from e
-            # check for duplicate keys
-            if key in mapping:
-                raise ConstructorError(
-                    "while constructing a mapping",
-                    node.start_mark,
-                    "found duplicate key",
-                    key_node.start_mark,
-                )
-            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
-            mapping[key] = value
-        return mapping
-
-
-class GLSLGenerator(object):
-    standard_header = """
-#version 450 core
-#define PRECISION $precision
-#define FORMAT $format
-
-"""
-
-    def __init__(self):  # type: ignore[no-untyped-def]
-        self.ops_template_params = {}
-
-    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
-        all_template_params = OrderedDict()
-        with open(parameters_yaml_file, "r") as f:
-            contents = yaml.load(f, Loader=UniqueKeyLoader)
-            for key in contents:
-                all_template_params[key] = contents[key]
-        self.validate_and_construct_op_params(all_template_params)  # type: ignore[no-untyped-call]
-
-    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
-        for op in all_template_params:
-            if op in self.ops_template_params:
-                raise KeyError(f"{op} params file has already been parsed")
-            op_params_default_vals = all_template_params[op][
-                "parameter_names_with_default_values"
-            ]
-            template_params_set = set(op_params_default_vals.keys())
-            self.ops_template_params[op] = []
-            self.ops_template_params[op].append(op_params_default_vals)
-            op_template_params_values = all_template_params[op]["parameter_values"]
-            for param_vals in op_template_params_values:
-                param_vals_set = set(param_vals.keys())
-                invalid_keys = param_vals_set - template_params_set
-                if (len(invalid_keys)) > 0:
-                    raise KeyError(f"Invalid keys {invalid_keys} are found")
-                param_vals_copy = copy.deepcopy(op_params_default_vals)
-                for key in param_vals:
-                    param_vals_copy[key] = param_vals[key]
-                self.ops_template_params[op].append(param_vals_copy)
-
-    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
-        glsl_template_name = os.path.basename(glsl_template_in)
-        op_name, extension_name = glsl_template_name.split(".")
-        if extension_name != "glslt":
-            raise TypeError(f"invalid file type for glsl template {extension_name}")
-        if op_name not in self.ops_template_params:
-            raise KeyError(f"{op_name} params have not been populated")
-        code_template = CodeTemplate.from_file(glsl_template_in)
-        for template_params in self.ops_template_params[op_name]:
-            content = GLSLGenerator.standard_header
-            param_vals_string = "x".join([str(i) for i in template_params.values()])
-            output_file_name = op_name + "_" + param_vals_string + ".glsl"
-            content += code_template.substitute(template_params)
-            output_file = os.path.join(out_dir, output_file_name)
-            with open(output_file, "w") as f:
-                f.write(content)
-
-
-# Remove this
-if __name__ == "__main__":
-    pass
diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py
index db4ed96..92f6a5a 100644
--- a/tools/gen_vulkan_spv.py
+++ b/tools/gen_vulkan_spv.py
@@ -2,21 +2,122 @@
 
 import argparse
 import array
+import copy
 import glob
 import os
 import re
 import sys
 import subprocess
+import yaml
+from collections import OrderedDict
 from torchgen.code_template import CodeTemplate
 from dataclasses import dataclass
-from typing import List
+from typing import Any, Dict, List
+from yaml.constructor import ConstructorError
+from yaml.nodes import MappingNode
 
-from tools.gen_vulkan_glsl import GLSLGenerator
+try:
+    from yaml import CLoader as Loader
+except ImportError:
+    from yaml import Loader  # type: ignore[misc]
 
 H_NAME = "spv.h"
 CPP_NAME = "spv.cpp"
 DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"}
 
+# https://gist.github.com/pypt/94d747fe5180851196eb
+class UniqueKeyLoader(Loader):
+    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
+        if not isinstance(node, MappingNode):
+            raise ConstructorError(
+                None,
+                None,
+                "expected a mapping node, but found %s" % node.id,
+                node.start_mark,
+            )
+        mapping = {}
+        for key_node, value_node in node.value:
+            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
+            try:
+                hash(key)
+            except TypeError as e:
+                raise ConstructorError(
+                    "while constructing a mapping",
+                    node.start_mark,
+                    "found unacceptable key ",
+                    key_node.start_mark,
+                ) from e
+            # check for duplicate keys
+            if key in mapping:
+                raise ConstructorError(
+                    "while constructing a mapping",
+                    node.start_mark,
+                    "found duplicate key",
+                    key_node.start_mark,
+                )
+            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
+            mapping[key] = value
+        return mapping
+
+
+class VulkanShaderGenerator(object):
+    standard_header = """
+#version 450 core
+#define PRECISION $precision
+#define FORMAT $format
+
+"""
+
+    def __init__(self: "VulkanShaderGenerator") -> None:
+        self.ops_template_params: Dict[Any, Any] = {}
+
+    def add_params_yaml(self, parameters_yaml_file):  # type: ignore[no-untyped-def]
+        all_template_params = OrderedDict()
+        with open(parameters_yaml_file, "r") as f:
+            contents = yaml.load(f, Loader=UniqueKeyLoader)
+            for key in contents:
+                all_template_params[key] = contents[key]
+        self.validate_and_construct_op_params(all_template_params)  # type: ignore[no-untyped-call]
+
+    def validate_and_construct_op_params(self, all_template_params):  # type: ignore[no-untyped-def]
+        for op in all_template_params:
+            if op in self.ops_template_params:
+                raise KeyError(f"{op} params file has already been parsed")
+            op_params_default_vals = all_template_params[op][
+                "parameter_names_with_default_values"
+            ]
+            template_params_set = set(op_params_default_vals.keys())
+            self.ops_template_params[op] = []
+            self.ops_template_params[op].append(op_params_default_vals)
+            op_template_params_values = all_template_params[op]["parameter_values"]
+            for param_vals in op_template_params_values:
+                param_vals_set = set(param_vals.keys())
+                missing_keys = template_params_set - param_vals_set
+                invalid_keys = param_vals_set - template_params_set
+                if (len(invalid_keys)) > 0:
+                    raise KeyError(f"Invalid keys {invalid_keys} are found")
+                param_vals_copy = copy.deepcopy(op_params_default_vals)
+                for key in param_vals:
+                    param_vals_copy[key] = param_vals[key]
+                self.ops_template_params[op].append(param_vals_copy)
+
+    def generate(self, glsl_template_in, out_dir):  # type: ignore[no-untyped-def]
+        glsl_template_name = os.path.basename(glsl_template_in)
+        op_name, extension_name = glsl_template_name.split(".")
+        if extension_name != "glslt":
+            raise TypeError(f"invalid file type for glsl template {extension_name}")
+        if op_name not in self.ops_template_params:
+            raise KeyError(f"{op_name} params have not been populated")
+        code_template = CodeTemplate.from_file(glsl_template_in)
+        for template_params in self.ops_template_params[op_name]:
+            content = VulkanShaderGenerator.standard_header
+            param_vals_string = "x".join([str(i) for i in template_params.values()])
+            output_file_name = op_name + "_" + param_vals_string + ".glsl"
+            content += code_template.substitute(template_params)
+            output_file = os.path.join(out_dir, output_file_name)
+            with open(output_file, "w") as f:
+                f.write(content)
+
 
 @dataclass
 class ShaderInfo:
@@ -25,38 +126,44 @@
     weight_storage_type: str = ""
     bias_storage_type: str = ""
 
-def getName(filePath):
+def getName(filePath: str) -> str:
     return os.path.basename(filePath).replace("/", "_").replace(".", "_")
 
-def isDescriptorLine(lineStr):
+def isDescriptorLine(lineStr: str) -> bool:
     descriptorLineId = r"^layout\(set"
-    return re.search(descriptorLineId, lineStr)
+    return re.search(descriptorLineId, lineStr) is not None
 
-def isTileSizeLine(lineStr):
+def isTileSizeLine(lineStr: str) -> bool:
     tile_size_id = r"^ \* TILE_SIZE = \("
-    return re.search(tile_size_id, lineStr)
+    return re.search(tile_size_id, lineStr) is not None
 
-def findTileSizes(lineStr):
+def findTileSizes(lineStr: str) -> List[int]:
     tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
     matches = re.search(tile_size_id, lineStr)
+    if matches is None:
+        raise AssertionError("matches is None in findTileSizes")
     return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
 
-def isWeightStorageTypeLine(lineStr):
+def isWeightStorageTypeLine(lineStr: str) -> bool:
     weight_storage_id = r"^ \* WEIGHT_STORAGE = "
-    return re.search(weight_storage_id, lineStr)
+    return re.search(weight_storage_id, lineStr) is not None
 
-def getWeightStorageType(lineStr):
+def getWeightStorageType(lineStr: str) -> str:
     weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)"
     matches = re.search(weight_storage_id, lineStr)
+    if matches is None:
+        raise AssertionError("matches is None in getWeightStorageType")
     return matches.group(1)
 
-def isBiasStorageTypeLine(lineStr):
+def isBiasStorageTypeLine(lineStr: str) -> bool:
     weight_storage_id = r"^ \* BIAS_STORAGE = "
-    return re.search(weight_storage_id, lineStr)
+    return re.search(weight_storage_id, lineStr) is not None
 
-def getBiasStorageType(lineStr):
+def getBiasStorageType(lineStr: str) -> str:
     weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)"
     matches = re.search(weight_storage_id, lineStr)
+    if matches is None:
+        raise AssertionError("matches is None in getBiasStorageType")
     return matches.group(1)
 
 typeIdMapping = {
@@ -73,12 +180,15 @@
     "": "api::StorageType::UNKNOWN",
 }
 
-def determineDescriptorType(lineStr):
+def determineDescriptorType(lineStr: str) -> str:
     for identifier, typeNum in typeIdMapping.items():
         if re.search(identifier, lineStr):
             return typeNum
+    raise AssertionError(
+        "No matching descriptor type for " + lineStr + " in determineDescriptorType"
+    )
 
-def getShaderInfo(srcFilePath):
+def getShaderInfo(srcFilePath: str) -> ShaderInfo:
     shader_info = ShaderInfo([], [], "")
     with open(srcFilePath, 'r') as srcFile:
         for line in srcFile:
@@ -93,14 +203,14 @@
 
     return shader_info
 
-def genGLSLFromGLSLT(src_dir_path, tmp_dir_path):
+def genGLSLFromGLSLT(src_dir_path: str, tmp_dir_path: str) -> None:
     template_dir_path = os.path.join(src_dir_path, "templates")
     vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True)
     parameter_yaml_files = []
     for f in vexs:
         if len(f) > 1:
             parameter_yaml_files.append(f)
-    generator = GLSLGenerator()
+    generator = VulkanShaderGenerator()
     for params_yaml in parameter_yaml_files:
         generator.add_params_yaml(params_yaml)  # type: ignore[no-untyped-call]
 
@@ -113,9 +223,20 @@
     for glslt in templateSrcPaths:
         generator.generate(glslt, tmp_dir_path)  # type: ignore[no-untyped-call]
 
-def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
-    print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
-        hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))
+
+def genCppH(
+    hFilePath: str,
+    cppFilePath: str,
+    srcDirPath: str,
+    glslcPath: str,
+    tmpDirPath: str,
+    env: Dict[Any, Any],
+) -> None:
+    print(
+        "hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
+            hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath
+        )
+    )
 
     vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
     templateSrcPaths = []
@@ -142,8 +263,8 @@
         codeTemplate = CodeTemplate.from_file(templateSrcPath)
         srcPath = tmpDirPath + "/" + name + ".glsl"
         content = codeTemplate.substitute(env)
-        with open(srcPath, 'w') as f:
-            f.write(content)
+        with open(srcPath, 'w') as fw:
+            fw.write(content)
 
         spvPath = tmpDirPath + "/" + name + ".spv"
         print("spvPath {}".format(spvPath))
@@ -188,8 +309,8 @@
         name = getName(spvPath)
 
         print("spvPath:{}".format(spvPath))
-        with open(spvPath, 'rb') as f:
-            next_bin = array.array('I', f.read())
+        with open(spvPath, 'rb') as fr:
+            next_bin = array.array('I', fr.read())
             sizeBytes = 4 * len(next_bin)
             shader_info_bin_code.append(
                 "const uint32_t {}_bin[] = {{\n  {}\n}};".format(
@@ -231,13 +352,13 @@
     cpp += nsend
     h += nsend
 
-    with open(hFilePath, "w") as f:
-        f.write(h)
-    with open(cppFilePath, "w") as f:
-        f.write(cpp)
+    with open(hFilePath, "w") as fw:
+        fw.write(h)
+    with open(cppFilePath, "w") as fw:
+        fw.write(cpp)
 
 
-def parse_arg_env(items):
+def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
     d = {}
     if items:
         for item in items:
@@ -248,7 +369,7 @@
     return d
 
 
-def main(argv):
+def main(argv: List[str]) -> int:
     parser = argparse.ArgumentParser(description='')
     parser.add_argument(
         '-i',
@@ -294,5 +415,7 @@
         tmpDirPath=options.tmp_dir_path,
         env=env)
 
+    return 0
+
 if __name__ == '__main__':
     sys.exit(main(sys.argv))
diff --git a/tools/test/test_vulkan_codegen.py b/tools/test/test_vulkan_codegen.py
index 8b0b4b3..ae87c27 100644
--- a/tools/test/test_vulkan_codegen.py
+++ b/tools/test/test_vulkan_codegen.py
@@ -2,11 +2,11 @@
 import tempfile
 import unittest
 
-from tools.gen_vulkan_glsl import GLSLGenerator
+from tools.gen_vulkan_spv import VulkanShaderGenerator
 from yaml.constructor import ConstructorError
 
 
-class TestGLSLCodegen(unittest.TestCase):
+class TestVulkanShaderCodegen(unittest.TestCase):
     def test_assert_on_duplicate_key_yaml(self) -> None:
         yaml_with_duplicate_keys = """
 conv2d_pw:
@@ -37,7 +37,7 @@
       TILE_SIZE_Y: 4
 """
 
-        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
+        generator = VulkanShaderGenerator()  # type: ignore[no-untyped-call]
         with tempfile.NamedTemporaryFile(mode="w") as fp:
             fp.write(yaml_with_duplicate_keys)
             fp.flush()
@@ -57,7 +57,7 @@
       TILE_SIZE_Z: 2
 """
 
-        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
+        generator = VulkanShaderGenerator()  # type: ignore[no-untyped-call]
         with tempfile.NamedTemporaryFile(mode="w") as fp:
             fp.write(yaml_with_key_mismatch)
             fp.flush()
@@ -77,7 +77,7 @@
 x = $TILE_SIZE_X + $TILE_SIZE_Y
 """
 
-        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
+        generator = VulkanShaderGenerator()  # type: ignore[no-untyped-call]
         with tempfile.NamedTemporaryFile(mode="w") as fp:
             fp.write(yaml_with_key_mismatch)
             fp.flush()