| import argparse |
| import torch |
| from os import path |
| import json |
| |
| # Import all utils so that getattr below can find them |
| from torch.utils import bottleneck, checkpoint, model_zoo |
| |
| all_submod_list = [ |
| "", |
| "nn", |
| "nn.functional", |
| "nn.init", |
| "optim", |
| "autograd", |
| "cuda", |
| "sparse", |
| "distributions", |
| "fft", |
| "linalg", |
| "jit", |
| "distributed", |
| "futures", |
| "onnx", |
| "random", |
| "utils.bottleneck", |
| "utils.checkpoint", |
| "utils.data", |
| "utils.model_zoo", |
| ] |
| |
| def get_content(submod): |
| mod = torch |
| if submod: |
| submod = submod.split(".") |
| for name in submod: |
| mod = getattr(mod, name) |
| content = dir(mod) |
| return content |
| |
| def namespace_filter(data): |
| out = set(d for d in data if d[0] != "_") |
| return out |
| |
| def run(args, submod): |
| print(f"## Processing torch.{submod}") |
| prev_filename = f"prev_data_{submod}.json" |
| new_filename = f"new_data_{submod}.json" |
| |
| if args.prev_version: |
| content = get_content(submod) |
| with open(prev_filename, "w") as f: |
| json.dump(content, f) |
| print("Data saved for previous version.") |
| elif args.new_version: |
| content = get_content(submod) |
| with open(new_filename, "w") as f: |
| json.dump(content, f) |
| print("Data saved for new version.") |
| else: |
| assert args.compare |
| if not path.exists(prev_filename): |
| raise RuntimeError("Previous version data not collected") |
| |
| if not path.exists(new_filename): |
| raise RuntimeError("New version data not collected") |
| |
| with open(prev_filename, "r") as f: |
| prev_content = set(json.load(f)) |
| |
| with open(new_filename, "r") as f: |
| new_content = set(json.load(f)) |
| |
| if not args.show_all: |
| prev_content = namespace_filter(prev_content) |
| new_content = namespace_filter(new_content) |
| |
| if new_content == prev_content: |
| print("Nothing changed.") |
| print("") |
| else: |
| print("Things that were added:") |
| print(new_content - prev_content) |
| print("") |
| |
| print("Things that were removed:") |
| print(prev_content - new_content) |
| print("") |
| |
| def main(): |
| parser = argparse.ArgumentParser(description='Tool to check namespace content changes') |
| |
| group = parser.add_mutually_exclusive_group(required=True) |
| group.add_argument('--prev-version', action='store_true') |
| group.add_argument('--new-version', action='store_true') |
| group.add_argument('--compare', action='store_true') |
| |
| group = parser.add_mutually_exclusive_group() |
| group.add_argument('--submod', default='', help='part of the submodule to check') |
| group.add_argument('--all-submod', action='store_true', help='collects data for all main submodules') |
| |
| parser.add_argument('--show-all', action='store_true', help='show all the diff, not just public APIs') |
| |
| |
| args = parser.parse_args() |
| |
| if args.all_submod: |
| submods = all_submod_list |
| else: |
| submods = [args.submod] |
| |
| for mod in submods: |
| run(args, mod) |
| |
| |
| if __name__ == '__main__': |
| main() |