# blob: 6ab4fd433effa30cbdbf8e6935148aa52df814d8 [file] [log] [blame]
"""
This script will generate default values of quantization configs.
These are for use in the documentation.
"""
import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
entry_to_pretty_str,
remove_boolean_dispatch_from_name,
)
import os.path
# Directory that receives the generated backend config file; it lives next to
# this script.  (The constant's name says "IMAGE", but the file written into
# it is the plain-text "default_backend_config.txt".)
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")),
    "quantization_backend_configs"
)

# Create the directory if it doesn't exist.  exist_ok=True replaces the
# exists()-then-mkdir() pair, which could raise FileExistsError if the
# directory was created between the check and the mkdir call.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Full path of the text file the generated config entries are written to.
output_path = os.path.join(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt")
def _sort_key_func(entry):
    """Return the string used to order a backend config entry in the output.

    Args:
        entry: a config entry dict; only its ``'pattern'`` value is read.

    Returns:
        The last dot-separated component of the pattern's name, lower-cased
        and with underscores removed.
    """
    pattern = entry['pattern']
    # If the pattern is a tuple (possibly nested), descend into its last
    # element until a non-tuple is reached.
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # we want
    #
    #   torch.nn.modules.pooling.AdaptiveAvgPool1d
    #
    # and
    #
    #   torch._VariableFunctionsClass.adaptive_avg_pool1d
    #
    # to be next to each other, so convert to all lower case
    # and remove the underscores, and compare the last part
    # of the string
    pattern_str_normalized = pattern.lower().replace('_', '')
    return pattern_str_normalized.split('.')[-1]


# Build the complete output text before touching the file, so that a failure
# while generating or sorting the configs cannot leave a truncated file behind.
native_backend_config_dict = get_native_backend_config_dict()
configs = native_backend_config_dict['configs']
configs.sort(key=_sort_key_func)
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# Explicit encoding keeps the generated file independent of the platform's
# default locale encoding.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)