Run Black on all of tools/

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76089

Approved by: https://github.com/albanD
diff --git a/tools/actions_local_runner.py b/tools/actions_local_runner.py
index 1d2b045..27287f8 100755
--- a/tools/actions_local_runner.py
+++ b/tools/actions_local_runner.py
@@ -59,7 +59,7 @@
     return [line.strip() for line in lines]
 
 
-def find_changed_files(ref_branch : str = "origin/master") -> List[str]:
+def find_changed_files(ref_branch: str = "origin/master") -> List[str]:
     untracked = []
 
     for line in git(["status", "--porcelain"]):
@@ -334,7 +334,7 @@
         return await shell_cmd(script, env=env)
 
 
-def changed_files(ref_branch : str = "origin/master") -> Optional[List[str]]:
+def changed_files(ref_branch: str = "origin/master") -> Optional[List[str]]:
     changed_files: Optional[List[str]] = None
     try:
         changed_files = sorted(find_changed_files(ref_branch))
@@ -381,9 +381,11 @@
         "--no-quiet", help="output commands", action="store_true", default=False
     )
     parser.add_argument("--step", action="append", help="steps to run (in order)")
-    parser.add_argument("--ref_branch",
-                        default="origin/master",
-                        help="remote/branch used during comparison for --changed-only (default=origin/master")
+    parser.add_argument(
+        "--ref_branch",
+        default="origin/master",
+        help="remote/branch used during comparison for --changed-only (default=origin/master",
+    )
     args = parser.parse_args()
 
     quiet = not args.no_quiet
diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py
index 799d1c6..e78c35b 100755
--- a/tools/amd_build/build_amd.py
+++ b/tools/amd_build/build_amd.py
@@ -4,43 +4,50 @@
 import os
 import argparse
 import sys
-sys.path.append(os.path.realpath(os.path.join(
-    __file__,
-    os.path.pardir,
-    os.path.pardir,
-    os.path.pardir,
-    'torch',
-    'utils')))
+
+sys.path.append(
+    os.path.realpath(
+        os.path.join(
+            __file__, os.path.pardir, os.path.pardir, os.path.pardir, "torch", "utils"
+        )
+    )
+)
 
 from hipify import hipify_python  # type: ignore[import]
 
-parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters')
+parser = argparse.ArgumentParser(
+    description="Top-level script for HIPifying, filling in most common parameters"
+)
 parser.add_argument(
-    '--out-of-place-only',
-    action='store_true',
-    help="Whether to only run hipify out-of-place on source files")
+    "--out-of-place-only",
+    action="store_true",
+    help="Whether to only run hipify out-of-place on source files",
+)
 
 parser.add_argument(
-    '--project-directory',
+    "--project-directory",
     type=str,
-    default='',
+    default="",
     help="The root of the project.",
-    required=False)
+    required=False,
+)
 
 parser.add_argument(
-    '--output-directory',
+    "--output-directory",
     type=str,
-    default='',
+    default="",
     help="The directory to store the hipified project",
-    required=False)
+    required=False,
+)
 
 parser.add_argument(
-    '--extra-include-dir',
+    "--extra-include-dir",
     type=str,
     default=[],
-    nargs='+',
+    nargs="+",
     help="The list of extra directories in caffe2 to hipify",
-    required=False)
+    required=False,
+)
 
 args = parser.parse_args()
 
@@ -93,13 +100,13 @@
 for new_dir in args.extra_include_dir:
     abs_new_dir = os.path.join(proj_dir, new_dir)
     if os.path.exists(abs_new_dir):
-        new_dir = os.path.join(new_dir, '**/*')
+        new_dir = os.path.join(new_dir, "**/*")
         includes.append(new_dir)
 
 ignores = [
     "caffe2/operators/depthwise_3x3_conv_op_cudnn.cu",
     "caffe2/operators/pool_op_cudnn.cu",
-    '*/hip/*',
+    "*/hip/*",
     # These files are compatible with both cuda and hip
     "aten/src/ATen/core/*",
     "torch/csrc/jit/codegen/cuda/codegen.cpp",
@@ -116,12 +123,13 @@
 # Check if the compiler is hip-clang.
 def is_hip_clang() -> bool:
     try:
-        hip_path = os.getenv('HIP_PATH', '/opt/rocm/hip')
-        with open(hip_path + '/lib/.hipInfo') as f:
-            return 'HIP_COMPILER=clang' in f.read()
+        hip_path = os.getenv("HIP_PATH", "/opt/rocm/hip")
+        with open(hip_path + "/lib/.hipInfo") as f:
+            return "HIP_COMPILER=clang" in f.read()
     except IOError:
         return False
 
+
 # TODO Remove once gloo submodule is recent enough to contain upstream fix.
 if is_hip_clang():
     gloo_cmake_file = "third_party/gloo/cmake/Hip.cmake"
@@ -129,7 +137,7 @@
     if os.path.exists(gloo_cmake_file):
         with open(gloo_cmake_file, "r") as sources:
             lines = sources.readlines()
-        newlines = [line.replace(' hip_hcc ', ' amdhip64 ') for line in lines]
+        newlines = [line.replace(" hip_hcc ", " amdhip64 ") for line in lines]
         if lines == newlines:
             print("%s skipped" % gloo_cmake_file)
         else:
@@ -143,7 +151,7 @@
     do_write = False
     with open(gloo_cmake_file, "r") as sources:
         lines = sources.readlines()
-    newlines = [line.replace('RCCL_LIBRARY', 'RCCL_LIBRARY_PATH') for line in lines]
+    newlines = [line.replace("RCCL_LIBRARY", "RCCL_LIBRARY_PATH") for line in lines]
     if lines == newlines:
         print("%s skipped" % gloo_cmake_file)
     else:
@@ -159,7 +167,7 @@
     if os.path.exists(gloo_cmake_file):
         with open(gloo_cmake_file, "r") as sources:
             lines = sources.readlines()
-        newlines = [line.replace('HIP_HCC_FLAGS', 'HIP_CLANG_FLAGS') for line in lines]
+        newlines = [line.replace("HIP_HCC_FLAGS", "HIP_CLANG_FLAGS") for line in lines]
         if lines == newlines:
             print("%s skipped" % gloo_cmake_file)
         else:
@@ -174,4 +182,5 @@
     includes=includes,
     ignores=ignores,
     out_of_place_only=args.out_of_place_only,
-    hip_clang_launch=is_hip_clang())
+    hip_clang_launch=is_hip_clang(),
+)
diff --git a/tools/autograd/context.py b/tools/autograd/context.py
index 66f4f81..cc357f9 100644
--- a/tools/autograd/context.py
+++ b/tools/autograd/context.py
@@ -7,9 +7,12 @@
 
 # Like tools.api.context.with_native_function, but for
 # NativeFunctionWithDifferentiabilityInfo.
-def with_native_function_with_differentiability_info(func: Callable[[NFWDI], T]) -> Callable[[NFWDI], T]:
+def with_native_function_with_differentiability_info(
+    func: Callable[[NFWDI], T]
+) -> Callable[[NFWDI], T]:
     @functools.wraps(func)
     def wrapper(f: NFWDI) -> T:
         with native_function_manager(f.func):
             return func(f)
+
     return wrapper
diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py
index 2d1dbd5..8898bb7 100644
--- a/tools/autograd/gen_annotated_fn_args.py
+++ b/tools/autograd/gen_annotated_fn_args.py
@@ -25,19 +25,26 @@
 from tools.codegen.context import with_native_function
 from tools.codegen.model import BaseOperatorName, NativeFunction
 import tools.codegen.api.python as python
-from .gen_python_functions import should_generate_py_binding, is_py_torch_function, \
-    is_py_nn_function, is_py_linalg_function, is_py_variable_method, is_py_special_function, \
-    is_py_fft_function
+from .gen_python_functions import (
+    should_generate_py_binding,
+    is_py_torch_function,
+    is_py_nn_function,
+    is_py_linalg_function,
+    is_py_variable_method,
+    is_py_special_function,
+    is_py_fft_function,
+)
+
 
 def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
     native_functions = parse_native_yaml(native_yaml_path).native_functions
     mappings = (
-        (is_py_torch_function, 'torch._C._VariableFunctions'),
-        (is_py_nn_function, 'torch._C._nn'),
-        (is_py_linalg_function, 'torch._C._linalg'),
-        (is_py_special_function, 'torch._C._special'),
-        (is_py_fft_function, 'torch._C._fft'),
-        (is_py_variable_method, 'torch.Tensor'),
+        (is_py_torch_function, "torch._C._VariableFunctions"),
+        (is_py_nn_function, "torch._C._nn"),
+        (is_py_linalg_function, "torch._C._linalg"),
+        (is_py_special_function, "torch._C._special"),
+        (is_py_fft_function, "torch._C._fft"),
+        (is_py_variable_method, "torch.Tensor"),
     )
     annotated_args: List[str] = []
     for pred, namespace in mappings:
@@ -48,13 +55,18 @@
             groups[f.func.name.name].append(f)
         for group in groups.values():
             for f in group:
-                annotated_args.append(f'{namespace}.{gen_annotated_args(f)}')
+                annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")
 
-    template_path = os.path.join(autograd_dir, 'templates')
+    template_path = os.path.join(autograd_dir, "templates")
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py.in', lambda: {
-        'annotated_args': textwrap.indent('\n'.join(annotated_args), '    '),
-    })
+    fm.write_with_template(
+        "annotated_fn_args.py",
+        "annotated_fn_args.py.in",
+        lambda: {
+            "annotated_args": textwrap.indent("\n".join(annotated_args), "    "),
+        },
+    )
+
 
 @with_native_function
 def gen_annotated_args(f: NativeFunction) -> str:
@@ -63,26 +75,28 @@
         if arg.default is not None:
             continue
         out_arg: Dict[str, Any] = {}
-        out_arg['name'] = arg.name
-        out_arg['simple_type'] = python.argument_type_str(arg.type, simple_type=True)
+        out_arg["name"] = arg.name
+        out_arg["simple_type"] = python.argument_type_str(arg.type, simple_type=True)
         size = python.argument_type_size(arg.type)
         if size:
-            out_arg['size'] = size
+            out_arg["size"] = size
         out_args.append(out_arg)
 
-    return f'{f.func.name.name}: {repr(out_args)},'
+    return f"{f.func.name.name}: {repr(out_args)},"
+
 
 def main() -> None:
-    parser = argparse.ArgumentParser(
-        description='Generate annotated_fn_args script')
-    parser.add_argument('native_functions', metavar='NATIVE',
-                        help='path to native_functions.yaml')
-    parser.add_argument('out', metavar='OUT',
-                        help='path to output directory')
-    parser.add_argument('autograd', metavar='AUTOGRAD',
-                        help='path to template directory')
+    parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
+    parser.add_argument(
+        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
+    )
+    parser.add_argument("out", metavar="OUT", help="path to output directory")
+    parser.add_argument(
+        "autograd", metavar="AUTOGRAD", help="path to template directory"
+    )
     args = parser.parse_args()
     gen_annotated(args.native_functions, args.out, args.autograd)
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py
index 26ab682..c6b2b61 100644
--- a/tools/autograd/gen_autograd.py
+++ b/tools/autograd/gen_autograd.py
@@ -26,19 +26,24 @@
 import os
 from tools.codegen.api import cpp
 from tools.codegen.api.autograd import (
-    match_differentiability_info, NativeFunctionWithDifferentiabilityInfo,
+    match_differentiability_info,
+    NativeFunctionWithDifferentiabilityInfo,
 )
 from tools.codegen.gen import parse_native_yaml
 from tools.codegen.selective_build.selector import SelectiveBuilder
 from typing import List
 from . import gen_python_functions
-from .gen_autograd_functions import gen_autograd_functions_lib, gen_autograd_functions_python
+from .gen_autograd_functions import (
+    gen_autograd_functions_lib,
+    gen_autograd_functions_python,
+)
 from .gen_trace_type import gen_trace_type
 from .gen_variable_type import gen_variable_type
 from .gen_inplace_or_view_type import gen_inplace_or_view_type
 from .gen_variable_factories import gen_variable_factories
 from .load_derivatives import load_derivatives
 
+
 def gen_autograd(
     native_functions_path: str,
     out: str,
@@ -48,27 +53,38 @@
 ) -> None:
     # Parse and load derivatives.yaml
     differentiability_infos = load_derivatives(
-        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
+        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path
+    )
 
-    template_path = os.path.join(autograd_dir, 'templates')
+    template_path = os.path.join(autograd_dir, "templates")
 
     native_funcs = parse_native_yaml(native_functions_path).native_functions
-    fns = list(sorted(filter(
-        operator_selector.is_native_function_selected_for_training,
-        native_funcs), key=lambda f: cpp.name(f.func)))
-    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(fns, differentiability_infos)
+    fns = list(
+        sorted(
+            filter(
+                operator_selector.is_native_function_selected_for_training, native_funcs
+            ),
+            key=lambda f: cpp.name(f.func),
+        )
+    )
+    fns_with_diff_infos: List[
+        NativeFunctionWithDifferentiabilityInfo
+    ] = match_differentiability_info(fns, differentiability_infos)
 
     # Generate VariableType.h/cpp
     if not disable_autograd:
-        gen_variable_type(out, native_functions_path, fns_with_diff_infos, template_path)
+        gen_variable_type(
+            out, native_functions_path, fns_with_diff_infos, template_path
+        )
 
-        gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path)
+        gen_inplace_or_view_type(
+            out, native_functions_path, fns_with_diff_infos, template_path
+        )
 
         # operator filter not applied as tracing sources are excluded in selective build
         gen_trace_type(out, native_funcs, template_path)
     # Generate Functions.h/cpp
-    gen_autograd_functions_lib(
-        out, differentiability_infos, template_path)
+    gen_autograd_functions_lib(out, differentiability_infos, template_path)
 
     # Generate variable_factories.h
     gen_variable_factories(out, native_functions_path, template_path)
@@ -80,34 +96,36 @@
     autograd_dir: str,
 ) -> None:
     differentiability_infos = load_derivatives(
-        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
+        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path
+    )
 
-    template_path = os.path.join(autograd_dir, 'templates')
+    template_path = os.path.join(autograd_dir, "templates")
 
     # Generate Functions.h/cpp
-    gen_autograd_functions_python(
-        out, differentiability_infos, template_path)
+    gen_autograd_functions_python(out, differentiability_infos, template_path)
 
     # Generate Python bindings
-    deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
-    gen_python_functions.gen(
-        out, native_functions_path, deprecated_path, template_path)
+    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
+    gen_python_functions.gen(out, native_functions_path, deprecated_path, template_path)
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(
-        description='Generate autograd C++ files script')
-    parser.add_argument('native_functions', metavar='NATIVE',
-                        help='path to native_functions.yaml')
-    parser.add_argument('out', metavar='OUT',
-                        help='path to output directory')
-    parser.add_argument('autograd', metavar='AUTOGRAD',
-                        help='path to autograd directory')
+    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
+    parser.add_argument(
+        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
+    )
+    parser.add_argument("out", metavar="OUT", help="path to output directory")
+    parser.add_argument(
+        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
+    )
     args = parser.parse_args()
-    gen_autograd(args.native_functions,
-                 args.out, args.autograd,
-                 SelectiveBuilder.get_nop_selector())
+    gen_autograd(
+        args.native_functions,
+        args.out,
+        args.autograd,
+        SelectiveBuilder.get_nop_selector(),
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index fd9f50e..35657b5 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -8,18 +8,36 @@
 
 from typing import List, Sequence, Tuple
 
-from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
-                                        SavedAttribute, uses_retain_variables,
-                                        uses_single_grad)
-from tools.codegen.api.types import (Binding, BaseCType, OptionalCType, tensorT, longT,
-                                     doubleT, scalarT, stringT, boolT, intArrayRefT,
-                                     tensorListT, MutRefCType, ListCType, ArrayRefCType,
-                                     optionalIntArrayRefT)
+from tools.codegen.api.autograd import (
+    Derivative,
+    DifferentiabilityInfo,
+    SavedAttribute,
+    uses_retain_variables,
+    uses_single_grad,
+)
+from tools.codegen.api.types import (
+    Binding,
+    BaseCType,
+    OptionalCType,
+    tensorT,
+    longT,
+    doubleT,
+    scalarT,
+    stringT,
+    boolT,
+    intArrayRefT,
+    tensorListT,
+    MutRefCType,
+    ListCType,
+    ArrayRefCType,
+    optionalIntArrayRefT,
+)
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.utils import FileManager
 from tools.codegen.model import Argument
 
-FUNCTION_DECLARATION = CodeTemplate("""\
+FUNCTION_DECLARATION = CodeTemplate(
+    """\
 struct TORCH_API ${op} : public ${superclass} {
   using ${superclass}::${superclass};
   variable_list apply(variable_list&& grads) override;
@@ -32,16 +50,20 @@
   ${saved_variables}
   ${saved_list_sizes}
 };
-""")
+"""
+)
 
-WILL_RELEASE_VARIABLES = CodeTemplate("""\
+WILL_RELEASE_VARIABLES = CodeTemplate(
+    """\
 bool retain_variables = true;
 void will_release_variables() override {
   retain_variables = false;
 }
-""")
+"""
+)
 
-FUNCTION_DEFINITION = CodeTemplate("""\
+FUNCTION_DEFINITION = CodeTemplate(
+    """\
 variable_list ${op}::apply(variable_list&& grads) {
   ${thread_lock}
   ${asserts}
@@ -51,34 +73,43 @@
   ${body}
   return grad_inputs;
 }
-""")
+"""
+)
 
-GRAD_INPUT_MASK = CodeTemplate("""\
+GRAD_INPUT_MASK = CodeTemplate(
+    """\
   auto grad_input_mask = std::array<bool, ${n}>{
     ${masks}
   };\
-""")
+"""
+)
 
-DERIVATIVE_SINGLE = CodeTemplate("""\
+DERIVATIVE_SINGLE = CodeTemplate(
+    """\
 if (should_compute_output({ ${name}_ix })) {
   auto grad_result = ${derivative};
   copy_range(grad_inputs, ${name}_ix, grad_result);
 }
-""")
+"""
+)
 
-DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate("""\
+DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
+    """\
   if (should_compute_output({ ${name}_ix })) {
     copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
   }
-""")
+"""
+)
 
-DERIVATIVE_MULTI = CodeTemplate("""\
+DERIVATIVE_MULTI = CodeTemplate(
+    """\
 if (should_compute_output({ ${idx_ranges} })) {
   ${grad_input_mask}
   auto grad_result = ${derivative};
   ${copy_ranges}
 }
-""")
+"""
+)
 
 # Generates python bindings
 #
@@ -89,12 +120,15 @@
 #       Each PyGetSetDef has a function ptr to a getter, also defined here (3).
 #   (3) Getters for each of grad_fn's saved inputs and outputs.
 #
-PY_FUNCTION_DEFINITION = CodeTemplate("""\
+PY_FUNCTION_DEFINITION = CodeTemplate(
+    """\
 static PyTypeObject ${op}Class;
 addClass<${op}>(${op}Class, "${op}", ${op}_properties);
-""")
+"""
+)
 
-PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate("""\
+PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
+    """\
 ${all_getter_definitions}
 
 static struct PyGetSetDef ${op}_properties[] = {
@@ -103,43 +137,55 @@
   {nullptr} /* sentinel */
 };
 
-""")
+"""
+)
 
-PY_GETSETDEF_STRUCT = CodeTemplate("""\
-{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}""")
+PY_GETSETDEF_STRUCT = CodeTemplate(
+    """\
+{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
+)
 
-PY_RAW_GETSETDEF_STRUCT = CodeTemplate("""\
-{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}""")
+PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
+    """\
+{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
+)
 
 # Getter templates
-GETTER_DEFINITION = CodeTemplate("""\
+GETTER_DEFINITION = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   auto prop = static_cast<${op}*>(self->cdata.get())->${name};
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
-GETTER_DEFINITION_SAVEDVAR = CodeTemplate("""\
+GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
-GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate("""\
+GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
-GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate("""\
+GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   const auto *node = static_cast<${op}*>(self->cdata.get());
@@ -151,9 +197,11 @@
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
-GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate("""\
+GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   const auto *node = static_cast<${op}*>(self->cdata.get());
@@ -165,9 +213,11 @@
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
-GETTER_DEFINITION_OPT = CodeTemplate("""\
+GETTER_DEFINITION_OPT = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
@@ -178,9 +228,11 @@
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
-GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate("""\
+GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
+    """\
 PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
   HANDLE_TH_ERRORS
   auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
@@ -191,7 +243,8 @@
   ${body}
   END_HANDLE_TH_ERRORS
 }
-""")
+"""
+)
 
 # Getter body
 GETTER_BODY_SAVEDVAR = """\
@@ -293,6 +346,7 @@
 # TODO: This is probably not exhaustive, but it's a start
 UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
 
+
 def gen_autograd_functions_lib(
     out: str,
     differentiability_infos: Sequence[DifferentiabilityInfo],
@@ -305,19 +359,26 @@
     """
 
     # only create an autograd function if we are actually going to calculate a derivative
-    infos = list(filter(lambda info: info.args_with_derivatives, differentiability_infos))
+    infos = list(
+        filter(lambda info: info.args_with_derivatives, differentiability_infos)
+    )
     declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
     definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))
 
-    file_basename = 'Functions'
+    file_basename = "Functions"
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    for suffix in ['.h', '.cpp']:
+    for suffix in [".h", ".cpp"]:
         fname = file_basename + suffix
-        fm.write_with_template(fname, fname, lambda: {
-            'generated_comment': '@' + f'generated from {fm.template_dir}/' + fname,
-            'autograd_function_declarations': declarations,
-            'autograd_function_definitions': definitions,
-        })
+        fm.write_with_template(
+            fname,
+            fname,
+            lambda: {
+                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
+                "autograd_function_declarations": declarations,
+                "autograd_function_definitions": definitions,
+            },
+        )
+
 
 def gen_autograd_functions_python(
     out: str,
@@ -327,34 +388,43 @@
 
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
     num_shards = 5
-    fm.write('python_functions.h', lambda: {
-        'generated_comment': f'@generated from {fm.template_dir}/python_functions.h',
-        'shard_forward_declare': [
-            f"void initialize_autogenerated_functions_{i}();"
-            for i in range(num_shards)
-        ],
-        'shard_call': [
-            f"initialize_autogenerated_functions_{i}();"
-            for i in range(num_shards)
-        ]
-    })
+    fm.write(
+        "python_functions.h",
+        lambda: {
+            "generated_comment": f"@generated from {fm.template_dir}/python_functions.h",
+            "shard_forward_declare": [
+                f"void initialize_autogenerated_functions_{i}();"
+                for i in range(num_shards)
+            ],
+            "shard_call": [
+                f"initialize_autogenerated_functions_{i}();" for i in range(num_shards)
+            ],
+        },
+    )
 
-    infos = list(filter(lambda info: info.args_with_derivatives, differentiability_infos))
+    infos = list(
+        filter(lambda info: info.args_with_derivatives, differentiability_infos)
+    )
     fm.write_sharded(
-        'python_functions.cpp',
+        "python_functions.cpp",
         infos,
         key_fn=lambda info: info.name,
         base_env={
-            'generated_comment': f'@generated from {fm.template_dir}/python_functions.cpp',
+            "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp",
         },
         env_callable=lambda info: {
-            'py_function_initializers': [process_function(info, PY_FUNCTION_DEFINITION)],
-            'py_function_props_and_getters': [process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)],
+            "py_function_initializers": [
+                process_function(info, PY_FUNCTION_DEFINITION)
+            ],
+            "py_function_props_and_getters": [
+                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
+            ],
         },
         num_shards=num_shards,
-        sharded_keys={'py_function_initializers', 'py_function_props_and_getters'}
+        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
     )
 
+
 def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
     saved_variables: List[str] = []
     release_variables: List[str] = []
@@ -366,12 +436,15 @@
     py_getsetdef_structs: List[str] = []
 
     for arg in info.args_with_derivatives:
-        if arg.type == 'at::TensorList' or arg.type == 'const c10::List<c10::optional<at::Tensor>> &':
-            size = f'{arg.name}_size_'
-            saved_list_sizes.append(f'size_t {arg.name}_size_;')
+        if (
+            arg.type == "at::TensorList"
+            or arg.type == "const c10::List<c10::optional<at::Tensor>> &"
+        ):
+            size = f"{arg.name}_size_"
+            saved_list_sizes.append(f"size_t {arg.name}_size_;")
         else:
-            size = '1'
-        compute_index_ranges.append(f'auto {arg.name}_ix = gen.range({size});')
+            size = "1"
+        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
 
     def save_var(var: SavedAttribute, is_output: bool) -> None:
         name = var.nctype.name
@@ -379,80 +452,124 @@
         should_append_getsetdef = True
         should_append_raw_getsetdef = False
 
-        if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
-                type == MutRefCType(OptionalCType(BaseCType(tensorT))) or \
-                (type == BaseCType(scalarT) and is_output):
-            saved_variables.append(f'SavedVariable {name}_;')
-            release_variables.append(f'{name}_.reset_data();')
-            ptr = 'shared_from_this()' if is_output else ''
-            unpack.append(f'auto {name} = {name}_.unpack({ptr});')
-            getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
-                op=info.op, name=name, body=GETTER_BODY_SAVEDVAR))
-            getter_definitions.append(GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
-                op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR))
+        if (
+            type == BaseCType(tensorT)
+            or type == OptionalCType(BaseCType(tensorT))
+            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
+            or (type == BaseCType(scalarT) and is_output)
+        ):
+            saved_variables.append(f"SavedVariable {name}_;")
+            release_variables.append(f"{name}_.reset_data();")
+            ptr = "shared_from_this()" if is_output else ""
+            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
+            getter_definitions.append(
+                GETTER_DEFINITION_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
+                )
+            )
+            getter_definitions.append(
+                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
+                )
+            )
             should_append_raw_getsetdef = True
         elif type == BaseCType(tensorListT):
-            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
-            saved_variables.append(f'bool {name}_released_ = false;')
+            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
+            saved_variables.append(f"bool {name}_released_ = false;")
             # Just clear() is sufficient, we don't need to loop and clear each variable.
             # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
-            release_variables.append(f'{name}_.clear();')
-            release_variables.append(f'{name}_released_ = true;')
-            unpack.append(f'auto {name} = unpack_list({name}_);')
-            asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
-            getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
-                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
-            getter_definitions.append(GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
-                op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
+            release_variables.append(f"{name}_.clear();")
+            release_variables.append(f"{name}_released_ = true;")
+            unpack.append(f"auto {name} = unpack_list({name}_);")
+            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
+            getter_definitions.append(
+                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
+                )
+            )
+            getter_definitions.append(
+                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
+                )
+            )
             should_append_raw_getsetdef = True
         elif type == ListCType(OptionalCType(BaseCType(tensorT))):
-            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
-            saved_variables.append(f'bool {name}_released_ = false;')
+            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
+            saved_variables.append(f"bool {name}_released_ = false;")
             # Just clear() is sufficient, we don't need to loop and clear each variable.
             # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
-            release_variables.append(f'{name}_.clear();')
-            release_variables.append(f'{name}_released_ = true;')
-            unpack.append(f'auto {name} = unpack_opt_list({name}_);')
-            asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
-            getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
-                op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
-            getter_definitions.append(GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
-                op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
+            release_variables.append(f"{name}_.clear();")
+            release_variables.append(f"{name}_released_ = true;")
+            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
+            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
+            getter_definitions.append(
+                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
+                )
+            )
+            getter_definitions.append(
+                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
+                )
+            )
             should_append_raw_getsetdef = True
         elif type == BaseCType(intArrayRefT):
-            saved_variables.append(f'std::vector<int64_t> {name};')
-            getter_definitions.append(GETTER_DEFINITION.substitute(
-                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
+            saved_variables.append(f"std::vector<int64_t> {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
+                )
+            )
         elif type == BaseCType(optionalIntArrayRefT):
-            saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
-            getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
-                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
+            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
+                )
+            )
         elif type == OptionalCType(BaseCType(intArrayRefT)):
-            saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
-            getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
-                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
+            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
+                )
+            )
         elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
-            saved_variables.append(f'c10::OptionalArray<double> {name};')
-            getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
-                op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
+            saved_variables.append(f"c10::OptionalArray<double> {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
+                )
+            )
         elif type == BaseCType(longT):
-            saved_variables.append(f'{type.cpp_type()} {name} = 0;')
-            getter_definitions.append(GETTER_DEFINITION.substitute(
-                op=info.op, name=name, body=GETTER_BODY_INT64_T))
+            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_INT64_T
+                )
+            )
         elif type == BaseCType(stringT):
-            saved_variables.append(f'std::string {name};')
-            getter_definitions.append(GETTER_DEFINITION.substitute(
-                op=info.op, name=name, body=GETTER_BODY_STRING))
+            saved_variables.append(f"std::string {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_STRING
+                )
+            )
         elif type == OptionalCType(BaseCType(stringT)):
-            saved_variables.append(f'c10::optional<std::string> {name};')
-            getter_definitions.append(GETTER_DEFINITION_OPT.substitute(
-                op=info.op, name=name, body=GETTER_BODY_STRING))
+            saved_variables.append(f"c10::optional<std::string> {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_STRING
+                )
+            )
         else:
-            saved_variables.append(f'{type.cpp_type()} {name};')
+            saved_variables.append(f"{type.cpp_type()} {name};")
 
             if type in MISC_GETTER_DEFS:
                 getter_def, body = MISC_GETTER_DEFS[type]
-                getter_definitions.append(getter_def.substitute(op=info.op, name=name, body=body))
+                getter_definitions.append(
+                    getter_def.substitute(op=info.op, name=name, body=body)
+                )
             else:
                 # Types we don't expose python bindings to yet:
                 #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
@@ -460,9 +577,13 @@
                 should_append_getsetdef = False
 
         if should_append_getsetdef:
-            py_getsetdef_structs.append(PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
+            py_getsetdef_structs.append(
+                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
+            )
         if should_append_raw_getsetdef:
-            py_getsetdef_structs.append(PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
+            py_getsetdef_structs.append(
+                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
+            )
 
     for var in info.all_saved_inputs:
         save_var(var, is_output=False)
@@ -472,24 +593,25 @@
     # lock the mutex when we release variables and in Node::apply to protect thread safety
     # see Note [Thread Safety on Autograd Node]
     if len(release_variables) > 0:
-        thread_lock = 'std::lock_guard<std::mutex> lock(mutex_);'
+        thread_lock = "std::lock_guard<std::mutex> lock(mutex_);"
     else:
-        thread_lock = ''
+        thread_lock = ""
 
     if uses_retain_variables(info):
         will_release_variables = WILL_RELEASE_VARIABLES.substitute()
     else:
-        will_release_variables = ''
+        will_release_variables = ""
 
     body: List[str] = []
 
     if uses_single_grad(info):
-        body.append('const auto& grad = grads[0];')
+        body.append("const auto& grad = grads[0];")
     else:
         # Generate aliases for gradients named for returned values.
         body.extend(
-            f'const auto& {name} = grads[{info.available_named_gradients.index(name)}];'
-            for name in info.used_named_gradients)
+            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
+            for name in info.used_named_gradients
+        )
 
     def emit_derivative(
         derivative: Derivative,
@@ -499,51 +621,65 @@
         var_names = derivative.var_names
         if len(var_names) == 1:
             checks_any_grad_defined = False
-            if 'not_implemented' not in formula:
+            if "not_implemented" not in formula:
                 matching_args = [
-                    arg for arg in args_with_derivatives
-                    if arg.name == var_names[0]]
+                    arg for arg in args_with_derivatives if arg.name == var_names[0]
+                ]
                 if len(matching_args) == 1:
                     # We can add undefined grad support if the input variable is a Tensor
                     arg = matching_args[0]
-                    if isinstance(arg.argument, Argument) and str(arg.argument.type) in ('Tensor', 'Tensor?'):
-                        formula = 'any_grad_defined ? (' + formula + ') : Tensor()'
+                    if isinstance(arg.argument, Argument) and str(
+                        arg.argument.type
+                    ) in ("Tensor", "Tensor?"):
+                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                         checks_any_grad_defined = True
-            return (checks_any_grad_defined,
-                    DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula))
+            return (
+                checks_any_grad_defined,
+                DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
+            )
         else:
-            if 'grad_input_mask' in formula:
-                masks = [f'should_compute_output({{ {n}_ix }}),' for n in var_names]
-                grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names))
+            if "grad_input_mask" in formula:
+                masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names]
+                grad_input_mask = GRAD_INPUT_MASK.substitute(
+                    masks=masks, n=len(var_names)
+                )
             else:
-                grad_input_mask = ''
-            idx_ranges = ', '.join(f'{n}_ix' for n in var_names)
+                grad_input_mask = ""
+            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
             copy_ranges: List[str] = []
             for i, n in enumerate(var_names):
                 copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
             return False, DERIVATIVE_MULTI.substitute(
-                idx_ranges=idx_ranges, copy_ranges=copy_ranges,
+                idx_ranges=idx_ranges,
+                copy_ranges=copy_ranges,
                 derivative=formula,
-                grad_input_mask=grad_input_mask)
+                grad_input_mask=grad_input_mask,
+            )
 
     body.extend(unpack)
     need_any_grad_defined_var = False
     for derivative in info.derivatives:
-        checks_any_grad_defined, derivative_text = emit_derivative(derivative, info.args_with_derivatives)
+        checks_any_grad_defined, derivative_text = emit_derivative(
+            derivative, info.args_with_derivatives
+        )
         body.append(derivative_text)
         need_any_grad_defined_var |= checks_any_grad_defined
     # Since single-output derivative formulas need to check if grads are
     # defined, only perform the check once, before all the formulas
     if need_any_grad_defined_var:
-        body.insert(-len(info.derivatives),
-                    'bool any_grad_defined = any_variable_defined(grads);')
+        body.insert(
+            -len(info.derivatives),
+            "bool any_grad_defined = any_variable_defined(grads);",
+        )
 
     if info.name in UNTRACEABLE_FUNCTIONS:
-        superclass = 'Node'
+        superclass = "Node"
     else:
-        superclass = 'TraceableFunction'
+        superclass = "TraceableFunction"
 
-    all_getsetdef_structs = ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
+    all_getsetdef_structs = (
+        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
+    )
     all_getter_definitions = "\n".join(getter_definitions)
 
     return template.substitute(
@@ -558,5 +694,5 @@
         body=body,
         superclass=superclass,
         all_getter_definitions=all_getter_definitions,
-        all_getsetdef_structs=all_getsetdef_structs
+        all_getsetdef_structs=all_getsetdef_structs,
     )
diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index 1ea1c30..b0e6b12 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -6,22 +6,39 @@
 
 from tools.codegen.api import cpp
 from tools.codegen.api.autograd import (
-    NativeFunctionWithDifferentiabilityInfo, gen_differentiable_outputs,
+    NativeFunctionWithDifferentiabilityInfo,
+    gen_differentiable_outputs,
     dispatch_strategy,
 )
-from tools.codegen.api.types import (Binding, DispatcherSignature, CType, BaseCType,
-                                     OptionalCType, longT, boolT, intArrayRefT, symIntArrayRefT)
+from tools.codegen.api.types import (
+    Binding,
+    DispatcherSignature,
+    CType,
+    BaseCType,
+    OptionalCType,
+    longT,
+    boolT,
+    intArrayRefT,
+    symIntArrayRefT,
+)
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.context import with_native_function
 from tools.codegen.model import (
-    Type, NativeFunction, SelfArgument, TensorOptionsArguments, SchemaKind,
+    Type,
+    NativeFunction,
+    SelfArgument,
+    TensorOptionsArguments,
+    SchemaKind,
     is_foreach_op,
 )
 from typing import List, Optional, Sequence, Tuple, Dict
 from tools.codegen.utils import FileManager
 from .context import with_native_function_with_differentiability_info
 from .gen_trace_type import (
-    MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value
+    MANUAL_AUTOGRAD,
+    type_wrapper_name,
+    tie_return_values,
+    get_return_value,
 )
 
 # See NOTE [ Autograd View Variables ] in variable.h for details.
@@ -33,58 +50,75 @@
 # A map: function name => name of the argument that all outputs are view of
 
 VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [
-    'view_as_complex',
-    'view_as_real',
-    '_conj',
-    '_neg_view'
+    "view_as_complex",
+    "view_as_real",
+    "_conj",
+    "_neg_view",
 ]
 
 VIEW_FUNCTIONS = {
-    'numpy_T': 'self',
-    'alias': 'self',
-    'as_strided': 'self',
-    'diagonal': 'self',
-    'expand': 'self',
-    'permute': 'self',
-    'select': 'self',
-    'slice': 'self',
-    'split': 'self',
-    'split_with_sizes': 'self',
-    'squeeze': 'self',
-    't': 'self',
-    'transpose': 'self',
-    'unfold': 'self',
-    'unsqueeze': 'self',
-    'flatten': 'self',
-    'view': 'self',
-    'unbind': 'self',
-    '_indices': 'self',
-    '_values': 'self',
-    'indices': 'self',
-    'values': 'self',
-    'crow_indices': 'self',
-    'col_indices': 'self',
+    "numpy_T": "self",
+    "alias": "self",
+    "as_strided": "self",
+    "diagonal": "self",
+    "expand": "self",
+    "permute": "self",
+    "select": "self",
+    "slice": "self",
+    "split": "self",
+    "split_with_sizes": "self",
+    "squeeze": "self",
+    "t": "self",
+    "transpose": "self",
+    "unfold": "self",
+    "unsqueeze": "self",
+    "flatten": "self",
+    "view": "self",
+    "unbind": "self",
+    "_indices": "self",
+    "_values": "self",
+    "indices": "self",
+    "values": "self",
+    "crow_indices": "self",
+    "col_indices": "self",
     # sparse_coo ctor output should really be views of both indices and values,
     # but we only supports making as view of a single variable, and indices is
     # discrete anyways.
     # FIXME: clone indices on construction.
-    'sparse_coo_tensor_with_dims_and_tensors': 'values',
-    '_reshape_alias': 'self',
+    "sparse_coo_tensor_with_dims_and_tensors": "values",
+    "_reshape_alias": "self",
 }
 
 for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
-    VIEW_FUNCTIONS[key] = 'self'
+    VIEW_FUNCTIONS[key] = "self"
 
 # note: some VIEW_FUNCTIONS are just compositions of the view functions above
 # this list contains both the root view functions and any that are purely composed
 # of viewing functions, and is used by the JIT to determine when an operator
 # may return a view of its inputs; however they may sometimes return a copy.
 # (e.g. `contiguous`)
-RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({
-    'chunk', 'detach', 'contiguous', 'reshape', 'reshape_as',
-    'expand_as', 'view_as', 'real', 'imag', 'narrow', 'movedim',
-    'tensor_split', 'swapdims', 'swapaxes', 'mT', 'mH', 'adjoint', 'matrix_H'
-})
+RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union(
+    {
+        "chunk",
+        "detach",
+        "contiguous",
+        "reshape",
+        "reshape_as",
+        "expand_as",
+        "view_as",
+        "real",
+        "imag",
+        "narrow",
+        "movedim",
+        "tensor_split",
+        "swapdims",
+        "swapaxes",
+        "mT",
+        "mH",
+        "adjoint",
+        "matrix_H",
+    }
+)
 
 # These are the functions we consider views for the purposes of validating
 # StorageImpl and TensorImpl in gen_variable_type.
@@ -93,68 +127,90 @@
 # See NOTE [Unsafe View] for more info.
 ALL_VIEW_FUNCTIONS = {
     **VIEW_FUNCTIONS,
-    '_unsafe_view': 'self',
+    "_unsafe_view": "self",
 }
 
-ARRAYREF_TO_VEC = CodeTemplate("""\
+ARRAYREF_TO_VEC = CodeTemplate(
+    """\
 auto ${vec} = ${arg}.vec();
-""")
+"""
+)
 
-OPTIONAL_TO_VAL = CodeTemplate("""\
+OPTIONAL_TO_VAL = CodeTemplate(
+    """\
 auto ${val} = ${arg}.value_or(${default});
-""")
+"""
+)
 
-CALL_DISPATCH = CodeTemplate("""\
-at::_ops::${unambiguous_name}::call(${unpacked_args})""")
+CALL_DISPATCH = CodeTemplate(
+    """\
+at::_ops::${unambiguous_name}::call(${unpacked_args})"""
+)
 
-SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate("""\
+SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
+    """\
 std::function<at::Tensor(const at::Tensor&)> func=nullptr;
 if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided()) {
   ${replay_view_func}
 }
-""")
+"""
+)
 
-REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate("""\
+REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
+    """\
 func = [=](const at::Tensor& ${input_base}) {
   return ${replay_view_call};
 };
-""")
+"""
+)
 
-METHOD_DEFINITION = CodeTemplate("""\
+METHOD_DEFINITION = CodeTemplate(
+    """\
 ${return_type} ${type_wrapper_name}(${formals}) {
   ${type_definition_body}
 }
-""")
+"""
+)
 
-WRAPPER_REGISTRATION = CodeTemplate("""\
+WRAPPER_REGISTRATION = CodeTemplate(
+    """\
 m.impl("${unqual_operator_name_with_overload}",
        TORCH_FN(${class_type}::${type_wrapper_name})
 );
-""")
+"""
+)
 
-AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate("""\
+AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate(
+    """\
 m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback());
-""")
+"""
+)
 
-INPLACE_REDISPATCH = CodeTemplate("""\
+INPLACE_REDISPATCH = CodeTemplate(
+    """\
 {
   at::AutoDispatchBelowADInplaceOrView guard;
   at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
 }
-""")
+"""
+)
 
-ASSIGN_RETURN_VALUE = CodeTemplate("""\
+ASSIGN_RETURN_VALUE = CodeTemplate(
+    """\
 ${return_values} = ${rhs_value};
-""")
+"""
+)
 
-VIEW_REDISPATCH = CodeTemplate("""\
+VIEW_REDISPATCH = CodeTemplate(
+    """\
 ${assign_return_values} ([&]() {
   at::AutoDispatchBelowADInplaceOrView guard;
   return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
 })();
-""")
+"""
+)
 
-TMP_VAR = '_tmp'
+TMP_VAR = "_tmp"
 
 # FIXME: Ideally these functions should be methods on Type class, but we have a
 #        comment in codegen/model.py there saying these concepts are not well defined.
@@ -163,27 +219,38 @@
     # TODO: Should handle optional here?
     return t.is_tensor_like() and t.is_list_like() is None
 
+
 def is_tensor_list_type(t: Type) -> bool:
     # TODO: Should handle optional here?
     return t.is_tensor_like() and t.is_list_like() is not None
 
-UNPACK_TENSOR = CodeTemplate("""\
-auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""")
+
+UNPACK_TENSOR = CodeTemplate(
+    """\
+auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});"""
+)
+
 
 def unpacked_name(arg_name: str) -> str:
-    return arg_name + '_'
+    return arg_name + "_"
+
 
 @with_native_function
 def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
     body: List[str] = []
     unpacked_bindings: List[Binding] = []
 
-    bindings = [r for a in f.func.schema_order_arguments()
-                for r in cpp.argument(a,
-                                      method=False,
-                                      cpp_no_default_args=set(),
-                                      faithful=False,
-                                      has_tensor_options=False)]
+    bindings = [
+        r
+        for a in f.func.schema_order_arguments()
+        for r in cpp.argument(
+            a,
+            method=False,
+            cpp_no_default_args=set(),
+            faithful=False,
+            has_tensor_options=False,
+        )
+    ]
 
     for i, binding in enumerate(bindings):
         assert not isinstance(binding.argument, SelfArgument)
@@ -197,25 +264,31 @@
 
         is_tensor_list = is_tensor_list_type(binding.argument.type)
         ref = (not is_nullable) and not is_tensor_list
-        suffix = '_opt' if is_nullable and not is_tensor_list else ''
-        body.append(UNPACK_TENSOR.substitute(
-            arg_name=binding.name,
-            arg_pos=i,
-            suffix=suffix,
-            ref='&' if ref else '',
-        ))
-        unpacked_bindings.append(Binding(
-            name=unpacked_name(binding.name),
-            nctype=binding.nctype,
-            argument=binding.argument,
-            default=binding.default,
-        ))
+        suffix = "_opt" if is_nullable and not is_tensor_list else ""
+        body.append(
+            UNPACK_TENSOR.substitute(
+                arg_name=binding.name,
+                arg_pos=i,
+                suffix=suffix,
+                ref="&" if ref else "",
+            )
+        )
+        unpacked_bindings.append(
+            Binding(
+                name=unpacked_name(binding.name),
+                nctype=binding.nctype,
+                argument=binding.argument,
+                default=binding.default,
+            )
+        )
 
     return body, unpacked_bindings
 
+
 def get_base_name(f: NativeFunction) -> str:
     return f.func.name.name.base  # TODO: should be str(f.func.name.name)?
 
+
 def get_view_info(f: NativeFunction) -> Optional[str]:
     base_name = get_base_name(f)
     view_info = VIEW_FUNCTIONS.get(base_name, None)
@@ -223,115 +296,148 @@
         view_info = "self"
     return view_info
 
+
 # For view replay calls, we generate an ordinary Dispatcher::call() instead, because:
 #  - We want to replay the entire call into the op, including any previously-set dispatch keys (including autograd!).
 #  - The view replay call also is not part of the hot path.
-def emit_view_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
+def emit_view_call(
+    f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
+) -> str:
     # View replay functions use the standard Dispatcher::call API.
     return CALL_DISPATCH.substitute(
-        unambiguous_name=f.func.name.unambiguous_name(),
-        unpacked_args=unpacked_args)
+        unambiguous_name=f.func.name.unambiguous_name(), unpacked_args=unpacked_args
+    )
+
 
 def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
-    """ Generate an additional lambda function to recover views in backward when as_strided is not supported.
+    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
     See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
-    input_base = 'input_base'
-    replay_view_func = ''
+    input_base = "input_base"
+    replay_view_func = ""
     updated_unpacked_args: List[str] = []
     known_view_arg_simple_types: List[CType] = [
         BaseCType(longT),
         OptionalCType(BaseCType(longT)),
         BaseCType(boolT),
         BaseCType(intArrayRefT),
-        BaseCType(symIntArrayRefT)]
+        BaseCType(symIntArrayRefT),
+    ]
     for unpacked_binding in unpacked_bindings:
         arg, arg_type = unpacked_binding.name, unpacked_binding.nctype.type
-        if arg == 'self_':
+        if arg == "self_":
             updated_unpacked_args.append(input_base)
             continue
         if arg_type not in known_view_arg_simple_types:
-            known_types_str = ', '.join([str(t) for t in known_view_arg_simple_types])
-            raise TypeError(f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: '
-                            f'{known_types_str}. Please update the list or materialize it so that it can be closed '
-                            'over by value, also add a test in pytorch/xla/test/test_operations.py where this code '
-                            'is exercised.')
+            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
+            raise TypeError(
+                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
+                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
+                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
+                "is exercised."
+            )
 
-        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(symIntArrayRefT):
+        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
+            symIntArrayRefT
+        ):
             # It's not safe to close over IntArrayRef by value, since this is a
             # reference type, so materialize a vector to close over by value
-            arg_vec = arg + '_vec'
+            arg_vec = arg + "_vec"
             replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
             updated_unpacked_args.append(arg_vec)
         elif arg_type == OptionalCType(BaseCType(longT)):
             # Materialize int64_t? to int64_t
-            arg_value = arg + '_val'
-            replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg, val=arg_value, default='0')
+            arg_value = arg + "_val"
+            replay_view_func += OPTIONAL_TO_VAL.substitute(
+                arg=arg, val=arg_value, default="0"
+            )
             updated_unpacked_args.append(arg_value)
         else:
             updated_unpacked_args.append(arg)
 
     replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
     replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
-        input_base=input_base,
-        replay_view_call=replay_view_call)
+        input_base=input_base, replay_view_call=replay_view_call
+    )
 
-    is_view_with_metadata_change = 'true' if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else 'false'
+    is_view_with_metadata_change = (
+        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
+    )
 
     return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
         is_view_with_metadata_change=is_view_with_metadata_change,
-        replay_view_func=replay_view_func)
+        replay_view_func=replay_view_func,
+    )
 
-def emit_view_body(fn: NativeFunctionWithDifferentiabilityInfo, var: str) -> Tuple[str, str]:
+
+def emit_view_body(
+    fn: NativeFunctionWithDifferentiabilityInfo, var: str
+) -> Tuple[str, str]:
     # See NOTE [ Autograd View Variables ] in variable.h for details.
     f = fn.func
     base_name = get_base_name(f)
     view_info = get_view_info(f)
-    call = ''
+    call = ""
     differentiable_outputs = gen_differentiable_outputs(fn)
     differentiable_output_vars = {r.name for r in differentiable_outputs}
     if not isinstance(view_info, str):
-        raise TypeError(f'The view info should be a string for {base_name}, but it is: {view_info}')
+        raise TypeError(
+            f"The view info should be a string for {base_name}, but it is: {view_info}"
+        )
     if len(differentiable_output_vars) == 0:
         # no output is differentiable (.indices() for SparseTensors for example)
-        rhs_value = (f'as_view({view_info}, {var}, '
-                     f'/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)')
+        rhs_value = (
+            f"as_view({view_info}, {var}, "
+            f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)"
+        )
     elif len(differentiable_output_vars) == 1:
         # Single differentiable output (Tensor or Tensor[])
         return_info = differentiable_outputs[0]
         # We only support simple Tensor or a TensorList for functions that return views
-        if not is_tensor_type(return_info.type) and not is_tensor_list_type(return_info.type):
-            raise RuntimeError(f'{base_name} that return differentiable views can only return Tensor or Tensor[]')
+        if not is_tensor_type(return_info.type) and not is_tensor_list_type(
+            return_info.type
+        ):
+            raise RuntimeError(
+                f"{base_name} that return differentiable views can only return Tensor or Tensor[]"
+            )
 
         # See Note [ View + Inplace detection]
         def get_creation_meta_in_mode(original: str) -> str:
-            creation_meta_with_grad_mode = f'(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)'
-            return f'InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}'
+            creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)"
+            return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}"
 
         # Only allow rebasing of the history if we return a single Tensor
         # If we are in a no grad block, raise a warning
         # See NOTE [ View + Inplace detection ] for more details about this logic
         if is_tensor_list_type(return_info.type):
-            creation_meta = get_creation_meta_in_mode('CreationMeta::MULTI_OUTPUT_NODE')
-            call += (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, '
-                     '/* is_fw_differentiable */ true, '
-                     f'/* creation_meta */ {creation_meta});')
-            rhs_value = f'std::move({var})'
+            creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
+            call += (
+                f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
+                "/* is_fw_differentiable */ true, "
+                f"/* creation_meta */ {creation_meta});"
+            )
+            rhs_value = f"std::move({var})"
         else:
             _, unpacked_bindings = unpack_args(f)
             call += emit_view_lambda(f, unpacked_bindings)
-            creation_meta = get_creation_meta_in_mode('CreationMeta::DEFAULT')
-            rhs_value = (f'as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, '
-                         '/* is_fw_differentiable */ true, '
-                         f'/* view_func */ func, /* creation_meta */ {creation_meta})')
+            creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
+            rhs_value = (
+                f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
+                "/* is_fw_differentiable */ true, "
+                f"/* view_func */ func, /* creation_meta */ {creation_meta})"
+            )
     else:
         # This could be supported but we don't need it at the moment, so keeping things simple.
-        raise RuntimeError('Function that return multiple differentiable output '
-                           'when at least one of them is view is not supported.')
+        raise RuntimeError(
+            "Function that return multiple differentiable output "
+            "when at least one of them is view is not supported."
+        )
     return call, rhs_value
 
+
 def modifies_arguments(f: NativeFunction) -> bool:
     return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
 
+
 @with_native_function_with_differentiability_info
 def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
     f = fn.func
@@ -342,48 +448,63 @@
 
     # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
     # See Note [Plumbing Keys Through The Dispatcher] for details.
-    dispatch_key_set = 'ks & c10::after_ADInplaceOrView_keyset'
-    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])
+    dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset"
+    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])
 
     # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
     # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
     if modifies_arguments(f):  # inplace op
-        inplace_view_body.append(INPLACE_REDISPATCH.substitute(
-            unambiguous_name=f.func.name.unambiguous_name(),
-            unpacked_args=redispatch_args,
-        ))
+        inplace_view_body.append(
+            INPLACE_REDISPATCH.substitute(
+                unambiguous_name=f.func.name.unambiguous_name(),
+                unpacked_args=redispatch_args,
+            )
+        )
         for r in cpp.return_names(f):
-            inplace_view_body.append(f'increment_version({r});')
+            inplace_view_body.append(f"increment_version({r});")
     else:
-        assert(get_view_info(f) is not None)
-        inplace_view_body.append(VIEW_REDISPATCH.substitute(
-            assign_return_values='auto ' + TMP_VAR + ' = ',
-            unambiguous_name=f.func.name.unambiguous_name(),
-            unpacked_args=redispatch_args,
-        ))
+        assert get_view_info(f) is not None
+        inplace_view_body.append(
+            VIEW_REDISPATCH.substitute(
+                assign_return_values="auto " + TMP_VAR + " = ",
+                unambiguous_name=f.func.name.unambiguous_name(),
+                unpacked_args=redispatch_args,
+            )
+        )
         call, rhs_value = emit_view_body(fn, TMP_VAR)
         inplace_view_body.append(call)
         assert rhs_value is not None
         inplace_view_body.append(
-            ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f), rhs_value=rhs_value))
+            ASSIGN_RETURN_VALUE.substitute(
+                return_values=tie_return_values(f), rhs_value=rhs_value
+            )
+        )
     if f.func.returns:
-        inplace_view_body.append(f'return {get_return_value(f)};')
+        inplace_view_body.append(f"return {get_return_value(f)};")
     return inplace_view_body
 
+
 @with_native_function
 def gen_formals(f: NativeFunction) -> str:
-    return ', '.join(
+    return ", ".join(
         # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
         # See Note [Plumbing Keys Through The Dispatcher] for details.
-        ['c10::DispatchKeySet ks'] +
-        [f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
-         for a in f.func.schema_order_arguments()]
+        ["c10::DispatchKeySet ks"]
+        + [
+            f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
+            for a in f.func.schema_order_arguments()
+        ]
     )
 
+
 @with_native_function_with_differentiability_info
-def inplace_or_view_method_definition(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
+def inplace_or_view_method_definition(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> Optional[str]:
     f = fn.func
-    if get_view_info(f) is None and (not modifies_arguments(f) or is_foreach_op(str(f.func.name))):
+    if get_view_info(f) is None and (
+        not modifies_arguments(f) or is_foreach_op(str(f.func.name))
+    ):
         return None
     return METHOD_DEFINITION.substitute(
         return_type=cpp.returns_type(f.func.returns).cpp_type(),
@@ -392,38 +513,55 @@
         type_definition_body=emit_inplace_or_view_body(fn),
     )
 
+
 @with_native_function_with_differentiability_info
-def inplace_or_view_method_registration(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
+def inplace_or_view_method_registration(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> Optional[str]:
     f = fn.func
-    if get_view_info(f) is None and (not modifies_arguments(f) or is_foreach_op(str(f.func.name))):
+    if get_view_info(f) is None and (
+        not modifies_arguments(f) or is_foreach_op(str(f.func.name))
+    ):
         return None
     return WRAPPER_REGISTRATION.substitute(
         unqual_operator_name_with_overload=f.func.name,
         type_wrapper_name=type_wrapper_name(f),
-        class_type='ADInplaceOrView',
+        class_type="ADInplaceOrView",
     )
 
+
 def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
     f = fn.func
     name = cpp.name(f.func)
-    return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived'
+    return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived"
 
-def gen_inplace_or_view_type_env(fn: NativeFunctionWithDifferentiabilityInfo) -> Dict[str, List[str]]:
+
+def gen_inplace_or_view_type_env(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> Dict[str, List[str]]:
     definition = inplace_or_view_method_definition(fn)
     registration = inplace_or_view_method_registration(fn)
 
     return {
-        'ops_headers': ([f'#include <ATen/ops/{fn.func.root_name}_ops.h>']
-                        if definition is not None else []),
-        'inplace_or_view_method_definitions': [definition] if definition is not None else [],
-        'inplace_or_view_wrapper_registrations': [registration] if registration is not None else [],
+        "ops_headers": (
+            [f"#include <ATen/ops/{fn.func.root_name}_ops.h>"]
+            if definition is not None
+            else []
+        ),
+        "inplace_or_view_method_definitions": [definition]
+        if definition is not None
+        else [],
+        "inplace_or_view_wrapper_registrations": [registration]
+        if registration is not None
+        else [],
     }
 
+
 def gen_inplace_or_view_type(
     out: str,
     native_yaml_path: str,
     fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
-    template_path: str
+    template_path: str,
 ) -> None:
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
@@ -431,15 +569,17 @@
 
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
     fm.write_sharded(
-        'ADInplaceOrViewType.cpp',
+        "ADInplaceOrViewType.cpp",
         [fn for fn in fns_with_infos if use_derived(fn)],
         key_fn=lambda fn: fn.func.root_name,
         base_env={
-            'generated_comment':
-            f'@generated from {template_path}/ADInplaceOrViewType.cpp',
+            "generated_comment": f"@generated from {template_path}/ADInplaceOrViewType.cpp",
         },
         env_callable=gen_inplace_or_view_type_env,
         num_shards=2,
-        sharded_keys={'ops_headers', 'inplace_or_view_method_definitions',
-                      'inplace_or_view_wrapper_registrations'}
+        sharded_keys={
+            "ops_headers",
+            "inplace_or_view_method_definitions",
+            "inplace_or_view_wrapper_registrations",
+        },
     )
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 2b9b133..6f31c09 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -40,22 +40,32 @@
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.api import cpp
 from tools.codegen.api.types import CppSignatureGroup
-from tools.codegen.api.python import (PythonArgument, PythonSignature,
-                                      PythonSignatureDeprecated,
-                                      PythonSignatureGroup,
-                                      PythonSignatureNativeFunctionPair,
-                                      arg_parser_output_exprs,
-                                      argument_type_str, cpp_dispatch_exprs,
-                                      cpp_dispatch_target,
-                                      dispatch_lambda_args,
-                                      dispatch_lambda_exprs,
-                                      dispatch_lambda_return_str,
-                                      has_tensor_options,
-                                      namedtuple_fieldnames, signature)
+from tools.codegen.api.python import (
+    PythonArgument,
+    PythonSignature,
+    PythonSignatureDeprecated,
+    PythonSignatureGroup,
+    PythonSignatureNativeFunctionPair,
+    arg_parser_output_exprs,
+    argument_type_str,
+    cpp_dispatch_exprs,
+    cpp_dispatch_target,
+    dispatch_lambda_args,
+    dispatch_lambda_exprs,
+    dispatch_lambda_return_str,
+    has_tensor_options,
+    namedtuple_fieldnames,
+    signature,
+)
 from tools.codegen.gen import cpp_string, parse_native_yaml
 from tools.codegen.context import with_native_function
-from tools.codegen.model import (Argument, BaseOperatorName, NativeFunction,
-                                 Type, Variant)
+from tools.codegen.model import (
+    Argument,
+    BaseOperatorName,
+    NativeFunction,
+    Type,
+    Variant,
+)
 from tools.codegen.utils import split_name_params, YamlLoader, FileManager
 
 from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable
@@ -70,49 +80,97 @@
 
 # These functions require manual Python bindings or are not exposed to Python
 _SKIP_PYTHON_BINDINGS = [
-    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
-    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
-    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
-    '_?sparse_csr_tensor.*',
-    '_arange.*', '_range.*', 'linspace.*', 'logspace.*',
-    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
-    'index', 'unique_dim_consecutive',
-    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
-    '_th_.*', '_thnn_.*',
-    'arange.*', 'range.*', '_solve.*', '_inverse.*',
-    'full(_out)?',
-    '_cholesky.*', '_triangular_solve.*', '_qr.*', '_symeig.*', '_svd.*',
-    'slice', 'randint(_out)?',
-    'item', '_local_scalar_dense', 'to',
-    '_to_copy',
-    'copy_sparse_to_sparse_', 'copy_',
-    'numpy_T', 'matrix_H', 'mT', 'mH',  # these need to be an attributes in Python, not functions
-    'nonzero(_(out|numpy))?',
-    'set_data',
-    '.*_overrideable',  # overrideable functions for backend extension
-    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
-    '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
-    'fake_quantize_per_channel_affine_cachemask',
-    '_new_zeros_with_same_feature_meta', '_has_same_storage_numel',  # used for forward AD internals
-    '_reshape_alias',
-    'replace_',  # only used by the functionalization pass, doesn't need to be exposed to python
+    "alias",
+    "contiguous",
+    "is_cuda",
+    "is_sparse",
+    "is_sparse_csr",
+    "size",
+    "stride",
+    ".*_backward",
+    ".*_backward_(out|input|weight|bias)",
+    ".*_forward",
+    ".*_forward_out",
+    "_unsafe_view",
+    "tensor",
+    "_?sparse_coo_tensor.*",
+    "_?sparse_csr_tensor.*",
+    "_arange.*",
+    "_range.*",
+    "linspace.*",
+    "logspace.*",
+    "_sparse_add_out",
+    "_sparse_div.*",
+    "_sparse_mul.*",
+    "_sparse_sub.*",
+    "_sparse_dense_add_out",
+    "index",
+    "unique_dim_consecutive",
+    "_cumsum.*",
+    "_cumprod.*",
+    "_sum.*",
+    "_prod.*",
+    "_th_.*",
+    "_thnn_.*",
+    "arange.*",
+    "range.*",
+    "_solve.*",
+    "_inverse.*",
+    "full(_out)?",
+    "_cholesky.*",
+    "_triangular_solve.*",
+    "_qr.*",
+    "_symeig.*",
+    "_svd.*",
+    "slice",
+    "randint(_out)?",
+    "item",
+    "_local_scalar_dense",
+    "to",
+    "_to_copy",
+    "copy_sparse_to_sparse_",
+    "copy_",
+    "numpy_T",
+    "matrix_H",
+    "mT",
+    "mH",  # these need to be an attributes in Python, not functions
+    "nonzero(_(out|numpy))?",
+    "set_data",
+    ".*_overrideable",  # overrideable functions for backend extension
+    "data",
+    "is_leaf",
+    "output_nr",
+    "_version",
+    "requires_grad_",
+    "retains_grad",
+    "set_",
+    "_fw_primal",
+    "fake_quantize_per_tensor_affine_cachemask",
+    "fake_quantize_per_channel_affine_cachemask",
+    "_new_zeros_with_same_feature_meta",
+    "_has_same_storage_numel",  # used for forward AD internals
+    "_reshape_alias",
+    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
 ]
 
-SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))
+SKIP_PYTHON_BINDINGS = list(
+    map(lambda pattern: re.compile(rf"^{pattern}$"), _SKIP_PYTHON_BINDINGS)
+)
 
 # These function signatures are not exposed to Python. Note that this signature
 # list does not support regex.
 SKIP_PYTHON_BINDINGS_SIGNATURES = [
-    'add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
-    'add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
-    'sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
-    'sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)',
-    'mul.Scalar(Tensor self, Scalar other) -> Tensor',
-    'mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
-    'div.Scalar(Tensor self, Scalar other) -> Tensor',
-    'div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)',
+    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
+    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
+    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
+    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
+    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
+    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
+    "div.Scalar(Tensor self, Scalar other) -> Tensor",
+    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
 ]
 
+
 @with_native_function
 def should_generate_py_binding(f: NativeFunction) -> bool:
     name = cpp.name(f.func)
@@ -127,32 +185,42 @@
 
     return True
 
+
 def get_pycname(name: BaseOperatorName) -> str:
-    return f'THPVariable_{name}'
+    return f"THPVariable_{name}"
+
 
 def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
     return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0
 
+
 def is_py_variable_method(f: NativeFunction) -> bool:
     return f.python_module is None and Variant.method in f.variants
 
+
 def is_py_torch_function(f: NativeFunction) -> bool:
     return f.python_module is None and Variant.function in f.variants
 
+
 def is_py_nn_function(f: NativeFunction) -> bool:
-    return f.python_module == 'nn'
+    return f.python_module == "nn"
+
 
 def is_py_fft_function(f: NativeFunction) -> bool:
-    return f.python_module == 'fft'
+    return f.python_module == "fft"
+
 
 def is_py_linalg_function(f: NativeFunction) -> bool:
-    return f.python_module == 'linalg'
+    return f.python_module == "linalg"
+
 
 def is_py_sparse_function(f: NativeFunction) -> bool:
-    return f.python_module == 'sparse'
+    return f.python_module == "sparse"
+
 
 def is_py_special_function(f: NativeFunction) -> bool:
-    return f.python_module == 'special'
+    return f.python_module == "special"
+
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
@@ -160,54 +228,104 @@
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
-def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
+
+def gen(
+    out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str
+) -> None:
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
     native_functions = parse_native_yaml(native_yaml_path).native_functions
     native_functions = list(filter(should_generate_py_binding, native_functions))
 
     methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
     create_python_bindings(
-        fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)
+        fm,
+        methods,
+        is_py_variable_method,
+        None,
+        "python_variable_methods.cpp",
+        method=True,
+    )
 
     # NOTE: num_shards here must be synced with gatherTorchFunctions in
     #       torch/csrc/autograd/python_torch_functions_manual.cpp
     functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
     create_python_bindings_sharded(
-        fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp',
-        method=False, num_shards=3)
+        fm,
+        functions,
+        is_py_torch_function,
+        "torch",
+        "python_torch_functions.cpp",
+        method=False,
+        num_shards=3,
+    )
 
     create_python_bindings(
-        fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)
+        fm,
+        functions,
+        is_py_nn_function,
+        "torch.nn",
+        "python_nn_functions.cpp",
+        method=False,
+    )
 
     create_python_bindings(
-        fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)
+        fm,
+        functions,
+        is_py_fft_function,
+        "torch.fft",
+        "python_fft_functions.cpp",
+        method=False,
+    )
 
     create_python_bindings(
-        fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
+        fm,
+        functions,
+        is_py_linalg_function,
+        "torch.linalg",
+        "python_linalg_functions.cpp",
+        method=False,
+    )
 
     create_python_bindings(
-        fm, functions, is_py_sparse_function, 'torch.sparse', 'python_sparse_functions.cpp', method=False)
+        fm,
+        functions,
+        is_py_sparse_function,
+        "torch.sparse",
+        "python_sparse_functions.cpp",
+        method=False,
+    )
 
     create_python_bindings(
-        fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)
+        fm,
+        functions,
+        is_py_special_function,
+        "torch.special",
+        "python_special_functions.cpp",
+        method=False,
+    )
 
     # Currently, we only use `functions` to generate `return_types` bindings.
     # All methods which return namedtuple have function variant at this point.
     # If any method only operator with namedtuple is added in the future,
     # we will have to address that.
     create_python_return_type_bindings(
-        fm, functions, lambda fn: True, 'python_return_types.cpp')
+        fm, functions, lambda fn: True, "python_return_types.cpp"
+    )
+
 
 def group_filter_overloads(
     pairs: Sequence[PythonSignatureNativeFunctionPair],
-    pred: Callable[[NativeFunction], bool]
+    pred: Callable[[NativeFunction], bool],
 ) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
-    grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
+    grouped: Dict[
+        BaseOperatorName, List[PythonSignatureNativeFunctionPair]
+    ] = defaultdict(list)
     for pair in pairs:
         if pred(pair.function):
             grouped[pair.function.func.name.name].append(pair)
     return grouped
 
+
 def create_python_bindings(
     fm: FileManager,
     pairs: Sequence[PythonSignatureNativeFunctionPair],
@@ -230,15 +348,20 @@
         py_methods.append(method_impl(name, module, overloads, method=method))
         py_method_defs.append(method_def(name, module, overloads, method=method))
         py_forwards.extend(forward_decls(name, overloads, method=method))
-        ops_headers.append(f'#include <ATen/ops/{name.base}.h>')
+        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")
 
-    fm.write_with_template(filename, filename, lambda: {
-        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
-        'ops_headers': ops_headers,
-        'py_forwards': py_forwards,
-        'py_methods': py_methods,
-        'py_method_defs': py_method_defs,
-    })
+    fm.write_with_template(
+        filename,
+        filename,
+        lambda: {
+            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
+            "ops_headers": ops_headers,
+            "py_forwards": py_forwards,
+            "py_methods": py_methods,
+            "py_method_defs": py_method_defs,
+        },
+    )
+
 
 def create_python_return_type_bindings(
     fm: FileManager,
@@ -257,15 +380,24 @@
 
     for name in sorted(grouped.keys(), key=lambda x: str(x)):
         overloads = grouped[name]
-        definitions, map_entries = generate_return_type_definition_and_map_entry(overloads)
-        py_return_types_definition.append("" if not definitions else "\n".join(definitions))
+        definitions, map_entries = generate_return_type_definition_and_map_entry(
+            overloads
+        )
+        py_return_types_definition.append(
+            "" if not definitions else "\n".join(definitions)
+        )
         py_return_types_map.append("" if not map_entries else "\n".join(map_entries))
 
-    fm.write_with_template(filename, filename, lambda: {
-        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
-        'py_return_types': py_return_types_definition,
-        'py_return_types_map' : py_return_types_map,
-    })
+    fm.write_with_template(
+        filename,
+        filename,
+        lambda: {
+            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
+            "py_return_types": py_return_types_definition,
+            "py_return_types_map": py_return_types_map,
+        },
+    )
+
 
 def create_python_bindings_sharded(
     fm: FileManager,
@@ -275,12 +407,14 @@
     filename: str,
     *,
     method: bool,
-    num_shards: int
+    num_shards: int,
 ) -> None:
     """Generates Python bindings to ATen functions"""
     grouped = group_filter_overloads(pairs, pred)
 
-    def key_func(kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]) -> str:
+    def key_func(
+        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
+    ) -> str:
         return kv[0].base
 
     def env_func(
@@ -288,25 +422,25 @@
     ) -> Dict[str, List[str]]:
         name, fn_pairs = kv
         return {
-            'ops_headers': [f'#include <ATen/ops/{name.base}.h>'],
-            'py_forwards': list(forward_decls(name, fn_pairs, method=method)),
-            'py_methods': [method_impl(name, module, fn_pairs, method=method)],
-            'py_method_defs': [method_def(name, module, fn_pairs, method=method)],
+            "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
+            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
+            "py_methods": [method_impl(name, module, fn_pairs, method=method)],
+            "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
         }
 
     fm.write_sharded(
         filename,
         grouped.items(),
         base_env={
-            'generated_comment':
-            '@' + f'generated from {fm.template_dir}/{filename}',
+            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
         },
         key_fn=key_func,
         env_callable=env_func,
         num_shards=num_shards,
-        sharded_keys={'ops_headers', 'py_forwards', 'py_methods', 'py_method_defs'}
+        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
     )
 
+
 def load_signatures(
     native_functions: List[NativeFunction],
     deprecated_yaml_path: str,
@@ -315,7 +449,6 @@
     skip_deprecated: bool = False,
     pyi: bool = False,
 ) -> Sequence[PythonSignatureNativeFunctionPair]:
-
     @with_native_function
     def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
         return PythonSignatureNativeFunctionPair(
@@ -324,9 +457,12 @@
         )
 
     pairs = list(map(gen_signature_pairs, native_functions))
-    deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method, pyi=pyi)
+    deprecated = load_deprecated_signatures(
+        pairs, deprecated_yaml_path, method=method, pyi=pyi
+    )
     return pairs if skip_deprecated else pairs + deprecated
 
+
 def load_deprecated_signatures(
     pairs: Sequence[PythonSignatureNativeFunctionPair],
     deprecated_yaml_path: str,
@@ -345,28 +481,35 @@
         # remove inplace suffix but keep outplace suffix
         opname = str(f.func.name.name.base)
         if f.func.is_out_fn():
-            opname += '_out'
+            opname += "_out"
         if f.func.name.name.inplace and pyi:
-            opname += '_'
-        args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
+            opname += "_"
+        args = CppSignatureGroup.from_native_function(
+            f, method=False
+        ).signature.arguments()
         # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
-        types = ', '.join(argument_type_str(a.argument.type)
-                          for a in args if isinstance(a.argument, Argument))
-        return f'{opname}({types})'
+        types = ", ".join(
+            argument_type_str(a.argument.type)
+            for a in args
+            if isinstance(a.argument, Argument)
+        )
+        return f"{opname}({types})"
 
     # deprecated -> type-only native signature (according to the call order)
-    def signature_deprecated(opname: str, params: List[str], call_args: List[str]) -> str:
+    def signature_deprecated(
+        opname: str, params: List[str], call_args: List[str]
+    ) -> str:
         # create a mapping of parameter name to parameter type
         types: Dict[str, str] = {}
         for param in params:
-            if param == '*':
+            if param == "*":
                 continue
-            type, name = param.split(' ')
+            type, name = param.split(" ")
             types[name] = type
         # if the name in the call is not in the parameter list, assume it's
         # a literal Scalar
-        rearranged_types = ', '.join(types.get(arg, 'Scalar') for arg in call_args)
-        return f'{opname}({rearranged_types})'
+        rearranged_types = ", ".join(types.get(arg, "Scalar") for arg in call_args)
+        return f"{opname}({rearranged_types})"
 
     # group the original ATen signatures by type-only signature
     grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
@@ -376,12 +519,12 @@
     # find matching original signatures for each deprecated signature
     results: List[PythonSignatureNativeFunctionPair] = []
 
-    with open(deprecated_yaml_path, 'r') as f:
+    with open(deprecated_yaml_path, "r") as f:
         deprecated_defs = yaml.load(f, Loader=YamlLoader)
 
     for deprecated in deprecated_defs:
-        _, params = split_name_params(deprecated['name'])
-        aten_name, call_args = split_name_params(deprecated['aten'])
+        _, params = split_name_params(deprecated["name"])
+        aten_name, call_args = split_name_params(deprecated["aten"])
 
         for pair in grouped[signature_deprecated(aten_name, params, call_args)]:
             # It uses the types from the original ATen declaration, but the
@@ -392,12 +535,15 @@
             # but never changes output_args nor TensorOptions (if any?),
             # so here we only look into these two types of args.
             python_sig = pair.signature
-            src_args: Dict[str, PythonArgument] = {a.name: PythonArgument(
-                name=a.name,
-                type=a.type,
-                default=None,
-                default_init=None,
-            ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)}
+            src_args: Dict[str, PythonArgument] = {
+                a.name: PythonArgument(
+                    name=a.name,
+                    type=a.type,
+                    default=None,
+                    default_init=None,
+                )
+                for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)
+            }
 
             args: List[str] = []
             input_args: List[PythonArgument] = []
@@ -405,10 +551,10 @@
 
             kwarg_only = False
             for param in params:
-                if param == '*':
+                if param == "*":
                     kwarg_only = True
                     continue
-                _, param_name = param.split(' ')
+                _, param_name = param.split(" ")
                 args.append(param_name)
 
                 if param_name not in src_args:
@@ -416,49 +562,56 @@
                     continue
 
                 if not kwarg_only:
-                    if not method or param_name != 'self':
+                    if not method or param_name != "self":
                         input_args.append(src_args[param_name])
                 else:
                     input_kwargs.append(src_args[param_name])
 
-            results.append(PythonSignatureNativeFunctionPair(
-                signature=PythonSignatureDeprecated(
-                    name=python_sig.name,
-                    input_args=tuple(input_args),
-                    input_kwargs=tuple(input_kwargs),
-                    output_args=python_sig.output_args,
-                    tensor_options_args=python_sig.tensor_options_args,
-                    method=python_sig.method,
-                    deprecated_args_names=tuple(args),
-                    deprecated_args_exprs=tuple(call_args),
-                    returns=python_sig.returns,
-                ),
-                function=pair.function,
-            ))
+            results.append(
+                PythonSignatureNativeFunctionPair(
+                    signature=PythonSignatureDeprecated(
+                        name=python_sig.name,
+                        input_args=tuple(input_args),
+                        input_kwargs=tuple(input_kwargs),
+                        output_args=python_sig.output_args,
+                        tensor_options_args=python_sig.tensor_options_args,
+                        method=python_sig.method,
+                        deprecated_args_names=tuple(args),
+                        deprecated_args_exprs=tuple(call_args),
+                        returns=python_sig.returns,
+                    ),
+                    function=pair.function,
+                )
+            )
 
     return results
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
 #                         Named Tuple Codegen
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
+
 @with_native_function
 def gen_namedtuple_typename_key(f: NativeFunction) -> str:
     name = cpp.name(f.func)
     fieldnames = namedtuple_fieldnames(f.func.returns)
-    return '_'.join([name] + fieldnames)
+    return "_".join([name] + fieldnames)
+
 
 def emit_namedtuple_call(
-    overloads: Sequence[PythonSignatureNativeFunctionPair]
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
 ) -> Tuple[List[str], Dict[str, str]]:
     """
     Generate block of named tuple type def inits, and add typeref snippets
     to declarations that use them
     """
-    typenames: Dict[str, str] = {}    # map from unique name + field name lists to typedef name
-    typedefs: List[str] = []          # typedef declarations and init code
+    typenames: Dict[
+        str, str
+    ] = {}  # map from unique name + field name lists to typedef name
+    typedefs: List[str] = []  # typedef declarations and init code
 
     for overload in overloads:
         fieldnames = namedtuple_fieldnames(overload.function.func.returns)
@@ -471,8 +624,10 @@
         if typename is None:
             typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
             typenames[tn_key] = typename
-            typedefs.append(f"""\
-static PyTypeObject* {typename} = get_namedtuple("{name}");""")
+            typedefs.append(
+                f"""\
+static PyTypeObject* {typename} = get_namedtuple("{name}");"""
+            )
 
     return typedefs, typenames
 
@@ -485,16 +640,20 @@
     and return named tuple for a native function which returns named tuple
     and relevant entry for the map in same file.
     """
-    typenames: Dict[str, str] = {}  # map from unique name + field name lists to typedef name
+    typenames: Dict[
+        str, str
+    ] = {}  # map from unique name + field name lists to typedef name
     definitions: List[str] = []  # function defintion to register the typedef
-    map_entries: List[str] = []  # C++ map entry of <function_name, function creates it namedtuple>
+    map_entries: List[
+        str
+    ] = []  # C++ map entry of <function_name, function creates it namedtuple>
 
     for overload in overloads:
         fieldnames = namedtuple_fieldnames(overload.function.func.returns)
         if not fieldnames:
             continue
 
-        fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)
+        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
 
         name = cpp.name(overload.function.func)  # use @with_native_function?
         tn_key = gen_namedtuple_typename_key(overload.function)
@@ -503,7 +662,8 @@
         if typename is None:
             typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
             typenames[tn_key] = typename
-            definitions.append(f"""\
+            definitions.append(
+                f"""\
 PyTypeObject* get_{name}_namedtuple() {{
     static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
     static PyTypeObject {typename};
@@ -516,11 +676,13 @@
     }}
     return &{typename};
 }}
-""")
+"""
+            )
             map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')
 
     return definitions, map_entries
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
 #                         Method Impl Codegen
@@ -528,7 +690,8 @@
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
 # python binding for all overloads of a particular function/method
-PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\
+PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
+    r"""\
 // ${name}
 static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
 {
@@ -546,19 +709,23 @@
   ${method_footer}
 }
 
-""")
+"""
+)
 
 # handler for a single parsed signature - may be a single overload or
 # a pair of overloads that whose signatures only differ in output params
 # (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
-PY_VARIABLE_CASE = CodeTemplate("""\
+PY_VARIABLE_CASE = CodeTemplate(
+    """\
 case ${overload_index}: {
   ${body}
 }
-""")
+"""
+)
 
 # python binding for single-overload function/method
-PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
+PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
+    """\
 // ${name}
 static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
 {
@@ -574,10 +741,12 @@
   ${method_footer}
 }
 
-""")
+"""
+)
 
 # python binding for a method with no args, shortcuts parsing
-PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
+PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
+    """\
 // ${name}
 static PyObject * ${pycname}(PyObject* self_, PyObject* args)
 {
@@ -587,14 +756,16 @@
   ${method_footer}
 }
 
-""")
+"""
+)
+
 
 def method_impl(
     name: BaseOperatorName,
     module: Optional[str],
     overloads: Sequence[PythonSignatureNativeFunctionPair],
     *,
-    method: bool
+    method: bool,
 ) -> str:
     """
     Generate a python binding for all overloads of an op.
@@ -603,15 +774,15 @@
     noarg = is_noarg(overloads)
     namedtuple_inits, namedtuple_typenames = emit_namedtuple_call(overloads)
 
-    method_header = ['HANDLE_TH_ERRORS']
+    method_header = ["HANDLE_TH_ERRORS"]
     method_header += namedtuple_inits
-    method_header += [
-        "const Tensor& self = THPVariable_Unpack(self_);"
-    ] if method else []
+    method_header += (
+        ["const Tensor& self = THPVariable_Unpack(self_);"] if method else []
+    )
 
-    method_footer = ([] if noarg else ['Py_RETURN_NONE;']) + ['END_HANDLE_TH_ERRORS']
+    method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"]
 
-    traceable = 'true' if all(should_trace(o.function) for o in overloads) else 'false'
+    traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"
 
     grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
     is_singleton = len(grouped_overloads) == 1
@@ -619,11 +790,15 @@
     dispatch: List[str] = []
     for overload_index, overload in enumerate(grouped_overloads):
         signature = overload.signature.signature_str()
-        signatures.append(f'{cpp_string(str(signature))},')
+        signatures.append(f"{cpp_string(str(signature))},")
         dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
         dispatch.append(
-            PY_VARIABLE_CASE.substitute(overload_index=overload_index, body=dispatch_body)
-            if not is_singleton else dispatch_body)
+            PY_VARIABLE_CASE.substitute(
+                overload_index=overload_index, body=dispatch_body
+            )
+            if not is_singleton
+            else dispatch_body
+        )
 
     if noarg:
         template = PY_VARIABLE_METHOD_NOARGS
@@ -650,6 +825,7 @@
         self_="self_" if method else "nullptr",
     )
 
+
 def gen_has_torch_function_check(
     name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
 ) -> str:
@@ -661,17 +837,21 @@
 }}
 """
         else:
-            return ''
+            return ""
 
     self_ = "self_" if method else "nullptr"
-    namespace = {
-        "torch": "THPVariableFunctionsModule",
-        "torch.nn": "THPNNVariableFunctionsModule",
-        "torch.fft": "THPFFTVariableFunctionsModule",
-        "torch.linalg": "THPLinalgVariableFunctionsModule",
-        "torch.sparse": "THPSparseVariableFunctionsModule",
-        "torch.special": "THPSpecialVariableFunctionsModule",
-    }[module] if module else "THPVariableClass"
+    namespace = (
+        {
+            "torch": "THPVariableFunctionsModule",
+            "torch.nn": "THPNNVariableFunctionsModule",
+            "torch.fft": "THPFFTVariableFunctionsModule",
+            "torch.linalg": "THPLinalgVariableFunctionsModule",
+            "torch.sparse": "THPSparseVariableFunctionsModule",
+            "torch.special": "THPSpecialVariableFunctionsModule",
+        }[module]
+        if module
+        else "THPVariableClass"
+    )
 
     return f"""\
 if(_r.has_torch_function()) {{
@@ -679,14 +859,18 @@
 }}
 """
 
+
 # handler for output/no-output overload pair
-PY_VARIABLE_OUT = CodeTemplate("""\
+PY_VARIABLE_OUT = CodeTemplate(
+    """\
 if (_r.isNone(${out_idx})) {
   ${call_dispatch}
 } else {
   ${call_dispatch_out}
 }
-""")
+"""
+)
+
 
 def emit_dispatch_case(
     overload: PythonSignatureGroup,
@@ -703,14 +887,18 @@
         return PY_VARIABLE_OUT.substitute(
             out_idx=overload.signature.output_idx(),
             call_dispatch=emit_single_dispatch(
-                overload.signature, overload.base, namedtuple_typenames),
+                overload.signature, overload.base, namedtuple_typenames
+            ),
             call_dispatch_out=emit_single_dispatch(
-                overload.signature, overload.outplace, namedtuple_typenames),
+                overload.signature, overload.outplace, namedtuple_typenames
+            ),
         )
     else:
         # no-output version only
         return emit_single_dispatch(
-            overload.signature, overload.base, namedtuple_typenames)
+            overload.signature, overload.base, namedtuple_typenames
+        )
+
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
@@ -718,24 +906,30 @@
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
+
 def forward_decls(
     name: BaseOperatorName,
     overloads: Sequence[PythonSignatureNativeFunctionPair],
     *,
-    method: bool
+    method: bool,
 ) -> Tuple[str, ...]:
     if method:
         return ()
 
     pycname = get_pycname(name)
     if is_noarg(overloads):
-        return (f"""\
+        return (
+            f"""\
 static PyObject * {pycname}(PyObject* self_, PyObject* args);
-""",)
+""",
+        )
     else:
-        return (f"""\
+        return (
+            f"""\
 static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
-""",)
+""",
+        )
+
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
@@ -743,12 +937,13 @@
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
+
 def method_def(
     name: BaseOperatorName,
     module: Optional[str],
     overloads: Sequence[PythonSignatureNativeFunctionPair],
     *,
-    method: bool
+    method: bool,
 ) -> str:
     """
     Generate method def entry.
@@ -756,14 +951,14 @@
     pycname = get_pycname(name)
 
     if is_noarg(overloads):
-        pyfunc_cast = ''
-        flags = 'METH_NOARGS' if method else 'METH_VARARGS | METH_KEYWORDS'
+        pyfunc_cast = ""
+        flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS"
     else:
-        pyfunc_cast = 'castPyCFunctionWithKeywords'
-        flags = 'METH_VARARGS | METH_KEYWORDS'
+        pyfunc_cast = "castPyCFunctionWithKeywords"
+        flags = "METH_VARARGS | METH_KEYWORDS"
 
     if module == "torch":
-        flags += ' | METH_STATIC'
+        flags += " | METH_STATIC"
 
     if name.dunder_method:
         # PyMethodDef entry for binary op, throws not implemented error
@@ -774,12 +969,14 @@
         return f"""\
 {{"{name}", {pyfunc_cast}({pycname}), {flags}, NULL}},"""
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
 #                   Overload Sorting and Grouping
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
+
 def group_overloads(
     overloads: Sequence[PythonSignatureNativeFunctionPair],
 ) -> Sequence[PythonSignatureGroup]:
@@ -792,15 +989,15 @@
         if overload.function.func.is_out_fn():
             if sig in outplaces:
                 raise RuntimeError(
-                    f'Found duplicated function definition:\n- {overload.function.func}.\n'
-                    f'Existing definition:\n- {outplaces[sig].function.func}.'
+                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
+                    f"Existing definition:\n- {outplaces[sig].function.func}."
                 )
             outplaces[sig] = overload
         else:
             if sig in bases:
                 raise RuntimeError(
-                    f'Found duplicated function definition:\n- {overload.function.func}.\n'
-                    f'Existing definition:\n- {bases[sig].function.func}.'
+                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
+                    f"Existing definition:\n- {bases[sig].function.func}."
                 )
             bases[sig] = overload
 
@@ -808,30 +1005,41 @@
         if sig not in bases:
             candidates: List[str] = []
             for overload in overloads:
-                if str(overload.function.func.name.name) == str(out.function.func.name.name) \
-                        and not overload.function.func.is_out_fn() \
-                        and not overload.signature.deprecated:
-                    candidates.append(overload.signature.signature_str(skip_outputs=True))
+                if (
+                    str(overload.function.func.name.name)
+                    == str(out.function.func.name.name)
+                    and not overload.function.func.is_out_fn()
+                    and not overload.signature.deprecated
+                ):
+                    candidates.append(
+                        overload.signature.signature_str(skip_outputs=True)
+                    )
             out_sig = out.signature.signature_str()
             raise RuntimeError(
-                f'While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. '
-                f'We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema '
-                'correctly in native_functions.yaml. We discovered the following candidate(s): \n'
-                + '\n'.join(f'- {candidate}' for candidate in candidates))
+                f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
+                f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
+                "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
+                + "\n".join(f"- {candidate}" for candidate in candidates)
+            )
 
     grouped: List[PythonSignatureGroup] = []
     for sig, base in bases.items():
         outplace = outplaces.get(sig)
-        grouped.append(PythonSignatureGroup(
-            # prefer the signature with optional out=... arguments because it's the
-            # superset that can be used to parse input for both base and outplace.
-            signature=outplace.signature if outplace is not None else base.signature,
-            base=base.function,
-            outplace=outplace.function if outplace is not None else None,
-        ))
+        grouped.append(
+            PythonSignatureGroup(
+                # prefer the signature with optional out=... arguments because it's the
+                # superset that can be used to parse input for both base and outplace.
+                signature=outplace.signature
+                if outplace is not None
+                else base.signature,
+                base=base.function,
+                outplace=outplace.function if outplace is not None else None,
+            )
+        )
 
     return sort_overloads(grouped)
 
+
 # This function declares a partial order on declarations, and sorts them according
 # to its linear extension. This is necessary, because there's some ambiguity in the
 # choice of overload, and we want a different order.
@@ -876,20 +1084,27 @@
 #     foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
 #
 
+
 def sort_overloads(
-    grouped_overloads: Sequence[PythonSignatureGroup]
+    grouped_overloads: Sequence[PythonSignatureGroup],
 ) -> Sequence[PythonSignatureGroup]:
-
     def is_arg_smaller(t1: Type, t2: Type) -> bool:
-        return (str(t1) == 'Scalar' and str(t2) == 'Tensor' or
-                'Dimname' in str(t1) and 'Dimname' not in str(t2) or
-                # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
-                # discussed why it is important to prioritize int/int? over int[]
-                str(t1) == 'int[]' and (str(t2) == 'int' or str(t2) == 'int?') or
-                # TensorList currently throws an error during argument parsing, that's why it needs to be
-                # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
-                str(t1) == 'Tensor[]' and str(t2).find("[]") != -1)
-
+        return (
+            str(t1) == "Scalar"
+            and str(t2) == "Tensor"
+            or "Dimname" in str(t1)
+            and "Dimname" not in str(t2)
+            or
+            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
+            # discussed why it is important to prioritize int/int? over int[]
+            str(t1) == "int[]"
+            and (str(t2) == "int" or str(t2) == "int?")
+            or
+            # TensorList currently throws an error during argument parsing, that's why it needs to be
+            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
+            str(t1) == "Tensor[]"
+            and str(t2).find("[]") != -1
+        )
 
     def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
         """Returns True if s1 < s2 in the partial order."""
@@ -900,13 +1115,16 @@
         # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
         # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
         equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
-        smaller_or_equal = all(str(arg1.type) == str(arg2.type)
-                               or is_arg_smaller(arg1.type, arg2.type)
-                               for arg1, arg2 in zip(args1, args2))
+        smaller_or_equal = all(
+            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
+            for arg1, arg2 in zip(args1, args2)
+        )
         return smaller_or_equal and not equal
 
     # First sort by signature
-    grouped_overloads = sorted(grouped_overloads, key=lambda x: x.signature.signature_str())
+    grouped_overloads = sorted(
+        grouped_overloads, key=lambda x: x.signature.signature_str()
+    )
 
     # Construct the relation graph
     larger_than: Dict[int, Set[int]] = defaultdict(set)
@@ -934,39 +1152,43 @@
 
     return list(map(lambda x: grouped_overloads[x], sorted_ids))
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
 #                       Codegen API Integration
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
+
 def emit_single_dispatch(
     ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
 ) -> str:
     """
     Emit dispatch code for a single native function.
     """
+
     @with_native_function
     def go(f: NativeFunction) -> str:
         # header comments
-        deprecated = '[deprecated] ' if ps.deprecated else ''
-        schema_comment = f'// {deprecated}aten::{f.func}'
+        deprecated = "[deprecated] " if ps.deprecated else ""
+        schema_comment = f"// {deprecated}aten::{f.func}"
 
         # dispatch lambda signature
         name = cpp.name(f.func)
-        lambda_formals = ', '.join(map(lambda a: f"{a.type_str} {a.name}",
-                                       dispatch_lambda_args(ps, f)))
+        lambda_formals = ", ".join(
+            map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))
+        )
         lambda_return = dispatch_lambda_return_str(f)
 
         # dispatch lambda body
         dispatch_callee = cpp_dispatch_target(f)
-        dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps))
+        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))
 
         # from arg parser outputs to dispatch lambda arguments
         parser_outputs = arg_parser_output_exprs(ps, f)
         lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
-        inits = '\n'.join(lambda_arg_exprs.inits)
-        lambda_args = ', '.join(lambda_arg_exprs.exprs)
+        inits = "\n".join(lambda_arg_exprs.inits)
+        lambda_args = ", ".join(lambda_arg_exprs.exprs)
 
         # scatter fields
         # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
@@ -974,12 +1196,17 @@
         #       new_full, new_empty, and new_zeros. A much better but more difficult to
         #       implement solution involves refactoring according to Ed's description here:
         #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
-        need_set_requires_grad = ps.tensor_options_args and (not has_tensor_options(f) or (
-            ps.method and ('requires_grad' in parser_outputs)))
-        set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \
-            if need_set_requires_grad else ''
+        need_set_requires_grad = ps.tensor_options_args and (
+            not has_tensor_options(f)
+            or (ps.method and ("requires_grad" in parser_outputs))
+        )
+        set_requires_grad = (
+            f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
+            if need_set_requires_grad
+            else ""
+        )
 
-        if lambda_return == 'void':
+        if lambda_return == "void":
             return f"""\
 {schema_comment}
 {inits}
@@ -992,7 +1219,7 @@
 """
         else:
             typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
-            namedtuple_typeref = f'{typename}, ' if typename is not None else ''
+            namedtuple_typeref = f"{typename}, " if typename is not None else ""
             return f"""\
 {schema_comment}
 {inits}
diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py
index 1b9cc7e..c1e72ad 100644
--- a/tools/autograd/gen_trace_type.py
+++ b/tools/autograd/gen_trace_type.py
@@ -6,8 +6,12 @@
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.context import with_native_function
 from tools.codegen.utils import FileManager
-from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
-                                 TensorOptionsArguments)
+from tools.codegen.model import (
+    Argument,
+    NativeFunction,
+    SchemaKind,
+    TensorOptionsArguments,
+)
 
 # Note [Manual Backend kernels]
 # For these ops, we want to manually register to dispatch key Backend and
@@ -19,16 +23,33 @@
 #   - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
 # Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
 # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
-MANUAL_BACKEND = set([
-    'options', 'data', 'set_data', 'is_leaf', 'output_nr', '_version', 'retain_grad',
-    '_backward', 'requires_grad_',
-])
+MANUAL_BACKEND = set(
+    [
+        "options",
+        "data",
+        "set_data",
+        "is_leaf",
+        "output_nr",
+        "_version",
+        "retain_grad",
+        "_backward",
+        "requires_grad_",
+    ]
+)
 
 # For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
 # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
-MANUAL_AUTOGRAD_AND_TRACER = set([
-    'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', '_fw_primal', '_make_dual',
-])
+MANUAL_AUTOGRAD_AND_TRACER = set(
+    [
+        "resize_",
+        "resize_as_",
+        "detach",
+        "detach_",
+        "copy_",
+        "_fw_primal",
+        "_make_dual",
+    ]
+)
 
 # Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
 #   union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)
@@ -41,45 +62,65 @@
 # on demand.  Only concrete ATen methods can be disabled this way; it will have
 # NO EFFECT otherwise.
 DONT_RECORD_TRACE = {
-    'convolution', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
-    'conv_transpose2d', 'conv_transpose3d', 'lstm_cell', 'gru_cell',
-    'rnn_tanh_cell', 'rnn_relu_cell',
+    "convolution",
+    "conv1d",
+    "conv2d",
+    "conv3d",
+    "conv_transpose1d",
+    "conv_transpose2d",
+    "conv_transpose3d",
+    "lstm_cell",
+    "gru_cell",
+    "rnn_tanh_cell",
+    "rnn_relu_cell",
     # FIXME: figure out a better way when we support sparse tensors in jit
-    '_coalesced',
+    "_coalesced",
 }
 
+
 def should_trace(f: NativeFunction) -> bool:
     # Operations involving Storage or Type are not traceable at the moment
-    if any(str(arg.type) in {'Storage', 'Type', 'ConstQuantizerPtr'}
-           for arg in f.func.schema_order_arguments()):
+    if any(
+        str(arg.type) in {"Storage", "Type", "ConstQuantizerPtr"}
+        for arg in f.func.schema_order_arguments()
+    ):
         return False
     # We can't trace functions which don't have any Tensor or TensorList returns
     if not any(r.type.is_tensor_like() for r in f.func.returns):
         return False
     return f.func.name.name.base not in DONT_RECORD_TRACE
 
-SELECT = CodeTemplate("""\
+
+SELECT = CodeTemplate(
+    """\
 
 if (${cond}) {
   ${true}
 } else {
   ${false}
 }
-""")
+"""
+)
 
-OP_NAME = CodeTemplate("""\
+OP_NAME = CodeTemplate(
+    """\
 op_name = c10::Symbol::fromQualString("aten::${trace_name}");
-""")
+"""
+)
 
 # These functions have their names recorded under trace renamed,
 RENAME_TRACE = {
-    'zero': 'zeros_like',  # replacing aten::zero_ with aten::zeros_like
-    'fill': 'full_like',  # replacing aten::fill_ with aten::full_like
+    "zero": "zeros_like",  # replacing aten::zero_ with aten::zeros_like
+    "fill": "full_like",  # replacing aten::fill_ with aten::full_like
 }
 
+
 def format_trace_op_name(f: NativeFunction) -> str:
     # TODO: byte-for-byte compatible with old codegen behavior - should clean up
-    if f.func.kind() in (SchemaKind.functional, SchemaKind.out) or f.func.name.name.dunder_method:
+    if (
+        f.func.kind() in (SchemaKind.functional, SchemaKind.out)
+        or f.func.name.name.dunder_method
+    ):
         # special case for *_out functions: the in-place and out-of-place ops
         # are overloaded with the same name in the JIT
         trace_name = str(f.func.name.name)
@@ -94,32 +135,39 @@
     inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name)
 
     return SELECT.substitute(
-        cond='tracer_state->force_outplace',
+        cond="tracer_state->force_outplace",
         true=OP_NAME.substitute(trace_name=outplace_trace_name),
         false=OP_NAME.substitute(trace_name=inplace_trace_name),
     )
 
+
 ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")
 
-def format_trace_inputs(f: NativeFunction) -> str:
 
-    def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequence[str]:
+def format_trace_inputs(f: NativeFunction) -> str:
+    def dispatch_trace_input(
+        arg: Union[Argument, TensorOptionsArguments]
+    ) -> Sequence[str]:
         if isinstance(arg, TensorOptionsArguments):
-            name = 'options'
+            name = "options"
             return [
-                ADD_TRACE_INPUT.substitute(name=name, input='optTypeMetaToScalarType(options.dtype_opt())'),
-                ADD_TRACE_INPUT.substitute(name=name, input='options.layout()'),
-                ADD_TRACE_INPUT.substitute(name=name, input='options.device()'),
-                ADD_TRACE_INPUT.substitute(name=name, input='options.pinned_memory()'),
+                ADD_TRACE_INPUT.substitute(
+                    name=name, input="optTypeMetaToScalarType(options.dtype_opt())"
+                ),
+                ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"),
+                ADD_TRACE_INPUT.substitute(name=name, input="options.device()"),
+                ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"),
             ]
         else:
             name = arg.name
-            if str(arg.type) == 'Tensor?[]':
+            if str(arg.type) == "Tensor?[]":
                 return [f'jit::tracer::addInputs(node, "{name}", {name});']
             else:
                 return [ADD_TRACE_INPUT.substitute(name=name, input=name)]
 
-    args: List[Union[Argument, TensorOptionsArguments]] = list(f.func.schema_order_arguments())
+    args: List[Union[Argument, TensorOptionsArguments]] = list(
+        f.func.schema_order_arguments()
+    )
 
     if f.func.is_out_fn():
         # *_out functions take the result as a separate argument, but we don't want to
@@ -129,7 +177,9 @@
         # there is only one output argument.
         args = args[:-1]
 
-    trace_inputs = itertools.chain.from_iterable(dispatch_trace_input(arg) for arg in args)
+    trace_inputs = itertools.chain.from_iterable(
+        dispatch_trace_input(arg) for arg in args
+    )
 
     if f.func.is_out_fn():
         # for *_out functions, handle the result argument differently for inplace/outplace.
@@ -141,32 +191,49 @@
         # Factories are a bit special because their out-of-place overloads
         # take an extra TensorOptions argument, which is missing in the _out function
         has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)
-        has_tensor_input_arg = any(a.type.is_tensor_like() for a in f.func.arguments.flat_non_out)
-        is_factory_method = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg)
+        has_tensor_input_arg = any(
+            a.type.is_tensor_like() for a in f.func.arguments.flat_non_out
+        )
+        is_factory_method = f.category_override == "factory" or (
+            has_tensor_return and not has_tensor_input_arg
+        )
 
         # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method`
         # flag for the whole family of ops with the same basename if any of them is a
         # factory method. For most cases the whole family of ops are indeed all factory
         # method - 'normal' is the only exception. So we handle it specially here to avoid
         # cloning the old logic.
-        if f.func.name.name.base == 'normal':
+        if f.func.name.name.base == "normal":
             is_factory_method = True
 
         if is_factory_method:
             outplace = [
-                ADD_TRACE_INPUT.substitute(name='out', input='optTypeMetaToScalarType(out.options().dtype_opt())'),
-                ADD_TRACE_INPUT.substitute(name='out', input='out.options().layout()'),
-                ADD_TRACE_INPUT.substitute(name='out', input='out.options().device()'),
-                ADD_TRACE_INPUT.substitute(name='out', input='out.options().pinned_memory()'),
+                ADD_TRACE_INPUT.substitute(
+                    name="out",
+                    input="optTypeMetaToScalarType(out.options().dtype_opt())",
+                ),
+                ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"),
+                ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"),
+                ADD_TRACE_INPUT.substitute(
+                    name="out", input="out.options().pinned_memory()"
+                ),
             ]
         else:
             outplace = []
 
         trace_inputs = itertools.chain(
             trace_inputs,
-            [SELECT.substitute(cond='tracer_state->force_outplace', true='\n'.join(outplace), false=inplace)])
+            [
+                SELECT.substitute(
+                    cond="tracer_state->force_outplace",
+                    true="\n".join(outplace),
+                    false=inplace,
+                )
+            ],
+        )
 
-    return '\n'.join(trace_inputs)
+    return "\n".join(trace_inputs)
+
 
 # `torch.jit.trace` have undocumented keyword argument `_force_outplace`,
 # which force jit to replace functions with outplace variants (for
@@ -191,29 +258,32 @@
 #  - Or keep `aten::zeros_like` arguments aligned with `aten::zero_`
 # arguments (inside of the `native_functions.yaml`)
 RENAME_TRACE_ADD_ARGS = {
-    'fill': '''\
+    "fill": """\
     jit::tracer::addInputs(node, "options", c10::optional<ScalarType>());
     jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt));
     jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt));
     jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt));
     c10::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
     jit::tracer::addInputs(node, "memory_format", memory_format);
-''',
-    'zero': '''\
+""",
+    "zero": """\
     jit::tracer::addInputs(node, "options", c10::optional<ScalarType>());
     jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt));
     jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt));
     jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt));
     c10::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
     jit::tracer::addInputs(node, "memory_format", memory_format);
-''',
+""",
 }
 
-INPLACE_GUARD = CodeTemplate("""\
+INPLACE_GUARD = CodeTemplate(
+    """\
 jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input});
-""")
+"""
+)
 
-PRE_RECORD_TRACE = CodeTemplate("""\
+PRE_RECORD_TRACE = CodeTemplate(
+    """\
 torch::jit::Node* node = nullptr;
 std::shared_ptr<jit::tracer::TracingState> tracer_state;
 if (jit::tracer::isTracing()) {
@@ -227,40 +297,59 @@
   ${inplace_guard}
   jit::tracer::setTracingState(nullptr);
 }
-""")
+"""
+)
+
 
 def format_prerecord_trace(f: NativeFunction) -> str:
     if not should_trace(f):
-        return ''
+        return ""
 
     # TODO: clean up old codegen behavior
-    is_inplace = f.func.kind() in (SchemaKind.inplace, SchemaKind.out) and not f.func.name.name.dunder_method
-    add_args = RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, '') if is_inplace else ''
-    additional_inputs = SELECT.substitute(
-        cond='tracer_state->force_outplace',
-        true=add_args,
-        false='',
-    ) if add_args else ''
+    is_inplace = (
+        f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
+        and not f.func.name.name.dunder_method
+    )
+    add_args = (
+        RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else ""
+    )
+    additional_inputs = (
+        SELECT.substitute(
+            cond="tracer_state->force_outplace",
+            true=add_args,
+            false="",
+        )
+        if add_args
+        else ""
+    )
 
     return PRE_RECORD_TRACE.substitute(
         set_op_name=format_trace_op_name(f),
         add_trace_inputs=format_trace_inputs(f) + additional_inputs,
         inplace_guard=INPLACE_GUARD.substitute(
             name=cpp.name(f.func),
-            mutable_input=f.func.arguments.out[0].name if f.func.arguments.out else 'self',
-        ) if is_inplace else '',
+            mutable_input=f.func.arguments.out[0].name
+            if f.func.arguments.out
+            else "self",
+        )
+        if is_inplace
+        else "",
     )
 
-POST_RECORD_TRACE = CodeTemplate("""\
+
+POST_RECORD_TRACE = CodeTemplate(
+    """\
 if (tracer_state) {
   jit::tracer::setTracingState(std::move(tracer_state));
   ${add_trace_outputs}
 }
-""")
+"""
+)
+
 
 def format_postrecord_trace(f: NativeFunction) -> str:
     if not should_trace(f):
-        return ''
+        return ""
 
     # For outplacing ops, *_out overloads require special handling to move the
     # output *argument* to a return value
@@ -271,29 +360,37 @@
         # Code size optimization: the common case is that the return value is
         # the same for both variants
         if output_names_outplace == output_names_inplace:
-            outputs = [f'jit::tracer::addOutput(node, {n});' for n in output_names_outplace]
+            outputs = [
+                f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace
+            ]
             return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
 
         selection = SELECT.substitute(
-            cond='force_outplace',
-            true='\n'.join(f'jit::tracer::addOutput(node, {n});' for n in output_names_outplace),
-            false='\n'.join(f'jit::tracer::addOutput(node, {n});' for n in output_names_inplace),
+            cond="force_outplace",
+            true="\n".join(
+                f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace
+            ),
+            false="\n".join(
+                f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace
+            ),
         )
         return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)
     else:
         output_names = cpp.return_names(f)
-        outputs = [f'jit::tracer::addOutput(node, {n});' for n in output_names]
+        outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names]
         return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
 
+
 def declare_returned_variables(f: NativeFunction) -> str:
     modifies_arguments = f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
     if modifies_arguments:
-        return ''
+        return ""
     if len(f.func.returns) == 1:
-        return ''
+        return ""
     types = map(cpp.return_type, f.func.returns)
     names = cpp.return_names(f)
-    return '\n'.join(f'{type.cpp_type()} {name};' for type, name in zip(types, names))
+    return "\n".join(f"{type.cpp_type()} {name};" for type, name in zip(types, names))
+
 
 def tie_return_values(f: NativeFunction) -> str:
     if len(f.func.returns) == 1:
@@ -301,6 +398,7 @@
     names = cpp.return_names(f)
     return f'std::tie({", ".join(names)})'
 
+
 def get_return_value(f: NativeFunction) -> str:
     names = cpp.return_names(f)
     if len(f.func.returns) == 1:
@@ -308,11 +406,15 @@
     if f.func.kind() == SchemaKind.out:
         return f'std::forward_as_tuple({", ".join(names)})'
     else:
-        moved = ", ".join(f'std::move({name})' for name in names)
-        return f'std::make_tuple({moved})'
+        moved = ", ".join(f"std::move({name})" for name in names)
+        return f"std::make_tuple({moved})"
 
-TRACE_DISPATCH = CodeTemplate("""\
-${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});""")
+
+TRACE_DISPATCH = CodeTemplate(
+    """\
+${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});"""
+)
+
 
 def emit_trace_body(f: NativeFunction) -> List[str]:
     trace_body: List[str] = []
@@ -325,47 +427,59 @@
 
     # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
     # See Note [Plumbing Keys Through The Dispatcher] for details.
-    dispatch_key_set = 'ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)'
-    redispatch_args = ', '.join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])
+    dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)"
+    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])
 
-    assign_return_values = f'{tie_return_values(f)} = ' \
-                           if f.func.kind() == SchemaKind.functional and f.func.returns else ''
+    assign_return_values = (
+        f"{tie_return_values(f)} = "
+        if f.func.kind() == SchemaKind.functional and f.func.returns
+        else ""
+    )
 
     # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
     # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
-    trace_body.append(TRACE_DISPATCH.substitute(
-        assign_return_values=assign_return_values,
-        unambiguous_name=f.func.name.unambiguous_name(),
-        unpacked_args=redispatch_args,
-    ))
+    trace_body.append(
+        TRACE_DISPATCH.substitute(
+            assign_return_values=assign_return_values,
+            unambiguous_name=f.func.name.unambiguous_name(),
+            unpacked_args=redispatch_args,
+        )
+    )
 
     trace_body.append(format_postrecord_trace(f))
     if f.func.returns:
-        trace_body.append(f'return {get_return_value(f)};')
+        trace_body.append(f"return {get_return_value(f)};")
     return trace_body
 
-METHOD_DEFINITION = CodeTemplate("""\
+
+METHOD_DEFINITION = CodeTemplate(
+    """\
 ${return_type} ${type_wrapper_name}(${formals}) {
   ${type_definition_body}
 }
-""")
+"""
+)
+
 
 def type_wrapper_name(f: NativeFunction) -> str:
     if f.func.name.overload_name:
-        return f'{cpp.name(f.func)}_{f.func.name.overload_name}'
+        return f"{cpp.name(f.func)}_{f.func.name.overload_name}"
     else:
         return cpp.name(f.func)
 
+
 @with_native_function
 def method_definition(f: NativeFunction) -> str:
     assert cpp.name(f.func) not in MANUAL_TRACER
 
-    formals = ', '.join(
+    formals = ", ".join(
         # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
         # See Note [Plumbing Keys Through The Dispatcher] for details.
-        ['c10::DispatchKeySet ks'] +
-        [f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
-            for a in f.func.schema_order_arguments()]
+        ["c10::DispatchKeySet ks"]
+        + [
+            f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
+            for a in f.func.schema_order_arguments()
+        ]
     )
 
     return METHOD_DEFINITION.substitute(
@@ -375,11 +489,15 @@
         type_definition_body=emit_trace_body(f),
     )
 
-WRAPPER_REGISTRATION = CodeTemplate("""\
+
+WRAPPER_REGISTRATION = CodeTemplate(
+    """\
 m.impl("${name}",
        TORCH_FN(${class_type}::${type_wrapper_name})
 );
-""")
+"""
+)
+
 
 @with_native_function
 def method_registration(f: NativeFunction) -> str:
@@ -388,31 +506,36 @@
     return WRAPPER_REGISTRATION.substitute(
         name=f.func.name,
         type_wrapper_name=type_wrapper_name(f),
-        class_type='TraceType',
+        class_type="TraceType",
     )
 
-def gen_trace_type_func(
-    fn: NativeFunction
-) -> Dict[str, List[str]]:
+
+def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:
     return {
-        'ops_headers': [f'#include <ATen/ops/{fn.root_name}_ops.h>'],
-        'trace_method_definitions': [method_definition(fn)],
-        'trace_wrapper_registrations': [method_registration(fn)],
+        "ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"],
+        "trace_method_definitions": [method_definition(fn)],
+        "trace_wrapper_registrations": [method_registration(fn)],
     }
 
-def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
+
+def gen_trace_type(
+    out: str, native_functions: List[NativeFunction], template_path: str
+) -> None:
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
     fm.write_sharded(
-        'TraceType.cpp',
+        "TraceType.cpp",
         [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
         key_fn=lambda fn: fn.root_name,
         base_env={
-            'generated_comment':
-            f'@generated from {template_path}/TraceType.cpp',
+            "generated_comment": f"@generated from {template_path}/TraceType.cpp",
         },
         env_callable=gen_trace_type_func,
         num_shards=5,
-        sharded_keys={'ops_headers', 'trace_method_definitions', 'trace_wrapper_registrations'}
+        sharded_keys={
+            "ops_headers",
+            "trace_method_definitions",
+            "trace_wrapper_registrations",
+        },
     )
diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py
index 1a09902..a1a57cc 100644
--- a/tools/autograd/gen_variable_factories.py
+++ b/tools/autograd/gen_variable_factories.py
@@ -20,28 +20,37 @@
 # TODO: maybe update the cpp argument API to take optional namespace argument?
 def fully_qualified_type(argument_type: str) -> str:
     def maybe_optional_type(type: str, is_opt: bool) -> str:
-        return f'c10::optional<{type}>' if is_opt else type
+        return f"c10::optional<{type}>" if is_opt else type
 
     opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
     is_opt = opt_match is not None
     if opt_match:
-        argument_type = argument_type[opt_match.start(1):opt_match.end(1)]
+        argument_type = argument_type[opt_match.start(1) : opt_match.end(1)]
     match = TYPE_PATTERN.match(argument_type)
     if match is None:
         return maybe_optional_type(argument_type, is_opt)
     index = match.start(1)
-    qualified_type = f'{argument_type[:index]}at::{argument_type[index:]}'
+    qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}"
     return maybe_optional_type(qualified_type, is_opt)
 
+
 def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
     native_functions = parse_native_yaml(native_yaml_path).native_functions
     factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: {
-        'generated_comment': '@' + f'generated from {fm.template_dir}/variable_factories.h',
-        'ops_headers': [f'#include <ATen/ops/{fn.root_name}.h>' for fn in factory_functions],
-        'function_definitions': list(mapMaybe(process_function, factory_functions)),
-    })
+    fm.write_with_template(
+        "variable_factories.h",
+        "variable_factories.h",
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir}/variable_factories.h",
+            "ops_headers": [
+                f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
+            ],
+            "function_definitions": list(mapMaybe(process_function, factory_functions)),
+        },
+    )
+
 
 @with_native_function
 def is_factory_function(f: NativeFunction) -> bool:
@@ -52,6 +61,7 @@
     has_tensor_options = python.has_tensor_options(f)
     return has_tensor_options or name.endswith("_like")
 
+
 @with_native_function
 def process_function(f: NativeFunction) -> Optional[str]:
     name = cpp.name(f.func)
@@ -64,22 +74,22 @@
     sig = CppSignatureGroup.from_native_function(f, method=False).signature
     formals: List[str] = []
     exprs: List[str] = []
-    requires_grad = 'false'
+    requires_grad = "false"
     for arg in sig.arguments():
         qualified_type = fully_qualified_type(arg.type)
         if arg.default:
-            formals.append(f'{qualified_type} {arg.name} = {arg.default}')
+            formals.append(f"{qualified_type} {arg.name} = {arg.default}")
         else:
-            formals.append(f'{qualified_type} {arg.name}')
+            formals.append(f"{qualified_type} {arg.name}")
 
         if isinstance(arg.argument, TensorOptionsArguments):
             # note: we remove the requires_grad setting from the TensorOptions because
             # it is ignored anyways (and we actually have an assertion that it isn't set
             # which would fail otherwise). We handle requires_grad explicitly here
             # instead of passing it through to the kernel.
-            exprs.append(f'at::TensorOptions({arg.name}).requires_grad(c10::nullopt)')
+            exprs.append(f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)")
             # Manually set the requires_grad bit on the result tensor.
-            requires_grad = f'{arg.name}.requires_grad()'
+            requires_grad = f"{arg.name}.requires_grad()"
         else:
             exprs.append(arg.name)
 
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index d52b2b2..8036f64 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -27,31 +27,68 @@
 #
 from .context import with_native_function_with_differentiability_info
 from .gen_trace_type import (
-    MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, declare_returned_variables,
-    tie_return_values, get_return_value, type_wrapper_name,
+    MANUAL_BACKEND,
+    MANUAL_AUTOGRAD_AND_TRACER,
+    declare_returned_variables,
+    tie_return_values,
+    get_return_value,
+    type_wrapper_name,
 )
 from .gen_inplace_or_view_type import (
-    get_view_info, is_tensor_type, is_tensor_list_type, unpack_args, get_base_name,
-    use_derived, modifies_arguments, WRAPPER_REGISTRATION, TMP_VAR, METHOD_DEFINITION,
-    ASSIGN_RETURN_VALUE, gen_formals, ALL_VIEW_FUNCTIONS, unpacked_name,
-    AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION
+    get_view_info,
+    is_tensor_type,
+    is_tensor_list_type,
+    unpack_args,
+    get_base_name,
+    use_derived,
+    modifies_arguments,
+    WRAPPER_REGISTRATION,
+    TMP_VAR,
+    METHOD_DEFINITION,
+    ASSIGN_RETURN_VALUE,
+    gen_formals,
+    ALL_VIEW_FUNCTIONS,
+    unpacked_name,
+    AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION,
 )
 
-from tools.codegen.api.types import (Binding, DispatcherSignature, BaseCType, intArrayRefT,
-                                     tensorT, tensorListT, MutRefCType, OptionalCType,
-                                     ListCType, SpecialArgName, scalarT, stringT,
-                                     TupleCType, VectorCType)
+from tools.codegen.api.types import (
+    Binding,
+    DispatcherSignature,
+    BaseCType,
+    intArrayRefT,
+    tensorT,
+    tensorListT,
+    MutRefCType,
+    OptionalCType,
+    ListCType,
+    SpecialArgName,
+    scalarT,
+    stringT,
+    TupleCType,
+    VectorCType,
+)
 from tools.codegen.api.autograd import (
-    DifferentiableInput, NativeFunctionWithDifferentiabilityInfo,
-    SavedAttribute, dispatch_strategy, gen_differentiable_outputs,
-    is_differentiable)
+    DifferentiableInput,
+    NativeFunctionWithDifferentiabilityInfo,
+    SavedAttribute,
+    dispatch_strategy,
+    gen_differentiable_outputs,
+    is_differentiable,
+)
 from tools.codegen.api import cpp
 from tools.codegen.code_template import CodeTemplate
 from tools.codegen.context import native_function_manager, with_native_function
 from tools.codegen.utils import mapMaybe, FileManager
-from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
-                                 SelfArgument, TensorOptionsArguments,
-                                 BaseType, ListType)
+from tools.codegen.model import (
+    Argument,
+    NativeFunction,
+    SchemaKind,
+    SelfArgument,
+    TensorOptionsArguments,
+    BaseType,
+    ListType,
+)
 from typing import Callable, List, Optional, Sequence, Tuple, Union, Dict
 
 # We don't set or modify grad_fn on these methods. Generally, they return
@@ -59,72 +96,255 @@
 # not examine or modify requires_grad or grad_fn.
 DONT_REQUIRE_DERIVATIVE = {
     # These only depend on the input Tensor's shape and device, not the data
-    'ones_like', 'zeros_like', 'rand_like', 'randn_like',
+    "ones_like",
+    "zeros_like",
+    "rand_like",
+    "randn_like",
     # These are only implemented on integral types
-    '__and__', '__iand__', '__ilshift__', '__ior__', '__irshift__', '__ixor__',
-    '__lshift__', '__or__', '__rshift__', '__xor__',
+    "__and__",
+    "__iand__",
+    "__ilshift__",
+    "__ior__",
+    "__irshift__",
+    "__ixor__",
+    "__lshift__",
+    "__or__",
+    "__rshift__",
+    "__xor__",
     # These work on integral data types, and hence don't require derivative
-    '_sobol_engine_draw', '_sobol_engine_ff', '_sobol_engine_scramble_',
-    '_sobol_engine_initialize_state_',
+    "_sobol_engine_draw",
+    "_sobol_engine_ff",
+    "_sobol_engine_scramble_",
+    "_sobol_engine_initialize_state_",
     # This is an unsafe method that is meant to be out of reach of autograd.
-    '_coalesced_',
+    "_coalesced_",
     # Quantize functions should not record gradients
-    'quantize_per_tensor', 'quantize_per_channel',
+    "quantize_per_tensor",
+    "quantize_per_channel",
     # Functions that return integers should not have output that require gradients
-    'argmax', 'argmin', 'argsort', 'searchsorted',
-    'bucketize',
+    "argmax",
+    "argmin",
+    "argsort",
+    "searchsorted",
+    "bucketize",
     # Functions that return booleans are not differentiable
-    'isnan', 'isposinf', 'isneginf', 'isinf', 'signbit', 'isin',
+    "isnan",
+    "isposinf",
+    "isneginf",
+    "isinf",
+    "signbit",
+    "isin",
     # Functions return none are not differentiable
-    'record_stream',
+    "record_stream",
     # These functions are not differentiable
-    'logical_and', 'logical_xor', 'logical_not', 'logical_or',
+    "logical_and",
+    "logical_xor",
+    "logical_not",
+    "logical_or",
 }
 
 # The C -> R functions at the time of adding this are still being audited and tested
 # but will not error out.
 # C -> C, R -> C functions for which backward is correctly implemented and tested
 GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
-    't', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone', 'diag_embed',
-    'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose',
-    'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril',
-    'triu', 'chunk', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum',
-    '_conj', 'sin', 'cos', 'mul', 'sinc', 'sinh', 'cosh', '__rmul__',
-    'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex', 'index_put',
-    'neg', 'complex', 'select', 'where', 'as_strided', 'slice', 'constant_pad_nd',
-    'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
-    'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'outer',
-    'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
-    'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'atanh', 'take', 'fill_',
-    'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
-    'matrix_exp', 'linalg_matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_linalg_svd', '_fft_c2c', '_fft_r2c',
-    'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv', 'linalg_inv_ex',
-    'l1_loss_backward', 'baddbmm', 'addbmm', 'addmm', 'addmv', 'addr', 'linalg_householder_product',
-    'constant_pad_nd', 'reflection_pad1d', 'reflection_pad2d', 'reflection_pad3d', 'linalg_cholesky_ex', 'linalg_eig',
-    'select_backward', 'diagonal_backward', 'slice_backward',
-    'reflection_pad1d_backward', 'reflection_pad2d_backward', 'reflection_pad3d_backward', 'symeig', '_sparse_sparse_matmul',
-    'replication_pad1d', 'replication_pad2d', 'replication_pad3d', 'take', 'put_', '_to_copy',
-    'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
-    'diag', 'masked_scatter', 'masked_select', 'index_add', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
-    'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
-    'index', 'masked_fill', 'linalg_cross', 'lu_unpack', 'renorm', '_conj_physical', 'linalg_lu_factor_ex',
-    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
-    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve',
-    'linalg_solve_triangular', 'linalg_pinv', 'linalg_lstsq', 'col2im', 'col2im_backward', 'im2col', 'im2col_backward',
-    'cholesky_inverse',
+    "t",
+    "view",
+    "reshape",
+    "reshape_as",
+    "view_as",
+    "roll",
+    "clone",
+    "diag_embed",
+    "repeat",
+    "expand",
+    "flip",
+    "fliplr",
+    "flipud",
+    "rot90",
+    "transpose",
+    "permute",
+    "squeeze",
+    "unsqueeze",
+    "resize",
+    "resize_as",
+    "tril",
+    "triu",
+    "chunk",
+    "zero_",
+    "eq_",
+    "ne_",
+    "add",
+    "__radd__",
+    "sum",
+    "_conj",
+    "sin",
+    "cos",
+    "mul",
+    "sinc",
+    "sinh",
+    "cosh",
+    "__rmul__",
+    "sgn",
+    "asin",
+    "acos",
+    "sub",
+    "div",
+    "cat",
+    "view_as_complex",
+    "index_put",
+    "neg",
+    "complex",
+    "select",
+    "where",
+    "as_strided",
+    "slice",
+    "constant_pad_nd",
+    "unbind",
+    "split",
+    "split_with_sizes",
+    "unsafe_split",
+    "split_with_sizes_backward",
+    "dot",
+    "vdot",
+    "cholesky",
+    "triangular_solve",
+    "mm",
+    "_unsafe_view",
+    "mv",
+    "outer",
+    "bmm",
+    "diagonal",
+    "alias",
+    "atan",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "reciprocal",
+    "tan",
+    "pow",
+    "rsqrt",
+    "tanh",
+    "tanh_backward",
+    "asinh",
+    "acosh",
+    "atanh",
+    "take",
+    "fill_",
+    "exp",
+    "nonzero",
+    "mean",
+    "inverse",
+    "solve",
+    "linalg_cholesky",
+    "addcmul",
+    "addcdiv",
+    "matrix_exp",
+    "linalg_matrix_exp",
+    "linalg_eigh",
+    "cholesky_solve",
+    "linalg_qr",
+    "_linalg_svd",
+    "_fft_c2c",
+    "_fft_r2c",
+    "linalg_solve",
+    "sqrt",
+    "stack",
+    "gather",
+    "index_select",
+    "index_add_",
+    "linalg_inv",
+    "linalg_inv_ex",
+    "l1_loss_backward",
+    "baddbmm",
+    "addbmm",
+    "addmm",
+    "addmv",
+    "addr",
+    "linalg_householder_product",
+    "constant_pad_nd",
+    "reflection_pad1d",
+    "reflection_pad2d",
+    "reflection_pad3d",
+    "linalg_cholesky_ex",
+    "linalg_eig",
+    "select_backward",
+    "diagonal_backward",
+    "slice_backward",
+    "reflection_pad1d_backward",
+    "reflection_pad2d_backward",
+    "reflection_pad3d_backward",
+    "symeig",
+    "_sparse_sparse_matmul",
+    "replication_pad1d",
+    "replication_pad2d",
+    "replication_pad3d",
+    "take",
+    "put_",
+    "_to_copy",
+    "replication_pad1d_backward",
+    "replication_pad2d_backward",
+    "replication_pad3d_backward",
+    "diag",
+    "masked_scatter",
+    "masked_select",
+    "index_add",
+    "index_fill",
+    "trace",
+    "polar",
+    "cumsum",
+    "rsub",
+    "eig",
+    "lerp",
+    "linalg_vector_norm",
+    "cumprod",
+    "prod",
+    "index_copy",
+    "lu",
+    "unfold",
+    "unfold_backward",
+    "index",
+    "masked_fill",
+    "linalg_cross",
+    "lu_unpack",
+    "renorm",
+    "_conj_physical",
+    "linalg_lu_factor_ex",
+    "scatter",
+    "scatter_add",
+    "sigmoid",
+    "sigmoid_backward",
+    "trapezoid",
+    "cumulative_trapezoid",
+    "conj_physical_",
+    "_neg_view",
+    "_reshape_alias",
+    "_det_lu_based_helper",
+    "lu_solve",
+    "linalg_solve_triangular",
+    "linalg_pinv",
+    "linalg_lstsq",
+    "col2im",
+    "col2im_backward",
+    "im2col",
+    "im2col_backward",
+    "cholesky_inverse",
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
-    '_to_dense', '_coalesce', 'coalesce', 'values', '_sparse_coo_tensor_with_dims_and_tensors',
-    'sparse_mask_helper_cuda', '_sparse_addmm',
+    "_to_dense",
+    "_coalesce",
+    "coalesce",
+    "values",
+    "_sparse_coo_tensor_with_dims_and_tensors",
+    "sparse_mask_helper_cuda",
+    "_sparse_addmm",
 }
 
 GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
-RESET_GRAD_ACCUMULATOR = {
-    'set', 'resize'
-}
+RESET_GRAD_ACCUMULATOR = {"set", "resize"}
 
 # NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
 #
@@ -139,212 +359,278 @@
 #      the input it is aliased with. Otherwise, its StorageImpl has use_count of 1
 #
 # The following code templates implement the checks for this invariant:
-SAVE_TENSOR_STORAGE = CodeTemplate("""\
+SAVE_TENSOR_STORAGE = CodeTemplate(
+    """\
 c10::optional<Storage> ${tensor_name}_storage_saved =
   ${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
-""")
+"""
+)
 
 # If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (2)
-ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate("""\
+ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate(
+    """\
 if (${tensor_name}_storage_saved.has_value())
   AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
-""")
+"""
+)
 
-SAVE_TENSORLIST_STORAGE = CodeTemplate("""\
+SAVE_TENSORLIST_STORAGE = CodeTemplate(
+    """\
 std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
 for (const Tensor& tensor : ${tensorlist_name})
   ${tensorlist_name}_storage_saved.push_back(
     tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
-""")
+"""
+)
 
-ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate("""\
+ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate(
+    """\
 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
   if (${tensorlist_name}_storage_saved[i].has_value())
     AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
 }
-""")
+"""
+)
 
-SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
+SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
+    """\
 std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
 for (const c10::optional<Tensor>& tensor : ${tensorlist_name})
   ${tensorlist_name}_storage_saved.push_back(
     tensor.has_value() && tensor->has_storage() ? c10::optional<Storage>(tensor->storage()) : c10::nullopt);
-""")
+"""
+)
 
-ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
+ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
+    """\
 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
   if (${tensorlist_name}_storage_saved[i].has_value())
     AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
         static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->storage()));
 }
-""")
+"""
+)
 
-SAVE_TENSOR_IMPL = CodeTemplate("""\
+SAVE_TENSOR_IMPL = CodeTemplate(
+    """\
 c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
 if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
-""")
+"""
+)
 
-ENFORCE_SAME_TENSOR_IMPL = CodeTemplate("""\
+ENFORCE_SAME_TENSOR_IMPL = CodeTemplate(
+    """\
 if (${tensor_name}_impl_saved) AT_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
-""")
+"""
+)
 
-ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate("""\
+ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate(
+    """\
 AT_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}");
-""")
+"""
+)
 
-ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate("""\
+ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate(
+    """\
 if (${tensor_name}.has_storage()) AT_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}");
-""")
+"""
+)
 
-SAVE_TENSORLIST_IMPL = CodeTemplate("""\
+SAVE_TENSORLIST_IMPL = CodeTemplate(
+    """\
 std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
 for (size_t i=0; i<${tensorlist_name}.size(); i++)
   if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
-""")
+"""
+)
 
-ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate("""\
+ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate(
+    """\
 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
   if (${tensorlist_name}_impl_saved[i])
     AT_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
 }
-""")
+"""
+)
 
-SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
+SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate(
+    """\
 std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
   c10::optional<Tensor> t = ${tensorlist_name}[i];
   if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
 }
-""")
+"""
+)
 
-ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
+ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate(
+    """\
 for (size_t i=0; i<${tensorlist_name}.size(); i++) {
   if (${tensorlist_name}_impl_saved[i])
     AT_ASSERT(${tensorlist_name}_impl_saved[i] == static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
 }
-""")
+"""
+)
 
 # The following list contains functions that we don't enforce the invariant on.
 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
     # These functions are expected to change impl or storage of input tensors
-    'set_', '_cudnn_rnn_flatten_weight',
+    "set_",
+    "_cudnn_rnn_flatten_weight",
 }
 DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
     # These non-inplace, non-out functions return tensors with use_count > 1
     # Therefore, they MAY (but not necessarily) return one of its inputs as-is
     # See https://github.com/pytorch/pytorch/issues/60426 for more information
-    '_embedding_bag', '_embedding_bag_forward_only',
-    'q_per_channel_scales', 'q_per_channel_zero_points',
-    'lu_unpack', '_cudnn_rnn_backward',
-
+    "_embedding_bag",
+    "_embedding_bag_forward_only",
+    "q_per_channel_scales",
+    "q_per_channel_zero_points",
+    "lu_unpack",
+    "_cudnn_rnn_backward",
     # The below failed StorageImpl use_count check but we skip tensor_impl check
     # just in case
-    '_cudnn_rnn', 'dequantize_self',
+    "_cudnn_rnn",
+    "dequantize_self",
 }
 
 DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
     # These non-view functions return tensors with storage use_count != 1
-    '_slow_conv2d_forward', 'slow_conv3d_forward', 'channel_shuffle',
-
+    "_slow_conv2d_forward",
+    "slow_conv3d_forward",
+    "channel_shuffle",
     # If an input is returned as-is in output, we cannot guarantee its storage_impl
     # use count to be 1 either.
     *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,
 }
 # END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ]
 
-DECLARE_GRAD_FN = CodeTemplate("""\
+DECLARE_GRAD_FN = CodeTemplate(
+    """\
 std::shared_ptr<${op}> grad_fn;
-""")
+"""
+)
 
-SETUP_ANY_REQUIRES_GRAD = CodeTemplate("""\
+SETUP_ANY_REQUIRES_GRAD = CodeTemplate(
+    """\
 auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} );
 ${extra_differentiability_conditions}
 (void)_any_requires_grad;
-""")
+"""
+)
 
-SETUP_DERIVATIVE = CodeTemplate("""\
+SETUP_DERIVATIVE = CodeTemplate(
+    """\
 if (_any_requires_grad) {
   ${setup}
 }
-""")
+"""
+)
 
-SETUP_NONE_REQUIRES_GRAD = CodeTemplate("""\
+SETUP_NONE_REQUIRES_GRAD = CodeTemplate(
+    """\
 if (compute_requires_grad( ${args_to_check} )) {
   throw_error_out_requires_grad("${base_name}");
 }
-""")
+"""
+)
 
-ASSIGN_GRAD_FN = CodeTemplate("""\
+ASSIGN_GRAD_FN = CodeTemplate(
+    """\
 grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
 grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
-""")
+"""
+)
 
-CALL_REDISPATCH = CodeTemplate("""\
-at::redispatch::${api_name}(${unpacked_args})""")
+CALL_REDISPATCH = CodeTemplate(
+    """\
+at::redispatch::${api_name}(${unpacked_args})"""
+)
 # If the non-variable operation has return values, we use the `tmp` variable to hold the
 # values temporarily and pass the values to the return variables outside of the
 # `at::AutoDispatchBelowAutograd` guard block.
-DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate("""\
+DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
+    """\
 auto ${tmp_var} = ([&]() {
   ${guard}
   return ${base_type_call};
 })();
-""")
+"""
+)
 
-DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate("""\
+DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate(
+    """\
 {
   ${guard}
   ${base_type_call};
 }
-""")
+"""
+)
 
-SET_HISTORY = CodeTemplate("""\
+SET_HISTORY = CodeTemplate(
+    """\
 if (grad_fn) {
     ${fn}_history(${differentiable_outputs}, grad_fn);
 }
-""")
+"""
+)
 
-CONDITIONAL = CodeTemplate("""\
+CONDITIONAL = CodeTemplate(
+    """\
 if (${cond}) {
   ${statements}
 }
-""")
+"""
+)
 
-RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\
+RUN_ONLY_IN_DEBUG_MODE = CodeTemplate(
+    """\
 #ifndef NDEBUG
 ${statements}
 #endif
-""")
+"""
+)
 
-FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate("""\
+FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate(
+    """\
 isFwGradDefined(${req_inp})\
-""")
+"""
+)
 
-FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate("""\
+FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
+    """\
 auto ${inp}_t_raw = toNonOptFwGrad(${inp});
 auto ${inp}_tensor = toNonOptTensor(${inp});
 auto ${inp}_t = (${inp}_t_raw.defined() || !${inp}_tensor.defined())
   ? ${inp}_t_raw : at::${zeros_fn}(${inp}_tensor.sizes(), ${inp}_tensor.options());
-""")
+"""
+)
 
-FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate("""\
+FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate(
+    """\
 auto ${inp}_p = toNonOptPrimal(${inp});
-""")
+"""
+)
 
-FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate("""\
+FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate(
+    """\
 if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined()) {
   // The hardcoded 0 here will need to be updated once we support multiple levels.
   ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
 }
-""")
+"""
+)
 
-FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate("""\
+FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate(
+    """\
 if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined()) {
   ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false);
 }
-""")
+"""
+)
 
-FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate("""\
+FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
+    """\
 if (${out_arg}_new_fw_grad_opt.has_value()) {
   auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
   TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
@@ -355,25 +641,33 @@
     }
   }
 }
-""")
+"""
+)
 
-FW_DERIVATIVE_TEMPLATE = CodeTemplate("""\
+FW_DERIVATIVE_TEMPLATE = CodeTemplate(
+    """\
 ${fw_grad_opt_definition}
 if (${requires_fw_grad}) {
     ${unpacked_arguments}
     ${out_arg}_new_fw_grad_opt = ${formula};
 }
-""")
+"""
+)
 
-FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate("""\
+FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate(
+    """\
 TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
-""")
+"""
+)
 
-FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate("""\
+FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate(
+    """\
 for (const auto& _t: ${arg}) {
     TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
 }
-""")
+"""
+)
+
 
 def gen_variable_type(
     out: str,
@@ -389,47 +683,54 @@
     compute the output. The grad_fn is attached to differentiable functions.
     """
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    fm.write('VariableType.h', lambda: {
-        'generated_comment': "@" f'generated from {template_path}/VariableType.h'
-    })
+    fm.write(
+        "VariableType.h",
+        lambda: {
+            "generated_comment": "@" f"generated from {template_path}/VariableType.h"
+        },
+    )
 
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
     fm.write_sharded(
-        'VariableType.cpp',
+        "VariableType.cpp",
         [fn for fn in fns_with_diff_infos if use_derived(fn)],
         key_fn=lambda fn: cpp.name(fn.func.func),
         base_env={
-            'generated_comment':
-            "@" f'generated from {template_path}/VariableType.cpp',
+            "generated_comment": "@" f"generated from {template_path}/VariableType.cpp",
         },
         env_callable=gen_variable_type_func,
         num_shards=5,
-        sharded_keys={'type_derived_method_definitions', 'wrapper_registrations'}
+        sharded_keys={"type_derived_method_definitions", "wrapper_registrations"},
     )
 
+
 @with_native_function
 def gen_wrapper_registration(f: NativeFunction) -> str:
     return WRAPPER_REGISTRATION.substitute(
         unqual_operator_name_with_overload=f.func.name,
         type_wrapper_name=type_wrapper_name(f),
-        class_type='VariableType',
+        class_type="VariableType",
     )
 
+
 def gen_variable_type_func(
-    fn: NativeFunctionWithDifferentiabilityInfo
+    fn: NativeFunctionWithDifferentiabilityInfo,
 ) -> Dict[str, List[str]]:
     f = fn.func
     with native_function_manager(f):
         name = cpp.name(f.func)
         formals = gen_formals(f)
 
-        if fn.info is None and not get_base_name(f) in RESET_GRAD_ACCUMULATOR \
-                and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE \
-                and len(gen_differentiable_outputs(fn)) > 0 \
-                and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE \
-                and not type_wrapper_name(f) in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT \
-                and not type_wrapper_name(f) in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
+        if (
+            fn.info is None
+            and not get_base_name(f) in RESET_GRAD_ACCUMULATOR
+            and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE
+            and len(gen_differentiable_outputs(fn)) > 0
+            and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
+            and not type_wrapper_name(f) in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
+            and not type_wrapper_name(f) in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT
+        ):
             # NOTE: [ Registering AutogradNotImplemented boxed kernel ]
             #
             # When there is no derivatives.yaml entry, we register a generic boxed
@@ -448,7 +749,8 @@
             #    to (1).
             type_definition = ""
             wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute(
-                unqual_operator_name_with_overload=f.func.name)
+                unqual_operator_name_with_overload=f.func.name
+            )
         else:
             type_definition = METHOD_DEFINITION.substitute(
                 return_type=cpp.returns_type(f.func.returns).cpp_type(),
@@ -463,21 +765,24 @@
     # If you want to register a kernel to Autograd, you must make the op abstract.
     # In other words, this op must have dispatch section in native_functions.yaml.
     if name in MANUAL_AUTOGRAD_AND_TRACER or (fn.info and fn.info.has_derivatives):
-        msg = (f'There\'s a formula for {name}(or its functional variant) in derivatives.yaml. '
-               f'It\'s required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA '
-               f'or CompositeExplicitAutograd in native_functions.yaml. Please see '
-               f'https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword '
-               f'for instructions to choose the right dispatch keyword.')
+        msg = (
+            f"There's a formula for {name}(or its functional variant) in derivatives.yaml. "
+            f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA "
+            f"or CompositeExplicitAutograd in native_functions.yaml. Please see "
+            f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword "
+            f"for instructions to choose the right dispatch keyword."
+        )
         assert f.is_abstract, msg
 
     return {
-        'type_derived_method_definitions': [type_definition],
-        'wrapper_registrations': [wrapper_registration],
+        "type_derived_method_definitions": [type_definition],
+        "wrapper_registrations": [wrapper_registration],
     }
 
+
 @with_native_function_with_differentiability_info
 def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
-    assert dispatch_strategy(fn) == 'use_derived'
+    assert dispatch_strategy(fn) == "use_derived"
     f = fn.func
     info = fn.info
     fw_derivatives = fn.fw_derivatives
@@ -513,7 +818,9 @@
     def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
         return list(mapMaybe(gen_differentiable_input, f.func.arguments.non_out))
 
-    def find_args_with_derivatives(differentiable_inputs: List[DifferentiableInput]) -> List[DifferentiableInput]:
+    def find_args_with_derivatives(
+        differentiable_inputs: List[DifferentiableInput],
+    ) -> List[DifferentiableInput]:
         """Find arguments that have derivative definitions"""
         if info is None or not info.has_derivatives:
             return differentiable_inputs
@@ -521,26 +828,38 @@
         differentiable = [arg for arg in differentiable_inputs if arg.name in names]
         if len(differentiable) != len(names):
             missing = names - set(arg.name for arg in differentiable)
-            raise RuntimeError(f'Missing arguments for derivatives: {missing} in {info.name}')
+            raise RuntimeError(
+                f"Missing arguments for derivatives: {missing} in {info.name}"
+            )
         return differentiable
 
     differentiable_inputs = gen_differentiable_inputs(f)
     args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
     differentiable_outputs = gen_differentiable_outputs(fn)
 
-    undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (name in DONT_REQUIRE_DERIVATIVE)
+    undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (
+        name in DONT_REQUIRE_DERIVATIVE
+    )
 
-    requires_derivative = (not undifferentiable) and (len(differentiable_inputs) > 0) and (len(differentiable_outputs) > 0)
+    requires_derivative = (
+        (not undifferentiable)
+        and (len(differentiable_inputs) > 0)
+        and (len(differentiable_outputs) > 0)
+    )
 
     if info is not None and info.has_derivatives and not requires_derivative:
-        raise RuntimeError(f'ERROR: derivative ignored for {name} -- specified an autograd function without derivative')
+        raise RuntimeError(
+            f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
+        )
 
     def emit_save_inputs() -> List[str]:
         setup: List[str] = []
         if info is None or not info.has_derivatives:
             return setup
 
-        has_tensorlist_arg = any(is_tensor_list_type(arg.type) for arg in args_with_derivatives)
+        has_tensorlist_arg = any(
+            is_tensor_list_type(arg.type) for arg in args_with_derivatives
+        )
 
         # We don't want to save tensors if we know that they will never be used
         # when computing the derivative, so we add guards to those statements
@@ -557,7 +876,7 @@
             # require_grad if the backward function even gets executed. I don't
             # have any good ideas for detecting those cases, so I simply disabled the
             # checks.
-            if 'backward' in info.name:
+            if "backward" in info.name:
                 return None
 
             # If there's a single derivative we could compute, we already have
@@ -587,12 +906,12 @@
             else:
                 raise AssertionError()
 
-            return f'grad_fn->should_compute_output({edge_off})'
+            return f"grad_fn->should_compute_output({edge_off})"
 
         setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
         for arg in args_with_derivatives:
             if is_tensor_list_type(arg.type):
-                setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')
+                setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
 
         return setup
 
@@ -600,25 +919,37 @@
         body: List[str] = []
         if is_out_fn:
             # For out functions, ensure that no input or output requires grad
-            body.append(DECLARE_GRAD_FN.substitute(op='Node'))
-            body.append(SETUP_NONE_REQUIRES_GRAD.substitute(
-                base_name=base_name,
-                args_to_check=[arg.name for arg in differentiable_inputs]))
-            body.append(SETUP_NONE_REQUIRES_GRAD.substitute(
-                base_name=base_name,
-                args_to_check=[arg.name for arg in differentiable_outputs]))
+            body.append(DECLARE_GRAD_FN.substitute(op="Node"))
+            body.append(
+                SETUP_NONE_REQUIRES_GRAD.substitute(
+                    base_name=base_name,
+                    args_to_check=[arg.name for arg in differentiable_inputs],
+                )
+            )
+            body.append(
+                SETUP_NONE_REQUIRES_GRAD.substitute(
+                    base_name=base_name,
+                    args_to_check=[arg.name for arg in differentiable_outputs],
+                )
+            )
             return body
 
-        op = info.op if info is not None and info.has_derivatives else 'NotImplemented'
+        op = info.op if info is not None and info.has_derivatives else "NotImplemented"
         setup = []
-        setup.extend(ASSIGN_GRAD_FN.substitute(
-            op=op,
-            op_ctor='' if info is not None and info.has_derivatives else f'"{cpp.name(f.func)}"',
-            args_with_derivatives=[arg.name for arg in args_with_derivatives],
-        ).split('\n'))
+        setup.extend(
+            ASSIGN_GRAD_FN.substitute(
+                op=op,
+                op_ctor=""
+                if info is not None and info.has_derivatives
+                else f'"{cpp.name(f.func)}"',
+                args_with_derivatives=[arg.name for arg in args_with_derivatives],
+            ).split("\n")
+        )
         setup.extend(emit_save_inputs())
 
-        body.extend(emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives))
+        body.extend(
+            emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives)
+        )
         body.append(DECLARE_GRAD_FN.substitute(op=op))
         body.append(SETUP_DERIVATIVE.substitute(setup=setup))
         return body
@@ -630,7 +961,11 @@
         for arg in differentiable_outputs:
             name = arg.name
             # TODO: should be `arg.type.is_tensor_like()`?
-            if arg.cpp_type in ['at::Tensor', 'at::TensorList', 'const c10::List<c10::optional<at::Tensor>> &']:
+            if arg.cpp_type in [
+                "at::Tensor",
+                "at::TensorList",
+                "const c10::List<c10::optional<at::Tensor>> &",
+            ]:
                 body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");')
         return body
 
@@ -646,7 +981,7 @@
             arg_name = arg.name
             if info and arg_name in info.non_differentiable_arg_names:
                 continue
-            if arg_name == 'output':
+            if arg_name == "output":
                 # Double-backwards definitions sometimes take in 'input' and
                 # 'output', but only define the derivative for input.
                 continue
@@ -656,17 +991,19 @@
     def emit_original_self_definition() -> List[str]:
         body: List[str] = []
         if inplace:
-            body.append('c10::optional<at::Tensor> original_self;')
+            body.append("c10::optional<at::Tensor> original_self;")
 
             all_forward_grad_cond = []
             for derivative in fw_derivatives:
                 if derivative.required_original_self_value:
-                    all_forward_grad_cond.append(get_any_has_forward_grad_name(derivative.var_names))
+                    all_forward_grad_cond.append(
+                        get_any_has_forward_grad_name(derivative.var_names)
+                    )
 
             if all_forward_grad_cond:
                 body.append(f'if ({" || ".join(all_forward_grad_cond)}) {{')
-                body.append('  original_self = self.clone();')
-                body.append('}')
+                body.append("  original_self = self.clone();")
+                body.append("}")
 
         return body
 
@@ -678,80 +1015,100 @@
         # assign the saved variables to the generated grad_fn
         stmts: List[str] = []
         for arg in saved_variables:
-            name = arg.nctype.name.name if isinstance(arg.nctype.name, SpecialArgName) else arg.nctype.name
+            name = (
+                arg.nctype.name.name
+                if isinstance(arg.nctype.name, SpecialArgName)
+                else arg.nctype.name
+            )
             type = arg.nctype.type
             expr = arg.expr
             stmts_prepend = None
-            if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
-                    type == MutRefCType(OptionalCType(BaseCType(tensorT))) or (is_output and type == BaseCType(scalarT)):
+            if (
+                type == BaseCType(tensorT)
+                or type == OptionalCType(BaseCType(tensorT))
+                or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
+                or (is_output and type == BaseCType(scalarT))
+            ):
                 var = name
-                name += '_'
-                if var == 'self' and inplace:
-                    stmts_prepend = 'if (!original_self.has_value()) original_self = self.clone()'
-                    var = 'original_self.value()'
+                name += "_"
+                if var == "self" and inplace:
+                    stmts_prepend = (
+                        "if (!original_self.has_value()) original_self = self.clone()"
+                    )
+                    var = "original_self.value()"
                     assert not is_output
                 if inplace and is_output:
-                    var = 'self'
-                    is_inplace_view = f'{var}.is_view()'
-                    expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
+                    var = "self"
+                    is_inplace_view = f"{var}.is_view()"
+                    expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})"
                 else:
-                    expr = f'SavedVariable({var}, {str(is_output).lower()})'
-            elif type == BaseCType(tensorListT) or type == ListCType(OptionalCType(BaseCType(tensorT))):
-                expr = f'make_saved_variable_list({name})'
-                name += '_'
+                    expr = f"SavedVariable({var}, {str(is_output).lower()})"
+            elif type == BaseCType(tensorListT) or type == ListCType(
+                OptionalCType(BaseCType(tensorT))
+            ):
+                expr = f"make_saved_variable_list({name})"
+                name += "_"
             elif type == BaseCType(intArrayRefT):
                 expr = expr + ".vec()"
             elif type == BaseCType(stringT):
-                expr = f'std::string({expr})'
+                expr = f"std::string({expr})"
             elif type == OptionalCType(BaseCType(stringT)):
-                expr = f'{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt'
+                expr = f"{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt"
             guard = guard_for(arg)
             if guard is None:
                 if stmts_prepend:
-                    stmts.append(f'{stmts_prepend};')
-                stmts.append(f'grad_fn->{name} = {expr};')
+                    stmts.append(f"{stmts_prepend};")
+                stmts.append(f"grad_fn->{name} = {expr};")
             else:
-                stmts.append(f'if ({guard}) {{')
+                stmts.append(f"if ({guard}) {{")
                 if stmts_prepend:
-                    stmts.append(f'  {stmts_prepend};')
-                stmts.append(f'  grad_fn->{name} = {expr};')
-                stmts.append('}')
+                    stmts.append(f"  {stmts_prepend};")
+                stmts.append(f"  grad_fn->{name} = {expr};")
+                stmts.append("}")
         return stmts
 
     # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons:
     #  - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS.
     #  - redispatch() avoids a redundant call to RecordFunction, which was already called right before
     #    we entered this autograd kernel.
-    def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
-        """ Dispatch call via function in a namespace or method on Tensor."""
+    def emit_dispatch_call(
+        f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
+    ) -> str:
+        """Dispatch call via function in a namespace or method on Tensor."""
         dispatcher_sig = DispatcherSignature.from_schema(f.func)
         dispatcher_exprs = dispatcher_sig.exprs()
 
         # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
         # Ops also always have a function variant of the redispatch API.
         # See Note [Plumbing Keys Through The Dispatcher] for details.
-        dispatch_key_set = 'ks & c10::after_autograd_keyset'
+        dispatch_key_set = "ks & c10::after_autograd_keyset"
         call = CALL_REDISPATCH.substitute(
             api_name=cpp.name(
                 f.func,
                 faithful_name_for_out_overloads=True,
             ),
-            unpacked_args=[dispatch_key_set] + list(unpacked_args))
+            unpacked_args=[dispatch_key_set] + list(unpacked_args),
+        )
         return call
 
-    def wrap_output(f: NativeFunction, unpacked_bindings: List[Binding], var: str) -> str:
-        call = ''
+    def wrap_output(
+        f: NativeFunction, unpacked_bindings: List[Binding], var: str
+    ) -> str:
+        call = ""
         rhs_value: Optional[str] = None
         if not any(r.type.is_tensor_like() for r in f.func.returns):
             rhs_value = var
         else:
-            rhs_value = f'std::move({var})'
+            rhs_value = f"std::move({var})"
         assert rhs_value is not None
-        call += ASSIGN_RETURN_VALUE.substitute(return_values=tie_return_values(f),
-                                               rhs_value=rhs_value)
+        call += ASSIGN_RETURN_VALUE.substitute(
+            return_values=tie_return_values(f), rhs_value=rhs_value
+        )
         return call
 
-    def check_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
+    def check_tensorimpl_and_storage(
+        call: str, unpacked_bindings: List[Binding]
+    ) -> str:
         # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
         stmts_before_call: List[str] = []
         stmts_after_call: List[str] = []
@@ -764,22 +1121,42 @@
             arg = unpacked_binding.name
             noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
             if noref_cpp_type == BaseCType(tensorListT):
-                stmts_before_call += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
-                                      SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
-                stmts_after_call += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
-                                     ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
+                stmts_before_call += [
+                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
+                ]
+                stmts_after_call += [
+                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
+                ]
             elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
-                stmts_before_call += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
-                                      SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
-                stmts_after_call += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
-                                     ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
+                stmts_before_call += [
+                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
+                ]
+                stmts_after_call += [
+                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
+                        tensorlist_name=arg
+                    ),
+                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
+                        tensorlist_name=arg
+                    ),
+                ]
             elif noref_cpp_type == BaseCType(tensorT):
-                stmts_before_call += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
-                                      SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
-                stmts_after_call += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg, out_tensor_name=arg),
-                                     ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
+                stmts_before_call += [
+                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
+                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg),
+                ]
+                stmts_after_call += [
+                    ENFORCE_SAME_TENSOR_STORAGE.substitute(
+                        tensor_name=arg, out_tensor_name=arg
+                    ),
+                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg),
+                ]
 
-        assert (stmts_before_call and stmts_after_call) or (not stmts_before_call and not stmts_after_call)
+        assert (stmts_before_call and stmts_after_call) or (
+            not stmts_before_call and not stmts_after_call
+        )
 
         # Check properties of outputs (enforce (2), (3))
         if not f.func.kind() in (SchemaKind.inplace, SchemaKind.out):
@@ -787,33 +1164,55 @@
             aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
             if aliased_arg_name is not None:
                 aliased_arg_name = unpacked_name(aliased_arg_name)
-            for i, (ret, ret_name) in enumerate(zip(f.func.returns, cpp.return_names(f))):
+            for i, (ret, ret_name) in enumerate(
+                zip(f.func.returns, cpp.return_names(f))
+            ):
                 noref_cpp_type = cpp.return_type(ret).remove_const_ref()
                 if noref_cpp_type == BaseCType(tensorT):
                     if aliased_arg_name is not None:
-                        assert i == 0, "Expect non-CompositeImplicitAutograd view function {base} to return single output"
-                        stmts_after_call += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=aliased_arg_name,
-                                                                                    out_tensor_name=ret_name)]
+                        assert (
+                            i == 0
+                        ), "Expect non-CompositeImplicitAutograd view function {base} to return single output"
+                        stmts_after_call += [
+                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
+                                tensor_name=aliased_arg_name, out_tensor_name=ret_name
+                            )
+                        ]
                     else:
-                        if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
-                            stmts_after_call += [ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
-                                tensor_name=ret_name, fn_name=type_wrapper_name(f))]
+                        if (
+                            type_wrapper_name(f)
+                            not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
+                        ):
+                            stmts_after_call += [
+                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
+                                    tensor_name=ret_name, fn_name=type_wrapper_name(f)
+                                )
+                            ]
 
                     if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
-                        stmts_after_call += [ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
-                            tensor_name=ret_name, fn_name=type_wrapper_name(f))]
+                        stmts_after_call += [
+                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
+                                tensor_name=ret_name, fn_name=type_wrapper_name(f)
+                            )
+                        ]
 
                 # Currently we don't have any functions that return the following types, but
                 # we should update the checks once we do
                 elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
-                    raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")
+                    raise AssertionError(
+                        f"Please add use_count checks for {noref_cpp_type}"
+                    )
                 elif noref_cpp_type == BaseCType(tensorListT):
-                    raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")
+                    raise AssertionError(
+                        f"Please add use_count checks for {noref_cpp_type}"
+                    )
 
         if stmts_before_call and stmts_after_call:
-            call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + \
-                call + \
-                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
+            call = (
+                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call)
+                + call
+                + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
+            )
         return call
 
     def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
@@ -823,55 +1222,61 @@
         # in are now Variables.
         # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
         unpacked_args = [b.name for b in unpacked_bindings]
-        base_type_call = emit_dispatch_call(f, 'self_', unpacked_args)
+        base_type_call = emit_dispatch_call(f, "self_", unpacked_args)
 
         if get_view_info(f) is not None or modifies_arguments(f):
-            guard = 'at::AutoDispatchBelowAutograd guard;'
+            guard = "at::AutoDispatchBelowAutograd guard;"
         else:
-            guard = 'at::AutoDispatchBelowADInplaceOrView guard;'
+            guard = "at::AutoDispatchBelowADInplaceOrView guard;"
 
         if not modifies_arguments(f) and not returns_void:
             call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
-                base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard)
+                base_type_call=base_type_call, tmp_var=TMP_VAR, guard=guard
+            )
 
             call += wrap_output(f, unpacked_bindings, TMP_VAR)
         else:
             call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
-                base_type_call=base_type_call, guard=guard)
+                base_type_call=base_type_call, guard=guard
+            )
         call = check_tensorimpl_and_storage(call, unpacked_bindings)
         return call
 
     def emit_history() -> str:
-        fn = 'rebase' if modifies_arguments(f) and view_info is None else 'set'
+        fn = "rebase" if modifies_arguments(f) and view_info is None else "set"
         output_names = [r.name for r in differentiable_outputs]
         # TODO: flatten allocates a std::vector, which could be expensive
-        outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(outs=output_names)
+        outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(
+            outs=output_names
+        )
         return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
 
     def emit_save_outputs() -> str:
         if is_out_fn:
             # out functions don't currently support differentiation
-            return ''
+            return ""
         if info is not None and info.has_derivatives:
             stmts = save_variables(info.all_saved_outputs, True)
             if len(stmts) == 0:
-                return ''
-            return CONDITIONAL.substitute(cond='grad_fn', statements=stmts)
-        return ''
+                return ""
+            return CONDITIONAL.substitute(cond="grad_fn", statements=stmts)
+        return ""
 
     def emit_any_requires_grad() -> List[str]:
-        extra_condition = ''
+        extra_condition = ""
         if fn.info and fn.info.output_differentiability_conditions:
             assert len(fn.info.output_differentiability_conditions) == 1
-            extra_condition = \
-                f'_any_requires_grad &= ({fn.info.output_differentiability_conditions[0]});'
-        return [SETUP_ANY_REQUIRES_GRAD.substitute(
-            args_with_derivatives=[arg.name for arg in args_with_derivatives],
-            extra_differentiability_conditions=extra_condition)]
+            extra_condition = f"_any_requires_grad &= ({fn.info.output_differentiability_conditions[0]});"
+        return [
+            SETUP_ANY_REQUIRES_GRAD.substitute(
+                args_with_derivatives=[arg.name for arg in args_with_derivatives],
+                extra_differentiability_conditions=extra_condition,
+            )
+        ]
 
     def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str:
         if len(var_names) == 1:
-            return f'_any_has_forward_grad_{var_names[0]}'
+            return f"_any_has_forward_grad_{var_names[0]}"
         else:
             return f'_any_has_forward_grad_{"_".join(var_names)}'
 
@@ -879,32 +1284,46 @@
         content: List[str] = []
         for derivative in fw_derivatives:
             assert derivative.required_inputs_fw_grad is not None
-            requires_fw_grad = " || ".join([FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
-                                           for inp in differentiable_inputs if inp.name in derivative.required_inputs_fw_grad])
+            requires_fw_grad = " || ".join(
+                [
+                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
+                    for inp in differentiable_inputs
+                    if inp.name in derivative.required_inputs_fw_grad
+                ]
+            )
             if not requires_fw_grad:
                 # Handle functions like stack
                 # For these, we don't unpack anything and always call the user function
-                if not (len(differentiable_inputs) == 1 and is_tensor_list_type(differentiable_inputs[0].type)):
-                    raise RuntimeError(f'No differentiable input to "{name}" is a differentiable Tensor (as the provided'
-                                       'forward AD formula does not use any input tangent) even though a forward gradient '
-                                       'formula has been defined for it. This case should only happen for function that '
-                                       'take a single TensorList as input. All other cases are not supported right now.')
+                if not (
+                    len(differentiable_inputs) == 1
+                    and is_tensor_list_type(differentiable_inputs[0].type)
+                ):
+                    raise RuntimeError(
+                        f'No differentiable input to "{name}" is a differentiable Tensor (as the provided'
+                        "forward AD formula does not use any input tangent) even though a forward gradient "
+                        "formula has been defined for it. This case should only happen for function that "
+                        "take a single TensorList as input. All other cases are not supported right now."
+                    )
                 requires_fw_grad = "true"
 
             if fn.info and fn.info.output_differentiability_conditions:
                 assert len(fn.info.output_differentiability_conditions) == 1
-                requires_fw_grad = \
-                    f'({fn.info.output_differentiability_conditions[0]}) && ({requires_fw_grad})'
+                requires_fw_grad = f"({fn.info.output_differentiability_conditions[0]}) && ({requires_fw_grad})"
 
-            content.append(f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n"
-                           f"(void){get_any_has_forward_grad_name(derivative.var_names)};")
+            content.append(
+                f"auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};\n"
+                f"(void){get_any_has_forward_grad_name(derivative.var_names)};"
+            )
 
         return content
 
     def emit_check_inplace() -> List[str]:
         if not inplace:
             return []
-        return [f'check_inplace({arg.name}, _any_requires_grad);' for arg in differentiable_outputs]
+        return [
+            f"check_inplace({arg.name}, _any_requires_grad);"
+            for arg in differentiable_outputs
+        ]
 
     def emit_fw_derivatives() -> List[str]:
         content: List[str] = []
@@ -912,7 +1331,9 @@
         for derivative in fw_derivatives:
             res = derivative.var_names
             if f.func.name.name.inplace:
-                assert len(res) == 1, "Expected number of outputs to be 1 if function is inplace"
+                assert (
+                    len(res) == 1
+                ), "Expected number of outputs to be 1 if function is inplace"
                 # TODO update this when inplace namings are unified
                 res = ("self",)
 
@@ -920,53 +1341,99 @@
 
             unpacked_arguments = ""
             for inp in differentiable_inputs:
-                zeros_fn = "zeros" if inplace and inp.name == "self" else "_efficientzerotensor"
+                zeros_fn = (
+                    "zeros"
+                    if inplace and inp.name == "self"
+                    else "_efficientzerotensor"
+                )
                 if inp.name in derivative.required_inputs_fw_grad:
-                    unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp=inp.name, zeros_fn=zeros_fn)
+                    unpacked_arguments += (
+                        FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
+                            inp=inp.name, zeros_fn=zeros_fn
+                        )
+                    )
                 if inp.name in (derivative.required_inputs_primal or []):
-                    unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp=inp.name)
+                    unpacked_arguments += (
+                        FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp=inp.name)
+                    )
             if derivative.required_original_self_value:
-                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp="original_self", zeros_fn=zeros_fn)
-                unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp="original_self")
+                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
+                    inp="original_self", zeros_fn=zeros_fn
+                )
+                unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
+                    inp="original_self"
+                )
             elif inplace and derivative.is_reusing_outplace_formula:
                 # The gradient wasn't already cloned, do it if grad mode is enabled
-                unpacked_arguments += "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"
+                unpacked_arguments += (
+                    "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"
+                )
 
             if inplace:
                 is_inplace_str = "true"
             else:
                 is_inplace_str = "false"
 
-            if all((isinstance(var_type, BaseType) and var_type.is_tensor_like()) for var_type in derivative.var_types):
+            if all(
+                (isinstance(var_type, BaseType) and var_type.is_tensor_like())
+                for var_type in derivative.var_types
+            ):
                 # Is there a way to get from BaseType to BaseCType
                 if len(derivative.var_types) == 1:
                     opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
-                    fw_grad_setters.append(FW_DERIVATIVE_SETTER_TENSOR.substitute(out_arg=res[0], is_inplace=is_inplace_str))
+                    fw_grad_setters.append(
+                        FW_DERIVATIVE_SETTER_TENSOR.substitute(
+                            out_arg=res[0], is_inplace=is_inplace_str
+                        )
+                    )
                 else:
-                    tuple_type = TupleCType([BaseCType(tensorT)] * len(derivative.var_types))
+                    tuple_type = TupleCType(
+                        [BaseCType(tensorT)] * len(derivative.var_types)
+                    )
                     opt_res_grad_type = OptionalCType(tuple_type).cpp_type()
                     for idx, single_res in enumerate(res):
-                        fw_grad_setters.append(FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute(idx=idx, all_res='_'.join(res),
-                                                                                            out_arg=single_res))
-            elif isinstance(derivative.var_types[0], ListType) and derivative.var_types[0].is_tensor_like():
-                assert len(derivative.var_types) == 1, "Expected number of outputs to be 1 if function returns ListType"
-                opt_res_grad_type = OptionalCType(VectorCType(BaseCType(tensorT))).cpp_type()
-                fw_grad_setters.append(FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(out_arg=res[0], is_inplace=is_inplace_str))
+                        fw_grad_setters.append(
+                            FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute(
+                                idx=idx, all_res="_".join(res), out_arg=single_res
+                            )
+                        )
+            elif (
+                isinstance(derivative.var_types[0], ListType)
+                and derivative.var_types[0].is_tensor_like()
+            ):
+                assert (
+                    len(derivative.var_types) == 1
+                ), "Expected number of outputs to be 1 if function returns ListType"
+                opt_res_grad_type = OptionalCType(
+                    VectorCType(BaseCType(tensorT))
+                ).cpp_type()
+                fw_grad_setters.append(
+                    FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
+                        out_arg=res[0], is_inplace=is_inplace_str
+                    )
+                )
             else:
                 raise RuntimeError("Unsupported output type for forward derivative")
 
-            fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
+            fw_grad_opt_definition = (
+                f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
+            )
 
             # View ops create fw_grad that already is a view of the base's fw_grad so just use that
-            content.append(FW_DERIVATIVE_TEMPLATE.substitute(
-                fw_grad_opt_definition=fw_grad_opt_definition,
-                requires_fw_grad=get_any_has_forward_grad_name(derivative.var_names),
-                formula=derivative.formula,
-                out_arg='_'.join(res),
-                unpacked_arguments=unpacked_arguments))
+            content.append(
+                FW_DERIVATIVE_TEMPLATE.substitute(
+                    fw_grad_opt_definition=fw_grad_opt_definition,
+                    requires_fw_grad=get_any_has_forward_grad_name(
+                        derivative.var_names
+                    ),
+                    formula=derivative.formula,
+                    out_arg="_".join(res),
+                    unpacked_arguments=unpacked_arguments,
+                )
+            )
 
         # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
-        content.append('\n'.join(fw_grad_setters))
+        content.append("\n".join(fw_grad_setters))
         return content
 
     def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
@@ -974,25 +1441,40 @@
             if is_out_fn:
                 msg = "because it is an out= function"
             else:
-                msg = ("because it has not been implemented yet.\\nPlease file an issue "
-                       "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
-                       "so that we can prioritize its implementation.")
+                msg = (
+                    "because it has not been implemented yet.\\nPlease file an issue "
+                    "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
+                    "so that we can prioritize its implementation."
+                )
             return msg
+
         res = ""
         to_check: List[str] = []
-        for inp in list(mapMaybe(gen_differentiable_input,
-                                 f.func.arguments.non_out + list(f.func.arguments.out))):  # type: ignore[operator]
+        for inp in list(
+            mapMaybe(
+                gen_differentiable_input,
+                f.func.arguments.non_out + list(f.func.arguments.out),  # type: ignore[operator]
+            )
+        ):
             if is_tensor_type(inp.type):
-                to_check.append(FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name))
+                to_check.append(
+                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
+                )
             elif is_tensor_list_type(inp.type):
                 cond = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp="_t")
-                res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(arg=inp.name, cond=cond, name=name, msg=get_msg())
+                res += FW_DERIVATIVE_FORBID_LIST_TEMPLATE.substitute(
+                    arg=inp.name, cond=cond, name=name, msg=get_msg()
+                )
             else:
-                raise RuntimeError(f'Unsupported input type for "{name}" when forbidding forward AD usage.')
+                raise RuntimeError(
+                    f'Unsupported input type for "{name}" when forbidding forward AD usage.'
+                )
 
         if len(to_check) > 0:
             cond = " || ".join(to_check)
-            res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=get_msg())
+            res += FW_DERIVATIVE_FORBID_TEMPLATE.substitute(
+                cond=cond, name=name, msg=get_msg()
+            )
         return res
 
     body: List[str] = []
@@ -1022,12 +1504,15 @@
             if len(fw_derivatives) == 0:
                 body.append(emit_forbid_fw_derivatives())
             else:
-                assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(differentiable_outputs), (
+                assert sum(
+                    len(derivative.var_names) for derivative in fw_derivatives
+                ) == len(differentiable_outputs), (
                     "Expected the number of forward derivatives implemented to match the "
                     "number of differentiable outputs. NB: This only applies when at least "
                     "one forward derivative is implemented. Not implementing any forward "
                     "derivatives is also okay, and we would require inputs to the op to "
-                    "not have associated tangents in that case.")
+                    "not have associated tangents in that case."
+                )
 
     if requires_derivative:
         # Save only after the forward AD has been set up
@@ -1039,7 +1524,7 @@
         # `reset_grad_accumulator` in an operator that's not `inplace`, you can
         # remove this assert but the code generation will get more elaborate
         assert inplace
-        body.append('reset_grad_accumulator(self);')
+        body.append("reset_grad_accumulator(self);")
     if not returns_void:
-        body.append(f'return {get_return_value(f)};')
+        body.append(f"return {get_return_value(f)};")
     return body
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 8c9eac8..084128c 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -7,18 +7,40 @@
 from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional
 import yaml
 
-from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
-                                        SavedAttribute, ForwardDerivative)
-from tools.codegen.api.types import (Binding, CppSignatureGroup, NamedCType, BaseCType, VectorCType,
-                                     intArrayRefT, tensorOptionsT, typeAndSizeT, longT, boolT, layoutT,
-                                     tensorGeometryT, scalarTypeT, SpecialArgName,
-                                     OptionalCType, stringT)
+from tools.codegen.api.autograd import (
+    Derivative,
+    DifferentiabilityInfo,
+    SavedAttribute,
+    ForwardDerivative,
+)
+from tools.codegen.api.types import (
+    Binding,
+    CppSignatureGroup,
+    NamedCType,
+    BaseCType,
+    VectorCType,
+    intArrayRefT,
+    tensorOptionsT,
+    typeAndSizeT,
+    longT,
+    boolT,
+    layoutT,
+    tensorGeometryT,
+    scalarTypeT,
+    SpecialArgName,
+    OptionalCType,
+    stringT,
+)
 from tools.codegen.api import cpp
 from tools.codegen.gen import parse_native_yaml, get_grouped_by_view_native_functions
 from tools.codegen.context import with_native_function
 from tools.codegen.model import (
-    FunctionSchema, NativeFunction, Variant, Type,
-    NativeFunctionsViewGroup, OperatorName
+    FunctionSchema,
+    NativeFunction,
+    Variant,
+    Type,
+    NativeFunctionsViewGroup,
+    OperatorName,
 )
 from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader, concatMap
 
@@ -29,30 +51,35 @@
 # we generate them here instead of duplicating them in the yaml.
 # See Note [Codegen'd {view}_copy Operators]
 def add_view_copy_derivatives(
-    infos: List[DifferentiabilityInfo],
-    view_groups: List[NativeFunctionsViewGroup]
+    infos: List[DifferentiabilityInfo], view_groups: List[NativeFunctionsViewGroup]
 ) -> List[DifferentiabilityInfo]:
     # Get the map from each view op's name to its corresponding view group
     view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
-        g.view.func.name: g for g in view_groups}
+        g.view.func.name: g for g in view_groups
+    }
 
     view_copy_differentiability_infos = []
     for info in infos:
         maybe_view_group = view_name_to_group.get(info.func.func.name, None)
         if maybe_view_group is not None and maybe_view_group.view_copy is not None:
-            view_copy_info = info.create_view_copy_from_view_derivative(maybe_view_group)
+            view_copy_info = info.create_view_copy_from_view_derivative(
+                maybe_view_group
+            )
             if view_copy_info is not None:
                 view_copy_differentiability_infos.append(view_copy_info)
 
     return view_copy_differentiability_infos
 
-def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
+
+def load_derivatives(
+    derivatives_yaml_path: str, native_yaml_path: str
+) -> Sequence[DifferentiabilityInfo]:
     # Do some caching as this is a deterministic function
     global _GLOBAL_LOAD_DERIVATIVE_CACHE
     key = (derivatives_yaml_path, native_yaml_path)
     if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
 
-        with open(derivatives_yaml_path, 'r') as f:
+        with open(derivatives_yaml_path, "r") as f:
             definitions = yaml.load(f, Loader=YamlLoader)
 
         funcs = parse_native_yaml(native_yaml_path).native_functions
@@ -61,16 +88,24 @@
         native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
         native_functions_without_view_copies = concatMap(
             # We need to pull out the view_inplace ops too, since they might have their own derivative entries.
-            lambda g: [g] if isinstance(g, NativeFunction) else list(g.functions(include_copy=False)),
-            native_functions_with_view_groups
+            lambda g: [g]
+            if isinstance(g, NativeFunction)
+            else list(g.functions(include_copy=False)),
+            native_functions_with_view_groups,
         )
-        view_groups = [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)]
+        view_groups = [
+            g
+            for g in native_functions_with_view_groups
+            if isinstance(g, NativeFunctionsViewGroup)
+        ]
 
         # What's the difference between function schema v.s. signature?
         # function schema is the complete declaration including mutability annotation / default value and etc.
         # signature is the canonical schema for a group of functions (in-place/out/functional variants)
         # that are semantically related.
-        functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list)
+        functions_by_signature: Dict[
+            FunctionSchema, List[NativeFunction]
+        ] = defaultdict(list)
         functions_by_schema: Dict[str, NativeFunction] = dict()
         for function in native_functions_without_view_copies:
             functions_by_signature[function.func.signature()].append(function)
@@ -82,39 +117,56 @@
         op_counter = Counter[str]()
 
         infos = [
-            create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
-            for defn in definitions]
+            create_differentiability_info(
+                defn, functions_by_signature, functions_by_schema, op_counter
+            )
+            for defn in definitions
+        ]
         infos += add_view_copy_derivatives(infos, view_groups)
 
         _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos
 
     return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
 
+
 @with_native_function
 def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
     return CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
 
-def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...],
-                      available_named_gradients: Sequence[str]) -> Derivative:
-    original_formula = formula
-    arguments: List[NamedCType] = [a.nctype.remove_const_ref() for a in cpp_arguments(f)]
 
-    return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f))
+def create_derivative(
+    f: NativeFunction,
+    formula: str,
+    var_names: Tuple[str, ...],
+    available_named_gradients: Sequence[str],
+) -> Derivative:
+    original_formula = formula
+    arguments: List[NamedCType] = [
+        a.nctype.remove_const_ref() for a in cpp_arguments(f)
+    ]
+
+    return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
     return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns)
 
-    named_returns = [NamedCType(name, type) for name, type in zip(return_names, return_types)]
+    named_returns = [
+        NamedCType(name, type) for name, type in zip(return_names, return_types)
+    ]
 
     formula, saved_inputs = saved_variables(formula, arguments, var_names)
     formula, saved_outputs = saved_variables(formula, named_returns, var_names)
 
-    used_named_gradients = {name for name in available_named_gradients if re.search(IDENT_REGEX.format(name), formula)}
+    used_named_gradients = {
+        name
+        for name in available_named_gradients
+        if re.search(IDENT_REGEX.format(name), formula)
+    }
 
     # Check that the referenced derivatives in the formula are in bounds
     for i in used_gradient_indices(formula):
         if i >= len(f.func.returns):
             raise RuntimeError(
-                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
-                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
+                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
+                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
             )
 
     return Derivative(
@@ -126,7 +178,10 @@
         named_gradients=used_named_gradients,
     )
 
-def create_forward_derivative(f: NativeFunction, formula: str, names: Tuple[str, ...]) -> ForwardDerivative:
+
+def create_forward_derivative(
+    f: NativeFunction, formula: str, names: Tuple[str, ...]
+) -> ForwardDerivative:
     var_names = names
     var_types: Optional[Tuple[Type, ...]] = None
     for r in f.func.returns:
@@ -157,7 +212,9 @@
         required_inputs_fw_grad=None,
         required_inputs_primal=None,
         required_original_self_value=False,
-        is_reusing_outplace_formula=False)
+        is_reusing_outplace_formula=False,
+    )
+
 
 def postprocess_forward_derivatives(
     f: NativeFunction,
@@ -165,22 +222,23 @@
     all_arg_names: List[str],
     derivatives: List[Derivative],
     forward_derivatives: List[ForwardDerivative],
-    args_with_derivatives: Sequence[Binding]
+    args_with_derivatives: Sequence[Binding],
 ) -> List[ForwardDerivative]:
-
     def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
         required_inputs = set()
         for arg in args_with_derivatives:
-            if arg.type == 'at::TensorList':
+            if arg.type == "at::TensorList":
                 # The functions taking TensorList handle everything internally
                 continue
             arg_name = arg.name
 
             found = re.search(IDENT_REGEX.format(arg_name), formula)
             if found:
-                raise RuntimeError(f"The forward formula for {defn_name} is using the base name of the {arg_name} "
-                                   f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
-                                   f"value and {arg_name}_t to access the tangent.")
+                raise RuntimeError(
+                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
+                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
+                    f"value and {arg_name}_t to access the tangent."
+                )
 
             found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
             if found:
@@ -194,16 +252,23 @@
         formula = defn.formula
         required_inputs_tangent = find_required_inputs(formula, "_t")
         if formula == "auto_element_wise":
-            if ((not len(args_with_derivatives) == 1) or len(forward_derivatives) > 1
-               or len(forward_derivatives[0].var_names) > 1):
-                raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml defines the "
-                                   "forward definition of gradient as element_wise but this only "
-                                   "works for functions with a single differentiable input and a "
-                                   "single differentiable output.")
+            if (
+                (not len(args_with_derivatives) == 1)
+                or len(forward_derivatives) > 1
+                or len(forward_derivatives[0].var_names) > 1
+            ):
+                raise RuntimeError(
+                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
+                    "forward definition of gradient as element_wise but this only "
+                    "works for functions with a single differentiable input and a "
+                    "single differentiable output."
+                )
             if not len(derivatives) == 1:
-                raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml defines the "
-                                   "forward definition of gradient as element_wise but it does not "
-                                   "defines the gradient formula for its argument which is required.")
+                raise RuntimeError(
+                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
+                    "forward definition of gradient as element_wise but it does not "
+                    "defines the gradient formula for its argument which is required."
+                )
             # This transformation is based on the observation that for element-wise functions, the Jacobian
             # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
             # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
@@ -222,6 +287,7 @@
             # Do replacement 1) of the grad
             def repl(m: Any) -> str:
                 return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"
+
             fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)
 
             # Do replacement 2) of the input variables
@@ -230,6 +296,7 @@
 
                 def repl(m: Any) -> str:
                     return f"{m.group(1)}{arg_name}_p{m.group(2)}"
+
                 fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)
 
             # Do the final conjugate 3)
@@ -240,10 +307,15 @@
             required_inputs_tangent = tuple(all_arg_names)
             formula = fw_formula
         elif formula == "auto_linear":
-            if len(forward_derivatives) > 1 or len(forward_derivatives[0].var_names) > 1:
-                raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml defines the "
-                                   "forward definition of gradient as linear but this only works "
-                                   "for functions with a single differentiable output.")
+            if (
+                len(forward_derivatives) > 1
+                or len(forward_derivatives[0].var_names) > 1
+            ):
+                raise RuntimeError(
+                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
+                    "forward definition of gradient as linear but this only works "
+                    "for functions with a single differentiable output."
+                )
             # This transformation is based on the observation that linear functions can be written as:
             #   y = f(x) = A * x
             # For some matrix A and the Jacobian of the function f is also A.
@@ -269,7 +341,9 @@
                 fw_formula = "at::{}({})".format(defn_name, ", ".join(new_args))
             else:
                 assert Variant.method in f.variants
-                fw_formula = "{}.{}({})".format(new_args[0], defn_name, ", ".join(new_args[1:]))
+                fw_formula = "{}.{}({})".format(
+                    new_args[0], defn_name, ", ".join(new_args[1:])
+                )
 
             # All of the input tangents are always used so all of them are required here.
             required_inputs_tangent = tuple(diff_arg_names)
@@ -281,18 +355,24 @@
         # This call inspects the formula to find for which input's primal are used.
         required_inputs_primal = find_required_inputs(formula, "_p")
 
-        updated_derivatives.append(ForwardDerivative(
-            formula=formula,
-            var_names=defn.var_names,
-            var_types=defn.var_types,
-            required_inputs_fw_grad=required_inputs_tangent,
-            required_inputs_primal=required_inputs_primal,
-            required_original_self_value=False,
-            is_reusing_outplace_formula=False))
+        updated_derivatives.append(
+            ForwardDerivative(
+                formula=formula,
+                var_names=defn.var_names,
+                var_types=defn.var_types,
+                required_inputs_fw_grad=required_inputs_tangent,
+                required_inputs_primal=required_inputs_primal,
+                required_original_self_value=False,
+                is_reusing_outplace_formula=False,
+            )
+        )
 
     return updated_derivatives
 
-def is_forward_derivative_definition(all_arg_names: List[str], names: Tuple[str, ...]) -> bool:
+
+def is_forward_derivative_definition(
+    all_arg_names: List[str], names: Tuple[str, ...]
+) -> bool:
     for name in names:
         if name not in all_arg_names:
             return True
@@ -300,6 +380,7 @@
             return False
     raise RuntimeError("Expected `names` to be non-empty")
 
+
 def create_differentiability_info(
     defn: Dict[Any, Any],
     functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
@@ -308,17 +389,19 @@
 ) -> DifferentiabilityInfo:
     """Processes a single entry `defn` in derivatives.yaml"""
 
-    def canonical_function(functions: Sequence[NativeFunction], name: str) -> NativeFunction:
+    def canonical_function(
+        functions: Sequence[NativeFunction], name: str
+    ) -> NativeFunction:
         for f in functions:
             if cpp.name(f.func) == name:
                 return f
         # some functions only have in-place variants
-        assert name + '_' == cpp.name(functions[0].func)
+        assert name + "_" == cpp.name(functions[0].func)
         return functions[0]
 
     def split_names(raw_names: str) -> Tuple[str, ...]:
         """Given "foo, bar", return ["foo", "bar"]."""
-        return tuple(x.strip() for x in raw_names.split(','))
+        return tuple(x.strip() for x in raw_names.split(","))
 
     def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
         """
@@ -327,14 +410,16 @@
         used with double backwards.
         """
 
-        uses_grad = False                   # true if any derivative uses "grad"
-        num_grads_uses = 0                  # count of uses of "grads" or "grads[INDEX]"
-        uses_named_grads = False            # true if any derivative uses "grad_{name}"
+        uses_grad = False  # true if any derivative uses "grad"
+        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
+        uses_named_grads = False  # true if any derivative uses "grad_{name}"
         used_grads_indices: List[int] = []  # which indices of grads are used
         for d in derivatives:
             formula = d.formula
-            uses_grad = uses_grad or bool(re.findall(IDENT_REGEX.format('grad'), formula))
-            num_grads_uses += len(re.findall(IDENT_REGEX.format('grads'), formula))
+            uses_grad = uses_grad or bool(
+                re.findall(IDENT_REGEX.format("grad"), formula)
+            )
+            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
             uses_named_grads = uses_named_grads or bool(d.named_gradients)
             used_grads_indices.extend(used_gradient_indices(formula))
         # This is a basic sanity check: the number of places we see
@@ -347,26 +432,32 @@
         only_used_grads_indices = num_grads_uses == len(used_grads_indices)
 
         if uses_grad and num_grads_uses > 0:
-            raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml illegally "
-                               "mixes use of 'grad' and 'grads'. Consider replacing "
-                               "occurrences of 'grad' with 'grads[0]'")
+            raise RuntimeError(
+                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
+                "mixes use of 'grad' and 'grads'. Consider replacing "
+                "occurrences of 'grad' with 'grads[0]'"
+            )
 
         if only_used_grads_indices and set(used_grads_indices) == {0}:
-            raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml solely "
-                               "refers to 'grads[0]'.  If the first output is indeed the "
-                               "only differentiable output, replace 'grads[0]' with 'grad'; "
-                               "otherwise, there is a likely error in your derivatives "
-                               "declaration.")
+            raise RuntimeError(
+                f"Derivative definition of {defn_name} in derivatives.yaml solely "
+                "refers to 'grads[0]'.  If the first output is indeed the "
+                "only differentiable output, replace 'grads[0]' with 'grad'; "
+                "otherwise, there is a likely error in your derivatives "
+                "declaration."
+            )
 
         if uses_named_grads and (uses_grad or num_grads_uses > 0):
             raise RuntimeError(
-                f'Derivative definition of {defn_name} in derivatives.yaml illegally '
+                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                 'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
-                'only one method for identifying gradients.')
-
+                "only one method for identifying gradients."
+            )
 
     @with_native_function
-    def set_up_derivatives(f: NativeFunction) -> Tuple[
+    def set_up_derivatives(
+        f: NativeFunction,
+    ) -> Tuple[
         Sequence[Derivative],
         Sequence[ForwardDerivative],
         Sequence[Binding],
@@ -380,7 +471,9 @@
         args_with_derivatives_set: Set[str] = set()
 
         all_arg_names = [a.name for a in cpp_arguments(f)]
-        all_ret_names = [r.name for r in f.func.returns]  # only used for the assert below
+        all_ret_names = [
+            r.name for r in f.func.returns
+        ]  # only used for the assert below
         # output_differentiability is captured from the enclosed
         # scope. Don't modify it.
         #
@@ -393,13 +486,15 @@
         differentiability = output_differentiability or [True] * len(f.func.returns)
         # A return is available as a named gradient ...
         available_named_gradients = [
-            f'grad_{ret.name}' for ret, differentiable in zip(f.func.returns, differentiability)
+            f"grad_{ret.name}"
+            for ret, differentiable in zip(f.func.returns, differentiability)
             # if it has not been explicitly made undifferentiable
             if differentiable
             # and if it has a name
             and ret.name is not None
             # and if its type is differentiable
-            and ret.type.is_tensor_like()]
+            and ret.type.is_tensor_like()
+        ]
 
         for raw_names in sorted(defn.keys()):
             formula = defn[raw_names]
@@ -408,62 +503,87 @@
             for name in names:
                 assert not (name in all_arg_names and name in all_ret_names), (
                     f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
-                    f"expected '{name}' to not be both an input arg and named return. ")
+                    f"expected '{name}' to not be both an input arg and named return. "
+                )
 
             if is_forward_derivative_definition(all_arg_names, names):
                 forward_derivatives.append(create_forward_derivative(f, formula, names))
             else:
-                if formula.lower().strip() == 'non_differentiable':
+                if formula.lower().strip() == "non_differentiable":
                     non_differentiable_arg_names += names
                 else:
-                    derivative = create_derivative(f, formula, names,
-                                                   available_named_gradients)
+                    derivative = create_derivative(
+                        f, formula, names, available_named_gradients
+                    )
                     derivatives.append(derivative)
                     args_with_derivatives_set |= set(names)
 
         overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
         if overlap:
-            raise RuntimeError(f'derivatives definition for {defn} have overlapped non_differentiable '
-                               f'and differentiable variables: {overlap}')
+            raise RuntimeError(
+                f"derivatives definition for {defn} have overlapped non_differentiable "
+                f"and differentiable variables: {overlap}"
+            )
 
         # Next, let us determine the list of inputs in order.
         # TODO: do we need eagerly calculate and save it here? Can it be derived
         # from NativeFunction and `derivatives` on callsites instead?
-        args_with_derivatives = [a for a in cpp_arguments(f) if a.name in args_with_derivatives_set]
+        args_with_derivatives = [
+            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
+        ]
 
         # Postprocess forward derivatives definitions now that we know the differentiable arguments
-        forward_derivatives = postprocess_forward_derivatives(f, defn_name, all_arg_names, derivatives,
-                                                              forward_derivatives, args_with_derivatives)
+        forward_derivatives = postprocess_forward_derivatives(
+            f,
+            defn_name,
+            all_arg_names,
+            derivatives,
+            forward_derivatives,
+            args_with_derivatives,
+        )
 
         # Test to see if the use of 'grads' makes sense.
         check_grad_usage(defn_name, derivatives)
 
-        return (derivatives, forward_derivatives, args_with_derivatives,
-                non_differentiable_arg_names, available_named_gradients)
+        return (
+            derivatives,
+            forward_derivatives,
+            args_with_derivatives,
+            non_differentiable_arg_names,
+            available_named_gradients,
+        )
 
     # NB: Removes 'name' from defn dictionary
-    specification = defn.pop('name')
+    specification = defn.pop("name")
     defn_name, _ = split_name_params(specification)
     # NB: Removes 'output_differentiability' from defn dictionary
     #     `None` means all differentiable.
-    output_differentiability = defn.pop('output_differentiability', None)
+    output_differentiability = defn.pop("output_differentiability", None)
     output_differentiability_conditions = None
-    if output_differentiability and any([isinstance(diff, str) for diff in output_differentiability]):
+    if output_differentiability and any(
+        [isinstance(diff, str) for diff in output_differentiability]
+    ):
         if len(output_differentiability) != 1:
-            raise RuntimeError(f'Not supported: for {specification},'
-                               f'output_differentiability must either be '
-                               f'List[bool] or a List[str] where each str is a '
-                               f'condition. In the case where it is a condition, '
-                               f'we only support single-output functions. '
-                               f'Please file us an issue. ')
+            raise RuntimeError(
+                f"Not supported: for {specification},"
+                f"output_differentiability must either be "
+                f"List[bool] or a List[str] where each str is a "
+                f"condition. In the case where it is a condition, "
+                f"we only support single-output functions. "
+                f"Please file us an issue. "
+            )
         output_differentiability_conditions = output_differentiability
         output_differentiability = [True]
 
     schema_function = functions_by_schema.get(specification)
     if not schema_function:
-        avail = '\n'.join(k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name)
-        raise RuntimeError(f'could not find ATen function for schema: {specification} '
-                           f'.  Available signatures:\n{avail}')
+        avail = "\n".join(
+            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
+        )
+        raise RuntimeError(
+            f"could not find ATen function for schema: {specification} "
+            f".  Available signatures:\n{avail}"
+        )
 
     # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
     # to map in-place schemas to the out-of-place variants.
@@ -471,24 +591,39 @@
     signature = schema_function.func.signature()
     functions = functions_by_signature[signature]
     if len(functions) == 0:
-        avail = '\n'.join(str(k) for k, v in functions_by_signature.items() if cpp.name(k) == defn_name)
-        raise RuntimeError(f'could not find ATen function for legacy signature: {signature} '
-                           f'corresponding to schema {specification}.  Please report a bug to PyTorch. '
-                           f'Available signatures:\n{avail}')
+        avail = "\n".join(
+            str(k)
+            for k, v in functions_by_signature.items()
+            if cpp.name(k) == defn_name
+        )
+        raise RuntimeError(
+            f"could not find ATen function for legacy signature: {signature} "
+            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
+            f"Available signatures:\n{avail}"
+        )
 
     canonical = canonical_function(functions, defn_name)
-    if 'grad_input_mask' in (a.name for a in cpp_arguments(canonical)):
-        raise RuntimeError(f"Schema for {defn_name} has an argument named grad_input_mask, "
-                           "but this name would be shadowed by our codegen. "
-                           "Please use a different name in native_functions.yaml.")
+    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
+        raise RuntimeError(
+            f"Schema for {defn_name} has an argument named grad_input_mask, "
+            "but this name would be shadowed by our codegen. "
+            "Please use a different name in native_functions.yaml."
+        )
 
-    if 'result' in (a.name for a in cpp_arguments(canonical)):
-        raise RuntimeError(f"Schema for {defn_name} has an argument named result, "
-                           "but this is only allowed for outputs."
-                           "Please use a different name in native_functions.yaml.")
+    if "result" in (a.name for a in cpp_arguments(canonical)):
+        raise RuntimeError(
+            f"Schema for {defn_name} has an argument named result, "
+            "but this is only allowed for outputs."
+            "Please use a different name in native_functions.yaml."
+        )
 
-    (derivatives, forward_derivatives, args_with_derivatives,
-     non_differentiable_arg_names, available_named_gradients) = set_up_derivatives(canonical)
+    (
+        derivatives,
+        forward_derivatives,
+        args_with_derivatives,
+        non_differentiable_arg_names,
+        available_named_gradients,
+    ) = set_up_derivatives(canonical)
 
     used_named_gradients: Set[str] = set()
     for d in derivatives:
@@ -498,7 +633,7 @@
     op = None
     if args_with_derivatives:
         op_prefix = _create_op_prefix(defn_name)
-        op = f'{op_prefix}{op_counter[op_prefix]}'
+        op = f"{op_prefix}{op_counter[op_prefix]}"
         op_counter[op_prefix] += 1
 
     return DifferentiabilityInfo(
@@ -517,7 +652,9 @@
         output_differentiability_conditions=output_differentiability_conditions,
     )
 
-GRAD_INDEX_REGEX = r'(?:^|\W)grads\[(\d+)\]'
+
+GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
+
 
 def used_gradient_indices(formula: str) -> List[int]:
     """Determine a list of gradient indices (the i in grads[i]) that
@@ -528,111 +665,167 @@
     """
     return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]
 
+
 def saved_variables(
     formula: str,
     nctypes: List[NamedCType],
     var_names: Tuple[str, ...],
 ) -> Tuple[str, Tuple[SavedAttribute, ...]]:
-
     def stride_expr(name: str) -> str:
         assert var_names == (name,), (
             'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
-            'that ".strides()" is being called on.')
+            'that ".strides()" is being called on.'
+        )
         return f'strides_or_error({name}, "{name}")'
 
     REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
         # replace self.sizes() with self_sizes
-        (r'{}.sizes\(\)', {
-            'suffix': '_sizes',
-            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
-        }),
+        (
+            r"{}.sizes\(\)",
+            {
+                "suffix": "_sizes",
+                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
+            },
+        ),
         # replace self->sizes() with self_sizes_opt
-        (r'{}->sizes\(\)', {
-            'suffix': '_sizes_opt',
-            'nctype': lambda name: NamedCType(name, OptionalCType(BaseCType(intArrayRefT))),
-            'expr': lambda name: f'{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt',
-        }),
+        (
+            r"{}->sizes\(\)",
+            {
+                "suffix": "_sizes_opt",
+                "nctype": lambda name: NamedCType(
+                    name, OptionalCType(BaseCType(intArrayRefT))
+                ),
+                "expr": lambda name: f"{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt",
+            },
+        ),
         # replace self.options() with self_options
-        (r'{}.options\(\)', {
-            'suffix': '_options',
-            'nctype': lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
-        }),
+        (
+            r"{}.options\(\)",
+            {
+                "suffix": "_options",
+                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
+            },
+        ),
         # replace zeros_like(self) with self_info
-        (r'zeros_like\({}\)', {
-            'suffix': '_info',
-            'nctype': lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
-            'expr': lambda name: name,  # at save-time
-            'res': lambda name: name + '_info.zeros()',  # at eval-time
-        }),
+        (
+            r"zeros_like\({}\)",
+            {
+                "suffix": "_info",
+                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
+                "expr": lambda name: name,  # at save-time
+                "res": lambda name: name + "_info.zeros()",  # at eval-time
+            },
+        ),
         # replace self.size(2) with self_size_2
-        (r'{}.size\((\w+)\)', {
-            'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
-            'nctype': lambda name: NamedCType(name, BaseCType(longT)),
-        }),
+        (
+            r"{}.size\((\w+)\)",
+            {
+                "suffix": lambda m: "_argsize_{}".format(*m.groups()),
+                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
+            },
+        ),
         # replace self.numel() with self_numel
-        (r'{}.numel\(\)', {
-            'suffix': '_numel',
-            'nctype': lambda name: NamedCType(name, BaseCType(longT)),
-        }),
+        (
+            r"{}.numel\(\)",
+            {
+                "suffix": "_numel",
+                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
+            },
+        ),
         # replace to_args_sizes(self) with self_args_sizes
-        (r'to_args_sizes\({}\)', {
-            'suffix': '_args_sizes',
-            'nctype': lambda name: NamedCType(name, VectorCType(VectorCType(BaseCType(longT)))),
-        }),
+        (
+            r"to_args_sizes\({}\)",
+            {
+                "suffix": "_args_sizes",
+                "nctype": lambda name: NamedCType(
+                    name, VectorCType(VectorCType(BaseCType(longT)))
+                ),
+            },
+        ),
         # replace to_args_scalartypes(self) with self_args_scalartypes
-        (r'to_args_scalartypes\({}\)', {
-            'suffix': '_args_scalartypes',
-            'nctype': lambda name: NamedCType(name, VectorCType(BaseCType(scalarTypeT))),
-        }),
+        (
+            r"to_args_scalartypes\({}\)",
+            {
+                "suffix": "_args_scalartypes",
+                "nctype": lambda name: NamedCType(
+                    name, VectorCType(BaseCType(scalarTypeT))
+                ),
+            },
+        ),
         # replace TensorGeometry(self) with self_geometry
-        (r'TensorGeometry\({}\)', {
-            'suffix': '_geometry',
-            'nctype': lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
-        }),
-        (r'{}.scalar_type\(\)', {
-            'suffix': '_scalar_type',
-            'nctype': lambda name: NamedCType(name, BaseCType(scalarTypeT)),
-        }),
+        (
+            r"TensorGeometry\({}\)",
+            {
+                "suffix": "_geometry",
+                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
+            },
+        ),
+        (
+            r"{}.scalar_type\(\)",
+            {
+                "suffix": "_scalar_type",
+                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
+            },
+        ),
         # replace self.dim() with self_dim
-        (r'{}.dim\(\)', {
-            'suffix': '_dim',
-            'nctype': lambda name: NamedCType(name, BaseCType(longT)),
-        }),
+        (
+            r"{}.dim\(\)",
+            {
+                "suffix": "_dim",
+                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
+            },
+        ),
         # replace self.strides() with self_strides
-        (r'{}.strides\(\)', {
-            'suffix': '_strides',
-            'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
-            'expr': stride_expr,
-        }),
+        (
+            r"{}.strides\(\)",
+            {
+                "suffix": "_strides",
+                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
+                "expr": stride_expr,
+            },
+        ),
         # replace self.layout() with self_layout
-        (r'{}.layout\(\)', {
-            'suffix': '_layout',
-            'nctype': lambda name: NamedCType(name, BaseCType(layoutT)),
-        }),
+        (
+            r"{}.layout\(\)",
+            {
+                "suffix": "_layout",
+                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
+            },
+        ),
         # replace self.is_conj() with self_conjugate
-        (r'{}.is_conj\(\)', {
-            'suffix': '_conjugate',
-            'nctype': lambda name: NamedCType(name, BaseCType(boolT)),
-        })
+        (
+            r"{}.is_conj\(\)",
+            {
+                "suffix": "_conjugate",
+                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
+            },
+        ),
     ]
 
     # find which arguments need to be saved
     saved: List[SavedAttribute] = []
 
     for nctype in nctypes:
-        name = nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
+        name = (
+            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
+        )
         # First search the formula for expressions which can be evaluated
         # when the autograd Function is created to avoid saving variables
         for regex, info in REPLACEMENTS:
+
             def repl(m: Match[str]) -> str:
-                suffix: str = info['suffix'](m) if callable(info['suffix']) else info['suffix']
-                expr: str = info['expr'](name) if 'expr' in info else m.group(0)
-                saved.append(SavedAttribute(
-                    nctype=info['nctype'](name + suffix),
-                    expr=expr,
-                ))
-                if 'res' in info:
-                    replacement: str = info['res'](name)
+                suffix: str = (
+                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
+                )
+                expr: str = info["expr"](name) if "expr" in info else m.group(0)
+                saved.append(
+                    SavedAttribute(
+                        nctype=info["nctype"](name + suffix),
+                        expr=expr,
+                    )
+                )
+                if "res" in info:
+                    replacement: str = info["res"](name)
                     return replacement
                 return name + suffix
 
@@ -643,19 +836,23 @@
         # the backward function
         if nctype.type == OptionalCType(BaseCType(stringT)):
             formula = re.sub(
-                rf'\b{name}\b',
-                f'{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt',
-                formula)
+                rf"\b{name}\b",
+                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
+                formula,
+            )
 
         # Find any variables which remain in the formula and save them
         if re.search(IDENT_REGEX.format(name), formula):
-            saved.append(SavedAttribute(
-                nctype=nctype,
-                expr=name,
-            ))
+            saved.append(
+                SavedAttribute(
+                    nctype=nctype,
+                    expr=name,
+                )
+            )
 
     return formula, tuple(saved)
 
+
 def _create_op_prefix(name: str) -> str:
     """Takes a native function name converts to a op prefix name.
 
@@ -669,15 +866,19 @@
     >>> _create_op_prefix('add')
     'AddBackward'
     """
-    camel_case = ''.join([p.title() for p in name.split('_')])
-    return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')
+    camel_case = "".join([p.title() for p in name.split("_")])
+    return (camel_case + "Backward").replace("ForwardBackward", "Backward")
 
 
 def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
     seen: Set[str] = set()
     saved: List[SavedAttribute] = []
     for var in vars:
-        name = var.nctype.name.name if isinstance(var.nctype.name, SpecialArgName) else var.nctype.name
+        name = (
+            var.nctype.name.name
+            if isinstance(var.nctype.name, SpecialArgName)
+            else var.nctype.name
+        )
         if name in seen:
             continue
         seen.add(name)
diff --git a/tools/build_libtorch.py b/tools/build_libtorch.py
index c263e50..c550877 100644
--- a/tools/build_libtorch.py
+++ b/tools/build_libtorch.py
@@ -11,13 +11,22 @@
 from tools.build_pytorch_libs import build_caffe2
 from tools.setup_helpers.cmake import CMake
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Placeholder for future interface. For now just gives a nice -h.
-    parser = argparse.ArgumentParser(description='Build libtorch')
-    parser.add_argument('--rerun-cmake', action="store_true", help='rerun cmake')
-    parser.add_argument('--cmake-only', action="store_true",
-                        help='Stop once cmake terminates. Leave users a chance to adjust build options')
+    parser = argparse.ArgumentParser(description="Build libtorch")
+    parser.add_argument("--rerun-cmake", action="store_true", help="rerun cmake")
+    parser.add_argument(
+        "--cmake-only",
+        action="store_true",
+        help="Stop once cmake terminates. Leave users a chance to adjust build options",
+    )
     options = parser.parse_args()
 
-    build_caffe2(version=None, cmake_python_library=None, build_python=False,
-                 rerun_cmake=options.rerun_cmake, cmake_only=options.cmake_only, cmake=CMake())
+    build_caffe2(
+        version=None,
+        cmake_python_library=None,
+        build_python=False,
+        rerun_cmake=options.rerun_cmake,
+        cmake_only=options.cmake_only,
+        cmake=CMake(),
+    )
diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py
index 60a2e3c..eba8ea1 100644
--- a/tools/build_pytorch_libs.py
+++ b/tools/build_pytorch_libs.py
@@ -9,23 +9,29 @@
 
 from setuptools import distutils  # type: ignore[import]
 
-def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
-    vc_arch = 'x64' if IS_64BIT else 'x86'
 
-    if platform.machine() == 'ARM64':
-        vc_arch = 'x64_arm64'
+def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
+    vc_arch = "x64" if IS_64BIT else "x86"
+
+    if platform.machine() == "ARM64":
+        vc_arch = "x64_arm64"
 
         # First Win11 Windows on Arm build version that supports x64 emulation
         # is 10.0.22000.
         win11_1st_version = (10, 0, 22000)
-        current_win_version = tuple(int(version_part) for version_part in
-                                    platform.version().split('.'))
+        current_win_version = tuple(
+            int(version_part) for version_part in platform.version().split(".")
+        )
         if current_win_version < win11_1st_version:
-            vc_arch = 'x86_arm64'
-            print("Warning: 32-bit toolchain will be used, but 64-bit linker "
-                  "is recommended to avoid out-of-memory linker error!")
-            print("Warning: Please consider upgrading to Win11, where x64 "
-                  "emulation is enabled!")
+            vc_arch = "x86_arm64"
+            print(
+                "Warning: 32-bit toolchain will be used, but 64-bit linker "
+                "is recommended to avoid out-of-memory linker error!"
+            )
+            print(
+                "Warning: Please consider upgrading to Win11, where x64 "
+                "emulation is enabled!"
+            )
 
     vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
     # Keys in `_get_vc_env` are always lowercase.
@@ -46,19 +52,21 @@
     # you should NEVER add something to this list. It is bad practice to
     # have cmake read the environment
     my_env = os.environ.copy()
-    if 'CUDA_HOME' in my_env:  # Keep CUDA_HOME. This env variable is still used in other part.
-        my_env['CUDA_BIN_PATH'] = my_env['CUDA_HOME']
+    if (
+        "CUDA_HOME" in my_env
+    ):  # Keep CUDA_HOME. This env variable is still used in other part.
+        my_env["CUDA_BIN_PATH"] = my_env["CUDA_HOME"]
     elif IS_WINDOWS:  # we should eventually make this as part of FindCUDA.
-        cuda_win = glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
+        cuda_win = glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")
         if len(cuda_win) > 0:
-            my_env['CUDA_BIN_PATH'] = cuda_win[0]
+            my_env["CUDA_BIN_PATH"] = cuda_win[0]
 
     if IS_WINDOWS and USE_NINJA:
         # When using Ninja under Windows, the gcc toolchain will be chosen as
         # default. But it should be set to MSVC as the user's first choice.
         my_env = _overlay_windows_vcvars(my_env)
-        my_env.setdefault('CC', 'cl')
-        my_env.setdefault('CXX', 'cl')
+        my_env.setdefault("CC", "cl")
+        my_env.setdefault("CXX", "cl")
     return my_env
 
 
@@ -71,18 +79,15 @@
     cmake: CMake,
 ) -> None:
     my_env = _create_build_env()
-    build_test = not check_negative_env_flag('BUILD_TEST')
-    cmake.generate(version,
-                   cmake_python_library,
-                   build_python,
-                   build_test,
-                   my_env,
-                   rerun_cmake)
+    build_test = not check_negative_env_flag("BUILD_TEST")
+    cmake.generate(
+        version, cmake_python_library, build_python, build_test, my_env, rerun_cmake
+    )
     if cmake_only:
         return
     cmake.build(my_env)
     if build_python:
-        caffe2_proto_dir = os.path.join(cmake.build_dir, 'caffe2', 'proto')
-        for proto_file in glob(os.path.join(caffe2_proto_dir, '*.py')):
-            if proto_file != os.path.join(caffe2_proto_dir, '__init__.py'):
-                shutil.copy(proto_file, os.path.join('caffe2', 'proto'))
+        caffe2_proto_dir = os.path.join(cmake.build_dir, "caffe2", "proto")
+        for proto_file in glob(os.path.join(caffe2_proto_dir, "*.py")):
+            if proto_file != os.path.join(caffe2_proto_dir, "__init__.py"):
+                shutil.copy(proto_file, os.path.join("caffe2", "proto"))
diff --git a/tools/code_analyzer/gen_op_registration_allowlist.py b/tools/code_analyzer/gen_op_registration_allowlist.py
index 00f880d..65e5685 100644
--- a/tools/code_analyzer/gen_op_registration_allowlist.py
+++ b/tools/code_analyzer/gen_op_registration_allowlist.py
@@ -16,24 +16,26 @@
 
 DepGraph = Dict[str, Set[str]]
 
+
 def canonical_name(opname: str) -> str:
     # Skip the overload name part as it's not supported by code analyzer yet.
-    return opname.split('.', 1)[0]
+    return opname.split(".", 1)[0]
+
 
 def load_op_dep_graph(fname: str) -> DepGraph:
-    with open(fname, 'r') as stream:
+    with open(fname, "r") as stream:
         result = defaultdict(set)
         for op in yaml.safe_load(stream):
-            op_name = canonical_name(op['name'])
-            for dep in op.get('depends', []):
-                dep_name = canonical_name(dep['name'])
+            op_name = canonical_name(op["name"])
+            for dep in op.get("depends", []):
+                dep_name = canonical_name(dep["name"])
                 result[op_name].add(dep_name)
         return dict(result)
 
 
 def load_root_ops(fname: str) -> List[str]:
     result = []
-    with open(fname, 'r') as stream:
+    with open(fname, "r") as stream:
         for op in yaml.safe_load(stream):
             result.append(canonical_name(op))
     return result
@@ -49,7 +51,7 @@
 
     # The dependency graph might contain a special entry with key = `__BASE__`
     # and value = (set of `base` ops to always include in custom build).
-    queue.append('__BASE__')
+    queue.append("__BASE__")
 
     # The dependency graph might contain a special entry with key = `__ROOT__`
     # and value = (set of ops reachable from C++ functions). Insert the special
@@ -58,7 +60,7 @@
     # '__ROOT__' is only needed for full-jit. Keep it only for training.
     # TODO: when FL is migrated from full-jit to lite trainer, remove '__ROOT__'
     if train:
-        queue.append('__ROOT__')
+        queue.append("__ROOT__")
 
     while queue:
         cur = queue.pop()
@@ -69,21 +71,25 @@
 
     return sorted(result)
 
+
 def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str:
-    return ' '.join(gen_transitive_closure(dep_graph, root_ops))
+    return " ".join(gen_transitive_closure(dep_graph, root_ops))
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Util to produce transitive dependencies for custom build')
+        description="Util to produce transitive dependencies for custom build"
+    )
     parser.add_argument(
-        '--op-dependency',
-        help='input yaml file of op dependency graph '
-             '- can be omitted for custom build with static dispatch')
+        "--op-dependency",
+        help="input yaml file of op dependency graph "
+        "- can be omitted for custom build with static dispatch",
+    )
     parser.add_argument(
-        '--root-ops',
+        "--root-ops",
         required=True,
-        help='input yaml file of root (directly used) operators')
+        help="input yaml file of root (directly used) operators",
+    )
     args = parser.parse_args()
 
     deps = load_op_dep_graph(args.op_dependency) if args.op_dependency else {}
diff --git a/tools/code_analyzer/gen_oplist.py b/tools/code_analyzer/gen_oplist.py
index 010b420..ea58488 100644
--- a/tools/code_analyzer/gen_oplist.py
+++ b/tools/code_analyzer/gen_oplist.py
@@ -7,11 +7,15 @@
 from typing import Set, List, Any
 
 import yaml
-from tools.codegen.selective_build.selector import combine_selective_builders, SelectiveBuilder
+from tools.codegen.selective_build.selector import (
+    combine_selective_builders,
+    SelectiveBuilder,
+)
 from tools.lite_interpreter.gen_selected_mobile_ops_header import (
     write_selected_mobile_ops,
 )
 
+
 def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
     ops = []
     for (op_name, op) in selective_builder.operators.items():
@@ -125,7 +129,7 @@
     )
     options = parser.parse_args()
 
-    if (os.path.isfile(options.model_file_list_path)):
+    if os.path.isfile(options.model_file_list_path):
         print("Processing model file: ", options.model_file_list_path)
         model_dicts = []
         model_dict = yaml.safe_load(open(options.model_file_list_path))
@@ -180,5 +184,6 @@
         selective_builder,
     )
 
+
 if __name__ == "__main__":
     main(sys.argv)
diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py
index 230ce92..37034c5 100644
--- a/tools/codegen/api/cpp.py
+++ b/tools/codegen/api/cpp.py
@@ -38,7 +38,7 @@
     intArrayRefT,
     optionalIntArrayRefT,
     tensorOptionsT,
-    symIntArrayRefT
+    symIntArrayRefT,
 )
 from tools.codegen import local
 from tools.codegen.utils import assert_never
@@ -155,7 +155,7 @@
             return NamedCType(binds, BaseCType(tensorListT))
         elif str(t.elem) == "Scalar":
             return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
-        elif str(t.elem) == 'SymInt':
+        elif str(t.elem) == "SymInt":
             return NamedCType(binds, BaseCType(symIntArrayRefT))
         elif str(t.elem) == "Dimname":
             return NamedCType(binds, BaseCType(dimnameListT))
diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py
index 6858552..4f8cadf 100644
--- a/tools/codegen/api/python.py
+++ b/tools/codegen/api/python.py
@@ -1224,9 +1224,9 @@
             return "intlist"
         elif str(t) == "float[]":
             return "doublelist"
-        elif str(t.elem) == 'SymInt':
+        elif str(t.elem) == "SymInt":
             # accept definite size
-            return 'symintlist'
+            return "symintlist"
         elif str(t) == "Scalar[]":
             return "scalarlist"
     raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
diff --git a/tools/codegen/api/translate.py b/tools/codegen/api/translate.py
index 964aa25..47aaeae 100644
--- a/tools/codegen/api/translate.py
+++ b/tools/codegen/api/translate.py
@@ -25,7 +25,7 @@
     intArrayRefT,
     scalar_t,
     opmath_t,
-    optionalIntArrayRefT
+    optionalIntArrayRefT,
 )
 
 # This file implements a small program synthesis engine that implements
diff --git a/tools/coverage_plugins_package/setup.py b/tools/coverage_plugins_package/setup.py
index c93f612..0125069 100644
--- a/tools/coverage_plugins_package/setup.py
+++ b/tools/coverage_plugins_package/setup.py
@@ -6,8 +6,8 @@
 setuptools.setup(
     name="coverage-plugins",
     version="0.0.1",
-    author='PyTorch Team',
-    author_email='[email protected]',
+    author="PyTorch Team",
+    author_email="[email protected]",
     description="plug-in to coverage for PyTorch JIT",
     long_description=long_description,
     long_description_content_type="text/markdown",
diff --git a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py
index 8dcd313..a64670b 100644
--- a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py
+++ b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py
@@ -1,4 +1,4 @@
-'''
+"""
 This coverage plug-in attempts to cover JIT'd functions and methods that were previously missed in code coverage. Any
 function and method that was passed through/decorated with torch.jit.script or torch.jit.script_method should now be
 marked covered when coverage is run with this plug-in.
@@ -6,39 +6,54 @@
 DISCLAIMER: note that this will mark the entire JIT'd function/method as covered without seeking proof that the
 compiled code has been executed. This means that even if the code chunk is merely compiled and not run, it will get
 marked as covered.
-'''
+"""
 
 from coverage import CoveragePlugin, CoverageData  # type: ignore[import]
-from inspect import ismodule, isclass, ismethod, isfunction, iscode, getsourcefile, getsourcelines
+from inspect import (
+    ismodule,
+    isclass,
+    ismethod,
+    isfunction,
+    iscode,
+    getsourcefile,
+    getsourcelines,
+)
 from time import time
 from typing import Any
 
 # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
 # `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
 # https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine
-cov_data = CoverageData(basename=f'.coverage.jit.{time()}')
+cov_data = CoverageData(basename=f".coverage.jit.{time()}")
 
 
 def is_not_builtin_class(obj: Any) -> bool:
-    return isclass(obj) and not type(obj).__module__ == 'builtins'
+    return isclass(obj) and not type(obj).__module__ == "builtins"
 
 
 class JitPlugin(CoveragePlugin):  # type: ignore[misc, no-any-unimported]
-    '''
+    """
     dynamic_context is an overridden function that gives us access to every frame run during the coverage process. We
     look for when the function being run is `should_drop`, as all functions that get passed into `should_drop` will be
     compiled and thus should be marked as covered.
-    '''
+    """
+
     def dynamic_context(self, frame: Any) -> None:
-        if frame.f_code.co_name == 'should_drop':
-            obj = frame.f_locals['fn']
+        if frame.f_code.co_name == "should_drop":
+            obj = frame.f_locals["fn"]
             # The many conditions in the if statement below are based on the accepted arguments to getsourcefile. Based
             # on its documentation (https://docs.python.org/3/library/inspect.html#inspect.getsourcefile), the argument
             # must be a module, class, method, function, traceback, frame, or code object AND it cannot be a built-in
             # module, class, or function.
             # Currently, we DO NOT include tracebacks or frames as they should not be JIT'd, and we have not checked for
             # built-in modules or functions as those do not seem to be JIT'd either.
-            if is_not_builtin_class(obj) or ismodule(obj) or ismethod(obj) or isfunction(obj) or iscode(obj):
+            if (
+                is_not_builtin_class(obj)
+                or ismodule(obj)
+                or ismethod(obj)
+                or isfunction(obj)
+                or iscode(obj)
+            ):
                 filename = getsourcefile(obj)
                 # We don't want to report for filename = None
                 if filename:
@@ -51,9 +66,14 @@
                     except OSError:
                         pass
                     else:
-                        line_data = {filename: range(starting_lineno, starting_lineno + len(sourcelines))}
+                        line_data = {
+                            filename: range(
+                                starting_lineno, starting_lineno + len(sourcelines)
+                            )
+                        }
                         cov_data.add_lines(line_data)
         super().dynamic_context(frame)
 
+
 def coverage_init(reg: Any, options: Any) -> None:
     reg.add_dynamic_context(JitPlugin())
diff --git a/tools/download_mnist.py b/tools/download_mnist.py
index dfb0f95..80894ad 100644
--- a/tools/download_mnist.py
+++ b/tools/download_mnist.py
@@ -6,15 +6,15 @@
 import sys
 
 MIRRORS = [
-    'http://yann.lecun.com/exdb/mnist/',
-    'https://ossci-datasets.s3.amazonaws.com/mnist/',
+    "http://yann.lecun.com/exdb/mnist/",
+    "https://ossci-datasets.s3.amazonaws.com/mnist/",
 ]
 
 RESOURCES = [
-    'train-images-idx3-ubyte.gz',
-    'train-labels-idx1-ubyte.gz',
-    't10k-images-idx3-ubyte.gz',
-    't10k-labels-idx1-ubyte.gz',
+    "train-images-idx3-ubyte.gz",
+    "train-labels-idx1-ubyte.gz",
+    "t10k-images-idx3-ubyte.gz",
+    "t10k-labels-idx1-ubyte.gz",
 ]
 
 
@@ -25,23 +25,23 @@
 ) -> None:
     if file_size != -1:
         percent = min(1, (chunk_number * chunk_size) / file_size)
-        bar = '#' * int(64 * percent)
-        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
+        bar = "#" * int(64 * percent)
+        sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
 
 
 def download(destination_path: str, resource: str, quiet: bool) -> None:
     if os.path.exists(destination_path):
         if not quiet:
-            print('{} already exists, skipping ...'.format(destination_path))
+            print("{} already exists, skipping ...".format(destination_path))
     else:
         for mirror in MIRRORS:
             url = mirror + resource
-            print('Downloading {} ...'.format(url))
+            print("Downloading {} ...".format(url))
             try:
                 hook = None if quiet else report_download_progress
                 urlretrieve(url, destination_path, reporthook=hook)
             except (URLError, ConnectionError) as e:
-                print('Failed to download (trying next):\n{}'.format(e))
+                print("Failed to download (trying next):\n{}".format(e))
                 continue
             finally:
                 if not quiet:
@@ -49,32 +49,32 @@
                     print()
             break
         else:
-            raise RuntimeError('Error downloading resource!')
+            raise RuntimeError("Error downloading resource!")
 
 
 def unzip(zipped_path: str, quiet: bool) -> None:
     unzipped_path = os.path.splitext(zipped_path)[0]
     if os.path.exists(unzipped_path):
         if not quiet:
-            print('{} already exists, skipping ... '.format(unzipped_path))
+            print("{} already exists, skipping ... ".format(unzipped_path))
         return
-    with gzip.open(zipped_path, 'rb') as zipped_file:
-        with open(unzipped_path, 'wb') as unzipped_file:
+    with gzip.open(zipped_path, "rb") as zipped_file:
+        with open(unzipped_path, "wb") as unzipped_file:
             unzipped_file.write(zipped_file.read())
             if not quiet:
-                print('Unzipped {} ...'.format(zipped_path))
+                print("Unzipped {} ...".format(zipped_path))
 
 
 def main() -> None:
     parser = argparse.ArgumentParser(
-        description='Download the MNIST dataset from the internet')
+        description="Download the MNIST dataset from the internet"
+    )
     parser.add_argument(
-        '-d', '--destination', default='.', help='Destination directory')
+        "-d", "--destination", default=".", help="Destination directory"
+    )
     parser.add_argument(
-        '-q',
-        '--quiet',
-        action='store_true',
-        help="Don't report about progress")
+        "-q", "--quiet", action="store_true", help="Don't report about progress"
+    )
     options = parser.parse_args()
 
     if not os.path.exists(options.destination):
@@ -86,8 +86,8 @@
             download(path, resource, options.quiet)
             unzip(path, options.quiet)
     except KeyboardInterrupt:
-        print('Interrupted')
+        print("Interrupted")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/extract_scripts.py b/tools/extract_scripts.py
index 1090886..7a9a29d 100755
--- a/tools/extract_scripts.py
+++ b/tools/extract_scripts.py
@@ -18,84 +18,85 @@
 
 
 def extract(step: Step) -> Optional[Script]:
-    run = step.get('run')
+    run = step.get("run")
 
     # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell
-    shell = step.get('shell', 'bash')
+    shell = step.get("shell", "bash")
     extension = {
-        'bash': '.sh',
-        'pwsh': '.ps1',
-        'python': '.py',
-        'sh': '.sh',
-        'cmd': '.cmd',
-        'powershell': '.ps1',
+        "bash": ".sh",
+        "pwsh": ".ps1",
+        "python": ".py",
+        "sh": ".sh",
+        "cmd": ".cmd",
+        "powershell": ".ps1",
     }.get(shell)
 
-    is_gh_script = step.get('uses', '').startswith('actions/github-script@')
-    gh_script = step.get('with', {}).get('script')
+    is_gh_script = step.get("uses", "").startswith("actions/github-script@")
+    gh_script = step.get("with", {}).get("script")
 
     if run is not None and extension is not None:
         script = {
-            'bash': f'#!/usr/bin/env bash\nset -eo pipefail\n{run}',
-            'sh': f'#!/usr/bin/env sh\nset -e\n{run}',
+            "bash": f"#!/usr/bin/env bash\nset -eo pipefail\n{run}",
+            "sh": f"#!/usr/bin/env sh\nset -e\n{run}",
         }.get(shell, run)
-        return {'extension': extension, 'script': script}
+        return {"extension": extension, "script": script}
     elif is_gh_script and gh_script is not None:
-        return {'extension': '.js', 'script': gh_script}
+        return {"extension": ".js", "script": gh_script}
     else:
         return None
 
 
 def main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument('--out', required=True)
+    parser.add_argument("--out", required=True)
     args = parser.parse_args()
 
     out = Path(args.out)
     if out.exists():
-        sys.exit(f'{out} already exists; aborting to avoid overwriting')
+        sys.exit(f"{out} already exists; aborting to avoid overwriting")
 
     gha_expressions_found = False
 
-    for p in Path('.github/workflows').iterdir():
+    for p in Path(".github/workflows").iterdir():
         with open(p, "rb") as f:
             workflow = yaml.safe_load(f)
 
-        for job_name, job in workflow['jobs'].items():
+        for job_name, job in workflow["jobs"].items():
             job_dir = out / p / job_name
             if "steps" not in job:
                 continue
-            steps = job['steps']
+            steps = job["steps"]
             index_chars = len(str(len(steps) - 1))
             for i, step in enumerate(steps, start=1):
                 extracted = extract(step)
                 if extracted:
-                    script = extracted['script']
-                    step_name = step.get('name', '')
-                    if '${{' in script:
+                    script = extracted["script"]
+                    step_name = step.get("name", "")
+                    if "${{" in script:
                         gha_expressions_found = True
                         print(
-                            f'{p} job `{job_name}` step {i}: {step_name}',
-                            file=sys.stderr
+                            f"{p} job `{job_name}` step {i}: {step_name}",
+                            file=sys.stderr,
                         )
 
                     job_dir.mkdir(parents=True, exist_ok=True)
 
                     sanitized = re.sub(
-                        '[^a-zA-Z_]+', '_',
-                        f'_{step_name}',
-                    ).rstrip('_')
-                    extension = extracted['extension']
-                    filename = f'{i:0{index_chars}}{sanitized}{extension}'
+                        "[^a-zA-Z_]+",
+                        "_",
+                        f"_{step_name}",
+                    ).rstrip("_")
+                    extension = extracted["extension"]
+                    filename = f"{i:0{index_chars}}{sanitized}{extension}"
                     (job_dir / filename).write_text(script)
 
     if gha_expressions_found:
         sys.exit(
-            'Each of the above scripts contains a GitHub Actions '
-            '${{ <expression> }} which must be replaced with an `env` variable'
-            ' for security reasons.'
+            "Each of the above scripts contains a GitHub Actions "
+            "${{ <expression> }} which must be replaced with an `env` variable"
+            " for security reasons."
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py
index f1bb4fa..0a1ae07c 100755
--- a/tools/fast_nvcc/fast_nvcc.py
+++ b/tools/fast_nvcc/fast_nvcc.py
@@ -14,12 +14,11 @@
 import subprocess
 import sys
 import time
-from typing import (Awaitable, DefaultDict, Dict, List, Match, Optional, Set,
-                    cast)
+from typing import Awaitable, DefaultDict, Dict, List, Match, Optional, Set, cast
 
 from typing_extensions import TypedDict
 
-help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]...
+help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]...
 
 Run the commands given by nvcc --dryrun, in parallel.
 
@@ -31,61 +30,61 @@
 instance passing --help (after "--") doesn't work since the --help
 execution path doesn't compile anything, so adding --dryrun there gives
 nothing in stderr.
-'''
+"""
 parser = argparse.ArgumentParser(help_msg)
 parser.add_argument(
-    '--faithful',
-    action='store_true',
+    "--faithful",
+    action="store_true",
     help="don't modify the commands given by nvcc (slower)",
 )
 parser.add_argument(
-    '--graph',
-    metavar='FILE.gv',
-    help='write Graphviz DOT file with execution graph',
+    "--graph",
+    metavar="FILE.gv",
+    help="write Graphviz DOT file with execution graph",
 )
 parser.add_argument(
-    '--nvcc',
-    metavar='PATH',
-    default='nvcc',
+    "--nvcc",
+    metavar="PATH",
+    default="nvcc",
     help='path to nvcc (default is just "nvcc")',
 )
 parser.add_argument(
-    '--save',
-    metavar='DIR',
-    help='copy intermediate files from each command into DIR',
+    "--save",
+    metavar="DIR",
+    help="copy intermediate files from each command into DIR",
 )
 parser.add_argument(
-    '--sequential',
-    action='store_true',
-    help='sequence commands instead of using the graph (slower)',
+    "--sequential",
+    action="store_true",
+    help="sequence commands instead of using the graph (slower)",
 )
 parser.add_argument(
-    '--table',
-    metavar='FILE.csv',
-    help='write CSV with times and intermediate file sizes',
+    "--table",
+    metavar="FILE.csv",
+    help="write CSV with times and intermediate file sizes",
 )
 parser.add_argument(
-    '--verbose',
-    metavar='FILE.txt',
-    help='like nvcc --verbose, but expanded and into a file',
+    "--verbose",
+    metavar="FILE.txt",
+    help="like nvcc --verbose, but expanded and into a file",
 )
 default_config = parser.parse_args([])
 
 
 # docs about temporary directories used by NVCC
-url_base = 'https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html'
-url_vars = f'{url_base}#keeping-intermediate-phase-files'
+url_base = "https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html"
+url_vars = f"{url_base}#keeping-intermediate-phase-files"
 
 
 # regex for temporary file names
-re_tmp = r'(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)'
+re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
 
 
 def fast_nvcc_warn(warning: str) -> None:
     """
     Warn the user about something regarding fast_nvcc.
     """
-    print(f'warning (fast_nvcc): {warning}', file=sys.stderr)
+    print(f"warning (fast_nvcc): {warning}", file=sys.stderr)
 
 
 def warn_if_windows() -> None:
@@ -95,7 +94,7 @@
     # use os.name instead of platform.system() because there is a
     # platform.py file in this directory, making it very difficult to
     # import the platform module from the Python standard library
-    if os.name == 'nt':
+    if os.name == "nt":
         fast_nvcc_warn("untested on Windows, might not work; see this URL:")
         fast_nvcc_warn(url_vars)
 
@@ -104,24 +103,24 @@
     """
     Warn the user that using fast_nvcc with some flags might not work.
     """
-    file_path_specs = 'file-and-path-specifications'
-    guiding_driver = 'options-for-guiding-compiler-driver'
+    file_path_specs = "file-and-path-specifications"
+    guiding_driver = "options-for-guiding-compiler-driver"
     scary_flags = {
-        '--objdir-as-tempdir': file_path_specs,
-        '-objtemp': file_path_specs,
-        '--keep': guiding_driver,
-        '-keep': guiding_driver,
-        '--keep-dir': guiding_driver,
-        '-keep-dir': guiding_driver,
-        '--save-temps': guiding_driver,
-        '-save-temps': guiding_driver,
+        "--objdir-as-tempdir": file_path_specs,
+        "-objtemp": file_path_specs,
+        "--keep": guiding_driver,
+        "-keep": guiding_driver,
+        "--keep-dir": guiding_driver,
+        "-keep-dir": guiding_driver,
+        "--save-temps": guiding_driver,
+        "-save-temps": guiding_driver,
     }
     for arg in args:
         for flag, frag in scary_flags.items():
-            if re.match(fr'^{re.escape(flag)}(?:=.*)?$', arg):
-                fast_nvcc_warn(f'{flag} not supported since it interacts with')
-                fast_nvcc_warn('TMPDIR, so fast_nvcc may break; see this URL:')
-                fast_nvcc_warn(f'{url_base}#{frag}')
+            if re.match(rf"^{re.escape(flag)}(?:=.*)?$", arg):
+                fast_nvcc_warn(f"{flag} not supported since it interacts with")
+                fast_nvcc_warn("TMPDIR, so fast_nvcc may break; see this URL:")
+                fast_nvcc_warn(f"{url_base}#{frag}")
 
 
 class DryunData(TypedDict):
@@ -135,18 +134,18 @@
     Return parsed environment variables and commands from nvcc --dryrun.
     """
     result = subprocess.run(  # type: ignore[call-overload]
-        [binary, '--dryrun'] + args,
+        [binary, "--dryrun"] + args,
         capture_output=True,
-        encoding='ascii',  # this is just a guess
+        encoding="ascii",  # this is just a guess
     )
-    print(result.stdout, end='')
+    print(result.stdout, end="")
     env = {}
     commands = []
     for line in result.stderr.splitlines():
-        match = re.match(r'^#\$ (.*)$', line)
+        match = re.match(r"^#\$ (.*)$", line)
         if match:
-            stripped, = match.groups()
-            mapping = re.match(r'^(\w+)=(.*)$', stripped)
+            (stripped,) = match.groups()
+            mapping = re.match(r"^(\w+)=(.*)$", stripped)
             if mapping:
                 name, val = mapping.groups()
                 env[name] = val
@@ -154,14 +153,14 @@
                 commands.append(stripped)
         else:
             print(line, file=sys.stderr)
-    return {'env': env, 'commands': commands, 'exit_code': result.returncode}
+    return {"env": env, "commands": commands, "exit_code": result.returncode}
 
 
 def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
     """
     Warn the user that setting TMPDIR with fast_nvcc might not work.
     """
-    if os.getenv('TMPDIR') or 'TMPDIR' in env:
+    if os.getenv("TMPDIR") or "TMPDIR" in env:
         fast_nvcc_warn("TMPDIR is set, might not work; see this URL:")
         fast_nvcc_warn(url_vars)
 
@@ -183,17 +182,17 @@
     """
     Guess the contents of the .module_id file contained within command.
     """
-    if command[0] == 'cicc':
+    if command[0] == "cicc":
         path = command[-3]
-    elif command[0] == 'cudafe++':
+    elif command[0] == "cudafe++":
         path = command[-1]
-    middle = pathlib.PurePath(path).name.replace('-', '_').replace('.', '_')
+    middle = pathlib.PurePath(path).name.replace("-", "_").replace(".", "_")
     # this suffix is very wrong (the real one is far less likely to be
     # unique), but it seems difficult to find a rule that reproduces the
     # real suffixes, so here's one that, while inaccurate, is at least
     # hopefully as straightforward as possible
     suffix = hashlib.md5(str.encode(middle)).hexdigest()[:8]
-    return f'_{len(middle)}_{middle}_{suffix}'
+    return f"_{len(middle)}_{middle}_{suffix}"
 
 
 def unique_module_id_files(commands: List[str]) -> List[str]:
@@ -206,14 +205,14 @@
         arr = []
 
         def uniqueify(s: Match[str]) -> str:
-            filename = re.sub(r'\-(\d+)', r'-\1-' + str(i), s.group(0))
+            filename = re.sub(r"\-(\d+)", r"-\1-" + str(i), s.group(0))
             arr.append(filename)
             return filename
 
-        line = re.sub(re_tmp + r'.module_id', uniqueify, line)
-        line = re.sub(r'\s*\-\-gen\_module\_id\_file\s*', ' ', line)
+        line = re.sub(re_tmp + r".module_id", uniqueify, line)
+        line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
         if arr:
-            filename, = arr
+            (filename,) = arr
             if not module_id:
                 module_id = module_id_contents(shlex.split(line))
             uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
@@ -225,7 +224,7 @@
     """
     Add --force to all rm commands.
     """
-    return [f'{c} --force' if c.startswith('rm ') else c for c in commands]
+    return [f"{c} --force" if c.startswith("rm ") else c for c in commands]
 
 
 def print_verbose_output(
@@ -238,12 +237,12 @@
     Human-readably write nvcc --dryrun data to stderr.
     """
     padding = len(str(len(commands) - 1))
-    with open(filename, 'w') as f:
+    with open(filename, "w") as f:
         for name, val in env.items():
             print(f'#{" "*padding}$ {name}={val}', file=f)
         for i, command in enumerate(commands):
-            prefix = f'{str(i).rjust(padding)}$ '
-            print(f'#{prefix}{command[0]}', file=f)
+            prefix = f"{str(i).rjust(padding)}$ "
+            print(f"#{prefix}{command[0]}", file=f)
             for part in command[1:]:
                 print(f'#{" "*len(prefix)}{part}', file=f)
 
@@ -262,7 +261,7 @@
     """
     Return fully-qualified names of all tmp files referenced by command.
     """
-    return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)]
+    return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
 
 
 def nvcc_data_dependencies(commands: List[str]) -> Graph:
@@ -291,11 +290,11 @@
                     for filename in fatbins[dep]:
                         if filename in tmp_files:
                             deps.add(tmp_files[filename])
-            if tmp.endswith('.fatbin.c') and not line.startswith('fatbinary'):
+            if tmp.endswith(".fatbin.c") and not line.startswith("fatbinary"):
                 fatbins[i].add(tmp)
             else:
                 tmp_files[tmp] = i
-        if line.startswith('rm ') and not deps:
+        if line.startswith("rm ") and not deps:
             deps.add(i - 1)
         graph.append(deps)
     return graph
@@ -329,7 +328,7 @@
     Warn the user if the execution graph is not weakly connected.
     """
     if not is_weakly_connected(graph):
-        fast_nvcc_warn('execution graph is not (weakly) connected')
+        fast_nvcc_warn("execution graph is not (weakly) connected")
 
 
 def print_dot_graph(
@@ -341,18 +340,19 @@
     """
     Print a DOT file displaying short versions of the commands in graph.
     """
+
     def name(k: int) -> str:
         return f'"{k} {os.path.basename(commands[k][0])}"'
-    with open(filename, 'w') as f:
-        print('digraph {', file=f)
+
+    with open(filename, "w") as f:
+        print("digraph {", file=f)
         # print all nodes, in case it's disconnected
         for i in range(len(graph)):
-            print(f'    {name(i)};', file=f)
+            print(f"    {name(i)};", file=f)
         for i, deps in enumerate(graph):
             for j in deps:
-                print(f'    {name(j)} -> {name(i)};', file=f)
-        print('}', file=f)
-
+                print(f"    {name(j)} -> {name(i)};", file=f)
+        print("}", file=f)
 
 
 class Result(TypedDict, total=False):
@@ -378,7 +378,7 @@
     for task in deps:
         dep_result = await task
         # abort if a previous step failed
-        if 'exit_code' not in dep_result or dep_result['exit_code'] != 0:
+        if "exit_code" not in dep_result or dep_result["exit_code"] != 0:
             return {}
     if gather_data:
         t1 = time.monotonic()
@@ -390,17 +390,17 @@
     )
     stdout, stderr = await proc.communicate()
     code = cast(int, proc.returncode)
-    results: Result = {'exit_code': code, 'stdout': stdout, 'stderr': stderr}
+    results: Result = {"exit_code": code, "stdout": stdout, "stderr": stderr}
     if gather_data:
         t2 = time.monotonic()
-        results['time'] = t2 - t1
+        results["time"] = t2 - t1
         sizes = {}
         for tmp_file in files_mentioned(command):
             if os.path.exists(tmp_file):
                 sizes[tmp_file] = os.path.getsize(tmp_file)
             else:
                 sizes[tmp_file] = 0
-        results['files'] = sizes
+        results["files"] = sizes
     if save:
         dest = pathlib.Path(save) / str(i)
         dest.mkdir()
@@ -424,14 +424,18 @@
     tasks: List[Awaitable[Result]] = []
     for i, (command, indices) in enumerate(zip(commands, graph)):
         deps = {tasks[j] for j in indices}
-        tasks.append(asyncio.create_task(run_command(  # type: ignore[attr-defined]
-            command,
-            env=env,
-            deps=deps,
-            gather_data=gather_data,
-            i=i,
-            save=save,
-        )))
+        tasks.append(
+            asyncio.create_task(
+                run_command(  # type: ignore[attr-defined]
+                    command,
+                    env=env,
+                    deps=deps,
+                    gather_data=gather_data,
+                    i=i,
+                    save=save,
+                )
+            )
+        )
     return [await task for task in tasks]
 
 
@@ -440,8 +444,8 @@
     Print captured stdout and stderr from commands.
     """
     for result in command_results:
-        sys.stdout.write(result.get('stdout', b'').decode('ascii'))
-        sys.stderr.write(result.get('stderr', b'').decode('ascii'))
+        sys.stdout.write(result.get("stdout", b"").decode("ascii"))
+        sys.stderr.write(result.get("stderr", b"").decode("ascii"))
 
 
 def write_log_csv(
@@ -455,15 +459,15 @@
     """
     tmp_files: List[str] = []
     for result in command_results:
-        tmp_files.extend(result.get('files', {}).keys())
-    with open(filename, 'w', newline='') as csvfile:
-        fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files))
+        tmp_files.extend(result.get("files", {}).keys())
+    with open(filename, "w", newline="") as csvfile:
+        fieldnames = ["command", "seconds"] + list(dict.fromkeys(tmp_files))
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()
         for i, result in enumerate(command_results):
-            command = f'{i} {os.path.basename(command_parts[i][0])}'
-            row = {'command': command, 'seconds': result.get('time', 0)}
-            writer.writerow({**row, **result.get('files', {})})
+            command = f"{i} {os.path.basename(command_parts[i][0])}"
+            row = {"command": command, "seconds": result.get("time", 0)}
+            writer.writerow({**row, **result.get("files", {})})
 
 
 def exit_code(results: List[Result]) -> int:
@@ -471,7 +475,7 @@
     Aggregate individual exit codes into a single code.
     """
     for result in results:
-        code = result.get('exit_code', 0)
+        code = result.get("exit_code", 0)
         if code != 0:
             return code
     return 0
@@ -497,9 +501,9 @@
     warn_if_windows()
     warn_if_tmpdir_flag(args)
     dryrun_data = nvcc_dryrun_data(config.nvcc, args)
-    env = dryrun_data['env']
+    env = dryrun_data["env"]
     warn_if_tmpdir_set(env)
-    commands = dryrun_data['commands']
+    commands = dryrun_data["commands"]
     if not config.faithful:
         commands = make_rm_force(unique_module_id_files(commands))
 
@@ -523,13 +527,15 @@
         )
     if config.sequential:
         graph = straight_line_dependencies(commands)
-    results = asyncio.run(run_graph(  # type: ignore[attr-defined]
-        env=env,
-        commands=commands,
-        graph=graph,
-        gather_data=bool(config.table),
-        save=config.save,
-    ))
+    results = asyncio.run(
+        run_graph(  # type: ignore[attr-defined]
+            env=env,
+            commands=commands,
+            graph=graph,
+            gather_data=bool(config.table),
+            save=config.save,
+        )
+    )
     print_command_outputs(results)
     if config.table:
         write_log_csv(command_parts, results, filename=config.table)
@@ -537,10 +543,10 @@
 
 
 def our_arg(arg: str) -> bool:
-    return arg != '--'
+    return arg != "--"
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     argv = sys.argv[1:]
     us = list(itertools.takewhile(our_arg, argv))
     them = list(itertools.dropwhile(our_arg, argv))
diff --git a/tools/gdb/pytorch-gdb.py b/tools/gdb/pytorch-gdb.py
index 46cdcde..0ed5160 100644
--- a/tools/gdb/pytorch-gdb.py
+++ b/tools/gdb/pytorch-gdb.py
@@ -2,6 +2,7 @@
 import textwrap
 from typing import Any
 
+
 class DisableBreakpoints:
     """
     Context-manager to temporarily disable all gdb breakpoints, useful if
@@ -20,6 +21,7 @@
         for b in self.disabled_breakpoints:
             b.enabled = True
 
+
 class TensorRepr(gdb.Command):  # type: ignore[misc, no-any-unimported]
     """
     Print a human readable representation of the given at::Tensor.
@@ -30,23 +32,26 @@
     internally creates a Python wrapper for the given tensor and call repr()
     on it.
     """
+
     __doc__ = textwrap.dedent(__doc__).strip()
 
     def __init__(self) -> None:
-        gdb.Command.__init__(self, 'torch-tensor-repr',
-                             gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION)
+        gdb.Command.__init__(
+            self, "torch-tensor-repr", gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION
+        )
 
     def invoke(self, args: str, from_tty: bool) -> None:
         args = gdb.string_to_argv(args)
         if len(args) != 1:
-            print('Usage: torch-tensor-repr EXP')
+            print("Usage: torch-tensor-repr EXP")
             return
         name = args[0]
         with DisableBreakpoints():
-            res = gdb.parse_and_eval('torch::gdb::tensor_repr(%s)' % name)
-            print('Python-level repr of %s:' % name)
+            res = gdb.parse_and_eval("torch::gdb::tensor_repr(%s)" % name)
+            print("Python-level repr of %s:" % name)
             print(res.string())
             # torch::gdb::tensor_repr returns a malloc()ed buffer, let's free it
-            gdb.parse_and_eval('(void)free(%s)' % int(res))
+            gdb.parse_and_eval("(void)free(%s)" % int(res))
+
 
 TensorRepr()
diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py
index 2ee17b7..e47c61f 100644
--- a/tools/generate_torch_version.py
+++ b/tools/generate_torch_version.py
@@ -5,46 +5,59 @@
 from setuptools import distutils  # type: ignore[import]
 from typing import Optional, Union
 
+
 def get_sha(pytorch_root: Union[str, Path]) -> str:
     try:
-        return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=pytorch_root).decode('ascii').strip()
+        return (
+            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=pytorch_root)
+            .decode("ascii")
+            .strip()
+        )
     except Exception:
-        return 'Unknown'
+        return "Unknown"
+
 
 def get_torch_version(sha: Optional[str] = None) -> str:
     pytorch_root = Path(__file__).parent.parent
-    version = open(pytorch_root / 'version.txt', 'r').read().strip()
+    version = open(pytorch_root / "version.txt", "r").read().strip()
 
-    if os.getenv('PYTORCH_BUILD_VERSION'):
-        assert os.getenv('PYTORCH_BUILD_NUMBER') is not None
-        build_number = int(os.getenv('PYTORCH_BUILD_NUMBER', ""))
-        version = os.getenv('PYTORCH_BUILD_VERSION', "")
+    if os.getenv("PYTORCH_BUILD_VERSION"):
+        assert os.getenv("PYTORCH_BUILD_NUMBER") is not None
+        build_number = int(os.getenv("PYTORCH_BUILD_NUMBER", ""))
+        version = os.getenv("PYTORCH_BUILD_VERSION", "")
         if build_number > 1:
-            version += '.post' + str(build_number)
-    elif sha != 'Unknown':
+            version += ".post" + str(build_number)
+    elif sha != "Unknown":
         if sha is None:
             sha = get_sha(pytorch_root)
-        version += '+git' + sha[:7]
+        version += "+git" + sha[:7]
     return version
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate torch/version.py from build and environment metadata.")
-    parser.add_argument("--is_debug", type=distutils.util.strtobool, help="Whether this build is debug mode or not.")
+    parser = argparse.ArgumentParser(
+        description="Generate torch/version.py from build and environment metadata."
+    )
+    parser.add_argument(
+        "--is_debug",
+        type=distutils.util.strtobool,
+        help="Whether this build is debug mode or not.",
+    )
     parser.add_argument("--cuda_version", type=str)
     parser.add_argument("--hip_version", type=str)
 
     args = parser.parse_args()
 
     assert args.is_debug is not None
-    args.cuda_version = None if args.cuda_version == '' else args.cuda_version
-    args.hip_version = None if args.hip_version == '' else args.hip_version
+    args.cuda_version = None if args.cuda_version == "" else args.cuda_version
+    args.hip_version = None if args.hip_version == "" else args.hip_version
 
     pytorch_root = Path(__file__).parent.parent
     version_path = pytorch_root / "torch" / "version.py"
     sha = get_sha(pytorch_root)
     version = get_torch_version(sha)
 
-    with open(version_path, 'w') as f:
+    with open(version_path, "w") as f:
         f.write("__version__ = '{}'\n".format(version))
         # NB: This is not 100% accurate, because you could have built the
         # library code with DEBUG, but csrc without DEBUG (in which case
diff --git a/tools/iwyu/fixup.py b/tools/iwyu/fixup.py
index b4d6294..4ce80bb 100644
--- a/tools/iwyu/fixup.py
+++ b/tools/iwyu/fixup.py
@@ -2,7 +2,7 @@
 import re
 
 QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"')
-ANGLE_INCLUDE_RE = re.compile(r'^#include <(.*)>')
+ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>")
 
 # By default iwyu will pick the C include, but we prefer the C++ headers
 STD_C_HEADER_MAP = {
@@ -34,25 +34,27 @@
     "<wctype.h>": "<cwctype>",
 }
 
+
 def main() -> None:
     for line in sys.stdin:
         # Convert all quoted includes to angle brackets
         match = QUOTE_INCLUDE_RE.match(line)
         if match is not None:
-            print(f"#include <{match.group(1)}>{line[match.end(0):]}", end='')
+            print(f"#include <{match.group(1)}>{line[match.end(0):]}", end="")
             continue
 
         match = ANGLE_INCLUDE_RE.match(line)
         if match is not None:
             path = f"<{match.group(1)}>"
             new_path = STD_C_HEADER_MAP.get(path, path)
-            tail = line[match.end(0):]
+            tail = line[match.end(0) :]
             if len(tail) > 1:
-                tail = ' ' + tail
-            print(f"#include {new_path}{tail}", end='')
+                tail = " " + tail
+            print(f"#include {new_path}{tail}", end="")
             continue
 
-        print(line, end='')
+        print(line, end="")
+
 
 if __name__ == "__main__":
     main()
diff --git a/tools/jit/gen_unboxing.py b/tools/jit/gen_unboxing.py
index 976cf3e..e5ed608 100644
--- a/tools/jit/gen_unboxing.py
+++ b/tools/jit/gen_unboxing.py
@@ -56,7 +56,9 @@
             arg_connector = ", "
             # function call and push back to stack
             prefix = "self_base." if sig.method else "at::"
-            translated_args = translate(binding_list, sig.arguments(), method=sig.method)
+            translated_args = translate(
+                binding_list, sig.arguments(), method=sig.method
+            )
             args_str = f"{arg_connector.join(e.expr for e in translated_args)}"
             if len(f.func.returns) == 0:
                 ret_str = ""
@@ -89,9 +91,7 @@
         if not self.selector.is_root_operator(f"aten::{f.func.name}"):
             return ""
         # We unconditionally generate function wrappers,
-        sig_group = CppSignatureGroup.from_native_function(
-            f, method=False
-        )
+        sig_group = CppSignatureGroup.from_native_function(f, method=False)
 
         sig = sig_group.most_faithful_signature()
 
@@ -105,11 +105,13 @@
         for arg in args:
             if not arg.default:
                 arg_cpp = "c10::IValue(c10::nullopt)"
-            elif arg.default.startswith('{'):
+            elif arg.default.startswith("{"):
                 arg_cpp = f"c10::IntArrayRef({arg.default})"
             else:
                 arg_cpp = f"c10::IValue({arg.default})"
-            args_code.append(f"""c10::Argument("{arg.name}", nullptr, c10::nullopt, {arg_cpp})""")
+            args_code.append(
+                f"""c10::Argument("{arg.name}", nullptr, c10::nullopt, {arg_cpp})"""
+            )
 
         returns = f.func.returns
         returns_code = []
@@ -136,10 +138,10 @@
 
 
 def gen_unboxing(
-        *,
-        native_functions: Sequence[NativeFunction],
-        cpu_fm: FileManager,
-        selector: SelectiveBuilder,
+    *,
+    native_functions: Sequence[NativeFunction],
+    cpu_fm: FileManager,
+    selector: SelectiveBuilder,
 ) -> None:
     def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
         return fn.root_name
@@ -158,7 +160,10 @@
         "UnboxingFunctions.h",
         lambda: {
             "declarations": list(
-                mapMaybe(ComputeUnboxingFunctions(Target.DECLARATION, selector), native_functions)
+                mapMaybe(
+                    ComputeUnboxingFunctions(Target.DECLARATION, selector),
+                    native_functions,
+                )
             ),
         },
     )
@@ -166,7 +171,9 @@
         "RegisterCodegenUnboxedKernels.cpp",
         native_functions,
         key_fn=key_func,
-        env_callable=lambda fn: {"unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]},
+        env_callable=lambda fn: {
+            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
+        },
         num_shards=10,
         sharded_keys={"unboxed_ops"},
     )
@@ -184,19 +191,23 @@
         "-d", "--install_dir", help="output directory", default="build/aten/src/ATen"
     )
     parser.add_argument(
-        '-o',
-        '--output-dependencies',
-        help='output a list of dependencies into the given file and exit')
+        "-o",
+        "--output-dependencies",
+        help="output a list of dependencies into the given file and exit",
+    )
     parser.add_argument(
-        '--dry-run', action='store_true',
-        help='run without writing any files (still updates outputs)')
+        "--dry-run",
+        action="store_true",
+        help="run without writing any files (still updates outputs)",
+    )
     parser.add_argument(
-        '--op_selection_yaml_path',
-        help='Provide a path to the operator selection (for custom build) YAML '
-             'that contains the information about the set of selected operators '
-             'and their categories (training, ...). Each operator is either a '
-             'full operator name with overload or just a bare operator name. '
-             'The operator names also contain the namespace prefix (e.g. aten::)')
+        "--op_selection_yaml_path",
+        help="Provide a path to the operator selection (for custom build) YAML "
+        "that contains the information about the set of selected operators "
+        "and their categories (training, ...). Each operator is either a "
+        "full operator name with overload or just a bare operator name. "
+        "The operator names also contain the namespace prefix (e.g. aten::)",
+    )
 
     options = parser.parse_args()
 
diff --git a/tools/linter/adapters/actionlint_linter.py b/tools/linter/adapters/actionlint_linter.py
index b785829..bbc9395 100644
--- a/tools/linter/adapters/actionlint_linter.py
+++ b/tools/linter/adapters/actionlint_linter.py
@@ -43,6 +43,7 @@
     """
 )
 
+
 def run_command(
     args: List[str],
 ) -> "subprocess.CompletedProcess[bytes]":
@@ -64,9 +65,7 @@
     files: List[str],
 ) -> List[LintMessage]:
     try:
-        proc = run_command(
-            [binary] + files
-        )
+        proc = run_command([binary] + files)
     except OSError as err:
         return [
             LintMessage(
diff --git a/tools/linter/adapters/circleci_linter.py b/tools/linter/adapters/circleci_linter.py
index 49bedde..8a76ed3 100644
--- a/tools/linter/adapters/circleci_linter.py
+++ b/tools/linter/adapters/circleci_linter.py
@@ -51,7 +51,11 @@
     start_time = time.monotonic()
     try:
         return subprocess.run(
-            args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True,
+            args,
+            cwd=cwd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            check=True,
         )
     finally:
         end_time = time.monotonic()
@@ -117,10 +121,13 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="circleci consistency linter", fromfile_prefix_chars="@",
+        description="circleci consistency linter",
+        fromfile_prefix_chars="@",
     )
     parser.add_argument(
-        "--config-yml", required=True, help="location of config.yml",
+        "--config-yml",
+        required=True,
+        help="location of config.yml",
     )
     parser.add_argument(
         "--regen-script-working-dir",
@@ -133,7 +140,9 @@
         help="location of the config generation script, relative to --regen-script-working-dir",
     )
     parser.add_argument(
-        "--verbose", action="store_true", help="verbose logging",
+        "--verbose",
+        action="store_true",
+        help="verbose logging",
     )
 
     args = parser.parse_args()
diff --git a/tools/linter/adapters/flake8_linter.py b/tools/linter/adapters/flake8_linter.py
index 62c830a..2027443 100644
--- a/tools/linter/adapters/flake8_linter.py
+++ b/tools/linter/adapters/flake8_linter.py
@@ -362,7 +362,9 @@
             assert len(parts) == 2, f"invalid severity `{severity}`"
             severities[parts[0]] = LintSeverity(parts[1])
 
-    lint_messages = check_files(args.filenames, flake8_plugins_path, severities, args.retries)
+    lint_messages = check_files(
+        args.filenames, flake8_plugins_path, severities, args.retries
+    )
     for lint_message in lint_messages:
         print(json.dumps(lint_message._asdict()), flush=True)
 
diff --git a/tools/linter/adapters/grep_linter.py b/tools/linter/adapters/grep_linter.py
index d894305..847d0f6 100644
--- a/tools/linter/adapters/grep_linter.py
+++ b/tools/linter/adapters/grep_linter.py
@@ -43,11 +43,17 @@
     return name.replace("\\", "/") if IS_WINDOWS else name
 
 
-def run_command(args: List[str],) -> "subprocess.CompletedProcess[bytes]":
+def run_command(
+    args: List[str],
+) -> "subprocess.CompletedProcess[bytes]":
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
-        return subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE,)
+        return subprocess.run(
+            args,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
     finally:
         end_time = time.monotonic()
         logging.debug("took %dms", (end_time - start_time) * 1000)
@@ -116,13 +122,18 @@
 
 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="grep wrapper linter.", fromfile_prefix_chars="@",
+        description="grep wrapper linter.",
+        fromfile_prefix_chars="@",
     )
     parser.add_argument(
-        "--pattern", required=True, help="pattern to grep for",
+        "--pattern",
+        required=True,
+        help="pattern to grep for",
     )
     parser.add_argument(
-        "--linter-name", required=True, help="name of the linter",
+        "--linter-name",
+        required=True,
+        help="name of the linter",
     )
     parser.add_argument(
         "--error-name",
@@ -142,10 +153,14 @@
         ),
     )
     parser.add_argument(
-        "--verbose", action="store_true", help="verbose logging",
+        "--verbose",
+        action="store_true",
+        help="verbose logging",
     )
     parser.add_argument(
-        "filenames", nargs="+", help="paths to lint",
+        "filenames",
+        nargs="+",
+        help="paths to lint",
     )
     args = parser.parse_args()
 
diff --git a/tools/linter/adapters/nativefunctions_linter.py b/tools/linter/adapters/nativefunctions_linter.py
index ddc9a18..28065f2 100644
--- a/tools/linter/adapters/nativefunctions_linter.py
+++ b/tools/linter/adapters/nativefunctions_linter.py
@@ -44,7 +44,8 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="native functions linter", fromfile_prefix_chars="@",
+        description="native functions linter",
+        fromfile_prefix_chars="@",
     )
     parser.add_argument(
         "--native-functions-yml",
diff --git a/tools/linter/adapters/newlines_linter.py b/tools/linter/adapters/newlines_linter.py
index 5ce5edc..16e22e9 100644
--- a/tools/linter/adapters/newlines_linter.py
+++ b/tools/linter/adapters/newlines_linter.py
@@ -109,13 +109,18 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="native functions linter", fromfile_prefix_chars="@",
+        description="native functions linter",
+        fromfile_prefix_chars="@",
     )
     parser.add_argument(
-        "--verbose", action="store_true", help="location of native_functions.yaml",
+        "--verbose",
+        action="store_true",
+        help="location of native_functions.yaml",
     )
     parser.add_argument(
-        "filenames", nargs="+", help="paths to lint",
+        "filenames",
+        nargs="+",
+        help="paths to lint",
     )
 
     args = parser.parse_args()
diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py
index 10fdcea..be76e1c 100644
--- a/tools/linter/adapters/pip_init.py
+++ b/tools/linter/adapters/pip_init.py
@@ -23,12 +23,18 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="pip initializer")
     parser.add_argument(
-        "packages", nargs="+", help="pip packages to install",
+        "packages",
+        nargs="+",
+        help="pip packages to install",
     )
     parser.add_argument(
-        "--verbose", action="store_true", help="verbose logging",
+        "--verbose",
+        action="store_true",
+        help="verbose logging",
     )
-    parser.add_argument("--dry-run", help="do not install anything, just print what would be done.")
+    parser.add_argument(
+        "--dry-run", help="do not install anything, just print what would be done."
+    )
 
     args = parser.parse_args()
 
diff --git a/tools/linter/clang_format_all.py b/tools/linter/clang_format_all.py
index 2a5f937..fccf880 100755
--- a/tools/linter/clang_format_all.py
+++ b/tools/linter/clang_format_all.py
@@ -25,7 +25,7 @@
     "torch/csrc/jit/",
     "torch/csrc/deploy/",
     "test/cpp/jit/",
-    "test/cpp/tensorexpr/"
+    "test/cpp/tensorexpr/",
 ]
 
 CLANG_FORMAT_BLOCK_LIST = {
@@ -37,7 +37,6 @@
 CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp|m|mm)$")
 
 
-
 def get_allowlisted_files() -> Set[str]:
     """
     Parse CLANG_FORMAT_ALLOWLIST and resolve all directories.
@@ -85,7 +84,9 @@
     cmd = "{} -style=file {}".format(CLANG_FORMAT_PATH, filename)
 
     async with semaphore:
-        proc = await asyncio.create_subprocess_shell(cmd, stdout=asyncio.subprocess.PIPE)
+        proc = await asyncio.create_subprocess_shell(
+            cmd, stdout=asyncio.subprocess.PIPE
+        )
         # Read back the formatted file.
         stdout, _ = await proc.communicate()
 
@@ -127,7 +128,12 @@
 
     # Format files in parallel.
     if diff:
-        for f in asyncio.as_completed([file_clang_formatted_correctly(f, semaphore, verbose) for f in get_allowlisted_files()]):
+        for f in asyncio.as_completed(
+            [
+                file_clang_formatted_correctly(f, semaphore, verbose)
+                for f in get_allowlisted_files()
+            ]
+        ):
             ok &= await f
 
         if ok:
@@ -135,10 +141,16 @@
         else:
             print("Some files not formatted correctly")
     else:
-        await asyncio.gather(*[run_clang_format_on_file(f, semaphore, verbose) for f in get_allowlisted_files()])
+        await asyncio.gather(
+            *[
+                run_clang_format_on_file(f, semaphore, verbose)
+                for f in get_allowlisted_files()
+            ]
+        )
 
     return ok
 
+
 def parse_args(args: List[str]) -> argparse.Namespace:
     """
     Parse and return command-line arguments.
@@ -154,8 +166,12 @@
         help="Determine whether running clang-format would produce changes",
     )
     parser.add_argument("--verbose", "-v", action="store_true", default=False)
-    parser.add_argument("--max-processes", type=int, default=50,
-                        help="Maximum number of subprocesses to create to format files in parallel")
+    parser.add_argument(
+        "--max-processes",
+        type=int,
+        default=50,
+        help="Maximum number of subprocesses to create to format files in parallel",
+    )
     return parser.parse_args(args)
 
 
@@ -167,7 +183,9 @@
     # Invoke clang-format on all files in the directories in the allowlist.
     if ok:
         loop = asyncio.get_event_loop()
-        ok = loop.run_until_complete(run_clang_format(options.max_processes, options.diff, options.verbose))
+        ok = loop.run_until_complete(
+            run_clang_format(options.max_processes, options.diff, options.verbose)
+        )
 
     # We have to invert because False -> 0, which is the code to be returned if everything is okay.
     return not ok
diff --git a/tools/linter/clang_format_utils.py b/tools/linter/clang_format_utils.py
index 021ba91..56ebc98 100644
--- a/tools/linter/clang_format_utils.py
+++ b/tools/linter/clang_format_utils.py
@@ -10,11 +10,16 @@
 # This dictionary maps each platform to a relative path to a file containing its reference hash.
 PLATFORM_TO_HASH = {
     "Darwin": os.path.join("tools", "clang_format_hash", "mac", "clang-format-mojave"),
-    "Linux": os.path.join("tools", "clang_format_hash", "linux64", "clang-format-linux64"),
+    "Linux": os.path.join(
+        "tools", "clang_format_hash", "linux64", "clang-format-linux64"
+    ),
 }
 
 CLANG_FORMAT_DIR = os.path.join(PYTORCH_ROOT, ".clang-format-bin")
 CLANG_FORMAT_PATH = os.path.join(CLANG_FORMAT_DIR, "clang-format")
 
+
 def get_and_check_clang_format(verbose: bool = False) -> bool:
-    return bool(download("clang-format", CLANG_FORMAT_DIR, PLATFORM_TO_CF_URL, PLATFORM_TO_HASH))
+    return bool(
+        download("clang-format", CLANG_FORMAT_DIR, PLATFORM_TO_CF_URL, PLATFORM_TO_HASH)
+    )
diff --git a/tools/linter/clang_tidy/__main__.py b/tools/linter/clang_tidy/__main__.py
index 18f2da2..f3be4ff 100644
--- a/tools/linter/clang_tidy/__main__.py
+++ b/tools/linter/clang_tidy/__main__.py
@@ -16,7 +16,8 @@
 
 # Returns '/usr/local/include/python<version number>'
 def get_python_include_dir() -> str:
-    return gp()['include']
+    return gp()["include"]
+
 
 def clang_search_dirs() -> List[str]:
     # Compilers are ordered based on fallback preference
@@ -100,8 +101,9 @@
     "include-dir": [
         "/usr/lib/llvm-11/include/openmp",
         get_python_include_dir(),
-        os.path.join(PYTORCH_ROOT, "third_party/pybind11/include")
-    ] + clang_search_dirs(),
+        os.path.join(PYTORCH_ROOT, "third_party/pybind11/include"),
+    ]
+    + clang_search_dirs(),
     "clang-tidy-exe": INSTALLATION_PATH,
     "compile-commands-dir": "build",
     "config-file": ".clang-tidy",
diff --git a/tools/linter/clang_tidy/generate_build_files.py b/tools/linter/clang_tidy/generate_build_files.py
index 95ff98c..dc90339 100644
--- a/tools/linter/clang_tidy/generate_build_files.py
+++ b/tools/linter/clang_tidy/generate_build_files.py
@@ -6,8 +6,15 @@
 
 def run_cmd(cmd: List[str]) -> None:
     print(f"Running: {cmd}")
-    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,)
-    stdout, stderr = result.stdout.decode("utf-8").strip(), result.stderr.decode("utf-8").strip()
+    result = subprocess.run(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    stdout, stderr = (
+        result.stdout.decode("utf-8").strip(),
+        result.stderr.decode("utf-8").strip(),
+    )
     print(stdout)
     print(stderr)
     if result.returncode != 0:
diff --git a/tools/linter/install/download_bin.py b/tools/linter/install/download_bin.py
index 3bb65ba..1f18d8f 100644
--- a/tools/linter/install/download_bin.py
+++ b/tools/linter/install/download_bin.py
@@ -14,7 +14,9 @@
 
 # PyTorch directory root
 result = subprocess.run(
-    ["git", "rev-parse", "--show-toplevel"], stdout=subprocess.PIPE, check=True,
+    ["git", "rev-parse", "--show-toplevel"],
+    stdout=subprocess.PIPE,
+    check=True,
 )
 PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
 
@@ -96,11 +98,16 @@
         try:
             os.mkdir(output_dir)
         except OSError as e:
-            print(f"Unable to create directory for {name} binary: {output_dir}", file=sys.stderr)
+            print(
+                f"Unable to create directory for {name} binary: {output_dir}",
+                file=sys.stderr,
+            )
             return False
         finally:
             if verbose:
-                print(f"Created directory {output_dir} for {name} binary", file=sys.stderr)
+                print(
+                    f"Created directory {output_dir} for {name} binary", file=sys.stderr
+                )
 
         # If the directory didn't exist, neither did the binary, so download it.
         ok = download_bin(name, output_dir, platform_to_url)
@@ -116,7 +123,10 @@
                 return False
         else:
             if verbose:
-                print(f"Found pre-existing {name} binary, skipping download", file=sys.stderr)
+                print(
+                    f"Found pre-existing {name} binary, skipping download",
+                    file=sys.stderr,
+                )
 
     # Now that the binary is where it should be, hash it.
     actual_bin_hash = compute_file_sha256(output_path)
@@ -143,7 +153,10 @@
 
         if reference_bin_hash != actual_bin_hash:
             print("The downloaded binary is not what was expected!", file=sys.stderr)
-            print(f"Downloaded hash: {repr(actual_bin_hash)} vs expected {reference_bin_hash}", file=sys.stderr)
+            print(
+                f"Downloaded hash: {repr(actual_bin_hash)} vs expected {reference_bin_hash}",
+                file=sys.stderr,
+            )
 
             # Err on the side of caution and try to delete the downloaded binary.
             try:
@@ -151,7 +164,10 @@
                 print("The binary has been deleted just to be safe", file=sys.stderr)
             except OSError as e:
                 print(f"Failed to delete binary: {e}", file=sys.stderr)
-                print("Delete this binary as soon as possible and do not execute it!", file=sys.stderr)
+                print(
+                    "Delete this binary as soon as possible and do not execute it!",
+                    file=sys.stderr,
+                )
 
             return False
         else:
diff --git a/tools/linter/mypy_wrapper.py b/tools/linter/mypy_wrapper.py
index fb1dbcb..ba8f166 100755
--- a/tools/linter/mypy_wrapper.py
+++ b/tools/linter/mypy_wrapper.py
@@ -25,6 +25,7 @@
 from typing import Any, Dict, List, Optional, Set, Tuple
 
 import mypy.api
+
 # not part of the public API, but this is the easiest way to ensure that
 # we agree with what mypy actually does
 import mypy.config_parser
@@ -37,9 +38,11 @@
     config = ConfigParser()
     config.read(config_path)
     # hopefully on Windows this gives posix paths
-    return set(mypy.config_parser.split_and_match_files(
-        config['mypy']['files'],
-    ))
+    return set(
+        mypy.config_parser.split_and_match_files(
+            config["mypy"]["files"],
+        )
+    )
 
 
 # see tools/test/test_mypy_wrapper.py for examples of many of the
@@ -50,7 +53,7 @@
     """
     Return a dict from all our `mypy` ini filenames to their `files`.
     """
-    return {str(ini): read_config(ini) for ini in Path().glob('mypy*.ini')}
+    return {str(ini): read_config(ini) for ini in Path().glob("mypy*.ini")}
 
 
 def split_path(path: str) -> List[str]:
@@ -107,9 +110,7 @@
 
 
 def make_plan(
-    *,
-    configs: Dict[str, Set[str]],
-    files: List[str]
+    *, configs: Dict[str, Set[str]], files: List[str]
 ) -> Dict[str, List[str]]:
     """
     Return a dict from config names to the files to run them with.
@@ -142,18 +143,21 @@
     run at most once for each `mypy` config used by this repo.
     """
     repo_root = Path.cwd()
-    plan = make_plan(configs=config_files(), files=[
-        PurePath(f).relative_to(repo_root).as_posix() for f in files
-    ])
+    plan = make_plan(
+        configs=config_files(),
+        files=[PurePath(f).relative_to(repo_root).as_posix() for f in files],
+    )
     mypy_results = [
         mypy.api.run(
             # insert custom flags after args to avoid being overridden
             # by existing flags in args
-            args + [
+            args
+            + [
                 # don't special-case the last line
-                '--no-error-summary',
-                f'--config-file={config}',
-            ] + filtered
+                "--no-error-summary",
+                f"--config-file={config}",
+            ]
+            + filtered
         )
         # by construction, filtered must be nonempty
         for config, filtered in plan.items()
@@ -165,11 +169,11 @@
             [exit_code for _, _, exit_code in mypy_results],
             default=0,
         ),
-        list(dict.fromkeys(  # remove duplicates, retain order
-            item
-            for stdout, _, _ in mypy_results
-            for item in stdout.splitlines()
-        )),
+        list(
+            dict.fromkeys(  # remove duplicates, retain order
+                item for stdout, _, _ in mypy_results for item in stdout.splitlines()
+            )
+        ),
         [stderr for _, stderr, _ in mypy_results],
     )
 
@@ -212,9 +216,9 @@
     for issue in mypy_issues:
         print(issue)
     for stderr in stderrs:
-        print(stderr, end='', file=sys.stderr)
+        print(stderr, end="", file=sys.stderr)
     sys.exit(exit_code)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main(sys.argv[1:])
diff --git a/tools/linter/trailing_newlines.py b/tools/linter/trailing_newlines.py
index ee743a4..90f2196 100755
--- a/tools/linter/trailing_newlines.py
+++ b/tools/linter/trailing_newlines.py
@@ -4,11 +4,11 @@
 import os
 import sys
 
-NEWLINE, = b'\n'
+(NEWLINE,) = b"\n"
 
 
 def correct_trailing_newlines(filename: str) -> bool:
-    with open(filename, 'rb') as f:
+    with open(filename, "rb") as f:
         a = len(f.read(2))
         if a == 0:
             return True
@@ -33,5 +33,5 @@
     return exit_code
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(main())
diff --git a/tools/linter/translate_annotations.py b/tools/linter/translate_annotations.py
index ed0147e..8e4e300 100755
--- a/tools/linter/translate_annotations.py
+++ b/tools/linter/translate_annotations.py
@@ -6,8 +6,17 @@
 import subprocess
 from bisect import bisect_right
 from collections import defaultdict
-from typing import (Callable, DefaultDict, Generic, List, Optional, Pattern,
-                    Sequence, TypeVar, cast)
+from typing import (
+    Callable,
+    DefaultDict,
+    Generic,
+    List,
+    Optional,
+    Pattern,
+    Sequence,
+    TypeVar,
+    cast,
+)
 
 from typing_extensions import TypedDict
 
@@ -25,7 +34,7 @@
 
 
 # @@ -start,count +start,count @@
-hunk_pattern = r'^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@'
+hunk_pattern = r"^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@"
 
 
 def parse_diff(diff: str) -> Diff:
@@ -37,26 +46,28 @@
         if name_found:
             if hunk_match:
                 old_start, old_count, new_start, new_count = hunk_match.groups()
-                hunks.append({
-                    'old_start': int(old_start),
-                    'old_count': int(old_count or '1'),
-                    'new_start': int(new_start),
-                    'new_count': int(new_count or '1'),
-                })
+                hunks.append(
+                    {
+                        "old_start": int(old_start),
+                        "old_count": int(old_count or "1"),
+                        "new_start": int(new_start),
+                        "new_count": int(new_count or "1"),
+                    }
+                )
         else:
             assert not hunk_match
-            name_match = re.match(r'^--- (?:(?:/dev/null)|(?:a/(.*)))$', line)
+            name_match = re.match(r"^--- (?:(?:/dev/null)|(?:a/(.*)))$", line)
             if name_match:
                 name_found = True
-                name, = name_match.groups()
+                (name,) = name_match.groups()
     return {
-        'old_filename': name,
-        'hunks': hunks,
+        "old_filename": name,
+        "hunks": hunks,
     }
 
 
-T = TypeVar('T')
-U = TypeVar('U')
+T = TypeVar("T")
+U = TypeVar("U")
 
 
 # we want to use bisect.bisect_right to find the closest hunk to a given
@@ -81,21 +92,20 @@
     if line_number < 1:
         return None
 
-    hunks = diff['hunks']
+    hunks = diff["hunks"]
     if not hunks:
         return line_number
 
     keyified = KeyifyList(
-        hunks,
-        lambda hunk: hunk['new_start'] + (0 if hunk['new_count'] > 0 else 1)
+        hunks, lambda hunk: hunk["new_start"] + (0 if hunk["new_count"] > 0 else 1)
     )
     i = bisect_right(cast(Sequence[int], keyified), line_number)
     if i < 1:
         return line_number
 
     hunk = hunks[i - 1]
-    d = line_number - (hunk['new_start'] + (hunk['new_count'] or 1))
-    return None if d < 0 else hunk['old_start'] + (hunk['old_count'] or 1) + d
+    d = line_number - (hunk["new_start"] + (hunk["new_count"] or 1))
+    return None if d < 0 else hunk["old_start"] + (hunk["old_count"] or 1) + d
 
 
 # we use camelCase here because this will be output as JSON and so the
@@ -113,68 +123,61 @@
     m = re.match(regex, line)
     if m:
         try:
-            line_number = int(m.group('lineNumber'))
-            column_number = int(m.group('columnNumber'))
+            line_number = int(m.group("lineNumber"))
+            column_number = int(m.group("columnNumber"))
         except ValueError:
             return None
         return {
-            'filename': m.group('filename'),
-            'lineNumber': line_number,
-            'columnNumber': column_number,
-            'errorCode': m.group('errorCode'),
-            'errorDesc': m.group('errorDesc'),
+            "filename": m.group("filename"),
+            "lineNumber": line_number,
+            "columnNumber": column_number,
+            "errorCode": m.group("errorCode"),
+            "errorDesc": m.group("errorDesc"),
         }
     else:
         return None
 
 
 def translate_all(
-    *,
-    lines: List[str],
-    regex: Pattern[str],
-    commit: str
+    *, lines: List[str], regex: Pattern[str], commit: str
 ) -> List[Annotation]:
     ann_dict: DefaultDict[str, List[Annotation]] = defaultdict(list)
     for line in lines:
         annotation = parse_annotation(regex, line)
         if annotation is not None:
-            ann_dict[annotation['filename']].append(annotation)
+            ann_dict[annotation["filename"]].append(annotation)
     ann_list = []
     for filename, annotations in ann_dict.items():
         raw_diff = subprocess.check_output(
-            ['git', 'diff-index', '--unified=0', commit, filename],
-            encoding='utf-8',
+            ["git", "diff-index", "--unified=0", commit, filename],
+            encoding="utf-8",
         )
         diff = parse_diff(raw_diff) if raw_diff.strip() else None
         # if there is a diff but it doesn't list an old filename, that
         # means the file is absent in the commit we're targeting, so we
         # skip it
-        if not (diff and not diff['old_filename']):
+        if not (diff and not diff["old_filename"]):
             for annotation in annotations:
-                line_number: Optional[int] = annotation['lineNumber']
+                line_number: Optional[int] = annotation["lineNumber"]
                 if diff:
-                    annotation['filename'] = cast(str, diff['old_filename'])
+                    annotation["filename"] = cast(str, diff["old_filename"])
                     line_number = translate(diff, cast(int, line_number))
                 if line_number:
-                    annotation['lineNumber'] = line_number
+                    annotation["lineNumber"] = line_number
                     ann_list.append(annotation)
     return ann_list
 
 
 def main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument('--file')
-    parser.add_argument('--regex')
-    parser.add_argument('--commit')
+    parser.add_argument("--file")
+    parser.add_argument("--regex")
+    parser.add_argument("--commit")
     args = parser.parse_args()
-    with open(args.file, 'r') as f:
+    with open(args.file, "r") as f:
         lines = f.readlines()
-    print(json.dumps(translate_all(
-        lines=lines,
-        regex=args.regex,
-        commit=args.commit
-    )))
+    print(json.dumps(translate_all(lines=lines, regex=args.regex, commit=args.commit)))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py
index e34b7bb..6169baa 100644
--- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py
+++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py
@@ -44,6 +44,7 @@
 
 """
 
+
 def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
     ops = []
     for (op_name, op) in selective_builder.operators.items():
@@ -51,18 +52,24 @@
             ops.append(op_name)
     return set(ops)
 
+
 def get_selected_kernel_dtypes_code(
-        selective_builder: SelectiveBuilder,
+    selective_builder: SelectiveBuilder,
 ) -> str:
     # See https://www.internalfb.com/intern/paste/P153411698/ for an example of the
     # generated code in case all kernel dtypes are selected and in case some kernel
     # dtypes are selected (i.e. both cases).
     #
     body = "return true;"
-    if selective_builder.include_all_operators is False and selective_builder.include_all_non_op_selectives is False:
+    if (
+        selective_builder.include_all_operators is False
+        and selective_builder.include_all_non_op_selectives is False
+    ):
         body_parts = []
         for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
-            conditions = list(map(lambda x: 'scalar_type == at::ScalarType::' + x, dtypes))
+            conditions = list(
+                map(lambda x: "scalar_type == at::ScalarType::" + x, dtypes)
+            )
             body_parts.append(
                 if_condition_template.substitute(
                     kernel_tag_name=kernel_tag,
@@ -79,8 +86,8 @@
 # 1. The selected root operators
 # 2. The selected kernel dtypes
 def write_selected_mobile_ops(
-        output_file_path: str,
-        selective_builder: SelectiveBuilder,
+    output_file_path: str,
+    selective_builder: SelectiveBuilder,
 ) -> None:
     root_ops = extract_root_operators(selective_builder)
     custom_classes = selective_builder.custom_classes
@@ -90,16 +97,29 @@
         # This condition checks if we are in selective build.
         # if these lists are not defined the corresponding selective build macros trivially return the item in question was selected
         if not selective_builder.include_all_operators:
-            body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
+            body_parts.append(
+                "#define TORCH_OPERATOR_WHITELIST "
+                + (";".join(sorted(root_ops)))
+                + ";\n\n"
+            )
             # This condition checks if we are in tracing based selective build
             if selective_builder.include_all_non_op_selectives is False:
-                body_parts.append("#define TORCH_CUSTOM_CLASS_ALLOWLIST " + (";".join(sorted(custom_classes))) + ";\n\n")
-                body_parts.append("#define TORCH_BUILD_FEATURE_ALLOWLIST " + (";".join(sorted(build_features))) + ";\n\n")
+                body_parts.append(
+                    "#define TORCH_CUSTOM_CLASS_ALLOWLIST "
+                    + (";".join(sorted(custom_classes)))
+                    + ";\n\n"
+                )
+                body_parts.append(
+                    "#define TORCH_BUILD_FEATURE_ALLOWLIST "
+                    + (";".join(sorted(build_features)))
+                    + ";\n\n"
+                )
 
         body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
         header_contents = "".join(body_parts)
         out_file.write(header_contents.encode("utf-8"))
 
+
 # root_ops: a set of selected root operators for selective build
 # Write the file selected_mobile_ops.h with optionally:
 # 1. The selected root operators from root_ops
@@ -110,7 +130,9 @@
 ) -> None:
     with open(output_file_path, "wb") as out_file:
         body_parts = [selected_mobile_ops_preamble]
-        body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
+        body_parts.append(
+            "#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n"
+        )
 
         selective_builder = SelectiveBuilder.get_nop_selector()
         body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
@@ -118,17 +140,25 @@
         header_contents = "".join(body_parts)
         out_file.write(header_contents.encode("utf-8"))
 
+
 def main() -> None:
     parser = argparse.ArgumentParser(
         description="Generate selected_mobile_ops.h for selective build."
     )
     parser.add_argument(
-        "-p", "--yaml_file_path", type=str, required=True, help="Path to the yaml"
-        " file with a list of operators used by the model."
+        "-p",
+        "--yaml_file_path",
+        type=str,
+        required=True,
+        help="Path to the yaml" " file with a list of operators used by the model.",
     )
     parser.add_argument(
-        "-o", "--output_file_path", type=str, required=True, help="Path to destination"
-        "folder where selected_mobile_ops.h will be written."
+        "-o",
+        "--output_file_path",
+        type=str,
+        required=True,
+        help="Path to destination"
+        "folder where selected_mobile_ops.h will be written.",
     )
     parsed_args = parser.parse_args()
     model_file_name = parsed_args.yaml_file_path
@@ -138,12 +168,13 @@
     with open(model_file_name, "rb") as model_file:
         loaded_model = yaml.load(model_file, Loader=Loader)
 
-
     root_operators_set = set(loaded_model)
     print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
     write_selected_mobile_ops_with_all_dtypes(
         os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
-        root_operators_set)
+        root_operators_set,
+    )
+
 
 if __name__ == "__main__":
     main()
diff --git a/tools/lldb/deploy_debugger.py b/tools/lldb/deploy_debugger.py
index deaf65d..5a13958 100644
--- a/tools/lldb/deploy_debugger.py
+++ b/tools/lldb/deploy_debugger.py
@@ -1,10 +1,12 @@
 import lldb  # type: ignore[import]
+
 # load into lldb instance with:
 #   command script import tools/lldb/deploy_debugger.py
 
 target = lldb.debugger.GetSelectedTarget()
 bp = target.BreakpointCreateByRegex("__deploy_register_code")
-bp.SetScriptCallbackBody("""\
+bp.SetScriptCallbackBody(
+    """\
 process = frame.thread.GetProcess()
 target = process.target
 symbol_addr = frame.module.FindSymbol("__deploy_module_info").GetStartAddress()
@@ -31,4 +33,5 @@
     lldb.debugger.HandleCommand(cmd2)
 
 return False
-""")
+"""
+)
diff --git a/tools/nightly.py b/tools/nightly.py
index 7a46a01..32733c5 100755
--- a/tools/nightly.py
+++ b/tools/nightly.py
@@ -40,8 +40,21 @@
 import subprocess
 from ast import literal_eval
 from argparse import ArgumentParser
-from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List,
-                    Optional, Sequence, Set, Tuple, TypeVar, cast)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+    cast,
+)
 
 LOGGER: Optional[logging.Logger] = None
 URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
@@ -199,7 +212,13 @@
         return "Branch name to checkout must be supplied with '-b' option"
     # next check that the local repo is clean
     cmd = ["git", "status", "--untracked-files=no", "--porcelain"]
-    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True, universal_newlines=True)
+    p = subprocess.run(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        check=True,
+        universal_newlines=True,
+    )
     if p.stdout.strip():
         return "Need to have clean working tree to checkout!\n\n" + p.stdout
     # next check that the branch name doesn't already exist
@@ -218,7 +237,7 @@
     logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]")
 
 
-F = TypeVar('F', bound=Callable[..., Any])
+F = TypeVar("F", bound=Callable[..., Any])
 
 
 def timed(prefix: str) -> Callable[[F], F]:
@@ -325,7 +344,7 @@
 
 @timed("Installing pytorch nightly binaries")
 def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
-    """"Install pytorch into a temporary directory"""
+    """ "Install pytorch into a temporary directory"""
     pytdir = tempfile.TemporaryDirectory()
     cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
     p = subprocess.run(cmd, check=True)
@@ -369,7 +388,13 @@
     # now cross reference with nightly version
     _ensure_commit(git_version)
     cmd = ["git", "show", "--no-patch", "--format=%s", git_version]
-    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True, universal_newlines=True)
+    p = subprocess.run(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        check=True,
+        universal_newlines=True,
+    )
     m = SHA1_RE.search(p.stdout)
     if m is None:
         raise RuntimeError(
@@ -516,7 +541,13 @@
 
 def _available_envs() -> Dict[str, str]:
     cmd = ["conda", "env", "list"]
-    p = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
+    p = subprocess.run(
+        cmd,
+        check=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        universal_newlines=True,
+    )
     lines = p.stdout.splitlines()
     envs = {}
     for line in map(str.strip, lines):
diff --git a/tools/onnx/update_default_opset_version.py b/tools/onnx/update_default_opset_version.py
index 358bbfd..a854a1c 100755
--- a/tools/onnx/update_default_opset_version.py
+++ b/tools/onnx/update_default_opset_version.py
@@ -23,9 +23,12 @@
 os.chdir(onnx_dir)
 
 date = datetime.datetime.now() - datetime.timedelta(days=18 * 30)
-onnx_commit = subprocess.check_output(("git", "log", f"--until={date}", "--max-count=1", "--format=%H"),
-                                      encoding="utf-8").strip()
-onnx_tags = subprocess.check_output(("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8")
+onnx_commit = subprocess.check_output(
+    ("git", "log", f"--until={date}", "--max-count=1", "--format=%H"), encoding="utf-8"
+).strip()
+onnx_tags = subprocess.check_output(
+    ("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8"
+)
 tag_tups = []
 semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
 for tag in onnx_tags.splitlines():
@@ -37,23 +40,31 @@
 
 print("Using ONNX release", version_str)
 
-head_commit = subprocess.check_output(("git", "log", "--max-count=1", "--format=%H", "HEAD"),
-                                      encoding="utf-8").strip()
+head_commit = subprocess.check_output(
+    ("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8"
+).strip()
 
 new_default = None
 
-subprocess.check_call(("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL)
+subprocess.check_call(
+    ("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL
+)
 try:
     from onnx import helper  # type: ignore[import]
+
     for version in helper.VERSION_TABLE:
         if version[0] == version_str:
             new_default = version[2]
             print("found new default opset_version", new_default)
             break
     if not new_default:
-        sys.exit(f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}")
+        sys.exit(
+            f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}"
+        )
 finally:
-    subprocess.check_call(("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL)
+    subprocess.check_call(
+        ("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL
+    )
 
 os.chdir(pytorch_dir)
 
@@ -66,13 +77,19 @@
         f.write(content_str)
     print("modified", path)
 
-read_sub_write(os.path.join("torch", "onnx", "symbolic_helper.py"),
-               r"(_default_onnx_opset_version = )\d+")
-read_sub_write(os.path.join("torch", "onnx", "__init__.py"),
-               r"(opset_version \(int, default )\d+")
+
+read_sub_write(
+    os.path.join("torch", "onnx", "symbolic_helper.py"),
+    r"(_default_onnx_opset_version = )\d+",
+)
+read_sub_write(
+    os.path.join("torch", "onnx", "__init__.py"), r"(opset_version \(int, default )\d+"
+)
 
 print("Updating operator .expect files")
-subprocess.check_call(("python", "setup.py", "develop"),
-                      stdout=DEVNULL, stderr=DEVNULL)
-subprocess.check_call(("python", os.path.join("test", "onnx", "test_operators.py"), "--accept"),
-                      stdout=DEVNULL, stderr=DEVNULL)
+subprocess.check_call(("python", "setup.py", "develop"), stdout=DEVNULL, stderr=DEVNULL)
+subprocess.check_call(
+    ("python", os.path.join("test", "onnx", "test_operators.py"), "--accept"),
+    stdout=DEVNULL,
+    stderr=DEVNULL,
+)
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index e2aacea..f85d99b 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -3,14 +3,20 @@
 from pprint import pformat
 
 from tools.codegen.model import Variant
-from tools.codegen.api.python import (PythonSignatureGroup,
-                                      PythonSignatureNativeFunctionPair,
-                                      returns_named_tuple_pyi)
+from tools.codegen.api.python import (
+    PythonSignatureGroup,
+    PythonSignatureNativeFunctionPair,
+    returns_named_tuple_pyi,
+)
 from tools.codegen.gen import parse_native_yaml
 from tools.codegen.utils import FileManager
 from typing import Sequence, List, Dict
 
-from tools.autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads
+from tools.autograd.gen_python_functions import (
+    should_generate_py_binding,
+    load_signatures,
+    group_overloads,
+)
 
 """
 This module implements generation of type stubs for PyTorch,
@@ -36,23 +42,29 @@
 read gen_pyi for the gory details.
 """
 
+
 def get_py_torch_functions(
-        python_funcs: Sequence[PythonSignatureNativeFunctionPair],
-        method: bool = False,
+    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
+    method: bool = False,
 ) -> Sequence[PythonSignatureGroup]:
     """
     Get declarations (grouped by name) which should be generated
     as either functions in the "torch" module or methods on Tensor.
     """
+
     def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool:
-        return (should_generate_py_binding(python_func.function) and
-                not python_func.function.python_module and
-                Variant.function in python_func.function.variants)
+        return (
+            should_generate_py_binding(python_func.function)
+            and not python_func.function.python_module
+            and Variant.function in python_func.function.variants
+        )
 
     def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
-        return (should_generate_py_binding(python_func.function) and
-                not python_func.function.python_module and
-                Variant.method in python_func.function.variants)
+        return (
+            should_generate_py_binding(python_func.function)
+            and not python_func.function.python_module
+            and Variant.method in python_func.function.variants
+        )
 
     should_bind = should_bind_method if method else should_bind_function
     return group_overloads([f for f in python_funcs if should_bind(f)])
@@ -62,76 +74,111 @@
 # the stubs to read on the human eye.
 
 DEVICE_PARAM = "device: Union[_device, str, None]=None"
-FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"
+FACTORY_PARAMS = (
+    f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"
+)
 
 # this could be more precise w.r.t list contents etc. How to do Ellipsis?
 INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
 
 blocklist = [
-    '__init_subclass__',
-    '__new__',
-    '__subclasshook__',
-    'cdist',
-    'device',
-    'grad',
-    'requires_grad',
-    'range',
+    "__init_subclass__",
+    "__new__",
+    "__subclasshook__",
+    "cdist",
+    "device",
+    "grad",
+    "requires_grad",
+    "range",
     # defined in functional
-    'einsum',
+    "einsum",
     # reduction argument; these bindings don't make sense
-    'binary_cross_entropy_with_logits',
-    'ctc_loss',
-    'cosine_embedding_loss',
-    'hinge_embedding_loss',
-    'kl_div',
-    'margin_ranking_loss',
-    'triplet_margin_loss',
+    "binary_cross_entropy_with_logits",
+    "ctc_loss",
+    "cosine_embedding_loss",
+    "hinge_embedding_loss",
+    "kl_div",
+    "margin_ranking_loss",
+    "triplet_margin_loss",
     # Somehow, these are defined in both _C and in functional. Ick!
-    'broadcast_tensors',
+    "broadcast_tensors",
     # Manually define named tensor type stubs in __init__.pyi.in
-    'align_tensors',
-    'meshgrid',
-    'cartesian_prod',
-    'block_diag',
-    'norm',
-    'chain_matmul',
-    'stft',
-    'tensordot',
-    'split',
-    'unique_consecutive',
-    'atleast_1d',
-    'atleast_2d',
-    'atleast_3d',
+    "align_tensors",
+    "meshgrid",
+    "cartesian_prod",
+    "block_diag",
+    "norm",
+    "chain_matmul",
+    "stft",
+    "tensordot",
+    "split",
+    "unique_consecutive",
+    "atleast_1d",
+    "atleast_2d",
+    "atleast_3d",
     # These are handled specially by python_arg_parser.cpp
-    'add',
-    'add_',
-    'add_out',
-    'sub',
-    'sub_',
-    'sub_out',
-    'mul',
-    'mul_',
-    'mul_out',
-    'div',
-    'div_',
-    'div_out',
-    'true_divide', 'true_divide_', 'true_divide_out',
-    'floor_divide', 'floor_divide_', 'floor_divide_out',
+    "add",
+    "add_",
+    "add_out",
+    "sub",
+    "sub_",
+    "sub_out",
+    "mul",
+    "mul_",
+    "mul_out",
+    "div",
+    "div_",
+    "div_out",
+    "true_divide",
+    "true_divide_",
+    "true_divide_out",
+    "floor_divide",
+    "floor_divide_",
+    "floor_divide_out",
 ]
 
-binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
-              'matmul', 'floordiv',
-              'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow',          # reverse arithmetic
-              'and', 'or', 'xor', 'rand', 'ror', 'rxor',  # logic
-              'iadd', 'iand', 'idiv', 'ilshift', 'imul',
-              'ior', 'irshift', 'isub', 'ixor', 'ifloordiv', 'imod',  # inplace ops
-              )
-symmetric_comparison_ops = ('eq', 'ne')
-asymmetric_comparison_ops = ('ge', 'gt', 'lt', 'le')
+binary_ops = (
+    "add",
+    "sub",
+    "mul",
+    "div",
+    "pow",
+    "lshift",
+    "rshift",
+    "mod",
+    "truediv",
+    "matmul",
+    "floordiv",
+    "radd",
+    "rsub",
+    "rmul",
+    "rtruediv",
+    "rfloordiv",
+    "rpow",  # reverse arithmetic
+    "and",
+    "or",
+    "xor",
+    "rand",
+    "ror",
+    "rxor",  # logic
+    "iadd",
+    "iand",
+    "idiv",
+    "ilshift",
+    "imul",
+    "ior",
+    "irshift",
+    "isub",
+    "ixor",
+    "ifloordiv",
+    "imod",  # inplace ops
+)
+symmetric_comparison_ops = ("eq", "ne")
+asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
 comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
 
-unary_ops = ('neg', 'abs', 'invert')
-to_py_type_ops = ('bool', 'float', 'complex', 'long', 'index', 'int', 'nonzero')
+unary_ops = ("neg", "abs", "invert")
+to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
 all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
 
 
@@ -142,32 +189,35 @@
 
     # we have to do this by hand, because they are hand-bound in Python
 
-    assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)
+    assert opname.endswith("__") and opname.startswith("__"), "Unexpected op {}".format(
+        opname
+    )
 
     name = opname[2:-2]
     if name in binary_ops:
-        return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
+        return ["def {}(self, other: Any) -> Tensor: ...".format(opname)]
     elif name in comparison_ops:
-        sig = 'def {}(self, other: Any) -> Tensor: ...'.format(opname)
+        sig = "def {}(self, other: Any) -> Tensor: ...".format(opname)
         if name in symmetric_comparison_ops:
             # unsafe override https://github.com/python/mypy/issues/5704
-            sig += '  # type: ignore[override]'
+            sig += "  # type: ignore[override]"
         return [sig]
     elif name in unary_ops:
-        return ['def {}(self) -> Tensor: ...'.format(opname)]
+        return ["def {}(self) -> Tensor: ...".format(opname)]
     elif name in to_py_type_ops:
-        if name in {'bool', 'float', 'complex'}:
+        if name in {"bool", "float", "complex"}:
             tname = name
-        elif name == 'nonzero':
-            tname = 'bool'
+        elif name == "nonzero":
+            tname = "bool"
         else:
-            tname = 'int'
-        if tname in {'float', 'int', 'bool', 'complex'}:
-            tname = 'builtins.' + tname
-        return ['def {}(self) -> {}: ...'.format(opname, tname)]
+            tname = "int"
+        if tname in {"float", "int", "bool", "complex"}:
+            tname = "builtins." + tname
+        return ["def {}(self) -> {}: ...".format(opname, tname)]
     else:
         raise Exception("unknown op", opname)
 
+
 def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
     type_hints: List[str] = []
 
@@ -185,79 +235,90 @@
     # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument
     # Generates the out variant if one exists. Otherwise, generate the functional variant
     type_hint = sig_group.signature.signature_str_pyi(
-        skip_outputs=sig_group.outplace is None)
+        skip_outputs=sig_group.outplace is None
+    )
     type_hints.append(type_hint)
 
     # Some operators also additionally have a vararg variant of their signature
     type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
-        skip_outputs=sig_group.outplace is None)
+        skip_outputs=sig_group.outplace is None
+    )
     if type_hint_vararg:
         type_hints.append(type_hint_vararg)
 
     return type_hints
 
+
 def gen_nn_functional(fm: FileManager) -> None:
     # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
     # through an `_add_docstr` call
     imports = [
-        'conv1d',
-        'conv2d',
-        'conv3d',
-        'conv_transpose1d',
-        'conv_transpose2d',
-        'conv_transpose3d',
-        'conv_tbc',
-        'avg_pool1d',
-        'relu_',
-        'selu_',
-        'celu_',
-        'rrelu_',
-        'pixel_shuffle',
-        'pixel_unshuffle',
-        'channel_shuffle',
-        'native_channel_shuffle',
-        'pdist',
-        'cosine_similarity',
+        "conv1d",
+        "conv2d",
+        "conv3d",
+        "conv_transpose1d",
+        "conv_transpose2d",
+        "conv_transpose3d",
+        "conv_tbc",
+        "avg_pool1d",
+        "relu_",
+        "selu_",
+        "celu_",
+        "rrelu_",
+        "pixel_shuffle",
+        "pixel_unshuffle",
+        "channel_shuffle",
+        "native_channel_shuffle",
+        "pdist",
+        "cosine_similarity",
     ]
     # Functions generated by `torch._jit_internal.boolean_dispatch`
     dispatches = [
-        'fractional_max_pool2d',
-        'fractional_max_pool3d',
-        'max_pool1d',
-        'max_pool2d',
-        'max_pool3d',
-        'adaptive_max_pool1d',
-        'adaptive_max_pool2d',
-        'adaptive_max_pool3d',
+        "fractional_max_pool2d",
+        "fractional_max_pool3d",
+        "max_pool1d",
+        "max_pool2d",
+        "max_pool3d",
+        "adaptive_max_pool1d",
+        "adaptive_max_pool2d",
+        "adaptive_max_pool3d",
     ]
     # Functions directly imported from `torch._C`
     from_c = [
-        'avg_pool2d',
-        'avg_pool3d',
-        'hardtanh_',
-        'elu_',
-        'leaky_relu_',
-        'logsigmoid',
-        'softplus',
-        'softshrink',
-        'one_hot',
+        "avg_pool2d",
+        "avg_pool3d",
+        "hardtanh_",
+        "elu_",
+        "leaky_relu_",
+        "logsigmoid",
+        "softplus",
+        "softshrink",
+        "one_hot",
     ]
     import_code = ["from .. import {0} as {0}".format(_) for _ in imports]
     # TODO make these types more precise
     dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
-    fm.write_with_template('torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: {
-        'imported_hints': import_code,
-        'dispatched_hints': dispatch_code,
-    })
+    fm.write_with_template(
+        "torch/nn/functional.pyi",
+        "torch/nn/functional.pyi.in",
+        lambda: {
+            "imported_hints": import_code,
+            "dispatched_hints": dispatch_code,
+        },
+    )
 
     # functional.pyi already contains the definitions for those functions
     # so, we don't export then to it
-    from_c.extend(['hardtanh', 'leaky_relu', 'hardsigmoid'])
+    from_c.extend(["hardtanh", "leaky_relu", "hardsigmoid"])
     dispatch_code = ["{}: Callable".format(_) for _ in (dispatches + from_c)]
-    fm.write_with_template('torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: {
-        'imported_hints': import_code,
-        'dispatched_hints': dispatch_code,
-    })
+    fm.write_with_template(
+        "torch/_C/_nn.pyi",
+        "torch/_C/_nn.pyi.in",
+        lambda: {
+            "imported_hints": import_code,
+            "dispatched_hints": dispatch_code,
+        },
+    )
 
 
 def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
@@ -280,119 +341,183 @@
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
     unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
-    unsorted_function_hints.update({
-        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
-        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
-        'asarray': ['def asarray(obj: Any, *, dtype: Optional[_dtype]=None, '
-                    'device: Union[_device, str, None]=None, copy: Optional[_bool]=None, '
-                    'requires_grad: _bool=False) -> Tensor: ...'],
-        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
-        'frombuffer': ['def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, '
-                       'offset: int=0, device: Union[_device, str, None]=None, '
-                       'requires_grad: _bool=False) -> Tensor: ...'],
-        'numel': ['def numel(self: Tensor) -> _int: ...'],
-        'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
-        'get_num_threads': ['def get_num_threads() -> _int: ...'],
-        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
-        'init_num_threads': ['def init_num_threads() -> None: ...'],
-        'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
-        'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
-        # These functions are explicitly disabled by
-        # SKIP_PYTHON_BINDINGS because they are hand bound.
-        # Correspondingly, we must hand-write their signatures.
-        'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
-        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
-                              ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
-                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
-        'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Union[Tensor, List],'
-                               'col_indices: Union[Tensor, List],'
-                               ' values: Union[Tensor, List], size: Optional[_size]=None,'
-                               ' *, dtype: Optional[_dtype]=None,'
-                               ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
-        '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
-                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
-                                      ' requires_grad: bool = False) -> Tensor: ...'],
-        '_sparse_csr_tensor_unsafe': ['def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],'
-                                      'col_indices: Union[Tensor, List],'
-                                      ' values: Union[Tensor, List], size: List[int],'
-                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
-                                      ' requires_grad: bool = False) -> Tensor: ...'],
-        'range': ['def range(start: Number, end: Number,'
-                  ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
-                  .format(FACTORY_PARAMS)],
-        'arange': ['def arange(start: Number, end: Number, step: Number, *,'
-                   ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
-                   .format(FACTORY_PARAMS),
-                   'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
-                   .format(FACTORY_PARAMS),
-                   'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
-                   .format(FACTORY_PARAMS)],
-        'linspace': ['def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,'
-                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
-        'logspace': ['def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,'
-                     ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(FACTORY_PARAMS)],
-        'randint': ['def randint(low: _int, high: _int, size: _size, *,'
-                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
-                    .format(FACTORY_PARAMS),
-                    'def randint(high: _int, size: _size, *,'
-                    ' generator: Optional[Generator]=None, {}) -> Tensor: ...'
-                    .format(FACTORY_PARAMS)],
-        'full': ['def full(size: _size, fill_value: Number, *,'
-                 ' out: Optional[Tensor]=None,'
-                 ' layout: _layout=strided, {}) -> Tensor: ...'
-                 .format(FACTORY_PARAMS),
-                 'def full(size: _size, fill_value: Number, *,'
-                 ' names: List[Union[str, None]],'
-                 ' layout: _layout=strided, {}) -> Tensor: ...'
-                 .format(FACTORY_PARAMS)],
-        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
-        'is_inference_mode_enabled': ['def is_inference_mode_enabled() -> _bool: ...'],
-        'nonzero': ['def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...',
-                    'def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
-        'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, '
-                                             'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, '
-                                             'reduce: Optional[bool] = None, reduction: str = ..., '
-                                             'pos_weight: Optional[Tensor] = None) -> Tensor: ...'],
-        'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, '
-                                  'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., '
-                                  'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
-        'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,'
-                     ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'],
-        'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,'
-                                 ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., '
-                                 'reduction: str = ...) -> Tensor: ...'],
-        'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., '
-                   'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'],
-        'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,'
-                                ' margin: float = ..., size_average: Optional[bool] = ..., '
-                                ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
-        'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, '
-                                'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., '
-                                'size_average: Optional[bool] = ..., '
-                                'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'],
-        'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
-        'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
-        'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, '
-                   'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'],
-        'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'],
-        'div': ['def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, '
-                'rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ...'],
-    })
-    for binop in ['mul', 'true_divide', 'floor_divide']:
+    unsorted_function_hints.update(
+        {
+            "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
+            "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
+            "asarray": [
+                "def asarray(obj: Any, *, dtype: Optional[_dtype]=None, "
+                "device: Union[_device, str, None]=None, copy: Optional[_bool]=None, "
+                "requires_grad: _bool=False) -> Tensor: ..."
+            ],
+            "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
+            "frombuffer": [
+                "def frombuffer(buffer: Any, *, dtype: _dtype, count: int=-1, "
+                "offset: int=0, device: Union[_device, str, None]=None, "
+                "requires_grad: _bool=False) -> Tensor: ..."
+            ],
+            "numel": ["def numel(self: Tensor) -> _int: ..."],
+            "as_tensor": [
+                "def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."
+            ],
+            "get_num_threads": ["def get_num_threads() -> _int: ..."],
+            "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
+            "init_num_threads": ["def init_num_threads() -> None: ..."],
+            "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."],
+            "set_num_interop_threads": [
+                "def set_num_interop_threads(num: _int) -> None: ..."
+            ],
+            # These functions are explicitly disabled by
+            # SKIP_PYTHON_BINDINGS because they are hand bound.
+            # Correspondingly, we must hand-write their signatures.
+            "tensor": [
+                "def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)
+            ],
+            "sparse_coo_tensor": [
+                "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],"
+                " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,"
+                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
+            ],
+            "sparse_csr_tensor": [
+                "def sparse_csr_tensor(crow_indices: Union[Tensor, List],"
+                "col_indices: Union[Tensor, List],"
+                " values: Union[Tensor, List], size: Optional[_size]=None,"
+                " *, dtype: Optional[_dtype]=None,"
+                " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
+            ],
+            "_sparse_coo_tensor_unsafe": [
+                "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],"
+                " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
+                " requires_grad: bool = False) -> Tensor: ..."
+            ],
+            "_sparse_csr_tensor_unsafe": [
+                "def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],"
+                "col_indices: Union[Tensor, List],"
+                " values: Union[Tensor, List], size: List[int],"
+                " dtype: Optional[_dtype] = None, device: Optional[_device] = None,"
+                " requires_grad: bool = False) -> Tensor: ..."
+            ],
+            "range": [
+                "def range(start: Number, end: Number,"
+                " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                )
+            ],
+            "arange": [
+                "def arange(start: Number, end: Number, step: Number, *,"
+                " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                ),
+                "def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                ),
+                "def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                ),
+            ],
+            "linspace": [
+                "def linspace(start: Number, end: Number, steps: Optional[_int]=None, *,"
+                " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS)
+            ],
+            "logspace": [
+                "def logspace(start: Number, end: Number, steps: Optional[_int]=None, base: _float=10.0, *,"
+                " out: Optional[Tensor]=None, {}) -> Tensor: ...".format(FACTORY_PARAMS)
+            ],
+            "randint": [
+                "def randint(low: _int, high: _int, size: _size, *,"
+                " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                ),
+                "def randint(high: _int, size: _size, *,"
+                " generator: Optional[Generator]=None, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                ),
+            ],
+            "full": [
+                "def full(size: _size, fill_value: Number, *,"
+                " out: Optional[Tensor]=None,"
+                " layout: _layout=strided, {}) -> Tensor: ...".format(FACTORY_PARAMS),
+                "def full(size: _size, fill_value: Number, *,"
+                " names: List[Union[str, None]],"
+                " layout: _layout=strided, {}) -> Tensor: ...".format(FACTORY_PARAMS),
+            ],
+            "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
+            "is_inference_mode_enabled": [
+                "def is_inference_mode_enabled() -> _bool: ..."
+            ],
+            "nonzero": [
+                "def nonzero(input: Tensor, *, as_tuple: Literal[False]=False, out: Optional[Tensor]=None) -> Tensor: ...",
+                "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
+            ],
+            "binary_cross_entropy_with_logits": [
+                "def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, "
+                "weight: Optional[Tensor] = None, size_average: Optional[bool] = None, "
+                "reduce: Optional[bool] = None, reduction: str = ..., "
+                "pos_weight: Optional[Tensor] = None) -> Tensor: ..."
+            ],
+            "cosine_embedding_loss": [
+                "def cosine_embedding_loss(input1: Tensor, input2: Tensor, "
+                "target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., "
+                "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
+            ],
+            "ctc_loss": [
+                "def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,"
+                " blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ..."
+            ],
+            "hinge_embedding_loss": [
+                "def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,"
+                " size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., "
+                "reduction: str = ...) -> Tensor: ..."
+            ],
+            "kl_div": [
+                "def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., "
+                "reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ..."
+            ],
+            "margin_ranking_loss": [
+                "def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,"
+                " margin: float = ..., size_average: Optional[bool] = ..., "
+                " reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
+            ],
+            "triplet_margin_loss": [
+                "def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, "
+                "margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., "
+                "size_average: Optional[bool] = ..., "
+                "reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ..."
+            ],
+            "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
+            "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
+            "saddmm": [
+                "def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, "
+                "alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ..."
+            ],
+            "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
+            "div": [
+                "def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, "
+                "rounding_mode: Optional[str] = None, out: Optional[Tensor]=None) -> Tensor: ..."
+            ],
+        }
+    )
+    for binop in ["mul", "true_divide", "floor_divide"]:
         unsorted_function_hints[binop].append(
-            'def {}(input: Union[Tensor, Number],'
-            ' other: Union[Tensor, Number],'
-            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
-    for binop in ['add', 'sub']:
+            "def {}(input: Union[Tensor, Number],"
+            " other: Union[Tensor, Number],"
+            " *, out: Optional[Tensor]=None) -> Tensor: ...".format(binop)
+        )
+    for binop in ["add", "sub"]:
         unsorted_function_hints[binop].append(
-            'def {}(input: Union[Tensor, Number],'
-            ' other: Union[Tensor, Number],'
-            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
+            "def {}(input: Union[Tensor, Number],"
+            " other: Union[Tensor, Number],"
+            " *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...".format(
+                binop
+            )
+        )
 
     native_functions = parse_native_yaml(native_yaml_path).native_functions
     native_functions = list(filter(should_generate_py_binding, native_functions))
 
-    function_signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
+    function_signatures = load_signatures(
+        native_functions, deprecated_yaml_path, method=False, pyi=True
+    )
     sig_groups = get_py_torch_functions(function_signatures)
     for group in sorted(sig_groups, key=lambda g: g.signature.name):
         name = group.signature.name
@@ -410,118 +535,187 @@
     function_hints = []
     for name, hints in sorted(unsorted_function_hints.items()):
         if len(hints) > 1:
-            hints = ['@overload\n' + h for h in hints]
+            hints = ["@overload\n" + h for h in hints]
         function_hints += hints
 
     # Generate type signatures for Tensor methods
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
     unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
-    unsorted_tensor_method_hints.update({
-        'size': ['def size(self) -> Size: ...',
-                 'def size(self, dim: _int) -> _int: ...'],
-        'stride': ['def stride(self) -> Tuple[_int]: ...',
-                   'def stride(self, _int) -> _int: ...'],
-        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
-                     format(FACTORY_PARAMS)],
-        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
-        # new and __init__ have the same signatures differ only in return type
-        # Adapted from legacy_tensor_ctor and legacy_tensor_new
-        'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM),
-                'def new(self, storage: Storage) -> Tensor: ...',
-                'def new(self, other: Tensor) -> Tensor: ...',
-                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
-                ],
-        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
-                     'def __init__(self, storage: Storage) -> None: ...',
-                     'def __init__(self, other: Tensor) -> None: ...',
-                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
-                     ],
-        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
-        '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."],
-        '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
-        '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
-                        " -> None: ...".format(INDICES)],
-        'tolist': ['def tolist(self) -> List: ...'],
-        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
-        'element_size': ['def element_size(self) -> _int: ...'],
-        'data_ptr': ['def data_ptr(self) -> _int: ...'],
-        'dim': ['def dim(self) -> _int: ...'],
-        'nonzero': ['def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...',
-                    'def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...'],
-        'numel': ['def numel(self) -> _int: ...'],
-        'ndimension': ['def ndimension(self) -> _int: ...'],
-        'nelement': ['def nelement(self) -> _int: ...'],
-        'cuda': ['def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'],
-        'numpy': ['def numpy(self) -> Any: ...'],
-        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
-        'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
-        'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'],
-        'storage': ['def _storage(self) -> Storage: ...'],
-        'storage_type': ['def storage_type(self) -> Storage: ...'],
-        'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
-                 'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
-                 ],
-        'get_device': ['def get_device(self) -> _int: ...'],
-        'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'],
-        'has_names': ['def has_names(self) -> _bool: ...'],
-        'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
-        '_is_view': ['def _is_view(self) -> _bool: ...'],
-        'is_cuda': ['is_cuda: _bool'],
-        'is_leaf': ['is_leaf: _bool'],
-        'is_nested': ['is_nested: _bool'],
-        'is_sparse': ['is_sparse: _bool'],
-        'is_sparse_csr' : ['is_sparse_csr: _bool'],
-        'is_quantized': ['is_quantized: _bool'],
-        'is_meta': ['is_meta: _bool'],
-        'is_ort': ['is_ort: _bool'],
-        'is_mkldnn': ['is_mkldnn: _bool'],
-        'is_vulkan': ['is_vulkan: _bool'],
-        'is_ipu': ['is_ipu: _bool'],
-        'storage_offset': ['def storage_offset(self) -> _int: ...'],
-        'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
-               'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
-               'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
-               'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
-               ],
-        'item': ["def item(self) -> Number: ..."],
-        'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
-        'set_': ['def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...',
-                 'def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...'],
-        'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
-                  'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
-        'div': ['def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
-        'div_': ['def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...'],
-    })
-    for binop in ['mul', 'true_divide', 'floor_divide']:
+    unsorted_tensor_method_hints.update(
+        {
+            "size": [
+                "def size(self) -> Size: ...",
+                "def size(self, dim: _int) -> _int: ...",
+            ],
+            "stride": [
+                "def stride(self) -> Tuple[_int]: ...",
+                "def stride(self, _int) -> _int: ...",
+            ],
+            "new_ones": [
+                "def new_ones(self, size: _size, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                )
+            ],
+            "new_tensor": [
+                "def new_tensor(self, data: Any, {}) -> Tensor: ...".format(
+                    FACTORY_PARAMS
+                )
+            ],
+            # new and __init__ have the same signatures differ only in return type
+            # Adapted from legacy_tensor_ctor and legacy_tensor_new
+            "new": [
+                "def new(self, *args: Any, {}) ->Tensor: ...".format(DEVICE_PARAM),
+                "def new(self, storage: Storage) -> Tensor: ...",
+                "def new(self, other: Tensor) -> Tensor: ...",
+                "def new(self, size: _size, *, {}) -> Tensor: ...".format(DEVICE_PARAM),
+            ],
+            "__init__": [
+                "def __init__(self, *args: Any, {}) -> None: ...".format(DEVICE_PARAM),
+                "def __init__(self, storage: Storage) -> None: ...",
+                "def __init__(self, other: Tensor) -> None: ...",
+                "def __init__(self, size: _size, *, {}) -> None: ...".format(
+                    DEVICE_PARAM
+                ),
+            ],
+            "as_subclass": ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
+            "_make_subclass": [
+                "def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."
+            ],
+            "__getitem__": ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
+            "__setitem__": [
+                "def __setitem__(self, {}, val: Union[Tensor, Number])"
+                " -> None: ...".format(INDICES)
+            ],
+            "tolist": ["def tolist(self) -> List: ..."],
+            "requires_grad_": [
+                "def requires_grad_(self, mode: _bool=True) -> Tensor: ..."
+            ],
+            "element_size": ["def element_size(self) -> _int: ..."],
+            "data_ptr": ["def data_ptr(self) -> _int: ..."],
+            "dim": ["def dim(self) -> _int: ..."],
+            "nonzero": [
+                "def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...",
+                "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
+            ],
+            "numel": ["def numel(self) -> _int: ..."],
+            "ndimension": ["def ndimension(self) -> _int: ..."],
+            "nelement": ["def nelement(self) -> _int: ..."],
+            "cuda": [
+                "def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ..."
+            ],
+            "numpy": ["def numpy(self) -> Any: ..."],
+            "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
+            "map_": [
+                "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."
+            ],
+            "map2_": [
+                "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
+            ],
+            "storage": ["def _storage(self) -> Storage: ..."],
+            "storage_type": ["def storage_type(self) -> Storage: ..."],
+            "type": [
+                "def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...",
+                "def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...",
+            ],
+            "get_device": ["def get_device(self) -> _int: ..."],
+            "contiguous": [
+                "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
+            ],
+            "has_names": ["def has_names(self) -> _bool: ..."],
+            "is_contiguous": [
+                "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
+            ],
+            "_is_view": ["def _is_view(self) -> _bool: ..."],
+            "is_cuda": ["is_cuda: _bool"],
+            "is_leaf": ["is_leaf: _bool"],
+            "is_nested": ["is_nested: _bool"],
+            "is_sparse": ["is_sparse: _bool"],
+            "is_sparse_csr": ["is_sparse_csr: _bool"],
+            "is_quantized": ["is_quantized: _bool"],
+            "is_meta": ["is_meta: _bool"],
+            "is_ort": ["is_ort: _bool"],
+            "is_mkldnn": ["is_mkldnn: _bool"],
+            "is_vulkan": ["is_vulkan: _bool"],
+            "is_ipu": ["is_ipu: _bool"],
+            "storage_offset": ["def storage_offset(self) -> _int: ..."],
+            "to": [
+                "def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
+                "def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, "
+                "non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
+                "def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...",
+            ],
+            "item": ["def item(self) -> Number: ..."],
+            "copy_": [
+                "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
+            ],
+            "set_": [
+                "def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
+                "def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
+            ],
+            "split": [
+                "def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
+                "def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...",
+            ],
+            "div": [
+                "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
+            ],
+            "div_": [
+                "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
+            ],
+        }
+    )
+    for binop in ["mul", "true_divide", "floor_divide"]:
         for inplace in [False, True]:
-            out_suffix = ', *, out: Optional[Tensor]=None'
+            out_suffix = ", *, out: Optional[Tensor]=None"
             if inplace:
-                binop += '_'
-                out_suffix = ''
+                binop += "_"
+                out_suffix = ""
             unsorted_tensor_method_hints[binop].append(
-                'def {}(self, other: Union[Tensor, Number]{})'
-                ' -> Tensor: ...'.format(binop, out_suffix))
-    for binop in ['add', 'sub']:
+                "def {}(self, other: Union[Tensor, Number]{})"
+                " -> Tensor: ...".format(binop, out_suffix)
+            )
+    for binop in ["add", "sub"]:
         for inplace in [False, True]:
-            out_suffix = ', out: Optional[Tensor]=None'
+            out_suffix = ", out: Optional[Tensor]=None"
             if inplace:
-                binop += '_'
-                out_suffix = ''
+                binop += "_"
+                out_suffix = ""
             unsorted_tensor_method_hints[binop].append(
-                'def {}(self, other: Union[Tensor, Number], '
-                '*, alpha: Optional[Number]=1{})'
-                ' -> Tensor: ...'.format(binop, out_suffix))
-    simple_conversions = ['byte', 'char', 'cpu', 'double', 'float',
-                          'half', 'int', 'long', 'short', 'bool',
-                          'bfloat16']
+                "def {}(self, other: Union[Tensor, Number], "
+                "*, alpha: Optional[Number]=1{})"
+                " -> Tensor: ...".format(binop, out_suffix)
+            )
+    simple_conversions = [
+        "byte",
+        "char",
+        "cpu",
+        "double",
+        "float",
+        "half",
+        "int",
+        "long",
+        "short",
+        "bool",
+        "bfloat16",
+    ]
     for name in simple_conversions:
-        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
+        unsorted_tensor_method_hints[name].append(
+            "def {}(self) -> Tensor: ...".format(name)
+        )
 
     # pyi tensor methods don't currently include deprecated signatures for some reason
     # TODO: we should probably add them in
-    tensor_method_signatures = load_signatures(native_functions, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
-    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)
+    tensor_method_signatures = load_signatures(
+        native_functions,
+        deprecated_yaml_path,
+        method=True,
+        skip_deprecated=True,
+        pyi=True,
+    )
+    tensor_method_sig_groups = get_py_torch_functions(
+        tensor_method_signatures, method=True
+    )
 
     for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
         name = group.signature.name
@@ -537,13 +731,13 @@
                 namedtuples[tuple_name] = tuple_def
 
     for op in all_ops:
-        name = '__{}__'.format(op)
+        name = "__{}__".format(op)
         unsorted_tensor_method_hints[name] += sig_for_ops(name)
 
     tensor_method_hints = []
     for name, hints in sorted(unsorted_tensor_method_hints.items()):
         if len(hints) > 1:
-            hints = ['@overload\n' + h for h in hints]
+            hints = ["@overload\n" + h for h in hints]
         tensor_method_hints += hints
 
     # TODO: Missing type hints for nn
@@ -551,96 +745,174 @@
     # Generate namedtuple definitions
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    namedtuple_defs = ['{} = {}'.format(name, defn) for name, defn in namedtuples.items()]
+    namedtuple_defs = [
+        "{} = {}".format(name, defn) for name, defn in namedtuples.items()
+    ]
 
     # Generate type signatures for legacy classes
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
     # TODO: These are deprecated, maybe we shouldn't type hint them
     legacy_storage_base_hints = []
-    dt = ('Double', 'Float', 'Long', 'Int',
-          'Short', 'Char', 'Byte', 'Bool',
-          'Half', 'BFloat16', 'ComplexDouble',
-          'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32', 'QUInt4x2', 'QUInt2x4')
+    dt = (
+        "Double",
+        "Float",
+        "Long",
+        "Int",
+        "Short",
+        "Char",
+        "Byte",
+        "Bool",
+        "Half",
+        "BFloat16",
+        "ComplexDouble",
+        "ComplexFloat",
+        "QUInt8",
+        "QInt8",
+        "QInt32",
+        "QUInt4x2",
+        "QUInt2x4",
+    )
     for c in dt:
-        legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))
+        legacy_storage_base_hints.append("class {}StorageBase(object): ...".format(c))
     for c in dt:
-        legacy_storage_base_hints.append('class Cuda{}StorageBase(object): ...'.format(c))
+        legacy_storage_base_hints.append(
+            "class Cuda{}StorageBase(object): ...".format(c)
+        )
 
     legacy_class_hints = []
-    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
-              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
-        legacy_class_hints.append('class {}(Tensor): ...'.format(c))
+    for c in (
+        "DoubleTensor",
+        "FloatTensor",
+        "LongTensor",
+        "IntTensor",
+        "ShortTensor",
+        "HalfTensor",
+        "CharTensor",
+        "ByteTensor",
+        "BoolTensor",
+    ):
+        legacy_class_hints.append("class {}(Tensor): ...".format(c))
 
     # Generate type signatures for dtype classes
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
     # TODO: don't explicitly list dtypes here; get it from canonical
     # source
-    dtype_class_hints = ['{}: dtype = ...'.format(n)
-                         for n in
-                         ['float32', 'float', 'float64', 'double', 'float16', 'bfloat16', 'half',
-                          'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
-                          'complex32', 'complex64', 'cfloat', 'complex128', 'cdouble',
-                          'quint8', 'qint8', 'qint32', 'bool', 'quint4x2', 'quint2x4']]
+    dtype_class_hints = [
+        "{}: dtype = ...".format(n)
+        for n in [
+            "float32",
+            "float",
+            "float64",
+            "double",
+            "float16",
+            "bfloat16",
+            "half",
+            "uint8",
+            "int8",
+            "int16",
+            "short",
+            "int32",
+            "int",
+            "int64",
+            "long",
+            "complex32",
+            "complex64",
+            "cfloat",
+            "complex128",
+            "cdouble",
+            "quint8",
+            "qint8",
+            "qint32",
+            "bool",
+            "quint4x2",
+            "quint2x4",
+        ]
+    ]
 
     # Generate __all__ directive
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
     # Include only the functions that contain hints, to prevent undefined
     # symbols to be included in the `__all__` directive.
-    hinted_function_names = [name for name, hint in unsorted_function_hints.items() if hint]
+    hinted_function_names = [
+        name for name, hint in unsorted_function_hints.items() if hint
+    ]
     all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
-    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
-    all_directive[0] = '__all__ = {}'.format(all_directive[0])
+    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
+    all_directive[0] = "__all__ = {}".format(all_directive[0])
 
     # Write out the stub
     # ~~~~~~~~~~~~~~~~~~
 
     env = {
-        'namedtuple_defs': namedtuple_defs,
-        'function_hints': function_hints,
-        'tensor_method_hints': tensor_method_hints,
-        'legacy_class_hints': legacy_class_hints,
-        'legacy_storage_base_hints': legacy_storage_base_hints,
-        'dtype_class_hints': dtype_class_hints,
-        'all_directive': all_directive
+        "namedtuple_defs": namedtuple_defs,
+        "function_hints": function_hints,
+        "tensor_method_hints": tensor_method_hints,
+        "legacy_class_hints": legacy_class_hints,
+        "legacy_storage_base_hints": legacy_storage_base_hints,
+        "dtype_class_hints": dtype_class_hints,
+        "all_directive": all_directive,
     }
-    fm.write_with_template('torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
-        'generated_comment': '@' + 'generated from torch/_C/__init__.pyi.in',
-        **env,
-    })
-    fm.write_with_template('torch/_C/_VariableFunctions.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
-        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
-        **env,
-    })
-    fm.write_with_template('torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
-        'generated_comment': '@' + 'generated from torch/_C/_VariableFunctions.pyi.in',
-        **env,
-    })
-    fm.write_with_template('torch/return_types.pyi', 'torch/_C/return_types.pyi.in', lambda: {
-        'generated_comment': '@' + 'generated from torch/_C/return_types.pyi',
-        **env,
-    })
+    fm.write_with_template(
+        "torch/_C/__init__.pyi",
+        "torch/_C/__init__.pyi.in",
+        lambda: {
+            "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
+            **env,
+        },
+    )
+    fm.write_with_template(
+        "torch/_C/_VariableFunctions.pyi",
+        "torch/_C/_VariableFunctions.pyi.in",
+        lambda: {
+            "generated_comment": "@"
+            + "generated from torch/_C/_VariableFunctions.pyi.in",
+            **env,
+        },
+    )
+    fm.write_with_template(
+        "torch/_VF.pyi",
+        "torch/_C/_VariableFunctions.pyi.in",
+        lambda: {
+            "generated_comment": "@"
+            + "generated from torch/_C/_VariableFunctions.pyi.in",
+            **env,
+        },
+    )
+    fm.write_with_template(
+        "torch/return_types.pyi",
+        "torch/_C/return_types.pyi.in",
+        lambda: {
+            "generated_comment": "@" + "generated from torch/_C/return_types.pyi",
+            **env,
+        },
+    )
     gen_nn_functional(fm)
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(
-        description='Generate type stubs for PyTorch')
-    parser.add_argument('--native-functions-path', metavar='NATIVE',
-                        default='aten/src/ATen/native/native_functions.yaml',
-                        help='path to native_functions.yaml')
-    parser.add_argument('--deprecated-functions-path', metavar='DEPRECATED',
-                        default='tools/autograd/deprecated.yaml',
-                        help='path to deprecated.yaml')
-    parser.add_argument('--out', metavar='OUT',
-                        default='.',
-                        help='path to output directory')
+    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
+    parser.add_argument(
+        "--native-functions-path",
+        metavar="NATIVE",
+        default="aten/src/ATen/native/native_functions.yaml",
+        help="path to native_functions.yaml",
+    )
+    parser.add_argument(
+        "--deprecated-functions-path",
+        metavar="DEPRECATED",
+        default="tools/autograd/deprecated.yaml",
+        help="path to deprecated.yaml",
+    )
+    parser.add_argument(
+        "--out", metavar="OUT", default=".", help="path to output directory"
+    )
     args = parser.parse_args()
-    fm = FileManager(install_dir=args.out, template_dir='.', dry_run=False)
+    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
     gen_pyi(args.native_functions_path, args.deprecated_functions_path, fm)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/render_junit.py b/tools/render_junit.py
index 28e617a..68adadd 100644
--- a/tools/render_junit.py
+++ b/tools/render_junit.py
@@ -16,12 +16,15 @@
 except ImportError:
     print("rich not found, for color output use 'pip install rich'")
 
+
 def parse_junit_reports(path_to_reports: str) -> List[TestCase]:  # type: ignore[no-any-unimported]
     def parse_file(path: str) -> List[TestCase]:  # type: ignore[no-any-unimported]
         try:
             return convert_junit_to_testcases(JUnitXml.fromfile(path))
         except Exception as err:
-            rich.print(f":Warning: [yellow]Warning[/yellow]: Failed to read {path}: {err}")
+            rich.print(
+                f":Warning: [yellow]Warning[/yellow]: Failed to read {path}: {err}"
+            )
             return []
 
     if not os.path.exists(path_to_reports):
@@ -46,6 +49,7 @@
             testcases.append(item)
     return testcases
 
+
 def render_tests(testcases: List[TestCase]) -> None:  # type: ignore[no-any-unimported]
     num_passed = 0
     num_skipped = 0
@@ -64,14 +68,15 @@
             else:
                 num_skipped += 1
                 continue
-            rich.print(f"{icon} [bold red]{testcase.classname}.{testcase.name}[/bold red]")
+            rich.print(
+                f"{icon} [bold red]{testcase.classname}.{testcase.name}[/bold red]"
+            )
             print(f"{result.text}")
     rich.print(f":white_check_mark: {num_passed} [green]Passed[green]")
     rich.print(f":dash: {num_skipped} [grey]Skipped[grey]")
     rich.print(f":rotating_light: {num_failed} [grey]Failed[grey]")
 
 
-
 def parse_args() -> Any:
     parser = argparse.ArgumentParser(
         description="Render xunit output for failed tests",
diff --git a/tools/setup_helpers/__init__.py b/tools/setup_helpers/__init__.py
index fa892df..4bf1747 100644
--- a/tools/setup_helpers/__init__.py
+++ b/tools/setup_helpers/__init__.py
@@ -8,8 +8,8 @@
     for d in path:
         fname = os.path.join(d, thefile)
         fnames = [fname]
-        if sys.platform == 'win32':
-            exts = os.environ.get('PATHEXT', '').split(os.pathsep)
+        if sys.platform == "win32":
+            exts = os.environ.get("PATHEXT", "").split(os.pathsep)
             fnames += [fname + ext for ext in exts]
         for name in fnames:
             if os.access(name, os.F_OK | os.X_OK) and not os.path.isdir(name):
diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py
index c73224a..98d9566 100644
--- a/tools/setup_helpers/cmake.py
+++ b/tools/setup_helpers/cmake.py
@@ -1,7 +1,6 @@
 "Manages CMake."
 
 
-
 import multiprocessing
 import os
 import platform
@@ -13,7 +12,7 @@
 from typing import IO, Any, Dict, List, Optional, Union, cast
 
 from . import which
-from .env import (BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag)
+from .env import BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag
 from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR
 
 
@@ -21,20 +20,23 @@
     try:
         os.makedirs(d, exist_ok=True)
     except OSError as e:
-        raise RuntimeError(f"Failed to create folder {os.path.abspath(d)}: {e.strerror}") from e
+        raise RuntimeError(
+            f"Failed to create folder {os.path.abspath(d)}: {e.strerror}"
+        ) from e
 
 
 # Ninja
 # Use ninja if it is on the PATH. Previous version of PyTorch required the
 # ninja python package, but we no longer use it, so we do not have to import it
-USE_NINJA = (not check_negative_env_flag('USE_NINJA') and
-             which('ninja') is not None)
+USE_NINJA = not check_negative_env_flag("USE_NINJA") and which("ninja") is not None
 
 
 CMakeValue = Optional[Union[bool, str]]
 
 
-def convert_cmake_value_to_python_value(cmake_value: str, cmake_type: str) -> CMakeValue:
+def convert_cmake_value_to_python_value(
+    cmake_value: str, cmake_type: str
+) -> CMakeValue:
     r"""Convert a CMake value in a string form to a Python value.
 
     Args:
@@ -47,18 +49,24 @@
 
     cmake_type = cmake_type.upper()
     up_val = cmake_value.upper()
-    if cmake_type == 'BOOL':
+    if cmake_type == "BOOL":
         # https://gitlab.kitware.com/cmake/community/wikis/doc/cmake/VariablesListsStrings#boolean-values-in-cmake
-        return not (up_val in ('FALSE', 'OFF', 'N', 'NO', '0', '', 'NOTFOUND') or up_val.endswith('-NOTFOUND'))
-    elif cmake_type == 'FILEPATH':
-        if up_val.endswith('-NOTFOUND'):
+        return not (
+            up_val in ("FALSE", "OFF", "N", "NO", "0", "", "NOTFOUND")
+            or up_val.endswith("-NOTFOUND")
+        )
+    elif cmake_type == "FILEPATH":
+        if up_val.endswith("-NOTFOUND"):
             return None
         else:
             return cmake_value
     else:  # Directly return the cmake_value.
         return cmake_value
 
-def get_cmake_cache_variables_from_file(cmake_cache_file: IO[str]) -> Dict[str, CMakeValue]:
+
+def get_cmake_cache_variables_from_file(
+    cmake_cache_file: IO[str],
+) -> Dict[str, CMakeValue]:
     r"""Gets values in CMakeCache.txt into a dictionary.
 
     Args:
@@ -70,7 +78,7 @@
     results = dict()
     for i, line in enumerate(cmake_cache_file, 1):
         line = line.strip()
-        if not line or line.startswith(('#', '//')):
+        if not line or line.startswith(("#", "//")):
             # Blank or comment line, skip
             continue
 
@@ -83,19 +91,24 @@
         #   USE_CUDA:=ON
         #   Intel(R) MKL-DNN_SOURCE_DIR:STATIC=/path/to/pytorch/third_party/ideep/mkl-dnn
         #   "OpenMP_COMPILE_RESULT_CXX_openmp:experimental":INTERNAL=FALSE
-        matched = re.match(r'("?)(.+?)\1(?::\s*([a-zA-Z_-][a-zA-Z0-9_-]*)?)?\s*=\s*(.*)', line)
+        matched = re.match(
+            r'("?)(.+?)\1(?::\s*([a-zA-Z_-][a-zA-Z0-9_-]*)?)?\s*=\s*(.*)', line
+        )
         if matched is None:  # Illegal line
-            raise ValueError('Unexpected line {} in {}: {}'.format(i, repr(cmake_cache_file), line))
+            raise ValueError(
+                "Unexpected line {} in {}: {}".format(i, repr(cmake_cache_file), line)
+            )
         _, variable, type_, value = matched.groups()
         if type_ is None:
-            type_ = ''
-        if type_.upper() in ('INTERNAL', 'STATIC'):
+            type_ = ""
+        if type_.upper() in ("INTERNAL", "STATIC"):
             # CMake internal variable, do not touch
             continue
         results[variable] = convert_cmake_value_to_python_value(value, type_)
 
     return results
 
+
 class CMake:
     "Manages cmake."
 
@@ -110,31 +123,36 @@
         Returns:
           string: The path to CMakeCache.txt.
         """
-        return os.path.join(self.build_dir, 'CMakeCache.txt')
+        return os.path.join(self.build_dir, "CMakeCache.txt")
 
     @staticmethod
     def _get_cmake_command() -> str:
         "Returns cmake command."
 
-        cmake_command = 'cmake'
+        cmake_command = "cmake"
         if IS_WINDOWS:
             return cmake_command
-        cmake3_version = CMake._get_version(which('cmake3'))
-        cmake_version = CMake._get_version(which('cmake'))
+        cmake3_version = CMake._get_version(which("cmake3"))
+        cmake_version = CMake._get_version(which("cmake"))
 
         _cmake_min_version = LooseVersion("3.10.0")
-        if all((ver is None or ver < _cmake_min_version for ver in [cmake_version, cmake3_version])):
-            raise RuntimeError('no cmake or cmake3 with version >= 3.10.0 found')
+        if all(
+            (
+                ver is None or ver < _cmake_min_version
+                for ver in [cmake_version, cmake3_version]
+            )
+        ):
+            raise RuntimeError("no cmake or cmake3 with version >= 3.10.0 found")
 
         if cmake3_version is None:
-            cmake_command = 'cmake'
+            cmake_command = "cmake"
         elif cmake_version is None:
-            cmake_command = 'cmake3'
+            cmake_command = "cmake3"
         else:
             if cmake3_version >= cmake_version:
-                cmake_command = 'cmake3'
+                cmake_command = "cmake3"
             else:
-                cmake_command = 'cmake'
+                cmake_command = "cmake"
         return cmake_command
 
     @staticmethod
@@ -143,16 +161,16 @@
 
         if cmd is None:
             return None
-        for line in check_output([cmd, '--version']).decode('utf-8').split('\n'):
-            if 'version' in line:
-                return LooseVersion(line.strip().split(' ')[2])
-        raise RuntimeError('no version found')
+        for line in check_output([cmd, "--version"]).decode("utf-8").split("\n"):
+            if "version" in line:
+                return LooseVersion(line.strip().split(" ")[2])
+        raise RuntimeError("no version found")
 
     def run(self, args: List[str], env: Dict[str, str]) -> None:
         "Executes cmake with arguments and an environment."
 
         command = [self._cmake_command] + args
-        print(' '.join(command))
+        print(" ".join(command))
         try:
             check_call(command, cwd=self.build_dir, env=env)
         except (CalledProcessError, KeyboardInterrupt) as e:
@@ -166,7 +184,7 @@
         "Adds definitions to a cmake argument list."
         for key, value in sorted(kwargs.items()):
             if value is not None:
-                args.append('-D{}={}'.format(key, value))
+                args.append("-D{}={}".format(key, value))
 
     def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
         r"""Gets values in CMakeCache.txt into a dictionary.
@@ -190,48 +208,54 @@
         if rerun and os.path.isfile(self._cmake_cache_file):
             os.remove(self._cmake_cache_file)
 
-        ninja_build_file = os.path.join(self.build_dir, 'build.ninja')
+        ninja_build_file = os.path.join(self.build_dir, "build.ninja")
         if os.path.exists(self._cmake_cache_file) and not (
-                USE_NINJA and not os.path.exists(ninja_build_file)):
+            USE_NINJA and not os.path.exists(ninja_build_file)
+        ):
             # Everything's in place. Do not rerun.
             return
 
         args = []
         if USE_NINJA:
             # Avoid conflicts in '-G' and the `CMAKE_GENERATOR`
-            os.environ['CMAKE_GENERATOR'] = 'Ninja'
-            args.append('-GNinja')
+            os.environ["CMAKE_GENERATOR"] = "Ninja"
+            args.append("-GNinja")
         elif IS_WINDOWS:
-            generator = os.getenv('CMAKE_GENERATOR', 'Visual Studio 15 2017')
-            supported = ['Visual Studio 15 2017', 'Visual Studio 16 2019']
+            generator = os.getenv("CMAKE_GENERATOR", "Visual Studio 15 2017")
+            supported = ["Visual Studio 15 2017", "Visual Studio 16 2019"]
             if generator not in supported:
-                print('Unsupported `CMAKE_GENERATOR`: ' + generator)
-                print('Please set it to one of the following values: ')
-                print('\n'.join(supported))
+                print("Unsupported `CMAKE_GENERATOR`: " + generator)
+                print("Please set it to one of the following values: ")
+                print("\n".join(supported))
                 sys.exit(1)
-            args.append('-G' + generator)
+            args.append("-G" + generator)
             toolset_dict = {}
-            toolset_version = os.getenv('CMAKE_GENERATOR_TOOLSET_VERSION')
+            toolset_version = os.getenv("CMAKE_GENERATOR_TOOLSET_VERSION")
             if toolset_version is not None:
-                toolset_dict['version'] = toolset_version
-                curr_toolset = os.getenv('VCToolsVersion')
+                toolset_dict["version"] = toolset_version
+                curr_toolset = os.getenv("VCToolsVersion")
                 if curr_toolset is None:
-                    print('When you specify `CMAKE_GENERATOR_TOOLSET_VERSION`, you must also '
-                          'activate the vs environment of this version. Please read the notes '
-                          'in the build steps carefully.')
+                    print(
+                        "When you specify `CMAKE_GENERATOR_TOOLSET_VERSION`, you must also "
+                        "activate the vs environment of this version. Please read the notes "
+                        "in the build steps carefully."
+                    )
                     sys.exit(1)
             if IS_64BIT:
-                if platform.machine() == 'ARM64':
-                    args.append('-A ARM64')
+                if platform.machine() == "ARM64":
+                    args.append("-A ARM64")
                 else:
-                    args.append('-Ax64')
-                    toolset_dict['host'] = 'x64'
+                    args.append("-Ax64")
+                    toolset_dict["host"] = "x64"
             if toolset_dict:
-                toolset_expr = ','.join(["{}={}".format(k, v) for k, v in toolset_dict.items()])
-                args.append('-T' + toolset_expr)
+                toolset_expr = ",".join(
+                    ["{}={}".format(k, v) for k, v in toolset_dict.items()]
+                )
+                args.append("-T" + toolset_expr)
 
-        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(
-            os.path.abspath(__file__))))
+        base_dir = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        )
         install_dir = os.path.join(base_dir, "torch")
 
         _mkdir_p(install_dir)
@@ -246,49 +270,53 @@
             # Key: environment variable name. Value: Corresponding variable name to be passed to CMake. If you are
             # adding a new build option to this block: Consider making these two names identical and adding this option
             # in the block below.
-            '_GLIBCXX_USE_CXX11_ABI': 'GLIBCXX_USE_CXX11_ABI',
-            'CUDNN_LIB_DIR': 'CUDNN_LIBRARY',
-            'USE_CUDA_STATIC_LINK': 'CAFFE2_STATIC_LINK_CUDA',
+            "_GLIBCXX_USE_CXX11_ABI": "GLIBCXX_USE_CXX11_ABI",
+            "CUDNN_LIB_DIR": "CUDNN_LIBRARY",
+            "USE_CUDA_STATIC_LINK": "CAFFE2_STATIC_LINK_CUDA",
         }
-        additional_options.update({
-            # Build options that have the same environment variable name and CMake variable name and that do not start
-            # with "BUILD_", "USE_", or "CMAKE_". If you are adding a new build option, also make sure you add it to
-            # CMakeLists.txt.
-            var: var for var in
-            ('BLAS',
-             'BUILDING_WITH_TORCH_LIBS',
-             'CUDA_HOST_COMILER',
-             'CUDA_NVCC_EXECUTABLE',
-             'CUDA_SEPARABLE_COMPILATION',
-             'CUDNN_LIBRARY',
-             'CUDNN_INCLUDE_DIR',
-             'CUDNN_ROOT',
-             'EXPERIMENTAL_SINGLE_THREAD_POOL',
-             'INSTALL_TEST',
-             'JAVA_HOME',
-             'INTEL_MKL_DIR',
-             'INTEL_OMP_DIR',
-             'MKL_THREADING',
-             'MKLDNN_CPU_RUNTIME',
-             'MSVC_Z7_OVERRIDE',
-             'CAFFE2_USE_MSVC_STATIC_RUNTIME',
-             'Numa_INCLUDE_DIR',
-             'Numa_LIBRARIES',
-             'ONNX_ML',
-             'ONNX_NAMESPACE',
-             'ATEN_THREADING',
-             'WERROR',
-             'OPENSSL_ROOT_DIR',
-             'STATIC_DISPATCH_BACKEND')
-        })
+        additional_options.update(
+            {
+                # Build options that have the same environment variable name and CMake variable name and that do not start
+                # with "BUILD_", "USE_", or "CMAKE_". If you are adding a new build option, also make sure you add it to
+                # CMakeLists.txt.
+                var: var
+                for var in (
+                    "BLAS",
+                    "BUILDING_WITH_TORCH_LIBS",
+                    "CUDA_HOST_COMILER",
+                    "CUDA_NVCC_EXECUTABLE",
+                    "CUDA_SEPARABLE_COMPILATION",
+                    "CUDNN_LIBRARY",
+                    "CUDNN_INCLUDE_DIR",
+                    "CUDNN_ROOT",
+                    "EXPERIMENTAL_SINGLE_THREAD_POOL",
+                    "INSTALL_TEST",
+                    "JAVA_HOME",
+                    "INTEL_MKL_DIR",
+                    "INTEL_OMP_DIR",
+                    "MKL_THREADING",
+                    "MKLDNN_CPU_RUNTIME",
+                    "MSVC_Z7_OVERRIDE",
+                    "CAFFE2_USE_MSVC_STATIC_RUNTIME",
+                    "Numa_INCLUDE_DIR",
+                    "Numa_LIBRARIES",
+                    "ONNX_ML",
+                    "ONNX_NAMESPACE",
+                    "ATEN_THREADING",
+                    "WERROR",
+                    "OPENSSL_ROOT_DIR",
+                    "STATIC_DISPATCH_BACKEND",
+                )
+            }
+        )
 
         # Aliases which are lower priority than their canonical option
         low_priority_aliases = {
-            'CUDA_HOST_COMPILER': 'CMAKE_CUDA_HOST_COMPILER',
-            'CUDAHOSTCXX': 'CUDA_HOST_COMPILER',
-            'CMAKE_CUDA_HOST_COMPILER': 'CUDA_HOST_COMPILER',
-            'CMAKE_CUDA_COMPILER': 'CUDA_NVCC_EXECUTABLE',
-            'CUDACXX': 'CUDA_NVCC_EXECUTABLE'
+            "CUDA_HOST_COMPILER": "CMAKE_CUDA_HOST_COMPILER",
+            "CUDAHOSTCXX": "CUDA_HOST_COMPILER",
+            "CMAKE_CUDA_HOST_COMPILER": "CUDA_HOST_COMPILER",
+            "CMAKE_CUDA_COMPILER": "CUDA_NVCC_EXECUTABLE",
+            "CUDACXX": "CUDA_NVCC_EXECUTABLE",
         }
         for var, val in my_env.items():
             # We currently pass over all environment variables that start with "BUILD_", "USE_", and "CMAKE_". This is
@@ -299,7 +327,9 @@
             true_var = additional_options.get(var)
             if true_var is not None:
                 build_options[true_var] = val
-            elif var.startswith(('BUILD_', 'USE_', 'CMAKE_')) or var.endswith(('EXITCODE', 'EXITCODE__TRYRUN_OUTPUT')):
+            elif var.startswith(("BUILD_", "USE_", "CMAKE_")) or var.endswith(
+                ("EXITCODE", "EXITCODE__TRYRUN_OUTPUT")
+            ):
                 build_options[var] = val
 
             if var in low_priority_aliases:
@@ -308,68 +338,81 @@
                     build_options[key] = val
 
         # The default value cannot be easily obtained in CMakeLists.txt. We set it here.
-        py_lib_path = sysconfig.get_path('purelib')
-        cmake_prefix_path = build_options.get('CMAKE_PREFIX_PATH', None)
+        py_lib_path = sysconfig.get_path("purelib")
+        cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH", None)
         if cmake_prefix_path:
             build_options["CMAKE_PREFIX_PATH"] = (
                 cast(str, py_lib_path) + ";" + cast(str, cmake_prefix_path)
             )
         else:
-            build_options['CMAKE_PREFIX_PATH'] = py_lib_path
+            build_options["CMAKE_PREFIX_PATH"] = py_lib_path
 
         # Some options must be post-processed. Ideally, this list will be shrunk to only one or two options in the
         # future, as CMake can detect many of these libraries pretty comfortably. We have them here for now before CMake
         # integration is completed. They appear here not in the CMake.defines call below because they start with either
         # "BUILD_" or "USE_" and must be overwritten here.
-        build_options.update({
-            # Note: Do not add new build options to this dict if it is directly read from environment variable -- you
-            # only need to add one in `CMakeLists.txt`. All build options that start with "BUILD_", "USE_", or "CMAKE_"
-            # are automatically passed to CMake; For other options you can add to additional_options above.
-            'BUILD_PYTHON': build_python,
-            'BUILD_TEST': build_test,
-            # Most library detection should go to CMake script, except this one, which Python can do a much better job
-            # due to NumPy's inherent Pythonic nature.
-            'USE_NUMPY': USE_NUMPY,
-        })
+        build_options.update(
+            {
+                # Note: Do not add new build options to this dict if it is directly read from environment variable -- you
+                # only need to add one in `CMakeLists.txt`. All build options that start with "BUILD_", "USE_", or "CMAKE_"
+                # are automatically passed to CMake; For other options you can add to additional_options above.
+                "BUILD_PYTHON": build_python,
+                "BUILD_TEST": build_test,
+                # Most library detection should go to CMake script, except this one, which Python can do a much better job
+                # due to NumPy's inherent Pythonic nature.
+                "USE_NUMPY": USE_NUMPY,
+            }
+        )
 
         # Options starting with CMAKE_
         cmake__options = {
-            'CMAKE_INSTALL_PREFIX': install_dir,
+            "CMAKE_INSTALL_PREFIX": install_dir,
         }
 
         # We set some CMAKE_* options in our Python build code instead of relying on the user's direct settings. Emit an
         # error if the user also attempts to set these CMAKE options directly.
         specified_cmake__options = set(build_options).intersection(cmake__options)
         if len(specified_cmake__options) > 0:
-            print(', '.join(specified_cmake__options) +
-                  ' should not be specified in the environment variable. They are directly set by PyTorch build script.')
+            print(
+                ", ".join(specified_cmake__options)
+                + " should not be specified in the environment variable. They are directly set by PyTorch build script."
+            )
             sys.exit(1)
         build_options.update(cmake__options)
 
-        CMake.defines(args,
-                      PYTHON_EXECUTABLE=sys.executable,
-                      PYTHON_LIBRARY=cmake_python_library,
-                      PYTHON_INCLUDE_DIR=sysconfig.get_path('include'),
-                      TORCH_BUILD_VERSION=version,
-                      NUMPY_INCLUDE_DIR=NUMPY_INCLUDE_DIR,
-                      **build_options)
+        CMake.defines(
+            args,
+            PYTHON_EXECUTABLE=sys.executable,
+            PYTHON_LIBRARY=cmake_python_library,
+            PYTHON_INCLUDE_DIR=sysconfig.get_path("include"),
+            TORCH_BUILD_VERSION=version,
+            NUMPY_INCLUDE_DIR=NUMPY_INCLUDE_DIR,
+            **build_options,
+        )
 
-        expected_wrapper = '/usr/local/opt/ccache/libexec'
+        expected_wrapper = "/usr/local/opt/ccache/libexec"
         if IS_DARWIN and os.path.exists(expected_wrapper):
-            if 'CMAKE_C_COMPILER' not in build_options and 'CC' not in os.environ:
+            if "CMAKE_C_COMPILER" not in build_options and "CC" not in os.environ:
                 CMake.defines(args, CMAKE_C_COMPILER="{}/gcc".format(expected_wrapper))
-            if 'CMAKE_CXX_COMPILER' not in build_options and 'CXX' not in os.environ:
-                CMake.defines(args, CMAKE_CXX_COMPILER="{}/g++".format(expected_wrapper))
+            if "CMAKE_CXX_COMPILER" not in build_options and "CXX" not in os.environ:
+                CMake.defines(
+                    args, CMAKE_CXX_COMPILER="{}/g++".format(expected_wrapper)
+                )
 
         for env_var_name in my_env:
-            if env_var_name.startswith('gh'):
+            if env_var_name.startswith("gh"):
                 # github env vars use utf-8, on windows, non-ascii code may
                 # cause problem, so encode first
                 try:
                     my_env[env_var_name] = str(my_env[env_var_name].encode("utf-8"))
                 except UnicodeDecodeError as e:
-                    shex = ':'.join('{:02x}'.format(ord(c)) for c in my_env[env_var_name])
-                    print('Invalid ENV[{}] = {}'.format(env_var_name, shex), file=sys.stderr)
+                    shex = ":".join(
+                        "{:02x}".format(ord(c)) for c in my_env[env_var_name]
+                    )
+                    print(
+                        "Invalid ENV[{}] = {}".format(env_var_name, shex),
+                        file=sys.stderr,
+                    )
                     print(e, file=sys.stderr)
         # According to the CMake manual, we should pass the arguments first,
         # and put the directory as the last element. Otherwise, these flags
@@ -385,7 +428,14 @@
 
         from .env import build_type
 
-        build_args = ['--build', '.', '--target', 'install', '--config', build_type.build_type_string]
+        build_args = [
+            "--build",
+            ".",
+            "--target",
+            "install",
+            "--config",
+            build_type.build_type_string,
+        ]
 
         # Determine the parallelism according to the following
         # priorities:
@@ -395,7 +445,7 @@
 
         # Allow the user to set parallelism explicitly. If unset,
         # we'll try to figure it out.
-        max_jobs = os.getenv('MAX_JOBS')
+        max_jobs = os.getenv("MAX_JOBS")
 
         if max_jobs is not None or not USE_NINJA:
             # Ninja is capable of figuring out the parallelism on its
@@ -414,10 +464,10 @@
             # build_args += ['-j', max_jobs] would be sufficient by
             # then. Until then, we use "--" to pass parameters to the
             # underlying build system.
-            build_args += ['--']
+            build_args += ["--"]
             if IS_WINDOWS and not USE_NINJA:
                 # We are likely using msbuild here
-                build_args += ['/p:CL_MPCount={}'.format(max_jobs)]
+                build_args += ["/p:CL_MPCount={}".format(max_jobs)]
             else:
-                build_args += ['-j', max_jobs]
+                build_args += ["-j", max_jobs]
         self.run(build_args, my_env)
diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py
index d658acd..bf693ca 100644
--- a/tools/setup_helpers/env.py
+++ b/tools/setup_helpers/env.py
@@ -6,37 +6,41 @@
 from typing import Iterable, List, Optional, cast
 
 
-IS_WINDOWS = (platform.system() == 'Windows')
-IS_DARWIN = (platform.system() == 'Darwin')
-IS_LINUX = (platform.system() == 'Linux')
+IS_WINDOWS = platform.system() == "Windows"
+IS_DARWIN = platform.system() == "Darwin"
+IS_LINUX = platform.system() == "Linux"
 
-IS_CONDA = 'conda' in sys.version or 'Continuum' in sys.version or any([x.startswith('CONDA') for x in os.environ])
-CONDA_DIR = os.path.join(os.path.dirname(sys.executable), '..')
+IS_CONDA = (
+    "conda" in sys.version
+    or "Continuum" in sys.version
+    or any([x.startswith("CONDA") for x in os.environ])
+)
+CONDA_DIR = os.path.join(os.path.dirname(sys.executable), "..")
 
-IS_64BIT = (struct.calcsize("P") == 8)
+IS_64BIT = struct.calcsize("P") == 8
 
-BUILD_DIR = 'build'
+BUILD_DIR = "build"
 
 
-def check_env_flag(name: str, default: str = '') -> bool:
-    return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
+def check_env_flag(name: str, default: str = "") -> bool:
+    return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
 
 
-def check_negative_env_flag(name: str, default: str = '') -> bool:
-    return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N']
+def check_negative_env_flag(name: str, default: str = "") -> bool:
+    return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"]
 
 
 def gather_paths(env_vars: Iterable[str]) -> List[str]:
-    return list(chain(*(os.getenv(v, '').split(os.pathsep) for v in env_vars)))
+    return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars)))
 
 
 def lib_paths_from_base(base_path: str) -> List[str]:
-    return [os.path.join(base_path, s) for s in ['lib/x64', 'lib', 'lib64']]
+    return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]]
 
 
 # We promised that CXXFLAGS should also be affected by CFLAGS
-if 'CFLAGS' in os.environ and 'CXXFLAGS' not in os.environ:
-    os.environ['CXXFLAGS'] = os.environ['CFLAGS']
+if "CFLAGS" in os.environ and "CXXFLAGS" not in os.environ:
+    os.environ["CXXFLAGS"] = os.environ["CFLAGS"]
 
 
 class BuildType(object):
@@ -55,39 +59,40 @@
             self.build_type_string = cmake_build_type_env
             return
 
-        cmake_cache_txt = os.path.join(BUILD_DIR, 'CMakeCache.txt')
+        cmake_cache_txt = os.path.join(BUILD_DIR, "CMakeCache.txt")
         if os.path.isfile(cmake_cache_txt):
             # Found CMakeCache.txt. Use the build type specified in it.
             from .cmake import get_cmake_cache_variables_from_file
+
             with open(cmake_cache_txt) as f:
                 cmake_cache_vars = get_cmake_cache_variables_from_file(f)
             # Normally it is anti-pattern to determine build type from CMAKE_BUILD_TYPE because it is not used for
             # multi-configuration build tools, such as Visual Studio and XCode. But since we always communicate with
             # CMake using CMAKE_BUILD_TYPE from our Python scripts, this is OK here.
-            self.build_type_string = cast(str, cmake_cache_vars['CMAKE_BUILD_TYPE'])
+            self.build_type_string = cast(str, cmake_cache_vars["CMAKE_BUILD_TYPE"])
         else:
-            self.build_type_string = os.environ.get('CMAKE_BUILD_TYPE', 'Release')
+            self.build_type_string = os.environ.get("CMAKE_BUILD_TYPE", "Release")
 
     def is_debug(self) -> bool:
         "Checks Debug build."
-        return self.build_type_string == 'Debug'
+        return self.build_type_string == "Debug"
 
     def is_rel_with_deb_info(self) -> bool:
         "Checks RelWithDebInfo build."
-        return self.build_type_string == 'RelWithDebInfo'
+        return self.build_type_string == "RelWithDebInfo"
 
     def is_release(self) -> bool:
         "Checks Release build."
-        return self.build_type_string == 'Release'
+        return self.build_type_string == "Release"
 
 
 # hotpatch environment variable 'CMAKE_BUILD_TYPE'. 'CMAKE_BUILD_TYPE' always prevails over DEBUG or REL_WITH_DEB_INFO.
-if 'CMAKE_BUILD_TYPE' not in os.environ:
-    if check_env_flag('DEBUG'):
-        os.environ['CMAKE_BUILD_TYPE'] = 'Debug'
-    elif check_env_flag('REL_WITH_DEB_INFO'):
-        os.environ['CMAKE_BUILD_TYPE'] = 'RelWithDebInfo'
+if "CMAKE_BUILD_TYPE" not in os.environ:
+    if check_env_flag("DEBUG"):
+        os.environ["CMAKE_BUILD_TYPE"] = "Debug"
+    elif check_env_flag("REL_WITH_DEB_INFO"):
+        os.environ["CMAKE_BUILD_TYPE"] = "RelWithDebInfo"
     else:
-        os.environ['CMAKE_BUILD_TYPE'] = 'Release'
+        os.environ["CMAKE_BUILD_TYPE"] = "Release"
 
 build_type = BuildType()
diff --git a/tools/setup_helpers/gen_version_header.py b/tools/setup_helpers/gen_version_header.py
index 963db1d..bd576af 100644
--- a/tools/setup_helpers/gen_version_header.py
+++ b/tools/setup_helpers/gen_version_header.py
@@ -76,7 +76,9 @@
         help="Path to the template (i.e. version.h.in)",
     )
     parser.add_argument(
-        "--version-path", required=True, help="Path to the file specifying the version",
+        "--version-path",
+        required=True,
+        help="Path to the file specifying the version",
     )
     parser.add_argument(
         "--output-path",
diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py
index 7bfe3a3..8685133 100644
--- a/tools/setup_helpers/generate_code.py
+++ b/tools/setup_helpers/generate_code.py
@@ -11,15 +11,15 @@
 except ImportError:
     from yaml import SafeLoader as YamlLoader  # type: ignore[misc]
 
-source_files = {'.py', '.cpp', '.h'}
+source_files = {".py", ".cpp", ".h"}
 
-NATIVE_FUNCTIONS_PATH = 'aten/src/ATen/native/native_functions.yaml'
+NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
 
 # TODO: This is a little inaccurate, because it will also pick
 # up setup_helper scripts which don't affect code generation
 def all_generator_source() -> List[str]:
     r = []
-    for directory, _, filenames in os.walk('tools'):
+    for directory, _, filenames in os.walk("tools"):
         for f in filenames:
             if os.path.splitext(f)[1] in source_files:
                 full = os.path.join(directory, f)
@@ -27,25 +27,26 @@
     return sorted(r)
 
 
-def generate_code(ninja_global: Optional[str] = None,
-                  native_functions_path: Optional[str] = None,
-                  install_dir: Optional[str] = None,
-                  subset: Optional[str] = None,
-                  disable_autograd: bool = False,
-                  force_schema_registration: bool = False,
-                  operator_selector: Any = None) -> None:
+def generate_code(
+    ninja_global: Optional[str] = None,
+    native_functions_path: Optional[str] = None,
+    install_dir: Optional[str] = None,
+    subset: Optional[str] = None,
+    disable_autograd: bool = False,
+    force_schema_registration: bool = False,
+    operator_selector: Any = None,
+) -> None:
     from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
     from tools.autograd.gen_annotated_fn_args import gen_annotated
     from tools.codegen.selective_build.selector import SelectiveBuilder
 
-
     # Build ATen based Variable classes
     if install_dir is None:
-        install_dir = 'torch/csrc'
-        python_install_dir = 'torch/testing/_internal/generated'
+        install_dir = "torch/csrc"
+        python_install_dir = "torch/testing/_internal/generated"
     else:
         python_install_dir = install_dir
-    autograd_gen_dir = os.path.join(install_dir, 'autograd', 'generated')
+    autograd_gen_dir = os.path.join(install_dir, "autograd", "generated")
     for d in (autograd_gen_dir, python_install_dir):
         os.makedirs(d, exist_ok=True)
     autograd_dir = os.fspath(pathlib.Path(__file__).parent.parent / "autograd")
@@ -54,7 +55,8 @@
         gen_autograd_python(
             native_functions_path or NATIVE_FUNCTIONS_PATH,
             autograd_gen_dir,
-            autograd_dir)
+            autograd_dir,
+        )
 
     if operator_selector is None:
         operator_selector = SelectiveBuilder.get_nop_selector()
@@ -73,17 +75,18 @@
         gen_annotated(
             native_functions_path or NATIVE_FUNCTIONS_PATH,
             python_install_dir,
-            autograd_dir)
+            autograd_dir,
+        )
 
 
 def get_selector_from_legacy_operator_selection_list(
-        selected_op_list_path: str,
+    selected_op_list_path: str,
 ) -> Any:
-    with open(selected_op_list_path, 'r') as f:
+    with open(selected_op_list_path, "r") as f:
         # strip out the overload part
         # It's only for legacy config - do NOT copy this code!
         selected_op_list = {
-            opname.split('.', 1)[0] for opname in yaml.load(f, Loader=YamlLoader)
+            opname.split(".", 1)[0] for opname in yaml.load(f, Loader=YamlLoader)
         }
 
     # Internal build doesn't use this flag any more. Only used by OSS
@@ -96,6 +99,7 @@
     is_used_for_training = True
 
     from tools.codegen.selective_build.selector import SelectiveBuilder
+
     selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
         selected_op_list,
         is_root_operator,
@@ -114,10 +118,12 @@
     sys.path.insert(0, root)
     from tools.codegen.selective_build.selector import SelectiveBuilder
 
-    assert not (selected_op_list_path is not None and
-                operators_yaml_path is not None), \
-        ("Expected at most one of selected_op_list_path and " +
-         "operators_yaml_path to be set.")
+    assert not (
+        selected_op_list_path is not None and operators_yaml_path is not None
+    ), (
+        "Expected at most one of selected_op_list_path and "
+        + "operators_yaml_path to be set."
+    )
 
     if selected_op_list_path is None and operators_yaml_path is None:
         return SelectiveBuilder.get_nop_selector()
@@ -128,43 +134,43 @@
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description='Autogenerate code')
-    parser.add_argument('--native-functions-path')
-    parser.add_argument('--ninja-global')
-    parser.add_argument('--install_dir')
+    parser = argparse.ArgumentParser(description="Autogenerate code")
+    parser.add_argument("--native-functions-path")
+    parser.add_argument("--ninja-global")
+    parser.add_argument("--install_dir")
     parser.add_argument(
-        '--subset',
-        help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
+        "--subset",
+        help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.',
     )
     parser.add_argument(
-        '--disable-autograd',
+        "--disable-autograd",
         default=False,
-        action='store_true',
-        help='It can skip generating autograd related code when the flag is set',
+        action="store_true",
+        help="It can skip generating autograd related code when the flag is set",
     )
     parser.add_argument(
-        '--selected-op-list-path',
-        help='Path to the YAML file that contains the list of operators to include for custom build.',
+        "--selected-op-list-path",
+        help="Path to the YAML file that contains the list of operators to include for custom build.",
     )
     parser.add_argument(
-        '--operators_yaml_path',
-        help='Path to the model YAML file that contains the list of operators to include for custom build.',
+        "--operators_yaml_path",
+        help="Path to the model YAML file that contains the list of operators to include for custom build.",
     )
     parser.add_argument(
-        '--force_schema_registration',
-        action='store_true',
-        help='force it to generate schema-only registrations for ops that are not'
-        'listed on --selected-op-list'
+        "--force_schema_registration",
+        action="store_true",
+        help="force it to generate schema-only registrations for ops that are not"
+        "listed on --selected-op-list",
     )
     parser.add_argument(
-        '--gen_lazy_ts_backend',
-        action='store_true',
-        help='Enable generation of the torch::lazy TorchScript backend'
+        "--gen_lazy_ts_backend",
+        action="store_true",
+        help="Enable generation of the torch::lazy TorchScript backend",
     )
     parser.add_argument(
-        '--per_operator_headers',
-        action='store_true',
-        help='Build lazy tensor ts backend with per-operator ATen headers, must match how ATen was built'
+        "--per_operator_headers",
+        action="store_true",
+        help="Build lazy tensor ts backend with per-operator ATen headers, must match how ATen was built",
     )
     options = parser.parse_args()
 
@@ -176,12 +182,14 @@
         options.disable_autograd,
         options.force_schema_registration,
         # options.selected_op_list
-        operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
+        operator_selector=get_selector(
+            options.selected_op_list_path, options.operators_yaml_path
+        ),
     )
 
     if options.gen_lazy_ts_backend:
         aten_path = os.path.dirname(os.path.dirname(options.native_functions_path))
-        ts_backend_yaml = os.path.join(aten_path, 'native/ts_native_functions.yaml')
+        ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
         ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
         ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
         if options.install_dir is None:
@@ -189,22 +197,29 @@
         lazy_install_dir = os.path.join(options.install_dir, "lazy/generated")
         os.makedirs(lazy_install_dir, exist_ok=True)
 
-        assert os.path.isfile(ts_backend_yaml), f"Unable to access ts_backend_yaml: {ts_backend_yaml}"
-        assert os.path.isfile(ts_native_functions), f"Unable to access {ts_native_functions}"
+        assert os.path.isfile(
+            ts_backend_yaml
+        ), f"Unable to access ts_backend_yaml: {ts_backend_yaml}"
+        assert os.path.isfile(
+            ts_native_functions
+        ), f"Unable to access {ts_native_functions}"
         from tools.codegen.gen_lazy_tensor import run_gen_lazy_tensor
         from tools.codegen.dest.lazy_ir import TSLazyIR
-        run_gen_lazy_tensor(aten_path=aten_path,
-                            source_yaml=ts_backend_yaml,
-                            backend_name="TorchScript",
-                            output_dir=lazy_install_dir,
-                            dry_run=False,
-                            impl_path=ts_native_functions,
-                            node_base="TsNode",
-                            node_base_hdr=ts_node_base,
-                            build_in_tree=True,
-                            lazy_ir_cls=TSLazyIR,
-                            per_operator_headers=options.per_operator_headers,
-                            gen_forced_fallback_code=True)
+
+        run_gen_lazy_tensor(
+            aten_path=aten_path,
+            source_yaml=ts_backend_yaml,
+            backend_name="TorchScript",
+            output_dir=lazy_install_dir,
+            dry_run=False,
+            impl_path=ts_native_functions,
+            node_base="TsNode",
+            node_base_hdr=ts_node_base,
+            build_in_tree=True,
+            lazy_ir_cls=TSLazyIR,
+            per_operator_headers=options.per_operator_headers,
+            gen_forced_fallback_code=True,
+        )
 
 
 if __name__ == "__main__":
diff --git a/tools/setup_helpers/numpy_.py b/tools/setup_helpers/numpy_.py
index 882de4b..e93fcfd 100644
--- a/tools/setup_helpers/numpy_.py
+++ b/tools/setup_helpers/numpy_.py
@@ -10,7 +10,7 @@
 
 # Set USE_NUMPY to what the user wants, because even if we fail here, cmake
 # will check for the presence of NumPy again (`cmake/Dependencies.cmake`).
-USE_NUMPY = not check_negative_env_flag('USE_NUMPY')
+USE_NUMPY = not check_negative_env_flag("USE_NUMPY")
 NUMPY_INCLUDE_DIR = None
 
 if USE_NUMPY:
diff --git a/tools/shared/cwrap_common.py b/tools/shared/cwrap_common.py
index 01ff97a..42548b9 100644
--- a/tools/shared/cwrap_common.py
+++ b/tools/shared/cwrap_common.py
@@ -6,17 +6,18 @@
 
 Arg = Dict[str, Any]
 
+
 def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
     new_args = []
     for arg in args:
         # Simple arg declaration of form "<type> <name>"
         if isinstance(arg, str):
-            t, _, name = arg.partition(' ')
-            new_args.append({'type': t, 'name': name})
+            t, _, name = arg.partition(" ")
+            new_args.append({"type": t, "name": name})
         elif isinstance(arg, dict):
-            if 'arg' in arg:
-                arg['type'], _, arg['name'] = arg['arg'].partition(' ')
-                del arg['arg']
+            if "arg" in arg:
+                arg["type"], _, arg["name"] = arg["arg"].partition(" ")
+                del arg["arg"]
             new_args.append(arg)
         else:
             raise AssertionError()
@@ -27,52 +28,66 @@
 
 
 def set_declaration_defaults(declaration: Declaration) -> None:
-    if 'schema_string' not in declaration:
+    if "schema_string" not in declaration:
         # This happens for legacy TH bindings like
         # _thnn_conv_depthwise2d_backward
-        declaration['schema_string'] = ''
-    declaration.setdefault('arguments', [])
-    declaration.setdefault('return', 'void')
-    if 'cname' not in declaration:
-        declaration['cname'] = declaration['name']
-    if 'backends' not in declaration:
-        declaration['backends'] = ['CPU', 'CUDA']
-    assert 'api_name' not in declaration
-    declaration['api_name'] = declaration['name']
+        declaration["schema_string"] = ""
+    declaration.setdefault("arguments", [])
+    declaration.setdefault("return", "void")
+    if "cname" not in declaration:
+        declaration["cname"] = declaration["name"]
+    if "backends" not in declaration:
+        declaration["backends"] = ["CPU", "CUDA"]
+    assert "api_name" not in declaration
+    declaration["api_name"] = declaration["name"]
     # NB: keep this in sync with gen_autograd.py
-    if declaration.get('overload_name'):
-        declaration['type_wrapper_name'] = "{}_{}".format(
-            declaration['name'], declaration['overload_name'])
+    if declaration.get("overload_name"):
+        declaration["type_wrapper_name"] = "{}_{}".format(
+            declaration["name"], declaration["overload_name"]
+        )
     else:
-        declaration['type_wrapper_name'] = declaration['name']
+        declaration["type_wrapper_name"] = declaration["name"]
     # TODO: Uggggh, parsing the schema string here, really???
-    declaration['operator_name_with_overload'] = declaration['schema_string'].split('(')[0]
-    if declaration['schema_string']:
-        declaration['unqual_schema_string'] = declaration['schema_string'].split('::')[1]
-        declaration['unqual_operator_name_with_overload'] = declaration['operator_name_with_overload'].split('::')[1]
+    declaration["operator_name_with_overload"] = declaration["schema_string"].split(
+        "("
+    )[0]
+    if declaration["schema_string"]:
+        declaration["unqual_schema_string"] = declaration["schema_string"].split("::")[
+            1
+        ]
+        declaration["unqual_operator_name_with_overload"] = declaration[
+            "operator_name_with_overload"
+        ].split("::")[1]
     else:
-        declaration['unqual_schema_string'] = ''
-        declaration['unqual_operator_name_with_overload'] = ''
+        declaration["unqual_schema_string"] = ""
+        declaration["unqual_operator_name_with_overload"] = ""
     # Simulate multiple dispatch, even if it's not necessary
-    if 'options' not in declaration:
-        declaration['options'] = [{
-            'arguments': copy.deepcopy(declaration['arguments']),
-            'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
-        }]
-        del declaration['arguments']
-        del declaration['schema_order_arguments']
+    if "options" not in declaration:
+        declaration["options"] = [
+            {
+                "arguments": copy.deepcopy(declaration["arguments"]),
+                "schema_order_arguments": copy.deepcopy(
+                    declaration["schema_order_arguments"]
+                ),
+            }
+        ]
+        del declaration["arguments"]
+        del declaration["schema_order_arguments"]
     # Parse arguments (some of them can be strings)
-    for option in declaration['options']:
-        option['arguments'] = parse_arguments(option['arguments'])
-        option['schema_order_arguments'] = parse_arguments(option['schema_order_arguments'])
+    for option in declaration["options"]:
+        option["arguments"] = parse_arguments(option["arguments"])
+        option["schema_order_arguments"] = parse_arguments(
+            option["schema_order_arguments"]
+        )
     # Propagate defaults from declaration to options
-    for option in declaration['options']:
+    for option in declaration["options"]:
         for k, v in declaration.items():
             # TODO(zach): why does cwrap not propagate 'name'? I need it
             # propagaged for ATen
-            if k != 'options':
+            if k != "options":
                 option.setdefault(k, v)
 
+
 # TODO(zach): added option to remove keyword handling for C++ which cannot
 # support it.
 
@@ -86,38 +101,41 @@
     remove_self: bool,
 ) -> List[Option]:
     def exclude_arg(arg: Arg) -> bool:
-        return arg['type'] == 'CONSTANT'  # type: ignore[no-any-return]
+        return arg["type"] == "CONSTANT"  # type: ignore[no-any-return]
 
     def exclude_arg_with_self_check(arg: Arg) -> bool:
-        return exclude_arg(arg) or (remove_self and arg['name'] == 'self')
+        return exclude_arg(arg) or (remove_self and arg["name"] == "self")
 
     def signature(option: Option, num_kwarg_only: int) -> str:
         if num_kwarg_only == 0:
             kwarg_only_count = None
         else:
             kwarg_only_count = -num_kwarg_only
-        arg_signature = '#'.join(
-            type_to_signature.get(arg['type'], arg['type'])
-            for arg in option['arguments'][:kwarg_only_count]
-            if not exclude_arg_with_self_check(arg))
+        arg_signature = "#".join(
+            type_to_signature.get(arg["type"], arg["type"])
+            for arg in option["arguments"][:kwarg_only_count]
+            if not exclude_arg_with_self_check(arg)
+        )
         if kwarg_only_count is None:
             return arg_signature
-        kwarg_only_signature = '#'.join(
-            arg['name'] + '#' + arg['type']
-            for arg in option['arguments'][kwarg_only_count:]
-            if not exclude_arg(arg))
+        kwarg_only_signature = "#".join(
+            arg["name"] + "#" + arg["type"]
+            for arg in option["arguments"][kwarg_only_count:]
+            if not exclude_arg(arg)
+        )
         return arg_signature + "#-#" + kwarg_only_signature
+
     seen_signatures = set()
     unique = []
     for option in options:
         # if only check num_kwarg_only == 0 if allow_kwarg == False
-        limit = len(option['arguments']) if allow_kwarg else 0
+        limit = len(option["arguments"]) if allow_kwarg else 0
         for num_kwarg_only in range(0, limit + 1):
             sig = signature(option, num_kwarg_only)
             if sig not in seen_signatures:
                 if num_kwarg_only > 0:
-                    for arg in option['arguments'][-num_kwarg_only:]:
-                        arg['kwarg_only'] = True
+                    for arg in option["arguments"][-num_kwarg_only:]:
+                        arg["kwarg_only"] = True
                 unique.append(option)
                 seen_signatures.add(sig)
                 break
@@ -126,49 +144,48 @@
 
 def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
     def num_args(option: Option) -> int:
-        return len(option['arguments'])
-    declaration['options'].sort(key=num_args, reverse=reverse)
+        return len(option["arguments"])
+
+    declaration["options"].sort(key=num_args, reverse=reverse)
 
 
 class Function(object):
-
     def __init__(self, name: str) -> None:
         self.name = name
-        self.arguments: List['Argument'] = []
+        self.arguments: List["Argument"] = []
 
-    def add_argument(self, arg: 'Argument') -> None:
+    def add_argument(self, arg: "Argument") -> None:
         assert isinstance(arg, Argument)
         self.arguments.append(arg)
 
     def __repr__(self) -> str:
-        return self.name + '(' + ', '.join(a.__repr__() for a in self.arguments) + ')'
+        return self.name + "(" + ", ".join(a.__repr__() for a in self.arguments) + ")"
 
 
 class Argument(object):
-
     def __init__(self, _type: str, name: str, is_optional: bool):
         self.type = _type
         self.name = name
         self.is_optional = is_optional
 
     def __repr__(self) -> str:
-        return self.type + ' ' + self.name
+        return self.type + " " + self.name
 
 
 def parse_header(path: str) -> List[Function]:
-    with open(path, 'r') as f:
-        lines: Iterable[Any] = f.read().split('\n')
+    with open(path, "r") as f:
+        lines: Iterable[Any] = f.read().split("\n")
 
     # Remove empty lines and prebackend directives
-    lines = filter(lambda l: l and not l.startswith('#'), lines)
+    lines = filter(lambda l: l and not l.startswith("#"), lines)
     # Remove line comments
-    lines = (l.partition('//') for l in lines)
+    lines = (l.partition("//") for l in lines)
     # Select line and comment part
     lines = ((l[0].strip(), l[2].strip()) for l in lines)
     # Remove trailing special signs
-    lines = ((l[0].rstrip(');').rstrip(','), l[1]) for l in lines)
+    lines = ((l[0].rstrip(");").rstrip(","), l[1]) for l in lines)
     # Split arguments
-    lines = ((l[0].split(','), l[1]) for l in lines)
+    lines = ((l[0].split(","), l[1]) for l in lines)
     # Flatten lines
     new_lines = []
     for l, c in lines:
@@ -182,32 +199,31 @@
     lines = filter(lambda l: l[0], lines)
     generic_functions = []
     for l, c in lines:
-        if l.startswith('TH_API void THNN_'):
-            fn_name = l[len('TH_API void THNN_'):]
-            if fn_name[0] == '(' and fn_name[-2] == ')':
+        if l.startswith("TH_API void THNN_"):
+            fn_name = l[len("TH_API void THNN_") :]
+            if fn_name[0] == "(" and fn_name[-2] == ")":
                 fn_name = fn_name[1:-2]
             else:
                 fn_name = fn_name[:-1]
             generic_functions.append(Function(fn_name))
-        elif l.startswith('TORCH_CUDA_CPP_API void THNN_'):
-            fn_name = l[len('TORCH_CUDA_CPP_API void THNN_'):]
-            if fn_name[0] == '(' and fn_name[-2] == ')':
+        elif l.startswith("TORCH_CUDA_CPP_API void THNN_"):
+            fn_name = l[len("TORCH_CUDA_CPP_API void THNN_") :]
+            if fn_name[0] == "(" and fn_name[-2] == ")":
                 fn_name = fn_name[1:-2]
             else:
                 fn_name = fn_name[:-1]
             generic_functions.append(Function(fn_name))
-        elif l.startswith('TORCH_CUDA_CU_API void THNN_'):
-            fn_name = l[len('TORCH_CUDA_CU_API void THNN_'):]
-            if fn_name[0] == '(' and fn_name[-2] == ')':
+        elif l.startswith("TORCH_CUDA_CU_API void THNN_"):
+            fn_name = l[len("TORCH_CUDA_CU_API void THNN_") :]
+            if fn_name[0] == "(" and fn_name[-2] == ")":
                 fn_name = fn_name[1:-2]
             else:
                 fn_name = fn_name[:-1]
             generic_functions.append(Function(fn_name))
         elif l:
             t, name = l.split()
-            if '*' in name:
-                t = t + '*'
+            if "*" in name:
+                t = t + "*"
                 name = name[1:]
-            generic_functions[-1].add_argument(
-                Argument(t, name, '[OPTIONAL]' in c))
+            generic_functions[-1].add_argument(Argument(t, name, "[OPTIONAL]" in c))
     return generic_functions
diff --git a/tools/shared/module_loader.py b/tools/shared/module_loader.py
index 7482047..910c3a6 100644
--- a/tools/shared/module_loader.py
+++ b/tools/shared/module_loader.py
@@ -5,6 +5,7 @@
 
 def import_module(name: str, path: str) -> ModuleType:
     import importlib.util
+
     spec = importlib.util.spec_from_file_location(name, path)
     module = importlib.util.module_from_spec(spec)
     cast(Loader, spec.loader).exec_module(module)
diff --git a/tools/stats/export_slow_tests.py b/tools/stats/export_slow_tests.py
index 6659438..13afbf9 100644
--- a/tools/stats/export_slow_tests.py
+++ b/tools/stats/export_slow_tests.py
@@ -5,58 +5,74 @@
 import os
 import statistics
 from collections import defaultdict
-from tools.stats.s3_stat_parser import get_previous_reports_for_branch, Report, Version2Report
+from tools.stats.s3_stat_parser import (
+    get_previous_reports_for_branch,
+    Report,
+    Version2Report,
+)
 from typing import cast, DefaultDict, Dict, List, Any
 from urllib.request import urlopen
 
-SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
+SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
 SLOW_TEST_CASE_THRESHOLD_SEC = 60.0
 RELATIVE_DIFFERENCE_THRESHOLD = 0.1
 IGNORED_JOBS = ["asan", "periodic"]
 
+
 def get_test_case_times() -> Dict[str, float]:
-    reports: List[Report] = get_previous_reports_for_branch('origin/viable/strict', "")
+    reports: List[Report] = get_previous_reports_for_branch("origin/viable/strict", "")
     # an entry will be like ("test_doc_examples (__main__.TestTypeHints)" -> [values]))
     test_names_to_times: DefaultDict[str, List[float]] = defaultdict(list)
     for report in reports:
-        if report.get('format_version', 1) != 2:  # type: ignore[misc]
+        if report.get("format_version", 1) != 2:  # type: ignore[misc]
             raise RuntimeError("S3 format currently handled is version 2 only")
         v2report = cast(Version2Report, report)
 
-        if any(job_name in str(report['build_job']) for job_name in IGNORED_JOBS):
+        if any(job_name in str(report["build_job"]) for job_name in IGNORED_JOBS):
             continue
 
-        for test_file in v2report['files'].values():
-            for suitename, test_suite in test_file['suites'].items():
-                for casename, test_case in test_suite['cases'].items():
+        for test_file in v2report["files"].values():
+            for suitename, test_suite in test_file["suites"].items():
+                for casename, test_case in test_suite["cases"].items():
                     # The below attaches a __main__ as that matches the format of test.__class__ in
                     # common_utils.py (where this data will be used), and also matches what the output
                     # of a running test would look like.
-                    name = f'{casename} (__main__.{suitename})'
-                    succeeded: bool = test_case['status'] is None
+                    name = f"{casename} (__main__.{suitename})"
+                    succeeded: bool = test_case["status"] is None
                     if succeeded:
-                        test_names_to_times[name].append(test_case['seconds'])
-    return {test_case: statistics.mean(times) for test_case, times in test_names_to_times.items()}
+                        test_names_to_times[name].append(test_case["seconds"])
+    return {
+        test_case: statistics.mean(times)
+        for test_case, times in test_names_to_times.items()
+    }
 
 
 def filter_slow_tests(test_cases_dict: Dict[str, float]) -> Dict[str, float]:
-    return {test_case: time for test_case, time in test_cases_dict.items() if time >= SLOW_TEST_CASE_THRESHOLD_SEC}
+    return {
+        test_case: time
+        for test_case, time in test_cases_dict.items()
+        if time >= SLOW_TEST_CASE_THRESHOLD_SEC
+    }
 
 
 def get_test_infra_slow_tests() -> Dict[str, float]:
     url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json"
-    contents = urlopen(url, timeout=1).read().decode('utf-8')
+    contents = urlopen(url, timeout=1).read().decode("utf-8")
     return cast(Dict[str, float], json.loads(contents))
 
 
-def too_similar(calculated_times: Dict[str, float], other_times: Dict[str, float], threshold: float) -> bool:
+def too_similar(
+    calculated_times: Dict[str, float], other_times: Dict[str, float], threshold: float
+) -> bool:
     # check that their keys are the same
     if calculated_times.keys() != other_times.keys():
         return False
 
     for test_case, test_time in calculated_times.items():
         other_test_time = other_times[test_case]
-        relative_difference = abs((other_test_time - test_time) / max(other_test_time, test_time))
+        relative_difference = abs(
+            (other_test_time - test_time) / max(other_test_time, test_time)
+        )
         if relative_difference > threshold:
             return False
     return True
@@ -65,38 +81,43 @@
 def export_slow_tests(options: Any) -> None:
     filename = options.filename
     if os.path.exists(filename):
-        print(f'Overwriting existent file: {filename}')
-    with open(filename, 'w+') as file:
+        print(f"Overwriting existent file: {filename}")
+    with open(filename, "w+") as file:
         slow_test_times: Dict[str, float] = filter_slow_tests(get_test_case_times())
         if options.ignore_small_diffs:
             test_infra_slow_tests_dict = get_test_infra_slow_tests()
-            if too_similar(slow_test_times, test_infra_slow_tests_dict, options.ignore_small_diffs):
+            if too_similar(
+                slow_test_times, test_infra_slow_tests_dict, options.ignore_small_diffs
+            ):
                 slow_test_times = test_infra_slow_tests_dict
-        json.dump(slow_test_times, file, indent='    ', separators=(',', ': '), sort_keys=True)
-        file.write('\n')
+        json.dump(
+            slow_test_times, file, indent="    ", separators=(",", ": "), sort_keys=True
+        )
+        file.write("\n")
 
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
-        description='Export a JSON of slow test cases in PyTorch unit test suite')
+        description="Export a JSON of slow test cases in PyTorch unit test suite"
+    )
     parser.add_argument(
-        '-f',
-        '--filename',
-        nargs='?',
+        "-f",
+        "--filename",
+        nargs="?",
         type=str,
         default=SLOW_TESTS_FILE,
         const=SLOW_TESTS_FILE,
-        help='Specify a file path to dump slow test times from previous S3 stats. Default file path: .pytorch-slow-tests.json',
+        help="Specify a file path to dump slow test times from previous S3 stats. Default file path: .pytorch-slow-tests.json",
     )
     parser.add_argument(
-        '--ignore-small-diffs',
-        nargs='?',
+        "--ignore-small-diffs",
+        nargs="?",
         type=float,
         const=RELATIVE_DIFFERENCE_THRESHOLD,
-        help='Compares generated results with stats/slow-tests.json in pytorch/test-infra. If the relative differences '
-             'between test times for each test are smaller than the threshold and the set of test cases have not '
-             'changed, we will export the stats already in stats/slow-tests.json. Else, we will export the calculated '
-             'results. The default threshold is 10%.',
+        help="Compares generated results with stats/slow-tests.json in pytorch/test-infra. If the relative differences "
+        "between test times for each test are smaller than the threshold and the set of test cases have not "
+        "changed, we will export the stats already in stats/slow-tests.json. Else, we will export the calculated "
+        "results. The default threshold is 10%.",
     )
     return parser.parse_args()
 
@@ -106,5 +127,5 @@
     export_slow_tests(options)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py
index 1b6c190..1afdd14 100644
--- a/tools/stats/import_test_stats.py
+++ b/tools/stats/import_test_stats.py
@@ -8,15 +8,16 @@
 from typing import Any, Callable, Dict, List, Optional, cast
 from urllib.request import urlopen
 
+
 def get_disabled_issues() -> List[str]:
-    pr_body = os.getenv('PR_BODY', '')
-    commit_messages = os.getenv('COMMIT_MESSAGES', '')
+    pr_body = os.getenv("PR_BODY", "")
+    commit_messages = os.getenv("COMMIT_MESSAGES", "")
     # The below regex is meant to match all *case-insensitive* keywords that
     # GitHub has delineated would link PRs to issues, more details here:
     # https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue.
     # E.g., "Close #62851", "fixES #62851" and "RESOLVED #62851" would all match, but not
     # "closes  #62851" --> extra space, "fixing #62851" --> not a keyword, nor "fix 62851" --> no #
-    regex = '(?i)(Close(d|s)?|Resolve(d|s)?|Fix(ed|es)?) (#|https://github.com/pytorch/pytorch/issues/)([0-9]+)'
+    regex = "(?i)(Close(d|s)?|Resolve(d|s)?|Fix(ed|es)?) (#|https://github.com/pytorch/pytorch/issues/)([0-9]+)"
     issue_numbers = [x[5] for x in re.findall(regex, pr_body + commit_messages)]
     print("Ignoring disabled issues: ", issue_numbers)
     return issue_numbers
@@ -24,16 +25,17 @@
 
 IGNORE_DISABLED_ISSUES: List[str] = get_disabled_issues()
 
-SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
-DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
+SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
+DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
 
 FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
 
+
 def fetch_and_cache(
     dirpath: str,
     name: str,
     url: str,
-    process_fn: Callable[[Dict[str, Any]], Dict[str, Any]]
+    process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
 ) -> Dict[str, Any]:
     """
     This fetch and cache utils allows sharing between different process.
@@ -56,18 +58,20 @@
 
     for _ in range(3):
         try:
-            contents = urlopen(url, timeout=5).read().decode('utf-8')
+            contents = urlopen(url, timeout=5).read().decode("utf-8")
             processed_contents = process_fn(json.loads(contents))
             with open(path, "w") as f:
                 f.write(json.dumps(processed_contents))
             return processed_contents
         except Exception as e:
-            print(f'Could not download {url} because: {e}.')
-    print(f'All retries exhausted, downloading {url} failed.')
+            print(f"Could not download {url} because: {e}.")
+    print(f"All retries exhausted, downloading {url} failed.")
     return {}
 
 
-def get_slow_tests(dirpath: str, filename: str = SLOW_TESTS_FILE) -> Optional[Dict[str, float]]:
+def get_slow_tests(
+    dirpath: str, filename: str = SLOW_TESTS_FILE
+) -> Optional[Dict[str, float]]:
     url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json"
     try:
         return fetch_and_cache(dirpath, filename, url, lambda x: x)
@@ -76,28 +80,36 @@
         return {}
 
 
-def get_disabled_tests(dirpath: str, filename: str = DISABLED_TESTS_FILE) -> Optional[Dict[str, Any]]:
+def get_disabled_tests(
+    dirpath: str, filename: str = DISABLED_TESTS_FILE
+) -> Optional[Dict[str, Any]]:
     def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
         disabled_test_from_issues = dict()
-        for item in the_response['items']:
-            title = item['title']
-            key = 'DISABLED '
-            issue_url = item['html_url']
-            issue_number = issue_url.split('/')[-1]
+        for item in the_response["items"]:
+            title = item["title"]
+            key = "DISABLED "
+            issue_url = item["html_url"]
+            issue_number = issue_url.split("/")[-1]
             if title.startswith(key) and issue_number not in IGNORE_DISABLED_ISSUES:
-                test_name = title[len(key):].strip()
-                body = item['body']
+                test_name = title[len(key) :].strip()
+                body = item["body"]
                 platforms_to_skip = []
-                key = 'platforms:'
+                key = "platforms:"
                 for line in body.splitlines():
                     line = line.lower()
                     if line.startswith(key):
                         pattern = re.compile(r"^\s+|\s*,\s*|\s+$")
-                        platforms_to_skip.extend([x for x in pattern.split(line[len(key):]) if x])
-                disabled_test_from_issues[test_name] = (item['html_url'], platforms_to_skip)
+                        platforms_to_skip.extend(
+                            [x for x in pattern.split(line[len(key) :]) if x]
+                        )
+                disabled_test_from_issues[test_name] = (
+                    item["html_url"],
+                    platforms_to_skip,
+                )
         return disabled_test_from_issues
+
     try:
-        url = 'https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/disabled-tests.json'
+        url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/disabled-tests.json"
         return fetch_and_cache(dirpath, filename, url, process_disabled_test)
     except Exception:
         print("Couldn't download test skip set, leaving all tests enabled...")
diff --git a/tools/stats/print_test_stats.py b/tools/stats/print_test_stats.py
index 836ee5f..56724bc 100755
--- a/tools/stats/print_test_stats.py
+++ b/tools/stats/print_test_stats.py
@@ -12,15 +12,41 @@
 import time
 from collections import defaultdict
 from pathlib import Path
-from typing import (Any, DefaultDict, Dict, Iterable, Iterator, List, Optional,
-                    Set, Tuple, cast)
+from typing import (
+    Any,
+    DefaultDict,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 from xml.dom import minidom
 
 from typing_extensions import TypedDict
-from tools.stats.s3_stat_parser import (newify_case, get_S3_object_from_bucket, get_test_stats_summaries_for_job,
-                                        Report, Status, Commit, HAVE_BOTO3, Version2Case, VersionedReport,
-                                        Version1Report, Version2Report, ReportMetaMeta)
-from tools.stats.scribe import send_to_scribe, rds_write, register_rds_schema, schema_from_sample
+from tools.stats.s3_stat_parser import (
+    newify_case,
+    get_S3_object_from_bucket,
+    get_test_stats_summaries_for_job,
+    Report,
+    Status,
+    Commit,
+    HAVE_BOTO3,
+    Version2Case,
+    VersionedReport,
+    Version1Report,
+    Version2Report,
+    ReportMetaMeta,
+)
+from tools.stats.scribe import (
+    send_to_scribe,
+    rds_write,
+    register_rds_schema,
+    schema_from_sample,
+)
 
 
 SimplerSuite = Dict[str, Version2Case]
@@ -61,12 +87,12 @@
 # share a name (for version 2 reports) or using a list of cases rather
 # than a dict.
 def simplify(report: Report) -> SimplerReport:
-    if 'format_version' not in report:  # version 1 implicitly
+    if "format_version" not in report:  # version 1 implicitly
         v1report = cast(Version1Report, report)
         return {
             # we just don't have test filename information sadly, so we
             # just make one fake filename that is the empty string
-            '': {
+            "": {
                 suite_name: {
                     # This clobbers some cases that have duplicate names
                     # because in version 1, we would merge together all
@@ -80,34 +106,34 @@
                     # we're only uploading in the new format (where
                     # everything is also keyed by filename) going
                     # forward, it shouldn't matter too much.
-                    case['name']: newify_case(case)
-                    for case in suite['cases']
+                    case["name"]: newify_case(case)
+                    for case in suite["cases"]
                 }
-                for suite_name, suite in v1report['suites'].items()
+                for suite_name, suite in v1report["suites"].items()
             }
         }
     else:
         v_report = cast(VersionedReport, report)
-        version = v_report['format_version']
+        version = v_report["format_version"]
         if version == 2:
             v2report = cast(Version2Report, v_report)
             return {
                 filename: {
-                    suite_name: suite['cases']
-                    for suite_name, suite in file_data['suites'].items()
+                    suite_name: suite["cases"]
+                    for suite_name, suite in file_data["suites"].items()
                 }
-                for filename, file_data in v2report['files'].items()
+                for filename, file_data in v2report["files"].items()
             }
         else:
-            raise RuntimeError(f'Unknown format version: {version}')
+            raise RuntimeError(f"Unknown format version: {version}")
 
 
 def plural(n: int) -> str:
-    return '' if n == 1 else 's'
+    return "" if n == 1 else "s"
 
 
 def get_base_commit(sha1: str) -> str:
-    default_branch = os.environ.get('GIT_DEFAULT_BRANCH')
+    default_branch = os.environ.get("GIT_DEFAULT_BRANCH")
     # capture None and "" cases
     if not default_branch:
         default_branch = "master"
@@ -124,28 +150,28 @@
     format: Tuple[Tuple[int, int], Tuple[int, int]],
 ) -> str:
     spread_len = format[1][0] + 1 + format[1][1]
-    spread = x['spread']
+    spread = x["spread"]
     if spread is not None:
-        spread_str = f' ± {spread:{spread_len}.{format[1][1]}f}s'
+        spread_str = f" ± {spread:{spread_len}.{format[1][1]}f}s"
     else:
-        spread_str = ' ' * (3 + spread_len + 1)
+        spread_str = " " * (3 + spread_len + 1)
     mean_len = format[0][0] + 1 + format[0][1]
     return f'{x["center"]:{mean_len}.{format[0][1]}f}s{spread_str}'
 
 
 def list_stat(l: List[float]) -> Stat:
     return {
-        'center': statistics.mean(l),
-        'spread': statistics.stdev(l) if len(l) > 1 else None
+        "center": statistics.mean(l),
+        "spread": statistics.stdev(l) if len(l) > 1 else None,
     }
 
 
 def zero_stat() -> Stat:
-    return {'center': 0, 'spread': None}
+    return {"center": 0, "spread": None}
 
 
 def recenter(was: Stat, now: float) -> Stat:
-    return {'center': now - was['center'], 'spread': was['spread']}
+    return {"center": now - was["center"], "spread": was["spread"]}
 
 
 def sum_normals(stats: Iterable[Stat]) -> Stat:
@@ -157,29 +183,29 @@
     """
     l = list(stats)
     spread: Optional[float]
-    if any(stat['spread'] is not None for stat in l):
-        spread = math.sqrt(sum((stat['spread'] or 0)**2 for stat in l))
+    if any(stat["spread"] is not None for stat in l):
+        spread = math.sqrt(sum((stat["spread"] or 0) ** 2 for stat in l))
     else:
         spread = None
     return {
-        'center': sum(stat['center'] for stat in l),
-        'spread': spread,
+        "center": sum(stat["center"] for stat in l),
+        "spread": spread,
     }
 
 
 def format_seconds(seconds: List[float]) -> str:
     if len(seconds) > 0:
         x = list_stat(seconds)
-        return f'total time {display_stat(x, ((5, 2), (4, 2)))}'.strip()
-    return ''
+        return f"total time {display_stat(x, ((5, 2), (4, 2)))}".strip()
+    return ""
 
 
 def show_ancestors(num_commits: int) -> str:
-    return f'    | : ({num_commits} commit{plural(num_commits)})'
+    return f"    | : ({num_commits} commit{plural(num_commits)})"
 
 
 def unlines(lines: List[str]) -> str:
-    return ''.join(f'{line}\n' for line in lines)
+    return "".join(f"{line}\n" for line in lines)
 
 
 def matching_test_times(
@@ -199,8 +225,8 @@
                 if suite:
                     case = suite.get(case_name)
                     if case:
-                        t = case['seconds']
-                        s = case['status']
+                        t = case["seconds"]
+                        s = case["status"]
                         if s == status:
                             times.append(t)
     return times
@@ -232,37 +258,49 @@
     for filename, suite_name in sorted(all_suites):
         case_diffs: List[CaseDiff] = []
         head_suite = head_report.get(filename, {}).get(suite_name)
-        base_cases: Dict[str, Status] = dict(sorted(set.intersection(*[
-            {
-                (n, case['status'])
-                for n, case
-                in report.get(filename, {}).get(suite_name, {}).items()
-            }
-            for report in base_report
-        ] or [set()])))
+        base_cases: Dict[str, Status] = dict(
+            sorted(
+                set.intersection(
+                    *[
+                        {
+                            (n, case["status"])
+                            for n, case in report.get(filename, {})
+                            .get(suite_name, {})
+                            .items()
+                        }
+                        for report in base_report
+                    ]
+                    or [set()]
+                )
+            )
+        )
         case_stats: Dict[str, Stat] = {}
         if head_suite:
-            now = sum(case['seconds'] for case in head_suite.values())
+            now = sum(case["seconds"] for case in head_suite.values())
             if any(
                 filename in report and suite_name in report[filename]
                 for report in base_report
             ):
                 removed_cases: List[CaseDiff] = []
                 for case_name, case_status in base_cases.items():
-                    case_stats[case_name] = list_stat(matching_test_times(
-                        base_reports=base_reports,
-                        filename=filename,
-                        suite_name=suite_name,
-                        case_name=case_name,
-                        status=case_status,
-                    ))
+                    case_stats[case_name] = list_stat(
+                        matching_test_times(
+                            base_reports=base_reports,
+                            filename=filename,
+                            suite_name=suite_name,
+                            case_name=case_name,
+                            status=case_status,
+                        )
+                    )
                     if case_name not in head_suite:
-                        removed_cases.append({
-                            'margin': '-',
-                            'name': case_name,
-                            'was': (case_stats[case_name], case_status),
-                            'now': None,
-                        })
+                        removed_cases.append(
+                            {
+                                "margin": "-",
+                                "name": case_name,
+                                "was": (case_stats[case_name], case_status),
+                                "now": None,
+                            }
+                        )
                 modified_cases: List[CaseDiff] = []
                 added_cases: List[CaseDiff] = []
                 for head_case_name in sorted(head_suite):
@@ -270,70 +308,86 @@
                     if head_case_name in base_cases:
                         stat = case_stats[head_case_name]
                         base_status = base_cases[head_case_name]
-                        if head_case['status'] != base_status:
-                            modified_cases.append({
-                                'margin': '!',
-                                'name': head_case_name,
-                                'was': (stat, base_status),
-                                'now': head_case,
-                            })
+                        if head_case["status"] != base_status:
+                            modified_cases.append(
+                                {
+                                    "margin": "!",
+                                    "name": head_case_name,
+                                    "was": (stat, base_status),
+                                    "now": head_case,
+                                }
+                            )
                     else:
-                        added_cases.append({
-                            'margin': '+',
-                            'name': head_case_name,
-                            'was': None,
-                            'now': head_case,
-                        })
+                        added_cases.append(
+                            {
+                                "margin": "+",
+                                "name": head_case_name,
+                                "was": None,
+                                "now": head_case,
+                            }
+                        )
                 # there might be a bug calculating this stdev, not sure
                 was = sum_normals(case_stats.values())
                 case_diffs = removed_cases + modified_cases + added_cases
                 if case_diffs:
-                    modified_suites.append({
-                        'margin': ' ',
-                        'name': suite_name,
-                        'was': was,
-                        'now': now,
-                        'cases': case_diffs,
-                    })
+                    modified_suites.append(
+                        {
+                            "margin": " ",
+                            "name": suite_name,
+                            "was": was,
+                            "now": now,
+                            "cases": case_diffs,
+                        }
+                    )
             else:
                 for head_case_name in sorted(head_suite):
                     head_case = head_suite[head_case_name]
-                    case_diffs.append({
-                        'margin': ' ',
-                        'name': head_case_name,
-                        'was': None,
-                        'now': head_case,
-                    })
-                added_suites.append({
-                    'margin': '+',
-                    'name': suite_name,
-                    'was': None,
-                    'now': now,
-                    'cases': case_diffs,
-                })
+                    case_diffs.append(
+                        {
+                            "margin": " ",
+                            "name": head_case_name,
+                            "was": None,
+                            "now": head_case,
+                        }
+                    )
+                added_suites.append(
+                    {
+                        "margin": "+",
+                        "name": suite_name,
+                        "was": None,
+                        "now": now,
+                        "cases": case_diffs,
+                    }
+                )
         else:
             for case_name, case_status in base_cases.items():
-                case_stats[case_name] = list_stat(matching_test_times(
-                    base_reports=base_reports,
-                    filename=filename,
-                    suite_name=suite_name,
-                    case_name=case_name,
-                    status=case_status,
-                ))
-                case_diffs.append({
-                    'margin': ' ',
-                    'name': case_name,
-                    'was': (case_stats[case_name], case_status),
-                    'now': None,
-                })
-            removed_suites.append({
-                'margin': '-',
-                'name': suite_name,
-                # there might be a bug calculating this stdev, not sure
-                'was': sum_normals(case_stats.values()),
-                'now': None,
-                'cases': case_diffs,
-            })
+                case_stats[case_name] = list_stat(
+                    matching_test_times(
+                        base_reports=base_reports,
+                        filename=filename,
+                        suite_name=suite_name,
+                        case_name=case_name,
+                        status=case_status,
+                    )
+                )
+                case_diffs.append(
+                    {
+                        "margin": " ",
+                        "name": case_name,
+                        "was": (case_stats[case_name], case_status),
+                        "now": None,
+                    }
+                )
+            removed_suites.append(
+                {
+                    "margin": "-",
+                    "name": suite_name,
+                    # there might be a bug calculating this stdev, not sure
+                    "was": sum_normals(case_stats.values()),
+                    "now": None,
+                    "cases": case_diffs,
+                }
+            )
 
     return removed_suites + modified_suites + added_suites
 
@@ -343,24 +397,24 @@
 
     case_fmt = ((3, 3), (2, 3))
 
-    was = diff['was']
+    was = diff["was"]
     if was:
-        was_line = f'    # was {display_stat(was[0], case_fmt)}'
+        was_line = f"    # was {display_stat(was[0], case_fmt)}"
         was_status = was[1]
         if was_status:
-            was_line += f' ({was_status})'
+            was_line += f" ({was_status})"
         lines.append(was_line)
 
-    now = diff['now']
+    now = diff["now"]
     if now:
-        now_stat: Stat = {'center': now['seconds'], 'spread': None}
-        now_line = f'    # now {display_stat(now_stat, case_fmt)}'
-        now_status = now['status']
+        now_stat: Stat = {"center": now["seconds"], "spread": None}
+        now_line = f"    # now {display_stat(now_stat, case_fmt)}"
+        now_status = now["status"]
         if now_status:
-            now_line += f' ({now_status})'
+            now_line += f" ({now_status})"
         lines.append(now_line)
 
-    return [''] + [f'{diff["margin"]} {l}' for l in lines]
+    return [""] + [f'{diff["margin"]} {l}' for l in lines]
 
 
 def display_suite_diff(diff: SuiteDiff) -> str:
@@ -368,23 +422,23 @@
 
     suite_fmt = ((4, 2), (3, 2))
 
-    was = diff['was']
+    was = diff["was"]
     if was:
-        lines.append(f'    # was {display_stat(was, suite_fmt)}')
+        lines.append(f"    # was {display_stat(was, suite_fmt)}")
 
-    now = diff['now']
+    now = diff["now"]
     if now is not None:
-        now_stat: Stat = {'center': now, 'spread': None}
-        lines.append(f'    # now {display_stat(now_stat, suite_fmt)}')
+        now_stat: Stat = {"center": now, "spread": None}
+        lines.append(f"    # now {display_stat(now_stat, suite_fmt)}")
 
-    for case_diff in diff['cases']:
-        lines.extend([f'  {l}' for l in case_diff_lines(case_diff)])
+    for case_diff in diff["cases"]:
+        lines.extend([f"  {l}" for l in case_diff_lines(case_diff)])
 
-    return unlines([''] + [f'{diff["margin"]} {l}'.rstrip() for l in lines] + [''])
+    return unlines([""] + [f'{diff["margin"]} {l}'.rstrip() for l in lines] + [""])
 
 
 def anomalies(diffs: List[SuiteDiff]) -> str:
-    return ''.join(map(display_suite_diff, diffs))
+    return "".join(map(display_suite_diff, diffs))
 
 
 def graph(
@@ -397,89 +451,91 @@
     other_ancestors: int = 0,
 ) -> str:
     lines = [
-        'Commit graph (base is most recent master ancestor with at least one S3 report):',
-        '',
-        '    : (master)',
-        '    |',
+        "Commit graph (base is most recent master ancestor with at least one S3 report):",
+        "",
+        "    : (master)",
+        "    |",
     ]
 
-    head_time_str = f'           {format_seconds([head_seconds])}'
+    head_time_str = f"           {format_seconds([head_seconds])}"
     if on_master:
-        lines.append(f'    * {head_sha[:10]} (HEAD)   {head_time_str}')
+        lines.append(f"    * {head_sha[:10]} (HEAD)   {head_time_str}")
     else:
-        lines.append(f'    | * {head_sha[:10]} (HEAD) {head_time_str}')
+        lines.append(f"    | * {head_sha[:10]} (HEAD) {head_time_str}")
 
         if ancestry_path > 0:
             lines += [
-                '    | |',
+                "    | |",
                 show_ancestors(ancestry_path),
             ]
 
         if other_ancestors > 0:
             lines += [
-                '    |/|',
+                "    |/|",
                 show_ancestors(other_ancestors),
-                '    |',
+                "    |",
             ]
         else:
-            lines.append('    |/')
+            lines.append("    |/")
 
     is_first = True
     for sha, seconds in base_seconds.items():
         num_runs = len(seconds)
         prefix = str(num_runs).rjust(3)
-        base = '(base)' if is_first and num_runs > 0 else '      '
+        base = "(base)" if is_first and num_runs > 0 else "      "
         if num_runs > 0:
             is_first = False
         t = format_seconds(seconds)
         p = plural(num_runs)
         if t:
-            p = f'{p}, '.ljust(3)
-        lines.append(f'    * {sha[:10]} {base} {prefix} report{p}{t}')
+            p = f"{p}, ".ljust(3)
+        lines.append(f"    * {sha[:10]} {base} {prefix} report{p}{t}")
 
-    lines.extend(['    |', '    :'])
+    lines.extend(["    |", "    :"])
 
     return unlines(lines)
 
 
 def case_delta(case: CaseDiff) -> Stat:
-    was = case['was']
-    now = case['now']
+    was = case["was"]
+    now = case["now"]
     return recenter(
         was[0] if was else zero_stat(),
-        now['seconds'] if now else 0,
+        now["seconds"] if now else 0,
     )
 
 
 def display_final_stat(stat: Stat) -> str:
-    center = stat['center']
-    spread = stat['spread']
+    center = stat["center"]
+    spread = stat["spread"]
     displayed = display_stat(
-        {'center': abs(center), 'spread': spread},
+        {"center": abs(center), "spread": spread},
         ((4, 2), (3, 2)),
     )
     if center < 0:
-        sign = '-'
+        sign = "-"
     elif center > 0:
-        sign = '+'
+        sign = "+"
     else:
-        sign = ' '
-    return f'{sign}{displayed}'.rstrip()
+        sign = " "
+    return f"{sign}{displayed}".rstrip()
 
 
 def summary_line(message: str, d: DefaultDict[str, List[CaseDiff]]) -> str:
     all_cases = [c for cs in d.values() for c in cs]
     tests = len(all_cases)
     suites = len(d)
-    sp = f'{plural(suites)})'.ljust(2)
-    tp = f'{plural(tests)},'.ljust(2)
+    sp = f"{plural(suites)})".ljust(2)
+    tp = f"{plural(tests)},".ljust(2)
     # there might be a bug calculating this stdev, not sure
     stat = sum_normals(case_delta(c) for c in all_cases)
-    return ''.join([
-        f'{message} (across {suites:>4} suite{sp}',
-        f'{tests:>6} test{tp}',
-        f' totaling {display_final_stat(stat)}',
-    ])
+    return "".join(
+        [
+            f"{message} (across {suites:>4} suite{sp}",
+            f"{tests:>6} test{tp}",
+            f" totaling {display_final_stat(stat)}",
+        ]
+    )
 
 
 def summary(analysis: List[SuiteDiff]) -> str:
@@ -489,17 +545,17 @@
 
     for diff in analysis:
         # the use of 'margin' here is not the most elegant
-        name = diff['name']
-        margin = diff['margin']
-        cases = diff['cases']
-        if margin == '-':
+        name = diff["name"]
+        margin = diff["margin"]
+        cases = diff["cases"]
+        if margin == "-":
             removed_tests[name] += cases
-        elif margin == '+':
+        elif margin == "+":
             added_tests[name] += cases
         else:
-            removed = list(filter(lambda c: c['margin'] == '-', cases))
-            added = list(filter(lambda c: c['margin'] == '+', cases))
-            modified = list(filter(lambda c: c['margin'] == '!', cases))
+            removed = list(filter(lambda c: c["margin"] == "-", cases))
+            added = list(filter(lambda c: c["margin"] == "+", cases))
+            modified = list(filter(lambda c: c["margin"] == "!", cases))
             if removed:
                 removed_tests[name] += removed
             if added:
@@ -507,11 +563,13 @@
             if modified:
                 modified_tests[name] += modified
 
-    return unlines([
-        summary_line('Removed ', removed_tests),
-        summary_line('Modified', modified_tests),
-        summary_line('Added   ', added_tests),
-    ])
+    return unlines(
+        [
+            summary_line("Removed ", removed_tests),
+            summary_line("Modified", modified_tests),
+            summary_line("Added   ", added_tests),
+        ]
+    )
 
 
 def regression_info(
@@ -543,44 +601,49 @@
         base_reports=simpler_base,
     )
 
-    return '\n'.join([
-        unlines([
-            '----- Historic stats comparison result ------',
-            '',
-            f'    job: {job_name}',
-            f'    commit: {head_sha}',
-        ]),
-
-        # don't print anomalies, because sometimes due to sharding, the
-        # output from this would be very long and obscure better signal
-
-        # anomalies(analysis),
-
-        graph(
-            head_sha=head_sha,
-            head_seconds=head_report['total_seconds'],
-            base_seconds={
-                c: [r['total_seconds'] for r in rs]
-                for c, rs in base_reports.items()
-            },
-            on_master=on_master,
-            ancestry_path=ancestry_path,
-            other_ancestors=other_ancestors,
-        ),
-        summary(analysis),
-    ])
+    return "\n".join(
+        [
+            unlines(
+                [
+                    "----- Historic stats comparison result ------",
+                    "",
+                    f"    job: {job_name}",
+                    f"    commit: {head_sha}",
+                ]
+            ),
+            # don't print anomalies, because sometimes due to sharding, the
+            # output from this would be very long and obscure better signal
+            # anomalies(analysis),
+            graph(
+                head_sha=head_sha,
+                head_seconds=head_report["total_seconds"],
+                base_seconds={
+                    c: [r["total_seconds"] for r in rs]
+                    for c, rs in base_reports.items()
+                },
+                on_master=on_master,
+                ancestry_path=ancestry_path,
+                other_ancestors=other_ancestors,
+            ),
+            summary(analysis),
+        ]
+    )
 
 
 class TestCase:
     def __init__(self, dom: Any) -> None:
-        self.class_name = str(dom.attributes['classname'].value)
-        self.name = str(dom.attributes['name'].value)
-        self.time = float(dom.attributes['time'].value)
+        self.class_name = str(dom.attributes["classname"].value)
+        self.name = str(dom.attributes["name"].value)
+        self.time = float(dom.attributes["time"].value)
         # The following attribute is currently ONLY used in process_intentional_test_runs for validation
         # reasons. The test filename that populates TestFile is calculated and passed down through the test report path.
         # The reason we don't just use this attribute is because it doesn't exist for cpp tests, e.g., in test_libtorch
-        self.file = str(dom.attributes['file'].value) if dom.hasAttribute('file') else 'N/A - probably a cpp test'
-        error_elements = dom.getElementsByTagName('error')
+        self.file = (
+            str(dom.attributes["file"].value)
+            if dom.hasAttribute("file")
+            else "N/A - probably a cpp test"
+        )
+        error_elements = dom.getElementsByTagName("error")
         # DISCLAIMER: unexpected successes and expected failures are currently not reported in assemble_s3_object
         self.expected_failure = False
         self.skipped = False
@@ -589,25 +652,32 @@
         if len(error_elements) > 0:
             # We are only expecting 1 element here
             error_element = error_elements[0]
-            self.unexpected_success = (error_element.hasAttribute('type') and
-                                       error_element.attributes['type'].value == 'UnexpectedSuccess')
+            self.unexpected_success = (
+                error_element.hasAttribute("type")
+                and error_element.attributes["type"].value == "UnexpectedSuccess"
+            )
             self.errored = not self.unexpected_success
-        skipped_elements = dom.getElementsByTagName('skipped')
+        skipped_elements = dom.getElementsByTagName("skipped")
         if len(skipped_elements) > 0:
             # We are only expecting 1 element here
             skipped_element = skipped_elements[0]
-            self.expected_failure = (skipped_element.hasAttribute('type') and
-                                     skipped_element.attributes['type'].value == 'XFAIL')
+            self.expected_failure = (
+                skipped_element.hasAttribute("type")
+                and skipped_element.attributes["type"].value == "XFAIL"
+            )
             self.skipped = not self.expected_failure
-        self.failed = len(dom.getElementsByTagName('failure')) > 0
+        self.failed = len(dom.getElementsByTagName("failure")) > 0
 
     def __repr__(self) -> str:
         return self.__str__()
 
     def __str__(self) -> str:
-        return f'[TestCase name: {self.name} | class_name: {self.class_name} | file: {self.file} | time: {self.time} | ' \
-            f'expected_failure: {self.expected_failure} | skipped: {self.skipped} | errored: {self.errored} | ' \
-            f'unexpected_success: {self.unexpected_success} | failed: {self.failed}]\n'
+        return (
+            f"[TestCase name: {self.name} | class_name: {self.class_name} | file: {self.file} | time: {self.time} | "
+            f"expected_failure: {self.expected_failure} | skipped: {self.skipped} | errored: {self.errored} | "
+            f"unexpected_success: {self.unexpected_success} | failed: {self.failed}]\n"
+        )
+
 
 class TestSuite:
     def __init__(self, name: str) -> None:
@@ -622,10 +692,12 @@
         self.expected_failure_count = 0
 
     def __repr__(self) -> str:
-        rc = f'{self.name} run_time: {self.total_time:.2f} tests: {len(self.test_cases)}'
+        rc = (
+            f"{self.name} run_time: {self.total_time:.2f} tests: {len(self.test_cases)}"
+        )
         if self.skipped_count > 0:
-            rc += f' skipped: {self.skipped_count}'
-        return f'TestSuite({rc})'
+            rc += f" skipped: {self.skipped_count}"
+        return f"TestSuite({rc})"
 
     def append(self, test_case: TestCase) -> None:
         self.test_cases[test_case.name] = test_case
@@ -638,7 +710,9 @@
 
     def update(self, test_case: TestCase) -> None:
         name = test_case.name
-        assert name in self.test_cases, f'Error: attempting to replace nonexistent test case {name}'
+        assert (
+            name in self.test_cases
+        ), f"Error: attempting to replace nonexistent test case {name}"
         # Note that time for unexpected successes and expected failures are reported as 0s
         self.test_cases[name].time += test_case.time
         self.test_cases[name].failed |= test_case.failed
@@ -650,24 +724,27 @@
 
 # Tests that spawn duplicates (usually only twice) intentionally
 MULTITESTS = [
-    'test_cpp_extensions_aot',
-    'distributed/test_distributed_spawn',
-    'distributed\\test_distributed_spawn',  # for windows
-    'distributed/test_c10d_gloo',
-    'distributed\\test_c10d_gloo',  # for windows
-    'cpp'  # The caffe2 cpp tests spawn duplicate test cases as well.
+    "test_cpp_extensions_aot",
+    "distributed/test_distributed_spawn",
+    "distributed\\test_distributed_spawn",  # for windows
+    "distributed/test_c10d_gloo",
+    "distributed\\test_c10d_gloo",  # for windows
+    "cpp",  # The caffe2 cpp tests spawn duplicate test cases as well.
 ]
 
 
 DuplicatedDict = Dict[str, Dict[str, List[TestCase]]]
 
+
 class TestFile:
     def __init__(self, name: str) -> None:
         self.name = name
         self.total_time = 0.0
         self.test_suites: Dict[str, TestSuite] = dict()
 
-    def append(self, test_case: TestCase, test_type: str, duplicated_tests_dict: DuplicatedDict) -> None:
+    def append(
+        self, test_case: TestCase, test_type: str, duplicated_tests_dict: DuplicatedDict
+    ) -> None:
         suite_name = test_case.class_name
         if suite_name not in self.test_suites:
             self.test_suites[suite_name] = TestSuite(suite_name)
@@ -680,7 +757,9 @@
             if suite_name not in duplicated_tests_dict:
                 duplicated_tests_dict[suite_name] = dict()
             if test_case.name not in duplicated_tests_dict[suite_name]:
-                duplicated_tests_dict[suite_name][test_case.name] = [self.test_suites[suite_name].test_cases[test_case.name]]
+                duplicated_tests_dict[suite_name][test_case.name] = [
+                    self.test_suites[suite_name].test_cases[test_case.name]
+                ]
             duplicated_tests_dict[suite_name][test_case.name].append(test_case)
         else:
             self.test_suites[suite_name].append(test_case)
@@ -693,7 +772,7 @@
     except Exception as e:
         print(f"Error occurred when parsing {path}: {e}")
         return
-    for test_case in dom.getElementsByTagName('testcase'):
+    for test_case in dom.getElementsByTagName("testcase"):
         yield TestCase(test_case)
 
 
@@ -713,11 +792,11 @@
 
 def parse_reports(folder: str) -> Tuple[Dict[str, TestFile], Dict[str, DuplicatedDict]]:
     tests_by_file = dict()
-    duplicated_tests_by_file : Dict[str, DuplicatedDict] = dict()
+    duplicated_tests_by_file: Dict[str, DuplicatedDict] = dict()
     for report in get_recursive_files(folder, ".xml"):
         report_path = Path(report)
         # basename of the directory of test-report is the test filename
-        test_filename = re.sub(r'\.', '/', report_path.parent.name)
+        test_filename = re.sub(r"\.", "/", report_path.parent.name)
         # test type is the parent directory (only applies to dist-*)
         # See: CUSTOM_HANDLERS in test/run_test.py
         test_type = report_path.parent.parent.name
@@ -726,7 +805,9 @@
         if test_filename not in tests_by_file:
             tests_by_file[test_filename] = TestFile(test_filename)
         for test_case in parse_report(report):
-            tests_by_file[test_filename].append(test_case, test_type, duplicated_tests_by_file[test_filename])
+            tests_by_file[test_filename].append(
+                test_case, test_type, duplicated_tests_by_file[test_filename]
+            )
     return tests_by_file, duplicated_tests_by_file
 
 
@@ -754,40 +835,66 @@
     # Do not run duplication checks for test files that spawn duplicate tests intentionally
     # and are not necessarily flaky test reruns.
     if not any(x in test_run.file for x in MULTITESTS):
-        err_msg = f'Warning: unintentional test case duplicates found for {test_run.name} in suite {test_run.class_name}.'
-        report_only = os.getenv('PYTORCH_OVERRIDE_FLAKY_SIGNAL') != '1'
-        if report_only and num_fail + num_errored + num_unexpected_success < 1 or not report_only and num_expected_fail < 1:
-            raise RuntimeWarning(f'{err_msg} Intentional reruns are only triggered when the first run fails or errors, but'
-                                 ' we found no failures nor errors.')
+        err_msg = f"Warning: unintentional test case duplicates found for {test_run.name} in suite {test_run.class_name}."
+        report_only = os.getenv("PYTORCH_OVERRIDE_FLAKY_SIGNAL") != "1"
+        if (
+            report_only
+            and num_fail + num_errored + num_unexpected_success < 1
+            or not report_only
+            and num_expected_fail < 1
+        ):
+            raise RuntimeWarning(
+                f"{err_msg} Intentional reruns are only triggered when the first run fails or errors, but"
+                " we found no failures nor errors."
+            )
         if num_unexpected_success + num_expected_fail < 1:
-            raise RuntimeWarning(f'{err_msg} Intentional reruns should raise at least one unexpected success or expected '
-                                 'failure, but none have been found.')
+            raise RuntimeWarning(
+                f"{err_msg} Intentional reruns should raise at least one unexpected success or expected "
+                "failure, but none have been found."
+            )
         if report_only and num_pass != num_unexpected_success:
-            raise RuntimeWarning(f'{err_msg} Every success in an intentional rerun is shadowed by one unexpected success.'
-                                 f'However, successes = {num_pass} and unexpected successes = {num_unexpected_success}')
+            raise RuntimeWarning(
+                f"{err_msg} Every success in an intentional rerun is shadowed by one unexpected success."
+                f"However, successes = {num_pass} and unexpected successes = {num_unexpected_success}"
+            )
         if not report_only and num_pass > 1:
-            raise RuntimeWarning(f'{err_msg} There should be at most 1 successful run in an intentional rerun that stops'
-                                 f' at first success. The number of successful runs = {num_pass}')
+            raise RuntimeWarning(
+                f"{err_msg} There should be at most 1 successful run in an intentional rerun that stops"
+                f" at first success. The number of successful runs = {num_pass}"
+            )
         if num_skipped > 0:
-            raise RuntimeWarning(f'{err_msg} No skips should occur in intentional reruns, but skips = {num_skipped}')
-    return max(num_unexpected_success, num_pass), num_fail + num_expected_fail + num_errored
+            raise RuntimeWarning(
+                f"{err_msg} No skips should occur in intentional reruns, but skips = {num_skipped}"
+            )
+    return (
+        max(num_unexpected_success, num_pass),
+        num_fail + num_expected_fail + num_errored,
+    )
 
 
-def assemble_flaky_test_stats(duplicated_tests_by_file: Dict[str, DuplicatedDict]) -> Any:
+def assemble_flaky_test_stats(
+    duplicated_tests_by_file: Dict[str, DuplicatedDict]
+) -> Any:
     flaky_tests = []
-    workflow_id = os.environ.get("GITHUB_RUN_ID", os.environ.get("CIRCLE_WORKFLOW_ID", None))
+    workflow_id = os.environ.get(
+        "GITHUB_RUN_ID", os.environ.get("CIRCLE_WORKFLOW_ID", None)
+    )
     for file_name, suite_to_dict in duplicated_tests_by_file.items():
         for suite_name, testcase_to_runs in suite_to_dict.items():
             for testcase_name, list_of_runs in testcase_to_runs.items():
                 num_green, num_red = process_intentional_test_runs(list_of_runs)
-                if num_green > 0 and num_red > 0:   # Flaky tests show different results in consecutive reruns
-                    flaky_tests.append({
-                        "name": testcase_name,
-                        "suite": suite_name,
-                        "file": file_name,
-                        "num_green": num_green,
-                        "num_red": num_red,
-                    })
+                if (
+                    num_green > 0 and num_red > 0
+                ):  # Flaky tests show different results in consecutive reruns
+                    flaky_tests.append(
+                        {
+                            "name": testcase_name,
+                            "suite": suite_name,
+                            "file": file_name,
+                            "num_green": num_green,
+                            "num_red": num_red,
+                        }
+                    )
     if len(flaky_tests) > 0:
         # write to RDS
         register_rds_schema("flaky_tests", schema_from_sample(flaky_tests[0]))
@@ -795,6 +902,7 @@
 
         # write to S3 to go to Rockset as well
         import uuid
+
         for flaky_test in flaky_tests:
             flaky_test["job_id"] = os.environ["GHA_WORKFLOW_JOB_ID"]
             flaky_test["workflow_id"] = workflow_id
@@ -808,11 +916,17 @@
         "build_pr": os.environ.get("PR_NUMBER", os.environ.get("CIRCLE_PR_NUMBER", "")),
         "build_tag": os.environ.get("TAG", os.environ.get("CIRCLE_TAG", "")),
         "build_sha1": os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1", "")),
-        "build_base_commit": get_base_commit(os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1", "HEAD"))),
+        "build_base_commit": get_base_commit(
+            os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1", "HEAD"))
+        ),
         "build_branch": os.environ.get("BRANCH", os.environ.get("CIRCLE_BRANCH", "")),
         "build_job": os.environ.get("JOB_BASE_NAME", os.environ.get("CIRCLE_JOB", "")),
-        "build_workflow_id": os.environ.get("WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID", "")),
-        "build_start_time_epoch": str(int(os.path.getmtime(os.path.realpath(__file__)))),
+        "build_workflow_id": os.environ.get(
+            "WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID", "")
+        ),
+        "build_start_time_epoch": str(
+            int(os.path.getmtime(os.path.realpath(__file__)))
+        ),
     }
 
 
@@ -820,7 +934,7 @@
     test_file: TestFile,
     test_suite: TestSuite,
     test_case: TestCase,
-    meta_info: ReportMetaMeta
+    meta_info: ReportMetaMeta,
 ) -> Dict[str, Dict[str, Any]]:
     return {
         "normal": {
@@ -846,7 +960,9 @@
         [
             {
                 "category": "perfpipe_pytorch_test_times",
-                "message": json.dumps(build_message(test_file, test_suite, test_case, meta_info)),
+                "message": json.dumps(
+                    build_message(test_file, test_suite, test_case, meta_info)
+                ),
                 "line_escape": False,
             }
             for test_file in reports.values()
@@ -865,44 +981,50 @@
 ) -> Version2Report:
     return {
         **build_info(),  # type: ignore[misc]
-        'total_seconds': total_seconds,
-        'format_version': 2,
-        'files': {
+        "total_seconds": total_seconds,
+        "format_version": 2,
+        "files": {
             name: {
-                'total_seconds': test_file.total_time,
-                'suites': {
+                "total_seconds": test_file.total_time,
+                "suites": {
                     name: {
-                        'total_seconds': suite.total_time,
-                        'cases': {
+                        "total_seconds": suite.total_time,
+                        "cases": {
                             name: {
-                                'seconds': case.time,
-                                'status': 'errored' if case.errored else
-                                          'failed' if case.failed else
-                                          'skipped' if case.skipped else None
+                                "seconds": case.time,
+                                "status": "errored"
+                                if case.errored
+                                else "failed"
+                                if case.failed
+                                else "skipped"
+                                if case.skipped
+                                else None,
                             }
                             for name, case in suite.test_cases.items()
                         },
                     }
                     for name, suite in test_file.test_suites.items()
-                }
+                },
             }
             for name, test_file in reports.items()
-        }
+        },
     }
 
 
 def send_report_to_s3(head_report: Version2Report) -> None:
-    job = os.getenv('JOB_BASE_NAME', os.environ.get('CIRCLE_JOB'))
-    sha1 = os.environ.get('SHA1', os.environ.get('CIRCLE_SHA1', ''))
+    job = os.getenv("JOB_BASE_NAME", os.environ.get("CIRCLE_JOB"))
+    sha1 = os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1", ""))
     now = datetime.datetime.utcnow().isoformat()
 
     # SHARD_NUMBER and TEST_CONFIG are specific to GHA, as these details would be included in CIRCLE_JOB already
-    shard = os.environ.get('SHARD_NUMBER', '')
-    test_config = os.environ.get('TEST_CONFIG')
+    shard = os.environ.get("SHARD_NUMBER", "")
+    test_config = os.environ.get("TEST_CONFIG")
 
-    job_report_dirname = f'{job}{f"-{test_config}" if test_config is not None else ""}{shard}'
-    key = f'test_time/{sha1}/{job_report_dirname}/{now}Z.json.bz2'  # Z meaning UTC
-    obj = get_S3_object_from_bucket('ossci-metrics', key)
+    job_report_dirname = (
+        f'{job}{f"-{test_config}" if test_config is not None else ""}{shard}'
+    )
+    key = f"test_time/{sha1}/{job_report_dirname}/{now}Z.json.bz2"  # Z meaning UTC
+    obj = get_S3_object_from_bucket("ossci-metrics", key)
     # use bz2 because the results are smaller than gzip, and the
     # compression time penalty we pay is only about half a second for
     # input files of a few megabytes in size like these JSON files, and
@@ -923,12 +1045,14 @@
         for suite in file.test_suites.values():
             for case in suite.test_cases.values():
                 if case.errored or case.failed:
-                    failures.append({
-                        "name": case.name,
-                        "suite": suite.name,
-                        "file": file.name,
-                        "status": "failure" if case.failed else "error"
-                    })
+                    failures.append(
+                        {
+                            "name": case.name,
+                            "suite": suite.name,
+                            "file": file.name,
+                            "status": "failure" if case.failed else "error",
+                        }
+                    )
 
     if len(failures) > 0:
         register_rds_schema("test_failures", schema_from_sample(failures[0]))
@@ -941,14 +1065,17 @@
     base = get_base_commit(sha1)
 
     count_spec = f"{base}..{sha1}"
-    intermediate_commits = int(subprocess.check_output(
-        ["git", "rev-list", "--count", count_spec],
-        encoding="ascii"
-    ))
-    ancestry_path = int(subprocess.check_output(
-        ["git", "rev-list", "--ancestry-path", "--count", count_spec],
-        encoding="ascii",
-    ))
+    intermediate_commits = int(
+        subprocess.check_output(
+            ["git", "rev-list", "--count", count_spec], encoding="ascii"
+        )
+    )
+    ancestry_path = int(
+        subprocess.check_output(
+            ["git", "rev-list", "--ancestry-path", "--count", count_spec],
+            encoding="ascii",
+        )
+    )
 
     # if current commit is already on main, we need to exclude it from
     # this history; otherwise we include the merge-base
@@ -973,15 +1100,18 @@
             objects[commit].extend(summary)
 
     print()
-    print(regression_info(
-        head_sha=sha1,
-        head_report=head_report,
-        base_reports=objects,
-        job_name=job,
-        on_master=on_master,
-        ancestry_path=ancestry_path - 1,
-        other_ancestors=intermediate_commits - ancestry_path,
-    ), end="")
+    print(
+        regression_info(
+            head_sha=sha1,
+            head_report=head_report,
+            base_reports=objects,
+            job_name=job,
+            on_master=on_master,
+            ancestry_path=ancestry_path - 1,
+            other_ancestors=intermediate_commits - ancestry_path,
+        ),
+        end="",
+    )
 
 
 def positive_integer(value: str) -> float:
@@ -1006,9 +1136,10 @@
     return True
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import argparse
     import sys
+
     parser = argparse.ArgumentParser(
         "Print statistics from test XML output.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
diff --git a/tools/stats/s3_stat_parser.py b/tools/stats/s3_stat_parser.py
index 71474bf..b856cfa 100644
--- a/tools/stats/s3_stat_parser.py
+++ b/tools/stats/s3_stat_parser.py
@@ -10,6 +10,7 @@
 try:
     import boto3  # type: ignore[import]
     import botocore  # type: ignore[import]
+
     HAVE_BOTO3 = True
 except ImportError:
     HAVE_BOTO3 = False
@@ -18,10 +19,10 @@
 logger = logging.getLogger(__name__)
 
 
-OSSCI_METRICS_BUCKET = 'ossci-metrics'
+OSSCI_METRICS_BUCKET = "ossci-metrics"
 
 Commit = str  # 40-digit SHA-1 hex string
-Status = Optional[Literal['errored', 'failed', 'skipped']]
+Status = Optional[Literal["errored", "failed", "skipped"]]
 
 
 class CaseMeta(TypedDict):
@@ -85,8 +86,10 @@
 Report = Union[Version1Report, VersionedReport]
 
 if HAVE_BOTO3:
-    S3_RESOURCE_READ_ONLY = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED))
-    S3_RESOURCE = boto3.resource('s3')
+    S3_RESOURCE_READ_ONLY = boto3.resource(
+        "s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)
+    )
+    S3_RESOURCE = boto3.resource("s3")
 
 
 def get_S3_bucket_readonly(bucket_name: str) -> Any:
@@ -98,7 +101,7 @@
 
 
 def case_status(case: Version1Case) -> Status:
-    for k in {'errored', 'failed', 'skipped'}:
+    for k in {"errored", "failed", "skipped"}:
         if case[k]:  # type: ignore[misc]
             return cast(Status, k)
     return None
@@ -106,8 +109,8 @@
 
 def newify_case(case: Version1Case) -> Version2Case:
     return {
-        'seconds': case['seconds'],
-        'status': case_status(case),
+        "seconds": case["seconds"],
+        "status": case_status(case),
     }
 
 
@@ -119,28 +122,28 @@
     test_name: Optional[str],
 ) -> List[Version2Case]:
     cases: List[Version2Case] = []
-    if 'format_version' not in data:  # version 1 implicitly
+    if "format_version" not in data:  # version 1 implicitly
         v1report = cast(Version1Report, data)
-        suites = v1report['suites']
+        suites = v1report["suites"]
         for sname, v1suite in suites.items():
             if not suite_name or sname == suite_name:
-                for v1case in v1suite['cases']:
-                    if not test_name or v1case['name'] == test_name:
+                for v1case in v1suite["cases"]:
+                    if not test_name or v1case["name"] == test_name:
                         cases.append(newify_case(v1case))
     else:
         v_report = cast(VersionedReport, data)
-        version = v_report['format_version']
+        version = v_report["format_version"]
         if version == 2:
             v2report = cast(Version2Report, v_report)
-            for fname, v2file in v2report['files'].items():
+            for fname, v2file in v2report["files"].items():
                 if fname == filename or not filename:
-                    for sname, v2suite in v2file['suites'].items():
+                    for sname, v2suite in v2file["suites"].items():
                         if sname == suite_name or not suite_name:
-                            for cname, v2case in v2suite['cases'].items():
+                            for cname, v2case in v2suite["cases"].items():
                                 if not test_name or cname == test_name:
                                     cases.append(v2case)
         else:
-            raise RuntimeError(f'Unknown format version: {version}')
+            raise RuntimeError(f"Unknown format version: {version}")
     return cases
 
 
@@ -148,19 +151,22 @@
     summary_dict = defaultdict(list)
     for summary in summaries:
         # master summary format: "test_time/{sha}/{job}/file"
-        summary_job = summary.key.split('/')[2]
+        summary_job = summary.key.split("/")[2]
         if summary_job in jobs or len(jobs) == 0:
             binary = summary.get()["Body"].read()
             string = bz2.decompress(binary).decode("utf-8")
             summary_dict[summary_job].append(json.loads(string))
     return summary_dict
 
-def _parse_pr_summaries(summaries: Any, job_prefix: str) -> Dict[str, List[Tuple[Report, str]]]:
+
+def _parse_pr_summaries(
+    summaries: Any, job_prefix: str
+) -> Dict[str, List[Tuple[Report, str]]]:
     summary_dict = defaultdict(list)
     for summary in summaries:
         # PR summary format: "pr_test_time/{pr}/{sha}/{job}/file"
-        summary_job = summary.key.split('/')[3]
-        summary_timestamp = summary.key.split('/')[4][:len("YYYY-MM-ddTHH:mm:ss")]
+        summary_job = summary.key.split("/")[3]
+        summary_timestamp = summary.key.split("/")[4][: len("YYYY-MM-ddTHH:mm:ss")]
         if not job_prefix or len(job_prefix) == 0 or summary_job.startswith(job_prefix):
             binary = summary.get()["Body"].read()
             string = bz2.decompress(binary).decode("utf-8")
@@ -171,18 +177,25 @@
 # Collect and decompress S3 test stats summaries into JSON.
 # data stored on S3 buckets are pathed by {sha}/{job} so we also allow
 # optional jobs filter
-def get_test_stats_summaries(*, sha: str, jobs: Optional[List[str]] = None) -> Dict[str, List[Report]]:
+def get_test_stats_summaries(
+    *, sha: str, jobs: Optional[List[str]] = None
+) -> Dict[str, List[Report]]:
     bucket = get_S3_bucket_readonly(OSSCI_METRICS_BUCKET)
     summaries = bucket.objects.filter(Prefix=f"test_time/{sha}")
     return _parse_master_summaries(summaries, jobs=list(jobs or []))
 
 
-def get_test_stats_summaries_for_job(*, sha: str, job_prefix: str) -> Dict[str, List[Report]]:
+def get_test_stats_summaries_for_job(
+    *, sha: str, job_prefix: str
+) -> Dict[str, List[Report]]:
     bucket = get_S3_bucket_readonly(OSSCI_METRICS_BUCKET)
     summaries = bucket.objects.filter(Prefix=f"test_time/{sha}/{job_prefix}")
     return _parse_master_summaries(summaries, jobs=list())
 
-def get_test_stats_summaries_for_pr(*, pr: str, job_prefix: str) -> Dict[str, List[Tuple[Report, str]]]:
+
+def get_test_stats_summaries_for_pr(
+    *, pr: str, job_prefix: str
+) -> Dict[str, List[Tuple[Report, str]]]:
     bucket = get_S3_bucket_readonly(OSSCI_METRICS_BUCKET)
     summaries = bucket.objects.filter(Prefix=f"pr_test_time/{pr}/")
     return _parse_pr_summaries(summaries, job_prefix=job_prefix)
@@ -191,35 +204,50 @@
 # This function returns a list of S3 test time reports. This function can run into errors if HAVE_BOTO3 = False
 # or the S3 bucket is somehow unavailable. Even though this function goes through ten commits' reports to find a
 # non-empty report, it is still conceivable (though highly unlikely) for this function to return no reports.
-def get_previous_reports_for_branch(branch: str, ci_job_prefix: str = "") -> List[Report]:
+def get_previous_reports_for_branch(
+    branch: str, ci_job_prefix: str = ""
+) -> List[Report]:
     commit_date_ts = subprocess.check_output(
-        ['git', 'show', '-s', '--format=%ct', 'HEAD'],
-        encoding="ascii").strip()
+        ["git", "show", "-s", "--format=%ct", "HEAD"], encoding="ascii"
+    ).strip()
     commit_date = datetime.fromtimestamp(int(commit_date_ts))
     # We go a day before this current commit to avoiding pulling incomplete reports
-    day_before_commit = str(commit_date - timedelta(days=1)).split(' ')[0]
+    day_before_commit = str(commit_date - timedelta(days=1)).split(" ")[0]
     # something like git rev-list --before="2021-03-04" --max-count=10 --remotes="*origin/nightly"
     commits = subprocess.check_output(
-        ["git", "rev-list", f"--before={day_before_commit}", "--max-count=10", f"--remotes=*{branch}"],
-        encoding="ascii").splitlines()
+        [
+            "git",
+            "rev-list",
+            f"--before={day_before_commit}",
+            "--max-count=10",
+            f"--remotes=*{branch}",
+        ],
+        encoding="ascii",
+    ).splitlines()
 
     reports: List[Report] = []
     commit_index = 0
     while len(reports) == 0 and commit_index < len(commits):
         commit = commits[commit_index]
-        logger.info(f'Grabbing reports from commit: {commit}')
-        summaries = get_test_stats_summaries_for_job(sha=commit, job_prefix=ci_job_prefix)
+        logger.info(f"Grabbing reports from commit: {commit}")
+        summaries = get_test_stats_summaries_for_job(
+            sha=commit, job_prefix=ci_job_prefix
+        )
         for job_name, summary in summaries.items():
             reports.append(summary[0])
             if len(summary) > 1:
-                logger.warning(f'WARNING: Multiple summary objects found for {commit}/{job_name}')
+                logger.warning(
+                    f"WARNING: Multiple summary objects found for {commit}/{job_name}"
+                )
         commit_index += 1
     return reports
 
 
-def get_previous_reports_for_pr(pr: str, ci_job_prefix: str = "") -> List[Tuple[Report, str]]:
+def get_previous_reports_for_pr(
+    pr: str, ci_job_prefix: str = ""
+) -> List[Tuple[Report, str]]:
     reports: List[Tuple[Report, str]] = []
-    logger.info(f'Grabbing reports from PR: {[pr]}')
+    logger.info(f"Grabbing reports from PR: {[pr]}")
     summaries = get_test_stats_summaries_for_pr(pr=pr, job_prefix=ci_job_prefix)
     for _, summary in summaries.items():
         reports.extend(summary)
diff --git a/tools/stats/test_history.py b/tools/stats/test_history.py
index d9a1e29..8375144 100755
--- a/tools/stats/test_history.py
+++ b/tools/stats/test_history.py
@@ -7,17 +7,12 @@
 from signal import SIG_DFL, SIGPIPE, signal
 from typing import Dict, Iterator, List, Optional, Set, Tuple
 
-from tools.stats.s3_stat_parser import (Report, get_cases,
-                                        get_test_stats_summaries)
+from tools.stats.s3_stat_parser import Report, get_cases, get_test_stats_summaries
 
 
-def get_git_commit_history(
-    *,
-    path: str,
-    ref: str
-) -> List[Tuple[str, datetime]]:
+def get_git_commit_history(*, path: str, ref: str) -> List[Tuple[str, datetime]]:
     rc = subprocess.check_output(
-        ['git', '-C', path, 'log', '--pretty=format:%H %ct', ref],
+        ["git", "-C", path, "log", "--pretty=format:%H %ct", ref],
     ).decode("latin-1")
     return [
         (x[0], datetime.fromtimestamp(int(x[1]), tz=timezone.utc))
@@ -37,23 +32,20 @@
     num_length = digits + 1 + decimals
     if data:
         cases = get_cases(
-            data=data,
-            filename=filename,
-            suite_name=suite_name,
-            test_name=test_name
+            data=data, filename=filename, suite_name=suite_name, test_name=test_name
         )
         if cases:
             case = cases[0]
-            status = case['status']
+            status = case["status"]
             omitted = len(cases) - 1
             if status:
-                return f'{status.rjust(num_length)} ', omitted
+                return f"{status.rjust(num_length)} ", omitted
             else:
                 return f'{case["seconds"]:{num_length}.{decimals}f}s', omitted
         else:
             return f'{"absent".rjust(num_length)} ', 0
     else:
-        return ' ' * (num_length + 1), 0
+        return " " * (num_length + 1), 0
 
 
 def make_columns(
@@ -83,10 +75,10 @@
         if job in omitted:
             total_omitted += omitted[job]
     if total_omitted > 0:
-        columns.append(f'({total_omitted} job re-runs omitted)')
+        columns.append(f"({total_omitted} job re-runs omitted)")
     if total_suites > 0:
-        columns.append(f'({total_suites} matching suites omitted)')
-    return ' '.join(columns)
+        columns.append(f"({total_suites} matching suites omitted)")
+    return " ".join(columns)
 
 
 def make_lines(
@@ -108,17 +100,17 @@
             )
             if cases:
                 case = cases[0]
-                status = case['status']
+                status = case["status"]
                 line = f'{job} {case["seconds"]}s{f" {status}" if status else ""}'
                 if len(cases) > 1:
-                    line += f' ({len(cases) - 1} matching suites omitted)'
+                    line += f" ({len(cases) - 1} matching suites omitted)"
                 lines.append(line)
             elif job in jobs:
-                lines.append(f'{job} (test not found)')
+                lines.append(f"{job} (test not found)")
     if lines:
         return lines
     else:
-        return ['(no reports in S3)']
+        return ["(no reports in S3)"]
 
 
 def history_lines(
@@ -142,26 +134,24 @@
             summaries = get_test_stats_summaries(sha=sha)
         else:
             summaries = get_test_stats_summaries(sha=sha, jobs=jobs)
-        if mode == 'columns':
+        if mode == "columns":
             assert jobs is not None
             # we assume that get_test_stats_summaries here doesn't
             # return empty lists
-            omitted = {
-                job: len(l) - 1
-                for job, l in summaries.items()
-                if len(l) > 1
-            }
-            lines = [make_columns(
-                jobs=jobs,
-                jsons={job: l[0] for job, l in summaries.items()},
-                omitted=omitted,
-                filename=filename,
-                suite_name=suite_name,
-                test_name=test_name,
-                digits=digits,
-            )]
+            omitted = {job: len(l) - 1 for job, l in summaries.items() if len(l) > 1}
+            lines = [
+                make_columns(
+                    jobs=jobs,
+                    jsons={job: l[0] for job, l in summaries.items()},
+                    omitted=omitted,
+                    filename=filename,
+                    suite_name=suite_name,
+                    test_name=test_name,
+                    digits=digits,
+                )
+            ]
         else:
-            assert mode == 'multiline'
+            assert mode == "multiline"
             lines = make_lines(
                 jobs=set(jobs or []),
                 jsons=summaries,
@@ -181,7 +171,7 @@
 
 
 def description() -> str:
-    return r'''
+    return r"""
 Display the history of a test.
 
 Each line of (non-error) output starts with the timestamp and SHA1 hash
@@ -236,7 +226,7 @@
 Minor note: in columns mode, a blank cell means that no report was found
 in S3, while the word "absent" means that a report was found but the
 indicated test was not found in that report.
-'''
+"""
 
 
 def parse_args(raw: List[str]) -> argparse.Namespace:
@@ -246,61 +236,57 @@
         formatter_class=HelpFormatter,
     )
     parser.add_argument(
-        '--mode',
-        choices=['columns', 'multiline'],
-        help='output format',
-        default='columns',
+        "--mode",
+        choices=["columns", "multiline"],
+        help="output format",
+        default="columns",
     )
     parser.add_argument(
-        '--pytorch',
-        help='path to local PyTorch clone',
-        default='.',
+        "--pytorch",
+        help="path to local PyTorch clone",
+        default=".",
     )
     parser.add_argument(
-        '--ref',
-        help='starting point (most recent Git ref) to display history for',
-        default='master',
+        "--ref",
+        help="starting point (most recent Git ref) to display history for",
+        default="master",
     )
     parser.add_argument(
-        '--delta',
+        "--delta",
         type=int,
-        help='minimum number of hours between commits',
+        help="minimum number of hours between commits",
         default=0,
     )
     parser.add_argument(
-        '--sha-length',
+        "--sha-length",
         type=int,
-        help='length of the prefix of the SHA1 hash to show',
+        help="length of the prefix of the SHA1 hash to show",
         default=40,
     )
     parser.add_argument(
-        '--digits',
+        "--digits",
         type=int,
-        help='(columns) number of digits to display before the decimal point',
+        help="(columns) number of digits to display before the decimal point",
         default=4,
     )
     parser.add_argument(
-        '--all',
-        action='store_true',
-        help='(multiline) ignore listed jobs, show all jobs for each commit',
+        "--all",
+        action="store_true",
+        help="(multiline) ignore listed jobs, show all jobs for each commit",
     )
     parser.add_argument(
-        '--file',
-        help='name of the file containing the test',
+        "--file",
+        help="name of the file containing the test",
     )
     parser.add_argument(
-        '--suite',
-        help='name of the suite containing the test',
+        "--suite",
+        help="name of the suite containing the test",
     )
+    parser.add_argument("--test", help="name of the test", required=True)
     parser.add_argument(
-        '--test',
-        help='name of the test',
-        required=True
-    )
-    parser.add_argument(
-        '--job',
-        help='names of jobs to display columns for, in order',
-        action='append',
+        "--job",
+        help="names of jobs to display columns for, in order",
+        action="append",
         default=[],
     )
     args = parser.parse_args(raw)
@@ -308,7 +294,7 @@
     args.jobs = None if args.all else args.job
     # We dont allow implicit or empty "--jobs", unless "--all" is specified.
     if args.jobs == []:
-        parser.error('No jobs specified.')
+        parser.error("No jobs specified.")
 
     return args
 
diff --git a/tools/stats/upload_binary_size_to_scuba.py b/tools/stats/upload_binary_size_to_scuba.py
index adf1d50..aacaf62 100644
--- a/tools/stats/upload_binary_size_to_scuba.py
+++ b/tools/stats/upload_binary_size_to_scuba.py
@@ -55,7 +55,9 @@
             "build_num": os.environ.get("CIRCLE_BUILD_NUM"),
             "sha1": os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1")),
             "branch": os.environ.get("BRANCH", os.environ.get("CIRCLE_BRANCH")),
-            "workflow_id": os.environ.get("WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID")),
+            "workflow_id": os.environ.get(
+                "WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID")
+            ),
         },
         "int": {
             "time": int(time.time()),
@@ -118,13 +120,17 @@
                     "pkg_type": "{}/{}/{}".format(android_build_type, arch, lib),
                     "cu_ver": "",  # dummy value for derived field `build_name`
                     "py_ver": "",  # dummy value for derived field `build_name`
-                    "pr": os.environ.get("PR_NUMBER", os.environ.get("CIRCLE_PR_NUMBER")),
+                    "pr": os.environ.get(
+                        "PR_NUMBER", os.environ.get("CIRCLE_PR_NUMBER")
+                    ),
                     # This is the only place where we use directly CIRCLE_BUILD_NUM, everywhere else CIRCLE_* vars
                     # are used as fallback, there seems to be no direct analogy between circle build number and GHA IDs
                     "build_num": os.environ.get("CIRCLE_BUILD_NUM"),
                     "sha1": os.environ.get("SHA1", os.environ.get("CIRCLE_SHA1")),
                     "branch": os.environ.get("BRANCH", os.environ.get("CIRCLE_BRANCH")),
-                    "workflow_id": os.environ.get("WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID")),
+                    "workflow_id": os.environ.get(
+                        "WORKFLOW_ID", os.environ.get("CIRCLE_WORKFLOW_ID")
+                    ),
                 },
                 "int": {
                     "time": int(time.time()),
diff --git a/tools/test/test_cmake.py b/tools/test/test_cmake.py
index ecbce07..2c4bead 100644
--- a/tools/test/test_cmake.py
+++ b/tools/test/test_cmake.py
@@ -9,49 +9,60 @@
 import tools.setup_helpers.cmake
 
 
-T = typing.TypeVar('T')
+T = typing.TypeVar("T")
 
 
 class TestCMake(unittest.TestCase):
-
-    @unittest.mock.patch('multiprocessing.cpu_count')
+    @unittest.mock.patch("multiprocessing.cpu_count")
     def test_build_jobs(self, mock_cpu_count: unittest.mock.MagicMock) -> None:
         """Tests that the number of build jobs comes out correctly."""
         mock_cpu_count.return_value = 13
         cases = [
             # MAX_JOBS, USE_NINJA, IS_WINDOWS,         want
-            ((     '8',      True,     False),          ['-j', '8']),  # noqa: E201,E241
-            ((    None,      True,     False),                 None),  # noqa: E201,E241
-            ((     '7',     False,     False),          ['-j', '7']),  # noqa: E201,E241
-            ((    None,     False,     False),         ['-j', '13']),  # noqa: E201,E241
-            ((     '6',      True,      True),          ['-j', '6']),  # noqa: E201,E241
-            ((    None,      True,      True),                 None),  # noqa: E201,E241
-            ((    '11',     False,      True), ['/p:CL_MPCount=11']),  # noqa: E201,E241
-            ((    None,     False,      True), ['/p:CL_MPCount=13']),  # noqa: E201,E241
+            (("8", True, False), ["-j", "8"]),  # noqa: E201,E241
+            ((None, True, False), None),  # noqa: E201,E241
+            (("7", False, False), ["-j", "7"]),  # noqa: E201,E241
+            ((None, False, False), ["-j", "13"]),  # noqa: E201,E241
+            (("6", True, True), ["-j", "6"]),  # noqa: E201,E241
+            ((None, True, True), None),  # noqa: E201,E241
+            (("11", False, True), ["/p:CL_MPCount=11"]),  # noqa: E201,E241
+            ((None, False, True), ["/p:CL_MPCount=13"]),  # noqa: E201,E241
         ]
         for (max_jobs, use_ninja, is_windows), want in cases:
-            with self.subTest(MAX_JOBS=max_jobs, USE_NINJA=use_ninja, IS_WINDOWS=is_windows):
+            with self.subTest(
+                MAX_JOBS=max_jobs, USE_NINJA=use_ninja, IS_WINDOWS=is_windows
+            ):
                 with contextlib.ExitStack() as stack:
-                    stack.enter_context(env_var('MAX_JOBS', max_jobs))
-                    stack.enter_context(unittest.mock.patch.object(tools.setup_helpers.cmake, 'USE_NINJA', use_ninja))
-                    stack.enter_context(unittest.mock.patch.object(tools.setup_helpers.cmake, 'IS_WINDOWS', is_windows))
+                    stack.enter_context(env_var("MAX_JOBS", max_jobs))
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            tools.setup_helpers.cmake, "USE_NINJA", use_ninja
+                        )
+                    )
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            tools.setup_helpers.cmake, "IS_WINDOWS", is_windows
+                        )
+                    )
 
                     cmake = tools.setup_helpers.cmake.CMake()
 
-                    with unittest.mock.patch.object(cmake, 'run') as cmake_run:
+                    with unittest.mock.patch.object(cmake, "run") as cmake_run:
                         cmake.build({})
 
                     cmake_run.assert_called_once()
-                    call, = cmake_run.mock_calls
+                    (call,) = cmake_run.mock_calls
                     build_args, _ = call.args
 
                 if want is None:
-                    self.assertNotIn('-j', build_args)
+                    self.assertNotIn("-j", build_args)
                 else:
                     self.assert_contains_sequence(build_args, want)
 
     @staticmethod
-    def assert_contains_sequence(sequence: Sequence[T], subsequence: Sequence[T]) -> None:
+    def assert_contains_sequence(
+        sequence: Sequence[T], subsequence: Sequence[T]
+    ) -> None:
         """Raises an assertion if the subsequence is not contained in the sequence."""
         if len(subsequence) == 0:
             return  # all sequences contain the empty subsequence
@@ -63,7 +74,7 @@
             assert len(candidate) == len(subsequence)  # sanity check
             if candidate == subsequence:
                 return  # found it
-        raise AssertionError(f'{subsequence} not found in {sequence}')
+        raise AssertionError(f"{subsequence} not found in {sequence}")
 
 
 @contextlib.contextmanager
diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py
index 0dded01..4e48424 100644
--- a/tools/test/test_codegen.py
+++ b/tools/test/test_codegen.py
@@ -6,70 +6,75 @@
 from tools.autograd import load_derivatives
 import tools.codegen.model
 
-class TestCreateDerivative(unittest.TestCase):
 
+class TestCreateDerivative(unittest.TestCase):
     def test_named_grads(self) -> None:
         schema = tools.codegen.model.FunctionSchema.parse(
-            'func(Tensor a, Tensor b) -> (Tensor x, Tensor y)')
-        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
-                                              func=schema)
+            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
+        )
+        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         derivative = load_derivatives.create_derivative(
             native_function,
-            formula='func_backward(grad_x, grad_y)',
+            formula="func_backward(grad_x, grad_y)",
             var_names=(),
-            available_named_gradients=['grad_x', 'grad_y'])
-        self.assertSetEqual(derivative.named_gradients, {'grad_x', 'grad_y'})
+            available_named_gradients=["grad_x", "grad_y"],
+        )
+        self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"})
 
     def test_non_differentiable_output(self) -> None:
-        specification = 'func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)'
+        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
         schema = tools.codegen.model.FunctionSchema.parse(specification)
-        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
-                                              func=schema)
+        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         differentiability_info = load_derivatives.create_differentiability_info(
-            defn={'name': specification,
-                  'a': 'grads[0]',
-                  'b': 'grads[2]',
-                  },
+            defn={
+                "name": specification,
+                "a": "grads[0]",
+                "b": "grads[2]",
+            },
             functions_by_signature={schema.signature(): [native_function]},
             functions_by_schema={specification: native_function},
             op_counter=typing.Counter[str](),
         )
 
-        self.assertSequenceEqual(differentiability_info.available_named_gradients,
-                                 # grad_y is not present because y is a
-                                 # bool and thus not differentiable.
-                                 ['grad_x', 'grad_z'])
+        self.assertSequenceEqual(
+            differentiability_info.available_named_gradients,
+            # grad_y is not present because y is a
+            # bool and thus not differentiable.
+            ["grad_x", "grad_z"],
+        )
 
     def test_indexed_grads(self) -> None:
         schema = tools.codegen.model.FunctionSchema.parse(
-            'func(Tensor a, Tensor b) -> (Tensor x, Tensor y)')
-        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
-                                              func=schema)
+            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
+        )
+        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         derivative = load_derivatives.create_derivative(
             native_function,
-            formula='func_backward(grads[0], grads[1])',
+            formula="func_backward(grads[0], grads[1])",
             var_names=(),
-            available_named_gradients=['grad_x', 'grad_y'])
+            available_named_gradients=["grad_x", "grad_y"],
+        )
         self.assertSetEqual(derivative.named_gradients, set())
 
     def test_named_grads_and_indexed_grads(self) -> None:
-        specification = 'func(Tensor a, Tensor b) -> (Tensor x, Tensor y)'
+        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
         schema = tools.codegen.model.FunctionSchema.parse(specification)
-        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
-                                              func=schema)
+        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
-        with self.assertRaisesRegex(RuntimeError,
-                                    'illegally mixes use of "grad_RETURN_NAME"'):
+        with self.assertRaisesRegex(
+            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
+        ):
             load_derivatives.create_differentiability_info(
-                defn={'name': specification,
-                      # Uh-oh, the derivatives reference gradients by
-                      # name and by index.
-                      'a': 'grad_x',
-                      'b': 'grads[1]',
-                      },
+                defn={
+                    "name": specification,
+                    # Uh-oh, the derivatives reference gradients by
+                    # name and by index.
+                    "a": "grad_x",
+                    "b": "grads[1]",
+                },
                 functions_by_signature={schema.signature(): [native_function]},
                 functions_by_schema={specification: native_function},
                 op_counter=typing.Counter[str](),
@@ -78,60 +83,59 @@
 
 class TestGenAutogradFunctions(unittest.TestCase):
     def test_non_differentiable_output_invalid_type(self) -> None:
-        specification = 'func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)'
+        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
         schema = tools.codegen.model.FunctionSchema.parse(specification)
-        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
-                                              func=schema)
+        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         differentiability_info = load_derivatives.create_differentiability_info(
-            defn={'name': specification,
-                  'a': 'grad_x',
-                  'b': 'grad_z',
-                  },
+            defn={
+                "name": specification,
+                "a": "grad_x",
+                "b": "grad_z",
+            },
             functions_by_signature={schema.signature(): [native_function]},
             functions_by_schema={specification: native_function},
             op_counter=typing.Counter[str](),
         )
         definition = gen_autograd_functions.process_function(
-            differentiability_info,
-            gen_autograd_functions.FUNCTION_DEFINITION)
+            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION
+        )
         # grad_z should map to grads[1], not grads[2] because output 1
         # (y) is not differentiable.
-        assert 'grad_z = grads[2]' not in definition
-        assert 'grad_z = grads[1]' in definition
-
+        assert "grad_z = grads[2]" not in definition
+        assert "grad_z = grads[1]" in definition
 
     def test_non_differentiable_output_output_differentiability(self) -> None:
-        specification = 'func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)'
+        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
         schema = tools.codegen.model.FunctionSchema.parse(specification)
-        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION,
-                                              func=schema)
+        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         differentiability_info = load_derivatives.create_differentiability_info(
-            defn={'name': specification,
-                  'a': 'grad_x',
-                  'b': 'grad_z',
-                  'output_differentiability': [True, False, True],
-                  },
+            defn={
+                "name": specification,
+                "a": "grad_x",
+                "b": "grad_z",
+                "output_differentiability": [True, False, True],
+            },
             functions_by_signature={schema.signature(): [native_function]},
             functions_by_schema={specification: native_function},
             op_counter=typing.Counter[str](),
         )
         definition = gen_autograd_functions.process_function(
-            differentiability_info,
-            gen_autograd_functions.FUNCTION_DEFINITION)
+            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION
+        )
         # grad_z should map to grads[1], not grads[2] because output 1
         # (y) is not differentiable.
-        assert 'grad_z = grads[2]' not in definition
-        assert 'grad_z = grads[1]' in definition
+        assert "grad_z = grads[2]" not in definition
+        assert "grad_z = grads[1]" in definition
 
 
 # Represents the most basic NativeFunction. Use dataclasses.replace()
 # to edit for use.
 DEFAULT_NATIVE_FUNCTION, _ = tools.codegen.model.NativeFunction.from_yaml(
-    {'func': 'func() -> bool'},
-    loc=tools.codegen.model.Location(__file__, 1))
+    {"func": "func() -> bool"}, loc=tools.codegen.model.Location(__file__, 1)
+)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_codegen_model.py b/tools/test/test_codegen_model.py
index 50ea595..59f9563 100644
--- a/tools/test/test_codegen_model.py
+++ b/tools/test/test_codegen_model.py
@@ -10,6 +10,7 @@
 import tools.codegen.gen as gen
 from tools.codegen.gen import LineLoader, parse_native_yaml_struct
 
+
 class TestCodegenModel(expecttest.TestCase):
     def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
         es = yaml.load(yaml_str, Loader=LineLoader)
@@ -17,8 +18,8 @@
             parse_native_yaml_struct(es)
         except AssertionError as e:
             # hack to strip out the context
-            msg, _ = str(e).split('  in ', 2)
-            self.assertExpectedInline('\n'.join(textwrap.wrap(msg)), expect, skip=1)
+            msg, _ = str(e).split("  in ", 2)
+            self.assertExpectedInline("\n".join(textwrap.wrap(msg)), expect, skip=1)
             return
         self.fail(msg="Did not raise when expected to")
 
@@ -26,7 +27,10 @@
         # parse a single structured group out of the yaml to g
         es = yaml.load(yaml_str, Loader=LineLoader)
         parsed_yaml = parse_native_yaml_struct(es)
-        native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
+        native_functions, backend_indices = (
+            parsed_yaml.native_functions,
+            parsed_yaml.backend_indices,
+        )
         grouped_native_functions = gen.get_grouped_native_functions(native_functions)
         assert len(grouped_native_functions) == 1
         g = grouped_native_functions[0]
@@ -44,81 +48,98 @@
             dest.compute_ufunc_cuda(g)
         except AssertionError as e:
             # hack to strip out the context
-            msg, _ = str(e).split('  in ', 2)
-            self.assertExpectedInline('\n'.join(textwrap.wrap(msg)), expect, skip=1)
+            msg, _ = str(e).split("  in ", 2)
+            self.assertExpectedInline("\n".join(textwrap.wrap(msg)), expect, skip=1)
             return
         self.fail(msg="Did not raise when expected to")
 
     # NB: indent is hardcoded to be two here, so format your yaml accordingly
-    binop_out = 'func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)'
-    ti_binop_out = f'''{binop_out}
+    binop_out = (
+        "func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
+    )
+    ti_binop_out = f"""{binop_out}
   structured: True
-  structured_inherits: TensorIteratorBase'''
-    ti_binop = '''func: binop(Tensor self, Tensor other) -> Tensor
+  structured_inherits: TensorIteratorBase"""
+    ti_binop = """func: binop(Tensor self, Tensor other) -> Tensor
   structured_delegate: binop.out
-'''
+"""
 
-    ti_unop_out = '''func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    ti_unop_out = """func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
-  structured_inherits: TensorIteratorBase'''
-    ti_unop = '''func: unop(Tensor self) -> Tensor
+  structured_inherits: TensorIteratorBase"""
+    ti_unop = """func: unop(Tensor self) -> Tensor
   structured_delegate: unop.out
-'''
+"""
 
     def test_nonstructured_ufunc(self) -> None:
-        yaml_str = f'''\
+        yaml_str = f"""\
 - {self.binop_out}
   ufunc_inner_loop:
     Generic: binop (Bool)
-'''
-        self.assertParseErrorInline(yaml_str, '''\
-ufunc must be structured''')
+"""
+        self.assertParseErrorInline(
+            yaml_str,
+            """\
+ufunc must be structured""",
+        )
 
     def test_overlapping_ufunc_and_dispatch(self) -> None:
-        yaml_str = f'''\
+        yaml_str = f"""\
 - {self.ti_binop_out}
   ufunc_inner_loop:
     Generic: binop (Bool)
   dispatch:
     CPU: binop_cpu
-'''
-        self.assertParseErrorInline(yaml_str, '''\
-ufunc should not have explicit dispatch entry for CPU''')
+"""
+        self.assertParseErrorInline(
+            yaml_str,
+            """\
+ufunc should not have explicit dispatch entry for CPU""",
+        )
 
     # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456
     @unittest.expectedFailure
     def test_scalaronly_shadowed(self) -> None:
-        yaml_str = f'''\
+        yaml_str = f"""\
 - {self.ti_binop_out}
   ufunc_inner_loop:
     Generic: binop (Bool)
     ScalarOnly: binop (Bool)
-'''
-        self.assertParseErrorInline(yaml_str, '''\
-''')
+"""
+        self.assertParseErrorInline(
+            yaml_str,
+            """\
+""",
+        )
 
     def test_conflicting_ufunc(self) -> None:
-        yaml_str = f'''\
+        yaml_str = f"""\
 - {self.ti_binop_out}
   ufunc_inner_loop:
     Generic: binop (Bool)
     ScalarOnly: binop_scalar (Bool)
 - {self.ti_binop}
-'''
-        self.assertUfuncErrorInline(yaml_str, '''\
-ScalarOnly and Generic must have same ufunc name''')
+"""
+        self.assertUfuncErrorInline(
+            yaml_str,
+            """\
+ScalarOnly and Generic must have same ufunc name""",
+        )
 
     def test_invalid_cudafunctoronself_for_binary_op(self) -> None:
-        yaml_str = f'''\
+        yaml_str = f"""\
 - {self.ti_unop_out}
   ufunc_inner_loop:
     Generic: unop (All)
     CUDAFunctorOnSelf: unop_self_cuda (All)
 - {self.ti_unop}
-'''
-        self.assertUfuncErrorInline(yaml_str, '''\
-cannot use CUDAFunctorOnSelf on non-binary function''')
+"""
+        self.assertUfuncErrorInline(
+            yaml_str,
+            """\
+cannot use CUDAFunctorOnSelf on non-binary function""",
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_extract_scripts.py b/tools/test/test_extract_scripts.py
index 3126893..9914032 100644
--- a/tools/test/test_extract_scripts.py
+++ b/tools/test/test_extract_scripts.py
@@ -2,84 +2,94 @@
 
 from tools import extract_scripts
 
-requirements_sh = '''
+requirements_sh = """
 #!/usr/bin/env bash
 set -eo pipefail
 pip install -r requirements.txt
-'''.strip()
+""".strip()
 
-hello_sh = '''
+hello_sh = """
 #!/usr/bin/env sh
 set -e
 echo hello world
-'''.strip()
+""".strip()
 
 
 class TestExtractScripts(unittest.TestCase):
     def test_extract_none(self) -> None:
         self.assertEqual(
-            extract_scripts.extract({
-                'name': 'Checkout PyTorch',
-                'uses': 'zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9',
-            }),
+            extract_scripts.extract(
+                {
+                    "name": "Checkout PyTorch",
+                    "uses": "zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9",
+                }
+            ),
             None,
         )
 
     def test_extract_run_default_bash(self) -> None:
         self.assertEqual(
-            extract_scripts.extract({
-                'name': 'Install requirements',
-                'run': 'pip install -r requirements.txt',
-            }),
+            extract_scripts.extract(
+                {
+                    "name": "Install requirements",
+                    "run": "pip install -r requirements.txt",
+                }
+            ),
             {
-                'extension': '.sh',
-                'script': requirements_sh,
+                "extension": ".sh",
+                "script": requirements_sh,
             },
         )
 
     def test_extract_run_sh(self) -> None:
         self.assertEqual(
-            extract_scripts.extract({
-                'name': 'Hello world',
-                'run': 'echo hello world',
-                'shell': 'sh',
-            }),
+            extract_scripts.extract(
+                {
+                    "name": "Hello world",
+                    "run": "echo hello world",
+                    "shell": "sh",
+                }
+            ),
             {
-                'extension': '.sh',
-                'script': hello_sh,
+                "extension": ".sh",
+                "script": hello_sh,
             },
         )
 
     def test_extract_run_py(self) -> None:
         self.assertEqual(
-            extract_scripts.extract({
-                'name': 'Hello world',
-                'run': 'print("Hello!")',
-                'shell': 'python',
-            }),
+            extract_scripts.extract(
+                {
+                    "name": "Hello world",
+                    "run": 'print("Hello!")',
+                    "shell": "python",
+                }
+            ),
             {
-                'extension': '.py',
-                'script': 'print("Hello!")',
+                "extension": ".py",
+                "script": 'print("Hello!")',
             },
         )
 
     def test_extract_github_script(self) -> None:
         self.assertEqual(
             # https://github.com/actions/github-script/tree/v3.1.1#reading-step-results
-            extract_scripts.extract({
-                'uses': 'actions/github-script@v3',
-                'id': 'set-result',
-                'with': {
-                    'script': 'return "Hello!"',
-                    'result-encoding': 'string',
-                },
-            }),
+            extract_scripts.extract(
+                {
+                    "uses": "actions/github-script@v3",
+                    "id": "set-result",
+                    "with": {
+                        "script": 'return "Hello!"',
+                        "result-encoding": "string",
+                    },
+                }
+            ),
             {
-                'extension': '.js',
-                'script': 'return "Hello!"',
+                "extension": ".js",
+                "script": 'return "Hello!"',
             },
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 9dae08c36..0024737 100644
--- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -9,229 +9,265 @@
 from tools.codegen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE  # noqa: F401
 
 path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, '../tools/codegen/gen_backend_stubs.py')
+gen_backend_stubs_path = os.path.join(path, "../tools/codegen/gen_backend_stubs.py")
 
 # gen_backend_stubs.py is an integration point that is called directly by external backends.
 # The tests here are to confirm that badly formed inputs result in reasonable error messages.
 class TestGenBackendStubs(expecttest.TestCase):
-
     def setUp(self) -> None:
         global _GLOBAL_PARSE_NATIVE_YAML_CACHE
         _GLOBAL_PARSE_NATIVE_YAML_CACHE.clear()
 
-
     def assert_success_from_gen_backend_stubs(self, yaml_str: str) -> None:
-        with tempfile.NamedTemporaryFile(mode='w') as fp:
+        with tempfile.NamedTemporaryFile(mode="w") as fp:
             fp.write(yaml_str)
             fp.flush()
-            run(fp.name, '', True)
+            run(fp.name, "", True)
 
     def get_errors_from_gen_backend_stubs(self, yaml_str: str) -> str:
-        with tempfile.NamedTemporaryFile(mode='w') as fp:
+        with tempfile.NamedTemporaryFile(mode="w") as fp:
             fp.write(yaml_str)
             fp.flush()
             try:
-                run(fp.name, '', True)
+                run(fp.name, "", True)
             except AssertionError as e:
                 # Scrub out the temp file name from any error messages to simplify assertions.
-                return str(e).replace(fp.name, '')
-            self.fail('Expected gen_backend_stubs to raise an AssertionError, but it did not.')
+                return str(e).replace(fp.name, "")
+            self.fail(
+                "Expected gen_backend_stubs to raise an AssertionError, but it did not."
+            )
 
     def test_valid_single_op(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
-- abs'''
+- abs"""
         self.assert_success_from_gen_backend_stubs(yaml_str)
 
     def test_valid_multiple_ops(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
 - add.Tensor
-- abs'''
+- abs"""
         self.assert_success_from_gen_backend_stubs(yaml_str)
 
     def test_valid_zero_ops(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
-supported:'''
+supported:"""
         self.assert_success_from_gen_backend_stubs(yaml_str)
 
     def test_valid_zero_ops_doesnt_require_backend_dispatch_key(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: BAD_XLA
 cpp_namespace: torch_xla
-supported:'''
+supported:"""
         # External codegen on a yaml file with no operators is effectively a no-op,
         # so there's no reason to parse the backend
         self.assert_success_from_gen_backend_stubs(yaml_str)
 
     def test_valid_with_autograd_ops(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
 - abs
 autograd:
-- add.Tensor'''
+- add.Tensor"""
         # External codegen on a yaml file with no operators is effectively a no-op,
         # so there's no reason to parse the backend
         self.assert_success_from_gen_backend_stubs(yaml_str)
 
     def test_missing_backend(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 cpp_namespace: torch_xla
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''You must provide a value for "backend"''')
+        self.assertExpectedInline(
+            output_error, '''You must provide a value for "backend"'''
+        )
 
     def test_empty_backend(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend:
 cpp_namespace: torch_xla
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''You must provide a value for "backend"''')
+        self.assertExpectedInline(
+            output_error, '''You must provide a value for "backend"'''
+        )
 
     def test_backend_invalid_dispatch_key(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: NOT_XLA
 cpp_namespace: torch_xla
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''\
+        self.assertExpectedInline(
+            output_error,
+            """\
 unknown dispatch key NOT_XLA
-  The provided value for "backend" must be a valid DispatchKey, but got NOT_XLA.''')  # noqa: B950
+  The provided value for "backend" must be a valid DispatchKey, but got NOT_XLA.""",
+        )  # noqa: B950
 
     def test_missing_cpp_namespace(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''You must provide a value for "cpp_namespace"''')
+        self.assertExpectedInline(
+            output_error, '''You must provide a value for "cpp_namespace"'''
+        )
 
     def test_whitespace_cpp_namespace(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace:\t
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''You must provide a value for "cpp_namespace"''')
+        self.assertExpectedInline(
+            output_error, '''You must provide a value for "cpp_namespace"'''
+        )
 
     # supported is a single item (it should be a list)
     def test_nonlist_supported(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
-supported: abs'''
+supported: abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''expected "supported" to be a list, but got: abs (of type <class 'str'>)''')
+        self.assertExpectedInline(
+            output_error,
+            """expected "supported" to be a list, but got: abs (of type <class 'str'>)""",
+        )
 
     # supported contains an op that isn't in native_functions.yaml
     def test_supported_invalid_op(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
-- abs_BAD'''
+- abs_BAD"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''Found an invalid operator name: abs_BAD''')
+        self.assertExpectedInline(
+            output_error, """Found an invalid operator name: abs_BAD"""
+        )
 
     # The backend is valid, but doesn't have a valid autograd key. They can't override autograd kernels in that case.
     # Only using Vulkan here because it has a valid backend key but not an autograd key- if this changes we can update the test.
     def test_backend_has_no_autograd_key_but_provides_entries(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: Vulkan
 cpp_namespace: torch_vulkan
 supported:
 - add
 autograd:
-- sub'''
+- sub"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''Found an invalid operator name: add''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error, """Found an invalid operator name: add"""
+        )  # noqa: B950
 
     # in an operator group, currently all operators must either be registered to the backend or autograd kernel.
     # Here, functional and out mismatch
     def test_backend_autograd_kernel_mismatch_out_functional(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
-- add.out'''
+- add.out"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_out is listed under "autograd".''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error,
+            """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_out is listed under "autograd".""",  # noqa: B950
+        )
 
     # in an operator group, currently all operators must either be registered to the backend or autograd kernel.
     # Here, functional and inplace mismatch
     def test_backend_autograd_kernel_mismatch_functional_inplace(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
-- add_.Tensor'''
+- add_.Tensor"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_ is listed under "autograd".''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error,
+            """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_ is listed under "autograd".""",  # noqa: B950
+        )
 
     # Currently, the same operator can't be listed under both 'supported' and 'autograd', which would
     # involve registering the same kernel to both the XLA and AutogradXLA keys.
     # If we need that functionality in the future, we'll need to augment the codegen.
     def test_op_appears_in_supported_and_autograd_lists(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
-- add.Tensor'''
+- add.Tensor"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add is listed under "autograd".''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error,
+            """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add is listed under "autograd".""",  # noqa: B950
+        )
 
     # unrecognized extra yaml key
     def test_unrecognized_key(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 supported:
 - abs
-invalid_key: invalid_val'''
+invalid_key: invalid_val"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, ''' contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error,
+            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen""",  # noqa: B950
+        )
 
     # if use_out_as_primary is provided, it must be a bool
     def test_use_out_as_primary_non_bool(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 use_out_as_primary: frue
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''You must provide either True or False for use_out_as_primary. Provided: frue''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error,
+            """You must provide either True or False for use_out_as_primary. Provided: frue""",
+        )  # noqa: B950
 
     # if device_guard is provided, it must be a bool
     def test_device_guard_non_bool(self) -> None:
-        yaml_str = '''\
+        yaml_str = """\
 backend: XLA
 cpp_namespace: torch_xla
 device_guard: frue
 supported:
-- abs'''
+- abs"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
-        self.assertExpectedInline(output_error, '''You must provide either True or False for device_guard. Provided: frue''')  # noqa: B950
+        self.assertExpectedInline(
+            output_error,
+            """You must provide either True or False for device_guard. Provided: frue""",
+        )  # noqa: B950
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_import_test_stats.py b/tools/test/test_import_test_stats.py
index 5a43a7d..ea9aad8 100644
--- a/tools/test/test_import_test_stats.py
+++ b/tools/test/test_import_test_stats.py
@@ -4,48 +4,64 @@
 from typing import List
 from unittest.mock import patch
 
-class TestGetDisabledIssues(unittest.TestCase):
 
-    def run_assert_disabled_issues(self, pr_body: str, commit_messages: str, expected: List[str]) -> None:
-        with patch.dict(os.environ, {"PR_BODY": pr_body, "COMMIT_MESSAGES": commit_messages}):
+class TestGetDisabledIssues(unittest.TestCase):
+    def run_assert_disabled_issues(
+        self, pr_body: str, commit_messages: str, expected: List[str]
+    ) -> None:
+        with patch.dict(
+            os.environ, {"PR_BODY": pr_body, "COMMIT_MESSAGES": commit_messages}
+        ):
             disabled_issues = get_disabled_issues()
         self.assertEqual(disabled_issues, expected)
 
     # test variations of close in PR_BODY
     def test_closes_pr_body(self) -> None:
-        pr_body = 'closes #123 Close #143 ClOsE #345 closed #10283'
-        self.run_assert_disabled_issues(pr_body, '', ['123', '143', '345', '10283'])
+        pr_body = "closes #123 Close #143 ClOsE #345 closed #10283"
+        self.run_assert_disabled_issues(pr_body, "", ["123", "143", "345", "10283"])
 
     # test variations of fix in COMMIT_MESSAGES
     def test_fixes_commit_messages(self) -> None:
-        commit_messages = 'fix #123 FixEd #143 fixes #345 FiXeD #10283'
-        self.run_assert_disabled_issues('', commit_messages, ['123', '143', '345', '10283'])
+        commit_messages = "fix #123 FixEd #143 fixes #345 FiXeD #10283"
+        self.run_assert_disabled_issues(
+            "", commit_messages, ["123", "143", "345", "10283"]
+        )
 
     # test variations of resolve in PR_BODY and COMMIT_MESSAGES
     def test_resolves_pr_commits(self) -> None:
-        pr_body = 'resolve #123 resolveS #143'
-        commit_messages = 'REsolved #345 RESOLVES #10283'
-        self.run_assert_disabled_issues(pr_body, commit_messages, ['123', '143', '345', '10283'])
+        pr_body = "resolve #123 resolveS #143"
+        commit_messages = "REsolved #345 RESOLVES #10283"
+        self.run_assert_disabled_issues(
+            pr_body, commit_messages, ["123", "143", "345", "10283"]
+        )
 
     # test links
     def test_issue_links(self) -> None:
-        pr_body = 'closes https://github.com/pytorch/pytorch/issues/75198 fixes https://github.com/pytorch/pytorch/issues/75123'
-        self.run_assert_disabled_issues(pr_body, '', ['75198', '75123'])
+        pr_body = "closes https://github.com/pytorch/pytorch/issues/75198 fixes https://github.com/pytorch/pytorch/issues/75123"
+        self.run_assert_disabled_issues(pr_body, "", ["75198", "75123"])
 
     # test strange spacing
     def test_spacing(self) -> None:
-        pr_body = 'resolve #123,resolveS #143Resolved #345\nRESOLVES #10283'
-        commit_messages = 'Fixed #2348fixes https://github.com/pytorch/pytorch/issues/75123resolveS #2134'
-        self.run_assert_disabled_issues(pr_body, commit_messages, ['123', '143', '345', '10283', '2348', '75123', '2134'])
+        pr_body = "resolve #123,resolveS #143Resolved #345\nRESOLVES #10283"
+        commit_messages = "Fixed #2348fixes https://github.com/pytorch/pytorch/issues/75123resolveS #2134"
+        self.run_assert_disabled_issues(
+            pr_body,
+            commit_messages,
+            ["123", "143", "345", "10283", "2348", "75123", "2134"],
+        )
 
     # test bad things
     def test_not_accepted(self) -> None:
-        pr_body = 'fixes189 fixeshttps://github.com/pytorch/pytorch/issues/75123 ' \
-            'closedhttps://githubcom/pytorch/pytorch/issues/75123'
-        commit_messages = 'fix 234, fixes # 45, fixing #123, close 234, closes#45, closing #123 resolve 234, ' \
-            'resolves  #45, resolving #123'
+        pr_body = (
+            "fixes189 fixeshttps://github.com/pytorch/pytorch/issues/75123 "
+            "closedhttps://githubcom/pytorch/pytorch/issues/75123"
+        )
+        commit_messages = (
+            "fix 234, fixes # 45, fixing #123, close 234, closes#45, closing #123 resolve 234, "
+            "resolves  #45, resolving #123"
+        )
         self.run_assert_disabled_issues(pr_body, commit_messages, [])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_mypy_wrapper.py b/tools/test/test_mypy_wrapper.py
index df7b0ab..460e4dd 100644
--- a/tools/test/test_mypy_wrapper.py
+++ b/tools/test/test_mypy_wrapper.py
@@ -5,45 +5,48 @@
 
 class TestMypyWrapper(unittest.TestCase):
     configs = {
-        'foo.ini': {
-            'file1.abc',
-            'dir2',
-            'dir3/file4.xyz',
+        "foo.ini": {
+            "file1.abc",
+            "dir2",
+            "dir3/file4.xyz",
         },
-        'bar/baz.ini': {
-            'file1.abc',
-            'dir2/dir5/file6.def',
-            'dir3/file7.abc',
+        "bar/baz.ini": {
+            "file1.abc",
+            "dir2/dir5/file6.def",
+            "dir3/file7.abc",
         },
     }
 
     trie: mypy_wrapper.Trie = {
-        'file1.abc': {None: {'foo.ini', 'bar/baz.ini'}},
-        'dir2': {
-            None: {'foo.ini'},
-            'dir5': {'file6.def': {None: {'bar/baz.ini'}}},
+        "file1.abc": {None: {"foo.ini", "bar/baz.ini"}},
+        "dir2": {
+            None: {"foo.ini"},
+            "dir5": {"file6.def": {None: {"bar/baz.ini"}}},
         },
-        'dir3': {
-            'file4.xyz': {None: {'foo.ini'}},
-            'file7.abc': {None: {'bar/baz.ini'}},
+        "dir3": {
+            "file4.xyz": {None: {"foo.ini"}},
+            "file7.abc": {None: {"bar/baz.ini"}},
         },
     }
 
     def test_config_files(self) -> None:
-        self.assertEqual(mypy_wrapper.config_files().keys(), {
-            'mypy.ini',
-            'mypy-strict.ini',
-        })
+        self.assertEqual(
+            mypy_wrapper.config_files().keys(),
+            {
+                "mypy.ini",
+                "mypy-strict.ini",
+            },
+        )
 
     def test_split_path(self) -> None:
-        self.assertEqual(mypy_wrapper.split_path('file1.abc'), ['file1.abc'])
+        self.assertEqual(mypy_wrapper.split_path("file1.abc"), ["file1.abc"])
         self.assertEqual(
-            mypy_wrapper.split_path('dir3/file4.xyz'),
-            ['dir3', 'file4.xyz'],
+            mypy_wrapper.split_path("dir3/file4.xyz"),
+            ["dir3", "file4.xyz"],
         )
         self.assertEqual(
-            mypy_wrapper.split_path('dir2/dir5/file6.def'),
-            ['dir2', 'dir5', 'file6.def'],
+            mypy_wrapper.split_path("dir2/dir5/file6.def"),
+            ["dir2", "dir5", "file6.def"],
         )
 
     def test_make_trie(self) -> None:
@@ -51,108 +54,120 @@
 
     def test_lookup(self) -> None:
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'file1.abc'),
-            {'foo.ini', 'bar/baz.ini'},
+            mypy_wrapper.lookup(self.trie, "file1.abc"),
+            {"foo.ini", "bar/baz.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir2/dir5/file6.def'),
-            {'foo.ini', 'bar/baz.ini'},
+            mypy_wrapper.lookup(self.trie, "dir2/dir5/file6.def"),
+            {"foo.ini", "bar/baz.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir3/file4.xyz'),
-            {'foo.ini'},
+            mypy_wrapper.lookup(self.trie, "dir3/file4.xyz"),
+            {"foo.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir3/file7.abc'),
-            {'bar/baz.ini'},
+            mypy_wrapper.lookup(self.trie, "dir3/file7.abc"),
+            {"bar/baz.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'file8.xyz'),
+            mypy_wrapper.lookup(self.trie, "file8.xyz"),
             set(),
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir2/dir9/file10.abc'),
-            {'foo.ini'},
+            mypy_wrapper.lookup(self.trie, "dir2/dir9/file10.abc"),
+            {"foo.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir3/file11.abc'),
+            mypy_wrapper.lookup(self.trie, "dir3/file11.abc"),
             set(),
         )
 
         # non-leaves shouldn't ever be passed to lookup in practice, but
         # still, good to consider/test these cases
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir2'),
-            {'foo.ini'},
+            mypy_wrapper.lookup(self.trie, "dir2"),
+            {"foo.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir2/dir5'),
-            {'foo.ini'},
+            mypy_wrapper.lookup(self.trie, "dir2/dir5"),
+            {"foo.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir3'),
+            mypy_wrapper.lookup(self.trie, "dir3"),
             set(),
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir2/dir9'),
-            {'foo.ini'},
+            mypy_wrapper.lookup(self.trie, "dir2/dir9"),
+            {"foo.ini"},
         )
         self.assertEqual(
-            mypy_wrapper.lookup(self.trie, 'dir4'),
+            mypy_wrapper.lookup(self.trie, "dir4"),
             set(),
         )
 
     def test_make_plan(self) -> None:
         self.assertEqual(
-            mypy_wrapper.make_plan(configs=self.configs, files=[
-                'file8.xyz',
-                'dir3/file11.abc',
-            ]),
-            {}
-        )
-        self.assertEqual(
-            mypy_wrapper.make_plan(configs=self.configs, files=[
-                'file8.xyz',
-                'dir2/dir9/file10.abc',
-                'dir3/file4.xyz',
-                'dir3/file11.abc',
-            ]),
-            {
-                'foo.ini': ['dir2/dir9/file10.abc', 'dir3/file4.xyz'],
-            }
-        )
-        self.assertEqual(
-            mypy_wrapper.make_plan(configs=self.configs, files=[
-                'file8.xyz',
-                'dir3/file11.abc',
-                'dir3/file7.abc',
-            ]),
-            {
-                'bar/baz.ini': ['dir3/file7.abc'],
-            }
-        )
-        self.assertEqual(
-            mypy_wrapper.make_plan(configs=self.configs, files=[
-                'dir2/dir9/file10.abc',
-                'dir2/dir5/file6.def',
-                'dir3/file7.abc',
-                'file1.abc',
-                'dir3/file11.abc',
-            ]),
-            {
-                'foo.ini': [
-                    'dir2/dir9/file10.abc',
-                    'dir2/dir5/file6.def',
-                    'file1.abc',
+            mypy_wrapper.make_plan(
+                configs=self.configs,
+                files=[
+                    "file8.xyz",
+                    "dir3/file11.abc",
                 ],
-                'bar/baz.ini': [
-                    'dir2/dir5/file6.def',
-                    'dir3/file7.abc',
-                    'file1.abc',
+            ),
+            {},
+        )
+        self.assertEqual(
+            mypy_wrapper.make_plan(
+                configs=self.configs,
+                files=[
+                    "file8.xyz",
+                    "dir2/dir9/file10.abc",
+                    "dir3/file4.xyz",
+                    "dir3/file11.abc",
                 ],
-            }
+            ),
+            {
+                "foo.ini": ["dir2/dir9/file10.abc", "dir3/file4.xyz"],
+            },
+        )
+        self.assertEqual(
+            mypy_wrapper.make_plan(
+                configs=self.configs,
+                files=[
+                    "file8.xyz",
+                    "dir3/file11.abc",
+                    "dir3/file7.abc",
+                ],
+            ),
+            {
+                "bar/baz.ini": ["dir3/file7.abc"],
+            },
+        )
+        self.assertEqual(
+            mypy_wrapper.make_plan(
+                configs=self.configs,
+                files=[
+                    "dir2/dir9/file10.abc",
+                    "dir2/dir5/file6.def",
+                    "dir3/file7.abc",
+                    "file1.abc",
+                    "dir3/file11.abc",
+                ],
+            ),
+            {
+                "foo.ini": [
+                    "dir2/dir9/file10.abc",
+                    "dir2/dir5/file6.def",
+                    "file1.abc",
+                ],
+                "bar/baz.ini": [
+                    "dir2/dir5/file6.def",
+                    "dir3/file7.abc",
+                    "file1.abc",
+                ],
+            },
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_stats.py b/tools/test/test_stats.py
index 46ad287..d352084 100644
--- a/tools/test/test_stats.py
+++ b/tools/test/test_stats.py
@@ -3,10 +3,16 @@
 from typing import Dict, List
 
 from tools.stats import print_test_stats
-from tools.stats.s3_stat_parser import (Commit, Report, ReportMetaMeta,
-                                        Status, Version1Case,
-                                        Version1Report, Version2Case,
-                                        Version2Report)
+from tools.stats.s3_stat_parser import (
+    Commit,
+    Report,
+    ReportMetaMeta,
+    Status,
+    Version1Case,
+    Version1Report,
+    Version2Case,
+    Version2Report,
+)
 
 
 def fakehash(char: str) -> str:
@@ -15,14 +21,14 @@
 
 def dummy_meta_meta() -> ReportMetaMeta:
     return {
-        'build_pr': '',
-        'build_tag': '',
-        'build_sha1': '',
-        'build_base_commit': '',
-        'build_branch': '',
-        'build_job': '',
-        'build_workflow_id': '',
-        'build_start_time_epoch': '',
+        "build_pr": "",
+        "build_tag": "",
+        "build_sha1": "",
+        "build_base_commit": "",
+        "build_branch": "",
+        "build_job": "",
+        "build_workflow_id": "",
+        "build_start_time_epoch": "",
     }
 
 
@@ -35,202 +41,210 @@
     skipped: bool = False,
 ) -> Version1Case:
     return {
-        'name': name,
-        'seconds': seconds,
-        'errored': errored,
-        'failed': failed,
-        'skipped': skipped,
+        "name": name,
+        "seconds": seconds,
+        "errored": errored,
+        "failed": failed,
+        "skipped": skipped,
     }
 
 
 def make_report_v1(tests: Dict[str, List[Version1Case]]) -> Version1Report:
     suites = {
         suite_name: {
-            'total_seconds': sum(case['seconds'] for case in cases),
-            'cases': cases,
+            "total_seconds": sum(case["seconds"] for case in cases),
+            "cases": cases,
         }
         for suite_name, cases in tests.items()
     }
     return {
         **dummy_meta_meta(),  # type: ignore[misc]
-        'total_seconds': sum(s['total_seconds'] for s in suites.values()),
-        'suites': suites,
+        "total_seconds": sum(s["total_seconds"] for s in suites.values()),
+        "suites": suites,
     }
 
 
 def make_case_v2(seconds: float, status: Status = None) -> Version2Case:
     return {
-        'seconds': seconds,
-        'status': status,
+        "seconds": seconds,
+        "status": status,
     }
 
 
-def make_report_v2(tests: Dict[str, Dict[str, Dict[str, Version2Case]]]) -> Version2Report:
+def make_report_v2(
+    tests: Dict[str, Dict[str, Dict[str, Version2Case]]]
+) -> Version2Report:
     files = {}
     for file_name, file_suites in tests.items():
         suites = {
             suite_name: {
-                'total_seconds': sum(case['seconds'] for case in cases.values()),
-                'cases': cases,
+                "total_seconds": sum(case["seconds"] for case in cases.values()),
+                "cases": cases,
             }
             for suite_name, cases in file_suites.items()
         }
         files[file_name] = {
-            'suites': suites,
-            'total_seconds': sum(suite['total_seconds'] for suite in suites.values()),
+            "suites": suites,
+            "total_seconds": sum(suite["total_seconds"] for suite in suites.values()),
         }
     return {
         **dummy_meta_meta(),  # type: ignore[misc]
-        'format_version': 2,
-        'total_seconds': sum(s['total_seconds'] for s in files.values()),
-        'files': files,
+        "format_version": 2,
+        "total_seconds": sum(s["total_seconds"] for s in files.values()),
+        "files": files,
     }
+
+
 maxDiff = None
 
+
 class TestPrintTestStats(unittest.TestCase):
-    version1_report: Version1Report = make_report_v1({
-        # input ordering of the suites is ignored
-        'Grault': [
-            # not printed: status same and time similar
-            makecase('test_grault0', 4.78, failed=True),
-            # status same, but time increased a lot
-            makecase('test_grault2', 1.473, errored=True),
-        ],
-        # individual tests times changed, not overall suite
-        'Qux': [
-            # input ordering of the test cases is ignored
-            makecase('test_qux1', 0.001, skipped=True),
-            makecase('test_qux6', 0.002, skipped=True),
-            # time in bounds, but status changed
-            makecase('test_qux4', 7.158, failed=True),
-            # not printed because it's the same as before
-            makecase('test_qux7', 0.003, skipped=True),
-            makecase('test_qux5', 11.968),
-            makecase('test_qux3', 23.496),
-        ],
-        # new test suite
-        'Bar': [
-            makecase('test_bar2', 3.742, failed=True),
-            makecase('test_bar1', 50.447),
-        ],
-        # overall suite time changed but no individual tests
-        'Norf': [
-            makecase('test_norf1', 3),
-            makecase('test_norf2', 3),
-            makecase('test_norf3', 3),
-            makecase('test_norf4', 3),
-        ],
-        # suite doesn't show up if it doesn't change enough
-        'Foo': [
-            makecase('test_foo1', 42),
-            makecase('test_foo2', 56),
-        ],
-    })
+    version1_report: Version1Report = make_report_v1(
+        {
+            # input ordering of the suites is ignored
+            "Grault": [
+                # not printed: status same and time similar
+                makecase("test_grault0", 4.78, failed=True),
+                # status same, but time increased a lot
+                makecase("test_grault2", 1.473, errored=True),
+            ],
+            # individual tests times changed, not overall suite
+            "Qux": [
+                # input ordering of the test cases is ignored
+                makecase("test_qux1", 0.001, skipped=True),
+                makecase("test_qux6", 0.002, skipped=True),
+                # time in bounds, but status changed
+                makecase("test_qux4", 7.158, failed=True),
+                # not printed because it's the same as before
+                makecase("test_qux7", 0.003, skipped=True),
+                makecase("test_qux5", 11.968),
+                makecase("test_qux3", 23.496),
+            ],
+            # new test suite
+            "Bar": [
+                makecase("test_bar2", 3.742, failed=True),
+                makecase("test_bar1", 50.447),
+            ],
+            # overall suite time changed but no individual tests
+            "Norf": [
+                makecase("test_norf1", 3),
+                makecase("test_norf2", 3),
+                makecase("test_norf3", 3),
+                makecase("test_norf4", 3),
+            ],
+            # suite doesn't show up if it doesn't change enough
+            "Foo": [
+                makecase("test_foo1", 42),
+                makecase("test_foo2", 56),
+            ],
+        }
+    )
 
     version2_report: Version2Report = make_report_v2(
         {
-            'test_a': {
-                'Grault': {
-                    'test_grault0': make_case_v2(4.78, 'failed'),
-                    'test_grault2': make_case_v2(1.473, 'errored'),
+            "test_a": {
+                "Grault": {
+                    "test_grault0": make_case_v2(4.78, "failed"),
+                    "test_grault2": make_case_v2(1.473, "errored"),
                 },
-                'Qux': {
-                    'test_qux1': make_case_v2(0.001, 'skipped'),
-                    'test_qux6': make_case_v2(0.002, 'skipped'),
-                    'test_qux4': make_case_v2(7.158, 'failed'),
-                    'test_qux7': make_case_v2(0.003, 'skipped'),
-                    'test_qux8': make_case_v2(11.968),
-                    'test_qux3': make_case_v2(23.496),
-                }
+                "Qux": {
+                    "test_qux1": make_case_v2(0.001, "skipped"),
+                    "test_qux6": make_case_v2(0.002, "skipped"),
+                    "test_qux4": make_case_v2(7.158, "failed"),
+                    "test_qux7": make_case_v2(0.003, "skipped"),
+                    "test_qux8": make_case_v2(11.968),
+                    "test_qux3": make_case_v2(23.496),
+                },
             },
-            'test_b': {
-                'Bar': {
-                    'test_bar2': make_case_v2(3.742, 'failed'),
-                    'test_bar1': make_case_v2(50.447),
+            "test_b": {
+                "Bar": {
+                    "test_bar2": make_case_v2(3.742, "failed"),
+                    "test_bar1": make_case_v2(50.447),
                 },
                 # overall suite time changed but no individual tests
-                'Norf': {
-                    'test_norf1': make_case_v2(3),
-                    'test_norf2': make_case_v2(3),
-                    'test_norf3': make_case_v2(3),
-                    'test_norf4': make_case_v2(3),
+                "Norf": {
+                    "test_norf1": make_case_v2(3),
+                    "test_norf2": make_case_v2(3),
+                    "test_norf3": make_case_v2(3),
+                    "test_norf4": make_case_v2(3),
                 },
             },
-            'test_c': {
-                'Foo': {
-                    'test_foo1': make_case_v2(42),
-                    'test_foo2': make_case_v2(56),
+            "test_c": {
+                "Foo": {
+                    "test_foo1": make_case_v2(42),
+                    "test_foo2": make_case_v2(56),
                 },
-            }
-        })
+            },
+        }
+    )
 
     def test_simplify(self) -> None:
         self.assertEqual(
             {
-                '': {
-                    'Bar': {
-                        'test_bar1': {'seconds': 50.447, 'status': None},
-                        'test_bar2': {'seconds': 3.742, 'status': 'failed'},
+                "": {
+                    "Bar": {
+                        "test_bar1": {"seconds": 50.447, "status": None},
+                        "test_bar2": {"seconds": 3.742, "status": "failed"},
                     },
-                    'Foo': {
-                        'test_foo1': {'seconds': 42, 'status': None},
-                        'test_foo2': {'seconds': 56, 'status': None},
+                    "Foo": {
+                        "test_foo1": {"seconds": 42, "status": None},
+                        "test_foo2": {"seconds": 56, "status": None},
                     },
-                    'Grault': {
-                        'test_grault0': {'seconds': 4.78, 'status': 'failed'},
-                        'test_grault2': {'seconds': 1.473, 'status': 'errored'},
+                    "Grault": {
+                        "test_grault0": {"seconds": 4.78, "status": "failed"},
+                        "test_grault2": {"seconds": 1.473, "status": "errored"},
                     },
-                    'Norf': {
-                        'test_norf1': {'seconds': 3, 'status': None},
-                        'test_norf3': {'seconds': 3, 'status': None},
-                        'test_norf2': {'seconds': 3, 'status': None},
-                        'test_norf4': {'seconds': 3, 'status': None},
+                    "Norf": {
+                        "test_norf1": {"seconds": 3, "status": None},
+                        "test_norf3": {"seconds": 3, "status": None},
+                        "test_norf2": {"seconds": 3, "status": None},
+                        "test_norf4": {"seconds": 3, "status": None},
                     },
-                    'Qux': {
-                        'test_qux1': {'seconds': 0.001, 'status': 'skipped'},
-                        'test_qux3': {'seconds': 23.496, 'status': None},
-                        'test_qux4': {'seconds': 7.158, 'status': 'failed'},
-                        'test_qux5': {'seconds': 11.968, 'status': None},
-                        'test_qux6': {'seconds': 0.002, 'status': 'skipped'},
-                        'test_qux7': {'seconds': 0.003, 'status': 'skipped'},
+                    "Qux": {
+                        "test_qux1": {"seconds": 0.001, "status": "skipped"},
+                        "test_qux3": {"seconds": 23.496, "status": None},
+                        "test_qux4": {"seconds": 7.158, "status": "failed"},
+                        "test_qux5": {"seconds": 11.968, "status": None},
+                        "test_qux6": {"seconds": 0.002, "status": "skipped"},
+                        "test_qux7": {"seconds": 0.003, "status": "skipped"},
                     },
                 },
             },
-            print_test_stats.simplify(self.version1_report)
+            print_test_stats.simplify(self.version1_report),
         )
 
         self.assertEqual(
             {
-                'test_a': {
-                    'Grault': {
-                        'test_grault0': {'seconds': 4.78, 'status': 'failed'},
-                        'test_grault2': {'seconds': 1.473, 'status': 'errored'},
+                "test_a": {
+                    "Grault": {
+                        "test_grault0": {"seconds": 4.78, "status": "failed"},
+                        "test_grault2": {"seconds": 1.473, "status": "errored"},
                     },
-                    'Qux': {
-                        'test_qux1': {'seconds': 0.001, 'status': 'skipped'},
-                        'test_qux3': {'seconds': 23.496, 'status': None},
-                        'test_qux4': {'seconds': 7.158, 'status': 'failed'},
-                        'test_qux6': {'seconds': 0.002, 'status': 'skipped'},
-                        'test_qux7': {'seconds': 0.003, 'status': 'skipped'},
-                        'test_qux8': {'seconds': 11.968, 'status': None},
+                    "Qux": {
+                        "test_qux1": {"seconds": 0.001, "status": "skipped"},
+                        "test_qux3": {"seconds": 23.496, "status": None},
+                        "test_qux4": {"seconds": 7.158, "status": "failed"},
+                        "test_qux6": {"seconds": 0.002, "status": "skipped"},
+                        "test_qux7": {"seconds": 0.003, "status": "skipped"},
+                        "test_qux8": {"seconds": 11.968, "status": None},
                     },
                 },
-                'test_b': {
-                    'Bar': {
-                        'test_bar1': {'seconds': 50.447, 'status': None},
-                        'test_bar2': {'seconds': 3.742, 'status': 'failed'},
+                "test_b": {
+                    "Bar": {
+                        "test_bar1": {"seconds": 50.447, "status": None},
+                        "test_bar2": {"seconds": 3.742, "status": "failed"},
                     },
-                    'Norf': {
-                        'test_norf1': {'seconds': 3, 'status': None},
-                        'test_norf2': {'seconds': 3, 'status': None},
-                        'test_norf3': {'seconds': 3, 'status': None},
-                        'test_norf4': {'seconds': 3, 'status': None},
+                    "Norf": {
+                        "test_norf1": {"seconds": 3, "status": None},
+                        "test_norf2": {"seconds": 3, "status": None},
+                        "test_norf3": {"seconds": 3, "status": None},
+                        "test_norf4": {"seconds": 3, "status": None},
                     },
                 },
-                'test_c': {
-                    'Foo': {
-                        'test_foo1': {'seconds': 42, 'status': None},
-                        'test_foo2': {'seconds': 56, 'status': None},
+                "test_c": {
+                    "Foo": {
+                        "test_foo1": {"seconds": 42, "status": None},
+                        "test_foo2": {"seconds": 56, "status": None},
                     },
                 },
             },
@@ -242,95 +256,101 @@
 
         base_reports: Dict[Commit, List[Report]] = {
             # bbbb has no reports, so base is cccc instead
-            fakehash('b'): [],
-            fakehash('c'): [
-                make_report_v1({
-                    'Baz': [
-                        makecase('test_baz2', 13.605),
-                        # no recent suites have & skip this test
-                        makecase('test_baz1', 0.004, skipped=True),
-                    ],
-                    'Foo': [
-                        makecase('test_foo1', 43),
-                        # test added since dddd
-                        makecase('test_foo2', 57),
-                    ],
-                    'Grault': [
-                        makecase('test_grault0', 4.88, failed=True),
-                        makecase('test_grault1', 11.967, failed=True),
-                        makecase('test_grault2', 0.395, errored=True),
-                        makecase('test_grault3', 30.460),
-                    ],
-                    'Norf': [
-                        makecase('test_norf1', 2),
-                        makecase('test_norf2', 2),
-                        makecase('test_norf3', 2),
-                        makecase('test_norf4', 2),
-                    ],
-                    'Qux': [
-                        makecase('test_qux3', 4.978, errored=True),
-                        makecase('test_qux7', 0.002, skipped=True),
-                        makecase('test_qux2', 5.618),
-                        makecase('test_qux4', 7.766, errored=True),
-                        makecase('test_qux6', 23.589, failed=True),
-                    ],
-                }),
+            fakehash("b"): [],
+            fakehash("c"): [
+                make_report_v1(
+                    {
+                        "Baz": [
+                            makecase("test_baz2", 13.605),
+                            # no recent suites have & skip this test
+                            makecase("test_baz1", 0.004, skipped=True),
+                        ],
+                        "Foo": [
+                            makecase("test_foo1", 43),
+                            # test added since dddd
+                            makecase("test_foo2", 57),
+                        ],
+                        "Grault": [
+                            makecase("test_grault0", 4.88, failed=True),
+                            makecase("test_grault1", 11.967, failed=True),
+                            makecase("test_grault2", 0.395, errored=True),
+                            makecase("test_grault3", 30.460),
+                        ],
+                        "Norf": [
+                            makecase("test_norf1", 2),
+                            makecase("test_norf2", 2),
+                            makecase("test_norf3", 2),
+                            makecase("test_norf4", 2),
+                        ],
+                        "Qux": [
+                            makecase("test_qux3", 4.978, errored=True),
+                            makecase("test_qux7", 0.002, skipped=True),
+                            makecase("test_qux2", 5.618),
+                            makecase("test_qux4", 7.766, errored=True),
+                            makecase("test_qux6", 23.589, failed=True),
+                        ],
+                    }
+                ),
             ],
-            fakehash('d'): [
-                make_report_v1({
-                    'Foo': [
-                        makecase('test_foo1', 40),
-                        # removed in cccc
-                        makecase('test_foo3', 17),
-                    ],
-                    'Baz': [
-                        # not skipped, so not included in stdev
-                        makecase('test_baz1', 3.14),
-                    ],
-                    'Qux': [
-                        makecase('test_qux7', 0.004, skipped=True),
-                        makecase('test_qux2', 6.02),
-                        makecase('test_qux4', 20.932),
-                    ],
-                    'Norf': [
-                        makecase('test_norf1', 3),
-                        makecase('test_norf2', 3),
-                        makecase('test_norf3', 3),
-                        makecase('test_norf4', 3),
-                    ],
-                    'Grault': [
-                        makecase('test_grault0', 5, failed=True),
-                        makecase('test_grault1', 14.325, failed=True),
-                        makecase('test_grault2', 0.31, errored=True),
-                    ],
-                }),
+            fakehash("d"): [
+                make_report_v1(
+                    {
+                        "Foo": [
+                            makecase("test_foo1", 40),
+                            # removed in cccc
+                            makecase("test_foo3", 17),
+                        ],
+                        "Baz": [
+                            # not skipped, so not included in stdev
+                            makecase("test_baz1", 3.14),
+                        ],
+                        "Qux": [
+                            makecase("test_qux7", 0.004, skipped=True),
+                            makecase("test_qux2", 6.02),
+                            makecase("test_qux4", 20.932),
+                        ],
+                        "Norf": [
+                            makecase("test_norf1", 3),
+                            makecase("test_norf2", 3),
+                            makecase("test_norf3", 3),
+                            makecase("test_norf4", 3),
+                        ],
+                        "Grault": [
+                            makecase("test_grault0", 5, failed=True),
+                            makecase("test_grault1", 14.325, failed=True),
+                            makecase("test_grault2", 0.31, errored=True),
+                        ],
+                    }
+                ),
             ],
-            fakehash('e'): [],
-            fakehash('f'): [
-                make_report_v1({
-                    'Foo': [
-                        makecase('test_foo3', 24),
-                        makecase('test_foo1', 43),
-                    ],
-                    'Baz': [
-                        makecase('test_baz2', 16.857),
-                    ],
-                    'Qux': [
-                        makecase('test_qux2', 6.422),
-                        makecase('test_qux4', 6.382, errored=True),
-                    ],
-                    'Norf': [
-                        makecase('test_norf1', 0.9),
-                        makecase('test_norf3', 0.9),
-                        makecase('test_norf2', 0.9),
-                        makecase('test_norf4', 0.9),
-                    ],
-                    'Grault': [
-                        makecase('test_grault0', 4.7, failed=True),
-                        makecase('test_grault1', 13.146, failed=True),
-                        makecase('test_grault2', 0.48, errored=True),
-                    ],
-                }),
+            fakehash("e"): [],
+            fakehash("f"): [
+                make_report_v1(
+                    {
+                        "Foo": [
+                            makecase("test_foo3", 24),
+                            makecase("test_foo1", 43),
+                        ],
+                        "Baz": [
+                            makecase("test_baz2", 16.857),
+                        ],
+                        "Qux": [
+                            makecase("test_qux2", 6.422),
+                            makecase("test_qux4", 6.382, errored=True),
+                        ],
+                        "Norf": [
+                            makecase("test_norf1", 0.9),
+                            makecase("test_norf3", 0.9),
+                            makecase("test_norf2", 0.9),
+                            makecase("test_norf4", 0.9),
+                        ],
+                        "Grault": [
+                            makecase("test_grault0", 4.7, failed=True),
+                            makecase("test_grault1", 13.146, failed=True),
+                            makecase("test_grault2", 0.48, errored=True),
+                        ],
+                    }
+                ),
             ],
         }
 
@@ -344,7 +364,7 @@
         )
 
         self.assertEqual(
-            '''\
+            """\
 
 - class Baz:
 -     # was   15.23s ±   2.30s
@@ -402,14 +422,14 @@
 +     def test_bar2: ...
 +         # now   3.742s           (failed)
 
-''',
+""",
             print_test_stats.anomalies(analysis),
         )
 
     def test_graph(self) -> None:
         # HEAD is on master
         self.assertEqual(
-            '''\
+            """\
 Commit graph (base is most recent master ancestor with at least one S3 report):
 
     : (master)
@@ -420,21 +440,21 @@
     * dddddddddd          0 reports
     |
     :
-''',
+""",
             print_test_stats.graph(
-                head_sha=fakehash('a'),
+                head_sha=fakehash("a"),
                 head_seconds=502.99,
                 base_seconds={
-                    fakehash('b'): [47.84],
-                    fakehash('c'): [332.50],
-                    fakehash('d'): [],
+                    fakehash("b"): [47.84],
+                    fakehash("c"): [332.50],
+                    fakehash("d"): [],
                 },
                 on_master=True,
-            )
+            ),
         )
 
         self.assertEqual(
-            '''\
+            """\
 Commit graph (base is most recent master ancestor with at least one S3 report):
 
     : (master)
@@ -446,21 +466,21 @@
     * dddddddddd          1 report,  total time  1234.56s
     |
     :
-''',
+""",
             print_test_stats.graph(
-                head_sha=fakehash('a'),
+                head_sha=fakehash("a"),
                 head_seconds=9988.77,
                 base_seconds={
-                    fakehash('b'): [7598.77] * 60 + [7654.32] + [7709.87] * 60,
-                    fakehash('c'): [5308.77] * 10 + [5802.33] * 10,
-                    fakehash('d'): [1234.56],
+                    fakehash("b"): [7598.77] * 60 + [7654.32] + [7709.87] * 60,
+                    fakehash("c"): [5308.77] * 10 + [5802.33] * 10,
+                    fakehash("d"): [1234.56],
                 },
                 on_master=False,
-            )
+            ),
         )
 
         self.assertEqual(
-            '''\
+            """\
 Commit graph (base is most recent master ancestor with at least one S3 report):
 
     : (master)
@@ -474,22 +494,22 @@
     * dddddddddd (base)  15 reports, total time    58.92s ±   25.82s
     |
     :
-''',
+""",
             print_test_stats.graph(
-                head_sha=fakehash('a'),
+                head_sha=fakehash("a"),
                 head_seconds=25.52,
                 base_seconds={
-                    fakehash('b'): [],
-                    fakehash('c'): [],
-                    fakehash('d'): [52.25] * 14 + [152.26],
+                    fakehash("b"): [],
+                    fakehash("c"): [],
+                    fakehash("d"): [52.25] * 14 + [152.26],
                 },
                 on_master=False,
                 ancestry_path=5,
-            )
+            ),
         )
 
         self.assertEqual(
-            '''\
+            """\
 Commit graph (base is most recent master ancestor with at least one S3 report):
 
     : (master)
@@ -503,22 +523,22 @@
     * dddddddddd          3 reports, total time     0.10s ±    0.05s
     |
     :
-''',
+""",
             print_test_stats.graph(
-                head_sha=fakehash('a'),
+                head_sha=fakehash("a"),
                 head_seconds=0.08,
                 base_seconds={
-                    fakehash('b'): [],
-                    fakehash('c'): [0.09],
-                    fakehash('d'): [0.05, 0.10, 0.15],
+                    fakehash("b"): [],
+                    fakehash("c"): [0.09],
+                    fakehash("d"): [0.05, 0.10, 0.15],
                 },
                 on_master=False,
                 other_ancestors=1,
-            )
+            ),
         )
 
         self.assertEqual(
-            '''\
+            """\
 Commit graph (base is most recent master ancestor with at least one S3 report):
 
     : (master)
@@ -534,24 +554,24 @@
     * dddddddddd         10 reports, total time     5.84s ±    0.92s
     |
     :
-''',
+""",
             print_test_stats.graph(
-                head_sha=fakehash('a'),
+                head_sha=fakehash("a"),
                 head_seconds=5.98,
                 base_seconds={
-                    fakehash('b'): [4.81, 7.23],
-                    fakehash('c'): [],
-                    fakehash('d'): [4.97] * 5 + [6.71] * 5,
+                    fakehash("b"): [4.81, 7.23],
+                    fakehash("c"): [],
+                    fakehash("d"): [4.97] * 5 + [6.71] * 5,
                 },
                 on_master=False,
                 ancestry_path=1,
                 other_ancestors=7,
-            )
+            ),
         )
 
     def test_regression_info(self) -> None:
         self.assertEqual(
-            '''\
+            """\
 ----- Historic stats comparison result ------
 
     job: foo_job
@@ -571,41 +591,48 @@
 Removed  (across    1 suite)      1 test,  totaling -   1.00s
 Modified (across    1 suite)      1 test,  totaling -  41.48s ±   2.12s
 Added    (across    1 suite)      1 test,  totaling +   3.00s
-''',
+""",
             print_test_stats.regression_info(
-                head_sha=fakehash('a'),
-                head_report=make_report_v1({
-                    'Foo': [
-                        makecase('test_foo', 0.02, skipped=True),
-                        makecase('test_baz', 3),
-                    ]}),
+                head_sha=fakehash("a"),
+                head_report=make_report_v1(
+                    {
+                        "Foo": [
+                            makecase("test_foo", 0.02, skipped=True),
+                            makecase("test_baz", 3),
+                        ]
+                    }
+                ),
                 base_reports={
-                    fakehash('b'): [
-                        make_report_v1({
-                            'Foo': [
-                                makecase('test_foo', 40),
-                                makecase('test_bar', 1),
-                            ],
-                        }),
+                    fakehash("b"): [
+                        make_report_v1(
+                            {
+                                "Foo": [
+                                    makecase("test_foo", 40),
+                                    makecase("test_bar", 1),
+                                ],
+                            }
+                        ),
                     ],
-                    fakehash('c'): [
-                        make_report_v1({
-                            'Foo': [
-                                makecase('test_foo', 43),
-                            ],
-                        }),
+                    fakehash("c"): [
+                        make_report_v1(
+                            {
+                                "Foo": [
+                                    makecase("test_foo", 43),
+                                ],
+                            }
+                        ),
                     ],
                 },
-                job_name='foo_job',
+                job_name="foo_job",
                 on_master=False,
                 ancestry_path=0,
                 other_ancestors=0,
-            )
+            ),
         )
 
     def test_regression_info_new_job(self) -> None:
         self.assertEqual(
-            '''\
+            """\
 ----- Historic stats comparison result ------
 
     job: foo_job
@@ -629,25 +656,28 @@
 Removed  (across    0 suites)     0 tests, totaling     0.00s
 Modified (across    0 suites)     0 tests, totaling     0.00s
 Added    (across    1 suite)      2 tests, totaling +   3.02s
-''',
+""",
             print_test_stats.regression_info(
-                head_sha=fakehash('a'),
-                head_report=make_report_v1({
-                    'Foo': [
-                        makecase('test_foo', 0.02, skipped=True),
-                        makecase('test_baz', 3),
-                    ]}),
+                head_sha=fakehash("a"),
+                head_report=make_report_v1(
+                    {
+                        "Foo": [
+                            makecase("test_foo", 0.02, skipped=True),
+                            makecase("test_baz", 3),
+                        ]
+                    }
+                ),
                 base_reports={
-                    fakehash('b'): [],
-                    fakehash('c'): [],
+                    fakehash("b"): [],
+                    fakehash("c"): [],
                 },
-                job_name='foo_job',
+                job_name="foo_job",
                 on_master=False,
                 ancestry_path=3,
                 other_ancestors=2,
-            )
+            ),
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_test_history.py b/tools/test/test_test_history.py
index 1b8b5c9..7851ca3 100644
--- a/tools/test/test_test_history.py
+++ b/tools/test/test_test_history.py
@@ -16,36 +16,33 @@
 
 def parse_block(block: List[str]) -> Optional[Example]:
     if block:
-        match = re.match(r'^\$ ([^ ]+) (.*)$', block[0])
+        match = re.match(r"^\$ ([^ ]+) (.*)$", block[0])
         if match:
             cmd, first = match.groups()
             args = []
             for i, line in enumerate([first] + block[1:]):
-                if line.endswith('\\'):
+                if line.endswith("\\"):
                     args.append(line[:-1])
                 else:
                     args.append(line)
                     break
             return {
-                'cmd': cmd,
-                'args': shlex.split(''.join(args)),
-                'lines': block[i + 1:]
+                "cmd": cmd,
+                "args": shlex.split("".join(args)),
+                "lines": block[i + 1 :],
             }
     return None
 
 
 def parse_description(description: str) -> List[Example]:
     examples: List[Example] = []
-    for block in description.split('\n\n'):
-        matches = [
-            re.match(r'^    (.*)$', line)
-            for line in block.splitlines()
-        ]
+    for block in description.split("\n\n"):
+        matches = [re.match(r"^    (.*)$", line) for line in block.splitlines()]
         if all(matches):
             lines = []
             for match in matches:
                 assert match
-                line, = match.groups()
+                (line,) = match.groups()
                 lines.append(line)
             example = parse_block(lines)
             if example:
@@ -62,14 +59,16 @@
         self.assertEqual(len(examples), 3)
         for i, example in enumerate(examples):
             with self.subTest(i=i):
-                self.assertTrue(test_history.__file__.endswith(example['cmd']))
-                expected = example['lines']
-                actual = list(itertools.islice(
-                    test_history.run(example['args']),
-                    len(expected),
-                ))
+                self.assertTrue(test_history.__file__.endswith(example["cmd"]))
+                expected = example["lines"]
+                actual = list(
+                    itertools.islice(
+                        test_history.run(example["args"]),
+                        len(expected),
+                    )
+                )
                 self.assertEqual(actual, expected)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py
index 5ea6fa8..10bf3a2 100644
--- a/tools/test/test_test_selections.py
+++ b/tools/test/test_test_selections.py
@@ -7,37 +7,37 @@
 
 class TestCalculateShards(unittest.TestCase):
     tests: List[str] = [
-        'super_long_test',
-        'long_test1',
-        'long_test2',
-        'normal_test1',
-        'normal_test2',
-        'normal_test3',
-        'short_test1',
-        'short_test2',
-        'short_test3',
-        'short_test4',
-        'short_test5',
+        "super_long_test",
+        "long_test1",
+        "long_test2",
+        "normal_test1",
+        "normal_test2",
+        "normal_test3",
+        "short_test1",
+        "short_test2",
+        "short_test3",
+        "short_test4",
+        "short_test5",
     ]
 
     test_times: Dict[str, float] = {
-        'super_long_test': 55,
-        'long_test1': 22,
-        'long_test2': 18,
-        'normal_test1': 9,
-        'normal_test2': 7,
-        'normal_test3': 5,
-        'short_test1': 1,
-        'short_test2': 0.6,
-        'short_test3': 0.4,
-        'short_test4': 0.3,
-        'short_test5': 0.01,
+        "super_long_test": 55,
+        "long_test1": 22,
+        "long_test2": 18,
+        "normal_test1": 9,
+        "normal_test2": 7,
+        "normal_test3": 5,
+        "short_test1": 1,
+        "short_test2": 0.6,
+        "short_test3": 0.4,
+        "short_test4": 0.3,
+        "short_test5": 0.01,
     }
 
     def assert_shards_equal(
         self,
         expected_shards: List[Tuple[float, List[str]]],
-        actual_shards: List[Tuple[float, List[str]]]
+        actual_shards: List[Tuple[float, List[str]]],
     ) -> None:
         for expected, actual in zip(expected_shards, actual_shards):
             self.assertAlmostEqual(expected[0], actual[0])
@@ -45,53 +45,117 @@
 
     def test_calculate_2_shards_with_complete_test_times(self) -> None:
         expected_shards = [
-            (60, ['super_long_test', 'normal_test3']),
-            (58.31, ['long_test1', 'long_test2', 'normal_test1', 'normal_test2', 'short_test1', 'short_test2',
-                     'short_test3', 'short_test4', 'short_test5'])
+            (60, ["super_long_test", "normal_test3"]),
+            (
+                58.31,
+                [
+                    "long_test1",
+                    "long_test2",
+                    "normal_test1",
+                    "normal_test2",
+                    "short_test1",
+                    "short_test2",
+                    "short_test3",
+                    "short_test4",
+                    "short_test5",
+                ],
+            ),
         ]
-        self.assert_shards_equal(expected_shards, calculate_shards(2, self.tests, self.test_times))
-
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(2, self.tests, self.test_times)
+        )
 
     def test_calculate_5_shards_with_complete_test_times(self) -> None:
         expected_shards = [
-            (55.0, ['super_long_test']),
-            (22.0, ['long_test1', ]),
-            (18.0, ['long_test2', ]),
-            (11.31, ['normal_test1', 'short_test1', 'short_test2', 'short_test3', 'short_test4', 'short_test5']),
-            (12.0, ['normal_test2', 'normal_test3']),
+            (55.0, ["super_long_test"]),
+            (
+                22.0,
+                [
+                    "long_test1",
+                ],
+            ),
+            (
+                18.0,
+                [
+                    "long_test2",
+                ],
+            ),
+            (
+                11.31,
+                [
+                    "normal_test1",
+                    "short_test1",
+                    "short_test2",
+                    "short_test3",
+                    "short_test4",
+                    "short_test5",
+                ],
+            ),
+            (12.0, ["normal_test2", "normal_test3"]),
         ]
-        self.assert_shards_equal(expected_shards, calculate_shards(5, self.tests, self.test_times))
-
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(5, self.tests, self.test_times)
+        )
 
     def test_calculate_2_shards_with_incomplete_test_times(self) -> None:
-        incomplete_test_times = {k: v for k, v in self.test_times.items() if 'test1' in k}
+        incomplete_test_times = {
+            k: v for k, v in self.test_times.items() if "test1" in k
+        }
         expected_shards = [
-            (22.0, ['long_test1', 'long_test2', 'normal_test3', 'short_test3', 'short_test5']),
-            (10.0, ['normal_test1', 'short_test1', 'super_long_test', 'normal_test2', 'short_test2', 'short_test4']),
+            (
+                22.0,
+                [
+                    "long_test1",
+                    "long_test2",
+                    "normal_test3",
+                    "short_test3",
+                    "short_test5",
+                ],
+            ),
+            (
+                10.0,
+                [
+                    "normal_test1",
+                    "short_test1",
+                    "super_long_test",
+                    "normal_test2",
+                    "short_test2",
+                    "short_test4",
+                ],
+            ),
         ]
-        self.assert_shards_equal(expected_shards, calculate_shards(2, self.tests, incomplete_test_times))
-
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(2, self.tests, incomplete_test_times)
+        )
 
     def test_calculate_5_shards_with_incomplete_test_times(self) -> None:
-        incomplete_test_times = {k: v for k, v in self.test_times.items() if 'test1' in k}
+        incomplete_test_times = {
+            k: v for k, v in self.test_times.items() if "test1" in k
+        }
         expected_shards = [
-            (22.0, ['long_test1', 'normal_test2', 'short_test5']),
-            (9.0, ['normal_test1', 'normal_test3']),
-            (1.0, ['short_test1', 'short_test2']),
-            (0.0, ['super_long_test', 'short_test3']),
-            (0.0, ['long_test2', 'short_test4']),
+            (22.0, ["long_test1", "normal_test2", "short_test5"]),
+            (9.0, ["normal_test1", "normal_test3"]),
+            (1.0, ["short_test1", "short_test2"]),
+            (0.0, ["super_long_test", "short_test3"]),
+            (0.0, ["long_test2", "short_test4"]),
         ]
-        self.assert_shards_equal(expected_shards, calculate_shards(5, self.tests, incomplete_test_times))
+        self.assert_shards_equal(
+            expected_shards, calculate_shards(5, self.tests, incomplete_test_times)
+        )
 
     def test_calculate_2_shards_against_optimal_shards(self) -> None:
         for _ in range(100):
             random.seed(120)
             random_times = {k: random.random() * 10 for k in self.tests}
             # all test times except first two
-            rest_of_tests = [i for k, i in random_times.items() if k != 'super_long_test' and k != 'long_test1']
+            rest_of_tests = [
+                i
+                for k, i in random_times.items()
+                if k != "super_long_test" and k != "long_test1"
+            ]
             sum_of_rest = sum(rest_of_tests)
-            random_times['super_long_test'] = max(sum_of_rest / 2, max(rest_of_tests))
-            random_times['long_test1'] = sum_of_rest - random_times['super_long_test']
+            random_times["super_long_test"] = max(sum_of_rest / 2, max(rest_of_tests))
+            random_times["long_test1"] = sum_of_rest - random_times["super_long_test"]
             # An optimal sharding would look like the below, but we don't need to compute this for the test:
             # optimal_shards = [
             #     (sum_of_rest, ['super_long_test', 'long_test1']),
@@ -103,10 +167,12 @@
                 # The calculated shard should not have a ratio worse than 7/6 for num_shards = 2
                 self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest)
                 sorted_tests = sorted(self.tests)
-                sorted_shard_tests = sorted(calculated_shards[0][1] + calculated_shards[1][1])
+                sorted_shard_tests = sorted(
+                    calculated_shards[0][1] + calculated_shards[1][1]
+                )
                 # All the tests should be represented by some shard
                 self.assertEqual(sorted_tests, sorted_shard_tests)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_trailing_newlines.py b/tools/test/test_trailing_newlines.py
index 4f4b662..2631c30 100644
--- a/tools/test/test_trailing_newlines.py
+++ b/tools/test/test_trailing_newlines.py
@@ -4,7 +4,7 @@
 
 
 def correct_trailing_newlines(file_contents: str) -> bool:
-    with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp:
+    with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
         filename = tmp.name
         tmp.write(file_contents)
     return trailing_newlines.correct_trailing_newlines(filename)
@@ -12,38 +12,38 @@
 
 class TestTrailingNewlines(unittest.TestCase):
     def test_empty(self) -> None:
-        self.assertTrue(correct_trailing_newlines(''))
+        self.assertTrue(correct_trailing_newlines(""))
 
     def test_single_byte(self) -> None:
-        self.assertFalse(correct_trailing_newlines('a'))
+        self.assertFalse(correct_trailing_newlines("a"))
 
     def test_single_newline(self) -> None:
-        self.assertFalse(correct_trailing_newlines('\n'))
+        self.assertFalse(correct_trailing_newlines("\n"))
 
     def test_two_newlines(self) -> None:
-        self.assertFalse(correct_trailing_newlines('\n\n'))
+        self.assertFalse(correct_trailing_newlines("\n\n"))
 
     def test_three_newlines(self) -> None:
-        self.assertFalse(correct_trailing_newlines('\n\n\n'))
+        self.assertFalse(correct_trailing_newlines("\n\n\n"))
 
     def test_hello_world(self) -> None:
-        self.assertFalse(correct_trailing_newlines('hello world'))
+        self.assertFalse(correct_trailing_newlines("hello world"))
 
     def test_hello_world_newline(self) -> None:
-        self.assertTrue(correct_trailing_newlines('hello world\n'))
+        self.assertTrue(correct_trailing_newlines("hello world\n"))
 
     def test_hello_world_two_newlines(self) -> None:
-        self.assertFalse(correct_trailing_newlines('hello world\n\n'))
+        self.assertFalse(correct_trailing_newlines("hello world\n\n"))
 
     def test_hello_world_three_newlines(self) -> None:
-        self.assertFalse(correct_trailing_newlines('hello world\n\n\n'))
+        self.assertFalse(correct_trailing_newlines("hello world\n\n\n"))
 
     def test_hello_world_multiline(self) -> None:
-        self.assertFalse(correct_trailing_newlines('hello\nworld'))
+        self.assertFalse(correct_trailing_newlines("hello\nworld"))
 
     def test_hello_world_multiline_gap(self) -> None:
-        self.assertTrue(correct_trailing_newlines('hello\n\nworld\n'))
+        self.assertTrue(correct_trailing_newlines("hello\n\nworld\n"))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/test/test_translate_annotations.py b/tools/test/test_translate_annotations.py
index 867decc..92f0c78 100644
--- a/tools/test/test_translate_annotations.py
+++ b/tools/test/test_translate_annotations.py
@@ -3,10 +3,8 @@
 
 from tools.linter.translate_annotations import parse_annotation, parse_diff, translate
 
-flake8_regex \
-    = r'^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorCode>\w+\d+) (?P<errorDesc>.*)'
-clang_tidy_regex \
-    = r'^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorDesc>.*?) \[(?P<errorCode>.*)\]'
+flake8_regex = r"^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorCode>\w+\d+) (?P<errorDesc>.*)"
+clang_tidy_regex = r"^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorDesc>.*?) \[(?P<errorCode>.*)\]"
 
 # in the below example patch, note that the filenames differ, so the
 # translation should reflect that as well as the line numbers
@@ -14,7 +12,7 @@
 # $ git clone -b 1.0.2 https://github.com/cscorley/whatthepatch.git
 # $ cd whatthepatch/tests/casefiles
 # $ git diff --no-index --unified=0 lao tzu
-lao_tzu_diff = '''
+lao_tzu_diff = """
 diff --git a/lao b/tzu
 index 635ef2c..5af88a8 100644
 --- a/lao
@@ -30,9 +28,9 @@
 +They both may be called deep and profound.
 +Deeper and more profound,
 +The door of all subtleties!
-'''.lstrip()
+""".lstrip()
 
-sparser_diff = '''
+sparser_diff = """
 diff --git a/foo.txt b/bar.txt
 index 27a6dad..6fae323 100644
 --- a/foo.txt
@@ -46,9 +44,9 @@
 @@ -10,2 +8,0 @@ more lines
 -even more
 -even more
-'''.lstrip()
+""".lstrip()
 
-new_file_diff = '''
+new_file_diff = """
 diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.h b/torch/csrc/jit/tensorexpr/operators/conv2d.h
 new file mode 100644
 index 0000000000..a81eeae346
@@ -74,10 +72,10 @@
 +} // namespace tensorexpr
 +} // namespace jit
 +} // namespace torch
-'''.lstrip()
+""".lstrip()
 
 # fun fact, this example fools VS Code's diff syntax highlighter
-haskell_diff = '''
+haskell_diff = """
 diff --git a/hello.hs b/hello.hs
 index ffb8d4ad14..0872ac9db6 100644
 --- a/hello.hs
@@ -85,7 +83,7 @@
 @@ -1 +1 @@
 --- a/hello/world/example
 +main = putStrLn "Hello, world!"
-'''.lstrip()
+""".lstrip()
 
 
 class TestTranslateAnnotations(unittest.TestCase):
@@ -95,25 +93,25 @@
         self.assertEqual(
             parse_diff(lao_tzu_diff),
             {
-                'old_filename': 'lao',
-                'hunks': [
+                "old_filename": "lao",
+                "hunks": [
                     {
-                        'old_start': 1,
-                        'old_count': 2,
-                        'new_start': 0,
-                        'new_count': 0,
+                        "old_start": 1,
+                        "old_count": 2,
+                        "new_start": 0,
+                        "new_count": 0,
                     },
                     {
-                        'old_start': 4,
-                        'old_count': 1,
-                        'new_start': 2,
-                        'new_count': 2,
+                        "old_start": 4,
+                        "old_count": 1,
+                        "new_start": 2,
+                        "new_count": 2,
                     },
                     {
-                        'old_start': 11,
-                        'old_count': 0,
-                        'new_start': 11,
-                        'new_count': 3,
+                        "old_start": 11,
+                        "old_count": 0,
+                        "new_start": 11,
+                        "new_count": 3,
                     },
                 ],
             },
@@ -123,13 +121,13 @@
         self.assertEqual(
             parse_diff(new_file_diff),
             {
-                'old_filename': None,
-                'hunks': [
+                "old_filename": None,
+                "hunks": [
                     {
-                        'old_start': 0,
-                        'old_count': 0,
-                        'new_start': 1,
-                        'new_count': 19,
+                        "old_start": 0,
+                        "old_count": 0,
+                        "new_start": 1,
+                        "new_count": 19,
                     },
                 ],
             },
@@ -139,13 +137,13 @@
         self.assertEqual(
             parse_diff(haskell_diff),
             {
-                'old_filename': 'hello.hs',
-                'hunks': [
+                "old_filename": "hello.hs",
+                "hunks": [
                     {
-                        'old_start': 1,
-                        'old_count': 1,
-                        'new_start': 1,
-                        'new_count': 1,
+                        "old_start": 1,
+                        "old_count": 1,
+                        "new_start": 1,
+                        "new_count": 1,
                     },
                 ],
             },
@@ -197,7 +195,7 @@
         self.assertEqual(translate(diff, 15), 13)
 
     def test_translate_empty(self) -> None:
-        diff = parse_diff('--- a/foo')
+        diff = parse_diff("--- a/foo")
 
         # again, we start numbering at 1
         self.assertEqual(translate(diff, -1), None)
@@ -252,29 +250,29 @@
     def test_parse_annotation_flake8(self) -> None:
         regex = re.compile(flake8_regex)
         self.assertEqual(
-            parse_annotation(regex, 'README.md:1:3: R100 make a better title'),
+            parse_annotation(regex, "README.md:1:3: R100 make a better title"),
             {
-                'filename': 'README.md',
-                'lineNumber': 1,
-                'columnNumber': 3,
-                'errorCode': 'R100',
-                'errorDesc': 'make a better title',
+                "filename": "README.md",
+                "lineNumber": 1,
+                "columnNumber": 3,
+                "errorCode": "R100",
+                "errorDesc": "make a better title",
             },
         )
 
     def test_parse_annotation_clang_tidy(self) -> None:
         regex = re.compile(clang_tidy_regex)
         self.assertEqual(
-            parse_annotation(regex, 'README.md:2:1: improve description [R200]'),
+            parse_annotation(regex, "README.md:2:1: improve description [R200]"),
             {
-                'filename': 'README.md',
-                'lineNumber': 2,
-                'columnNumber': 1,
-                'errorCode': 'R200',
-                'errorDesc': 'improve description',
+                "filename": "README.md",
+                "lineNumber": 2,
+                "columnNumber": 1,
+                "errorCode": "R200",
+                "errorDesc": "improve description",
             },
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tools/testing/explicit_ci_jobs.py b/tools/testing/explicit_ci_jobs.py
index 5944d22..3de04e1 100755
--- a/tools/testing/explicit_ci_jobs.py
+++ b/tools/testing/explicit_ci_jobs.py
@@ -45,7 +45,13 @@
     if requires is not None:
         for requirement in requires:
             dependency = past_jobs[requirement]
-            add_job(workflows, dependency["workflow_name"], dependency["type"], dependency["job"], past_jobs)
+            add_job(
+                workflows,
+                dependency["workflow_name"],
+                dependency["type"],
+                dependency["job"],
+                past_jobs,
+            )
 
     workflows[workflow_name]["jobs"].append({type: job})
 
@@ -88,13 +94,16 @@
 def commit_ci(files: List[str], message: str) -> None:
     # Check that there are no other modified files than the ones edited by this
     # tool
-    stdout = subprocess.run(["git", "status", "--porcelain"], stdout=subprocess.PIPE).stdout.decode()
+    stdout = subprocess.run(
+        ["git", "status", "--porcelain"], stdout=subprocess.PIPE
+    ).stdout.decode()
     for line in stdout.split("\n"):
         if line == "":
             continue
         if line[0] != " ":
-            raise RuntimeError(f"Refusing to commit while other changes are already staged: {line}")
-
+            raise RuntimeError(
+                f"Refusing to commit while other changes are already staged: {line}"
+            )
 
     # Make the commit
     subprocess.run(["git", "add"] + files)
@@ -107,10 +116,12 @@
     )
     parser.add_argument("--job", action="append", help="job name", default=[])
     parser.add_argument(
-        "--filter-gha", help="keep only these github actions (glob match)", default=''
+        "--filter-gha", help="keep only these github actions (glob match)", default=""
     )
     parser.add_argument(
-        "--make-commit", action="store_true", help="add change to git with to a do-not-merge commit"
+        "--make-commit",
+        action="store_true",
+        help="add change to git with to a do-not-merge commit",
     )
     args = parser.parse_args()
 
@@ -118,7 +129,9 @@
     with open(CONFIG_YML, "r") as f:
         config_yml = yaml.safe_load(f.read())
 
-    config_yml["workflows"] = get_filtered_circleci_config(config_yml["workflows"], args.job)
+    config_yml["workflows"] = get_filtered_circleci_config(
+        config_yml["workflows"], args.job
+    )
 
     with open(CONFIG_YML, "w") as f:
         yaml.dump(config_yml, f)
@@ -131,13 +144,15 @@
                 path.resolve().unlink()
 
     if args.make_commit:
-        jobs_str = '\n'.join([f" * {job}" for job in args.job])
-        message = textwrap.dedent(f"""
+        jobs_str = "\n".join([f" * {job}" for job in args.job])
+        message = textwrap.dedent(
+            f"""
         [skip ci][do not merge] Edit config.yml to filter specific jobs
 
         Filter CircleCI to only run:
         {jobs_str}
 
         See [Run Specific CI Jobs](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md#run-specific-ci-jobs) for details.
-        """).strip()
+        """
+        ).strip()
         commit_ci([str(f.relative_to(REPO_ROOT)) for f in touched_files], message)
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index f09b87a..b11d0dd 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -6,16 +6,16 @@
 from tools.stats.s3_stat_parser import (
     get_previous_reports_for_branch,
     get_previous_reports_for_pr,
-    Report, Version2Report,
-    HAVE_BOTO3)
-from tools.stats.import_test_stats import (
-    get_disabled_tests,
-    get_slow_tests
+    Report,
+    Version2Report,
+    HAVE_BOTO3,
 )
+from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
 
 from typing import Any, Dict, List, Optional, Tuple, cast
 from typing_extensions import TypedDict
 
+
 class JobTimeJSON(TypedDict):
     commit: str
     JOB_BASE_NAME: str
@@ -23,50 +23,55 @@
 
 
 def _get_stripped_CI_job() -> str:
-    """E.g. convert 'pytorch_windows_vs2019_py36_cuda10.1_build' to 'pytorch_windows_vs2019_py36_cuda10.1'.
-    """
-    job = os.environ.get("JOB_BASE_NAME", "").rstrip('0123456789')
-    if job.endswith('_slow_test'):
-        job = job[:len(job) - len('_slow_test')]
-    elif job.endswith('_test') or job.endswith('-test'):
-        job = job[:len(job) - len('_test')]
-    elif job.endswith('_build') or job.endswith('-build'):
-        job = job[:len(job) - len('_build')]
+    """E.g. convert 'pytorch_windows_vs2019_py36_cuda10.1_build' to 'pytorch_windows_vs2019_py36_cuda10.1'."""
+    job = os.environ.get("JOB_BASE_NAME", "").rstrip("0123456789")
+    if job.endswith("_slow_test"):
+        job = job[: len(job) - len("_slow_test")]
+    elif job.endswith("_test") or job.endswith("-test"):
+        job = job[: len(job) - len("_test")]
+    elif job.endswith("_build") or job.endswith("-build"):
+        job = job[: len(job) - len("_build")]
     return job
 
 
 def _get_job_times_json(job_times: Dict[str, float]) -> JobTimeJSON:
     return {
-        'commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding="ascii").strip(),
-        'JOB_BASE_NAME': _get_stripped_CI_job(),
-        'job_times': job_times,
+        "commit": subprocess.check_output(
+            ["git", "rev-parse", "HEAD"], encoding="ascii"
+        ).strip(),
+        "JOB_BASE_NAME": _get_stripped_CI_job(),
+        "job_times": job_times,
     }
 
 
 def _calculate_job_times(reports: List["Report"]) -> Dict[str, float]:
-    """Compute test runtime by filename: ("test_file_name" -> (current_avg, # values))
-    """
+    """Compute test runtime by filename: ("test_file_name" -> (current_avg, # values))"""
     jobs_to_times: Dict[str, Tuple[float, int]] = dict()
     for report in reports:
         v_report = cast(Version2Report, report)
-        assert 'format_version' in v_report.keys() and v_report.get('format_version') == 2, \
-            "S3 format currently handled is version 2 only"
-        files: Dict[str, Any] = v_report['files']
+        assert (
+            "format_version" in v_report.keys() and v_report.get("format_version") == 2
+        ), "S3 format currently handled is version 2 only"
+        files: Dict[str, Any] = v_report["files"]
         for name, test_file in files.items():
             if name not in jobs_to_times:
-                jobs_to_times[name] = (test_file['total_seconds'], 1)
+                jobs_to_times[name] = (test_file["total_seconds"], 1)
             else:
                 curr_avg, curr_count = jobs_to_times[name]
                 new_count = curr_count + 1
-                new_avg = (curr_avg * curr_count + test_file['total_seconds']) / new_count
+                new_avg = (
+                    curr_avg * curr_count + test_file["total_seconds"]
+                ) / new_count
                 jobs_to_times[name] = (new_avg, new_count)
 
     return {job: time for job, (time, _) in jobs_to_times.items()}
 
 
-def calculate_shards(num_shards: int, tests: List[str], job_times: Dict[str, float]) -> List[Tuple[float, List[str]]]:
+def calculate_shards(
+    num_shards: int, tests: List[str], job_times: Dict[str, float]
+) -> List[Tuple[float, List[str]]]:
     filtered_job_times: Dict[str, float] = dict()
-    unknown_jobs : List[str] = []
+    unknown_jobs: List[str] = []
     for test in tests:
         if test in job_times:
             filtered_job_times[test] = job_times[test]
@@ -75,13 +80,18 @@
 
     # The following attempts to implement a partition approximation greedy algorithm
     # See more at https://en.wikipedia.org/wiki/Greedy_number_partitioning
-    sorted_jobs = sorted(filtered_job_times, key=lambda j: filtered_job_times[j], reverse=True)
+    sorted_jobs = sorted(
+        filtered_job_times, key=lambda j: filtered_job_times[j], reverse=True
+    )
     sharded_jobs: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]
     for job in sorted_jobs:
         min_shard_index = sorted(range(num_shards), key=lambda i: sharded_jobs[i][0])[0]
         curr_shard_time, curr_shard_jobs = sharded_jobs[min_shard_index]
         curr_shard_jobs.append(job)
-        sharded_jobs[min_shard_index] = (curr_shard_time + filtered_job_times[job], curr_shard_jobs)
+        sharded_jobs[min_shard_index] = (
+            curr_shard_time + filtered_job_times[job],
+            curr_shard_jobs,
+        )
 
     # Round robin the unknown jobs starting with the smallest shard
     index = sorted(range(num_shards), key=lambda i: sharded_jobs[i][0])[0]
@@ -94,14 +104,20 @@
 def _pull_job_times_from_S3() -> Dict[str, float]:
     if HAVE_BOTO3:
         ci_job_prefix = _get_stripped_CI_job()
-        s3_reports: List["Report"] = get_previous_reports_for_branch('origin/viable/strict', ci_job_prefix)
+        s3_reports: List["Report"] = get_previous_reports_for_branch(
+            "origin/viable/strict", ci_job_prefix
+        )
     else:
-        print('Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser.')
-        print('If not installed, please install boto3 for automatic sharding and test categorization.')
+        print(
+            "Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser."
+        )
+        print(
+            "If not installed, please install boto3 for automatic sharding and test categorization."
+        )
         s3_reports = []
 
     if len(s3_reports) == 0:
-        print('Gathered no reports from S3. Please proceed without them.')
+        print("Gathered no reports from S3. Please proceed without them.")
         return dict()
 
     return _calculate_job_times(s3_reports)
@@ -116,20 +132,26 @@
         with open(test_times_file) as file:
             test_times_json: JobTimeJSON = json.load(file)
 
-        curr_commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding="ascii").strip()
-        file_commit = test_times_json.get('commit', '')
+        curr_commit = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"], encoding="ascii"
+        ).strip()
+        file_commit = test_times_json.get("commit", "")
         curr_ci_job = _get_stripped_CI_job()
-        file_ci_job = test_times_json.get('JOB_BASE_NAME', 'N/A')
+        file_ci_job = test_times_json.get("JOB_BASE_NAME", "N/A")
         if curr_commit != file_commit:
-            print(f'Current test times file is from different commit {file_commit}.')
+            print(f"Current test times file is from different commit {file_commit}.")
         elif curr_ci_job != file_ci_job:
-            print(f'Current test times file is for different CI job {file_ci_job}.')
+            print(f"Current test times file is for different CI job {file_ci_job}.")
         else:
-            print(f'Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values.')
-            return test_times_json.get('job_times', {})
+            print(
+                f"Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values."
+            )
+            return test_times_json.get("job_times", {})
 
         # Found file, but commit or CI job in JSON doesn't match
-        print(f'Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}')
+        print(
+            f"Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}"
+        )
 
     job_times = export_S3_test_times(test_times_file)
 
@@ -142,14 +164,18 @@
         return test_modules
     report = reports[0][0]
     v_report = cast(Version2Report, report)
-    assert 'format_version' in v_report.keys() and v_report.get('format_version') == 2, \
-        "S3 format currently handled is version 2 only"
-    files: Dict[str, Any] = v_report['files']
+    assert (
+        "format_version" in v_report.keys() and v_report.get("format_version") == 2
+    ), "S3 format currently handled is version 2 only"
+    files: Dict[str, Any] = v_report["files"]
     for fname, file in files.items():
         contains_failure = any(
-            any(case['status'] == 'errored' or case['status'] == 'failed'
-                for _, case in suite['cases'].items())
-            for _, suite in file['suites'].items())
+            any(
+                case["status"] == "errored" or case["status"] == "failed"
+                for _, case in suite["cases"].items()
+            )
+            for _, suite in file["suites"].items()
+        )
         if contains_failure:
             test_modules.append(fname)
     return test_modules
@@ -168,14 +194,15 @@
     return lines
 
 
-def get_shard_based_on_S3(which_shard: int, num_shards: int, tests: List[str], test_times_file: str) -> List[str]:
-    """Get sharded test allocation based on historic S3 data.
-    """
+def get_shard_based_on_S3(
+    which_shard: int, num_shards: int, tests: List[str], test_times_file: str
+) -> List[str]:
+    """Get sharded test allocation based on historic S3 data."""
     jobs_to_times = _query_past_job_times(test_times_file)
 
     # Got no stats from S3, returning early to save runtime
     if len(jobs_to_times) == 0:
-        print('Gathered no stats from S3. Proceeding with default sharding plan.')
+        print("Gathered no stats from S3. Proceeding with default sharding plan.")
         return tests[which_shard - 1 :: num_shards]
 
     shards = calculate_shards(num_shards, tests, jobs_to_times)
@@ -183,14 +210,15 @@
     return tests_from_shard
 
 
-def get_slow_tests_based_on_S3(test_list: List[str], td_list: List[str], slow_test_threshold: int) -> List[str]:
-    """Get list of slow tests based on historic S3 data.
-    """
+def get_slow_tests_based_on_S3(
+    test_list: List[str], td_list: List[str], slow_test_threshold: int
+) -> List[str]:
+    """Get list of slow tests based on historic S3 data."""
     jobs_to_times: Dict[str, float] = _query_past_job_times()
 
     # Got no stats from S3, returning early to save runtime
     if len(jobs_to_times) == 0:
-        print('Gathered no stats from S3. No new slow tests calculated.')
+        print("Gathered no stats from S3. No new slow tests calculated.")
         return []
 
     slow_tests: List[str] = []
@@ -202,38 +230,42 @@
 
 
 def get_specified_test_cases(filename: str, tests: List[str]) -> Dict[str, List[str]]:
-    """Get test cases from a specified test case file. Usually exported manually or through CI system.
-    """
+    """Get test cases from a specified test case file. Usually exported manually or through CI system."""
     if not os.path.exists(filename):
-        print(f'Could not find specified tests file: {filename}. Proceeding with default behavior.')
+        print(
+            f"Could not find specified tests file: {filename}. Proceeding with default behavior."
+        )
         return dict()
 
     # The below encoding is utf-8-sig because utf-8 doesn't properly handle the byte-order-mark character
-    with open(filename, mode='r', encoding="utf-8-sig") as csv_file:
+    with open(filename, mode="r", encoding="utf-8-sig") as csv_file:
         csv_reader = csv.DictReader(csv_file)
         line_count = 0
         specified_test_case_dict: Dict[str, List[str]] = dict()
         for row in csv_reader:
             line_count += 1
             if line_count == 1:
-                if 'test_filename' not in row or 'test_case_name' not in row:
-                    print('Data is missing necessary columns for test specification. Proceeding with default behavior.')
+                if "test_filename" not in row or "test_case_name" not in row:
+                    print(
+                        "Data is missing necessary columns for test specification. Proceeding with default behavior."
+                    )
                     return dict()
-            test_filename = row['test_filename']
-            test_case_name = row['test_case_name']
+            test_filename = row["test_filename"]
+            test_case_name = row["test_case_name"]
             if test_filename not in tests:
-                print(f'Specified test_filename {test_filename} not found in TESTS. Skipping.')
+                print(
+                    f"Specified test_filename {test_filename} not found in TESTS. Skipping."
+                )
                 continue
             if test_filename not in specified_test_case_dict:
                 specified_test_case_dict[test_filename] = []
             specified_test_case_dict[test_filename].append(test_case_name)
-        print(f'Processed {line_count} test cases.')
+        print(f"Processed {line_count} test cases.")
         return specified_test_case_dict
 
 
 def get_reordered_tests(tests: List[str], is_reordering_by_pr: bool) -> List[str]:
-    """Get the reordered test filename list based on github PR history or git changed file.
-    """
+    """Get the reordered test filename list based on github PR history or git changed file."""
     prioritized_tests = []
     # Try using historic stats from PR.
     if is_reordering_by_pr and HAVE_BOTO3:
@@ -241,7 +273,8 @@
         if len(pr_number):
             ci_job_prefix = _get_stripped_CI_job()
             s3_reports: List[Tuple["Report", str]] = get_previous_reports_for_pr(
-                pr_number, ci_job_prefix)
+                pr_number, ci_job_prefix
+            )
             prioritized_tests = _query_failure_test_module(s3_reports)
             print("Prioritized test from previous CI info.")
 
@@ -254,9 +287,11 @@
             return tests
 
         prefix = f"test{os.path.sep}"
-        prioritized_tests = [f for f in changed_files if f.startswith(prefix) and f.endswith(".py")]
-        prioritized_tests = [f[len(prefix):] for f in prioritized_tests]
-        prioritized_tests = [f[:-len(".py")] for f in prioritized_tests]
+        prioritized_tests = [
+            f for f in changed_files if f.startswith(prefix) and f.endswith(".py")
+        ]
+        prioritized_tests = [f[len(prefix) :] for f in prioritized_tests]
+        prioritized_tests = [f[: -len(".py")] for f in prioritized_tests]
         print("Prioritized test from test file changes.")
 
     bring_to_front = []
@@ -268,12 +303,16 @@
         else:
             the_rest.append(test)
     if len(tests) == len(bring_to_front) + len(the_rest):
-        print(f"reordering tests for PR:\n"
-              f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n")
+        print(
+            f"reordering tests for PR:\n"
+            f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
+        )
         return bring_to_front + the_rest
     else:
-        print(f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
-              f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n")
+        print(
+            f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
+            f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
+        )
         return tests
 
 
@@ -281,13 +320,13 @@
 def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str, float]:
     test_times: Dict[str, float] = _pull_job_times_from_S3()
     if test_times_filename is not None:
-        print(f'Exporting S3 test stats to {test_times_filename}.')
+        print(f"Exporting S3 test stats to {test_times_filename}.")
         if os.path.exists(test_times_filename):
-            print(f'Overwriting existent file: {test_times_filename}')
-        with open(test_times_filename, 'w+') as file:
+            print(f"Overwriting existent file: {test_times_filename}")
+        with open(test_times_filename, "w+") as file:
             job_times_json = _get_job_times_json(test_times)
-            json.dump(job_times_json, file, indent='    ', separators=(',', ': '))
-            file.write('\n')
+            json.dump(job_times_json, file, indent="    ", separators=(",", ": "))
+            file.write("\n")
     return test_times
 
 
diff --git a/tools/update_masked_docs.py b/tools/update_masked_docs.py
index 6d705d5..87ee083 100644
--- a/tools/update_masked_docs.py
+++ b/tools/update_masked_docs.py
@@ -7,24 +7,26 @@
 
 import os
 
+
 def main() -> None:
 
-    target = os.path.join('torch', '_masked', '_docs.py')
+    target = os.path.join("torch", "_masked", "_docs.py")
 
     try:
         import torch
     except ImportError as msg:
-        print(f'Failed to import torch required to build {target}: {msg}')
+        print(f"Failed to import torch required to build {target}: {msg}")
         return
 
     if os.path.isfile(target):
         with open(target) as _f:
             current_content = _f.read()
     else:
-        current_content = ''
+        current_content = ""
 
     _new_content = []
-    _new_content.append('''\
+    _new_content.append(
+        """\
 # -*- coding: utf-8 -*-
 # This file is generated, do not modify it!
 #
@@ -35,24 +37,25 @@
 # The script must be called from an environment where the development
 # version of torch package can be imported and is functional.
 #
-''')
+"""
+    )
 
     for func_name in sorted(torch._masked.__all__):
         func = getattr(torch._masked, func_name)
         func_doc = torch._masked._generate_docstring(func)
         _new_content.append(f'{func_name}_docstring = """{func_doc}"""\n')
 
-    new_content = '\n'.join(_new_content)
+    new_content = "\n".join(_new_content)
 
     if new_content == current_content:
-        print(f'Nothing to update in {target}')
+        print(f"Nothing to update in {target}")
         return
 
-    with open(target, 'w') as _f:
+    with open(target, "w") as _f:
         _f.write(new_content)
 
-    print(f'Successfully updated {target}')
+    print(f"Successfully updated {target}")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/tools/vscode_settings.py b/tools/vscode_settings.py
index 88dbfb4..5c7fa87 100755
--- a/tools/vscode_settings.py
+++ b/tools/vscode_settings.py
@@ -5,17 +5,17 @@
 
 
 def main() -> None:
-    folder = Path('.vscode')
-    recommended = json.loads((folder / 'settings_recommended.json').read_text())
-    path = folder / 'settings.json'
+    folder = Path(".vscode")
+    recommended = json.loads((folder / "settings_recommended.json").read_text())
+    path = folder / "settings.json"
     try:
         current = json.loads(path.read_text())
     except Exception:
         current = {}
-    with open(path, 'w') as f:
+    with open(path, "w") as f:
         json.dump({**current, **recommended}, f, indent=2)
-        f.write('\n')
+        f.write("\n")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()