blob: 8df048992c01965c164ff37b6b61f13a0cec8880 [file] [log] [blame]
#!/usr/bin/env python3
from __future__ import print_function
import collections
import os
import sys
import logging
# Marker text written at the top of every generated wrapper file and of both
# generated .bzl files, so readers know not to edit them by hand.
BANNER = "Auto-generated by generate-wrappers.py script. Do not modify"
# Maps each CMake source-list variable to the C preprocessor condition that
# guards the generated wrapper's #include. A value of None means the source
# is included unconditionally. Insertion order matters: it determines the
# order in which wrapper files and xnnpack_wrapper_defs.bzl entries are
# written.
#
# NOTE(review): some PROD_* and ALL_* variants of what looks like the same
# instruction set use different guards (e.g. PROD_NEON_AARCH64 vs
# ALL_NEON_AARCH64, PROD_NEONV8 vs ALL_NEONV8, PROD_FP16ARITH vs
# ALL_FP16ARITH, PROD_NEONFP16ARITH_AARCH64 vs ALL_NEONFP16ARITH_AARCH64).
# The values are kept exactly as they were -- confirm which guard is intended.
WRAPPER_SRC_NAMES = {
    "PROD_SCALAR_PORTABLE_MICROKERNEL_SRCS": None,
    "PROD_SCALAR_AARCH32_MICROKERNEL_SRCS": "defined(__arm__)",
    "PROD_NEON_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_NEONFP16_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_NEONFMA_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_AARCH64_NEON_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "PROD_NEONV8_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "PROD_NEONDOT_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_SSE_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_SSE2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_SSSE3_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_SSE41_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_AVX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_F16C_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_XOP_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_FMA3_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_AVX2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "PROD_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "AARCH32_ASM_MICROKERNEL_SRCS": "defined(__arm__)",
    "AARCH64_ASM_MICROKERNEL_SRCS": "defined(__aarch64__)",
    # Additional source lists:
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_ARMSIMD32_MICROKERNEL_SRCS": "defined(__arm__)",
    "ALL_AVX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_AVX2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_F16C_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_FMA3_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_FP16ARITH_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEON_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEON_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "ALL_NEONBF16_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEONDOT_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEONFMA_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "ALL_NEONFP16_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEONFP16ARITH_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "ALL_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "ALL_NEONV8_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "ALL_SCALAR_MICROKERNEL_SRCS": "defined(__arm__)",
    "ALL_SSE_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_SSE2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_SSE41_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_SSSE3_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    "ALL_XOP_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
    # "AARCH32_ASM_MICROKERNEL_SRCS" used to be repeated here with the same
    # value as above; the duplicate key was dropped (no change to the mapping).
    "PROD_FP16ARITH_MICROKERNEL_SRCS": "defined(__aarch64__)",
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
    "PROD_SCALAR_MICROKERNEL_SRCS": "defined(__arm__)",
}
# Names of the CMake source-list variables whose contents are emitted into
# xnnpack_src_defs.bzl. This is a set, so listing a name twice has no effect;
# entries that were repeated in the earlier version of this list have been
# removed.
SRC_NAMES = {
    "OPERATOR_SRCS",
    "SUBGRAPH_SRCS",
    "LOGGING_SRCS",
    "XNNPACK_SRCS",
    "HOT_SRCS",
    "TABLE_SRCS",
    "JIT_SRCS",
    "JIT_AARCH32_SRCS",
    "JIT_AARCH64_SRCS",
    "PROD_SCALAR_PORTABLE_MICROKERNEL_SRCS",
    "PROD_SSE_MICROKERNEL_SRCS",
    "PROD_SSE2_MICROKERNEL_SRCS",
    "PROD_SSSE3_MICROKERNEL_SRCS",
    "PROD_SSE41_MICROKERNEL_SRCS",
    "PROD_AVX_MICROKERNEL_SRCS",
    "PROD_F16C_MICROKERNEL_SRCS",
    "PROD_XOP_MICROKERNEL_SRCS",
    "PROD_FMA3_MICROKERNEL_SRCS",
    "PROD_AVX2_MICROKERNEL_SRCS",
    "PROD_AVX512F_MICROKERNEL_SRCS",
    "PROD_AVX512SKX_MICROKERNEL_SRCS",
    "PROD_SCALAR_MICROKERNEL_SRCS",
    "PROD_SCALAR_AARCH32_MICROKERNEL_SRCS",
    "PROD_SCALAR_RISCV_MICROKERNEL_SRCS",
    "PROD_ARMSIMD32_MICROKERNEL_SRCS",
    "PROD_FP16ARITH_MICROKERNEL_SRCS",
    "PROD_NEON_MICROKERNEL_SRCS",
    "PROD_NEONFP16_MICROKERNEL_SRCS",
    "PROD_NEONFMA_MICROKERNEL_SRCS",
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONV8_MICROKERNEL_SRCS",
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS",
    "PROD_NEONDOT_MICROKERNEL_SRCS",
    "PROD_AVX512VBMI_MICROKERNEL_SRCS",
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS",
    # Newly added source lists:
    "ALL_ARMSIMD32_MICROKERNEL_SRCS",
    "ALL_AVX_MICROKERNEL_SRCS",
    "ALL_AVX2_MICROKERNEL_SRCS",
    "ALL_AVX512F_MICROKERNEL_SRCS",
    "ALL_AVX512SKX_MICROKERNEL_SRCS",
    "ALL_AVX512VBMI_MICROKERNEL_SRCS",
    "ALL_F16C_MICROKERNEL_SRCS",
    "ALL_FMA3_MICROKERNEL_SRCS",
    "ALL_FP16ARITH_MICROKERNEL_SRCS",
    "ALL_HEXAGON_MICROKERNEL_SRCS",
    "ALL_NEON_MICROKERNEL_SRCS",
    "ALL_NEON_AARCH64_MICROKERNEL_SRCS",
    "ALL_NEONBF16_MICROKERNEL_SRCS",
    "ALL_NEONBF16_AARCH64_MICROKERNEL_SRCS",
    "ALL_NEONDOT_MICROKERNEL_SRCS",
    "ALL_NEONFMA_MICROKERNEL_SRCS",
    "ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS",
    "ALL_NEONFP16_MICROKERNEL_SRCS",
    "ALL_NEONFP16ARITH_MICROKERNEL_SRCS",
    "ALL_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS",
    "ALL_NEONV8_MICROKERNEL_SRCS",
    "ALL_SCALAR_MICROKERNEL_SRCS",
    "ALL_SSE_MICROKERNEL_SRCS",
    "ALL_SSE2_MICROKERNEL_SRCS",
    "ALL_SSE41_MICROKERNEL_SRCS",
    "ALL_SSSE3_MICROKERNEL_SRCS",
    "ALL_WASM_MICROKERNEL_SRCS",
    "ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS",
    "ALL_WASMSIMD_MICROKERNEL_SRCS",
    "ALL_XOP_MICROKERNEL_SRCS",
    "AARCH32_ASM_MICROKERNEL_SRCS",
    "AARCH64_ASM_MICROKERNEL_SRCS",
}
def handle_singleline_parse(line):
    """Parse a one-line CMake ``SET(NAME src/path)`` statement.

    Args:
        line: One line of CMake text whose parentheses contain a variable
            name followed by a single source path.

    Returns:
        A ``(name, path)`` tuple. ``path`` has its leading ``src/`` removed
        when that prefix is present; otherwise it is returned unchanged.

    Raises:
        IndexError: If fewer than two whitespace-separated tokens are found
            inside the parentheses.
    """
    # find() returns -1 when "(" is absent, so ``start`` becomes 0 and the
    # whole line is scanned.
    start = line.find("(") + 1
    end = line.find(")", start)
    if end == -1:
        # No closing parenthesis: take everything up to the end of the line
        # rather than silently dropping the last character.
        end = len(line)

    # split() with no argument collapses runs of spaces/tabs, so repeated
    # separators do not produce empty tokens.
    tokens = line[start:end].split()
    name, value = tokens[0], tokens[1]

    prefix = "src/"
    if value.startswith(prefix):
        value = value[len(prefix):]
    return name, value
def _strip_src_prefix(path):
    """Return *path* without a leading ``src/``, if it has one."""
    prefix = "src/"
    return path[len(prefix):] if path.startswith(prefix) else path


def update_sources(xnnpack_path, cmakefile="XNNPACK/CMakeLists.txt"):
    """Collect source-file lists from ``SET(...)`` statements in a CMake file.

    Two statement shapes are recognised, both only on lines that start with
    ``SET``:

    * a single line that contains ``src/``, parsed by
      ``handle_singleline_parse``;
    * a multi-line list whose variable name is a key of ``WRAPPER_SRC_NAMES``
      or a member of ``SRC_NAMES``; every following line up to and including
      the first line containing ``)`` contributes one entry.

    Args:
        xnnpack_path: Directory against which ``cmakefile`` is resolved.
        cmakefile: Path of the CMake file, relative to ``xnnpack_path``.

    Returns:
        A ``collections.defaultdict(list)`` mapping each variable name to the
        list of source paths found for it, with the leading ``src/`` removed.
    """
    # Build the lookup set once instead of on every scanned line.
    known_names = set(WRAPPER_SRC_NAMES) | set(SRC_NAMES)
    sources = collections.defaultdict(list)

    with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
        lines = cmake.readlines()

    i = 0
    while i < len(lines):
        line = lines[i]
        if not line.startswith("SET"):
            i += 1
            continue

        if "src/" in line:
            # Whole statement on one line: SET(NAME src/file)
            name, val = handle_singleline_parse(line)
            sources[name].append(val)
            i += 1
            continue

        # Guard against lines that start with "SET" but have no "(".
        pieces = line.split("(")
        name = pieces[1].strip(" \t\n\r") if len(pieces) > 1 else None
        i += 1
        if name not in known_names:
            continue

        # Entries of a multi-line list, one per line, until a line with ")".
        while i < len(lines) and ")" not in lines[i]:
            entry = lines[i].strip(" \t\n\r")
            if entry:  # skip blank lines instead of recording ""
                sources[name].append(_strip_src_prefix(entry))
            i += 1

        # The closing line may carry a last entry in front of the ")".
        # ``i`` is deliberately not advanced here: the outer loop looks at
        # this line again and steps past it.
        if i < len(lines):
            entry = lines[i].strip(" \t\n\r)")
            if entry:  # a bare ")" line contributes nothing
                sources[name].append(_strip_src_prefix(entry))

    return sources
def _write_bzl_lists(bzl_path, names, sources, path_prefix):
    """Write one list assignment per name in *names* to the file *bzl_path*."""
    with open(bzl_path, "w") as bzl:
        print('"""', file=bzl)
        print(BANNER, file=bzl)
        print('"""', file=bzl)
        for name in names:
            print('\n' + name + ' = [', file=bzl)
            for file_name in sources[name]:
                print(' "{}{}",'.format(path_prefix, file_name), file=bzl)
            print(']', file=bzl)


def gen_wrappers(xnnpack_path):
    """Generate the wrapper sources and the two ``.bzl`` definition files.

    Source lists are read from ``XNNPACK/CMakeLists.txt`` and
    ``XNNPACK/cmake/microkernels.cmake`` below ``xnnpack_path``; lists found
    in the latter replace same-named lists from the former. For every file in
    a list named by ``WRAPPER_SRC_NAMES`` a wrapper is written under
    ``<xnnpack_path>/xnnpack_wrappers`` that ``#include``s the file, guarded
    by the list's preprocessor condition when it has one.
    ``xnnpack_wrapper_defs.bzl`` and ``xnnpack_src_defs.bzl`` are written to
    the directory of this script.

    Args:
        xnnpack_path: Directory that contains the ``XNNPACK`` tree and in
            which ``xnnpack_wrappers`` is created.
    """
    sources = update_sources(xnnpack_path)
    sources.update(update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake"))

    # Group the files by the preprocessor condition of the list they are in.
    xnnpack_sources = collections.defaultdict(list)
    for name, condition in WRAPPER_SRC_NAMES.items():
        xnnpack_sources[condition].extend(sources[name])

    for condition, filenames in xnnpack_sources.items():
        print(condition)
        for filename in filenames:
            filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename)
            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            with open(filepath, "w") as wrapper:
                print("/* {} */".format(BANNER), file=wrapper)
                print(file=wrapper)

                # Architecture- or platform-dependent preprocessor flags can be
                # defined here. Note: platform_preprocessor_flags can't be used
                # because they are ignored by arc focus & buck project.
                if condition is None:
                    print("#include <%s>" % filename, file=wrapper)
                else:
                    # Include source file only if condition is satisfied
                    print("#if %s" % condition, file=wrapper)
                    print("#include <%s>" % filename, file=wrapper)
                    print("#endif /* %s */" % condition, file=wrapper)

    script_dir = os.path.dirname(__file__)

    # update xnnpack_wrapper_defs.bzl file under the same folder
    _write_bzl_lists(
        os.path.join(script_dir, "xnnpack_wrapper_defs.bzl"),
        WRAPPER_SRC_NAMES,
        sources,
        "xnnpack_wrappers/",
    )

    # update xnnpack_src_defs.bzl file under the same folder.
    # SRC_NAMES is a set, whose iteration order can differ between runs;
    # sorting keeps the generated file stable from one run to the next.
    _write_bzl_lists(
        os.path.join(script_dir, "xnnpack_src_defs.bzl"),
        sorted(SRC_NAMES),
        sources,
        "XNNPACK/src/",
    )
def main(argv):
    """Run the generator for the directory named by the first argument.

    Args:
        argv: Command-line arguments without the program name. When it is
            ``None`` or empty, the current directory (``"."``) is used.
    """
    target_dir = argv[0] if argv else "."
    gen_wrappers(target_dir)
# Usage: the optional first argument is the directory that contains the
# "XNNPACK" tree and in which the "xnnpack_wrappers" folder is created.
# Running the script without arguments uses the current directory.
# The two .bzl files are always written next to this script (the directory
# of __file__), regardless of that argument.
# Script entry point: forward the command-line arguments (minus the program
# name) to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)