| # mypy: allow-untyped-defs |
| import re |
| |
| import torch._C as C |
| |
| |
| """ |
The PythonDispatcher class is a thin Python binding to the C++ dispatcher,
designed to show how dispatcher precomputation works. In particular, it
shows, for a certain op `foo`, what the computed dispatch table looks
like after users register their kernels to certain dispatch keys.
| |
In the real C++ dispatcher we support many dispatch keys for different
functionalities. For simplicity, PythonDispatcher only supports dispatch
keys for a single example of each use case. These use cases are listed below:
| |
- CPU/AutogradCPU: represents in-tree backends for which we usually have dedicated
  inference & autograd kernels in the pytorch core library.
  E.g. CPU, CUDA
- FPGA/AutogradOther: represents in-tree backends for which we usually have
  backend-specific inference kernels, but which share the same autograd kernel
  specified in AutogradOther.
  E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends for which we have neither
  inference nor autograd kernels defined in the pytorch core library. The backend
  owner is responsible for registering both inference & autograd kernels in their
  extensions (e.g. torch-xla) for the operators they support.
  E.g. XLA, XPU, MPS
- CompositeExplicitAutograd: alias key mapped to the inference kernels of all backends
  like CPU, CUDA, XLA etc.
  Kernels registered to this key MUST work for inference for all backends.
- Autograd: alias key mapped to the autograd kernels of all backends like AutogradCPU,
  AutogradXLA, AutogradOther.
  Kernels registered to this key MUST work for autograd for all backends.
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
  Kernels registered to this key MUST work for both inference & autograd for all backends.
| |
Note that we only allow registrations to alias keys inside the pytorch core library.
E.g. you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
kernel from a torch-xla extension; instead you should upstream the kernel into the
pytorch/pytorch repo so that it's available for all backends and continuously
tested even without the extension.
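
For example (illustrative only; the exact table text depends on the C++
dispatcher's debug output), registering solely to an alias key fills the
matching slots of every runtime backend with the same generated kernel:

    dispatcher = PythonDispatcher()
    dispatcher.register(["CompositeExplicitAutograd"])
    # dispatcher.dispatchTable() now shows the inference entries for
    # CPU, FPGA, XLA and Lazy all pointing at fn_CompositeExplicitAutograd.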
| |
Usage:
    dispatcher = PythonDispatcher()
    dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
    print(dispatcher.dispatchTable())  # This tells you exactly which kernel is used for each backend.
    # For more debugging information
    # print(dispatcher.keys())
    # print(dispatcher.registrations())
    # print(dispatcher.rawRegistrations())
    # print(dispatcher.rawDispatchTable())
PythonDispatcher calls the C++ dispatcher under the hood to precompute the
dispatch table. This file only provides a simplified API for developers;
the relevant test code is located in test/test_dispatch.py.
| """ |
| |
| |
| class PythonDispatcher: |
| namespace = "__test__" |
| name = "foo" |
| # fmt: off |
| runtime_keys = [ |
| "CPU", "AutogradCPU", |
| "FPGA", "AutogradOther", |
| "XLA", "AutogradXLA", |
| "Lazy", "AutogradLazy", |
| ] |
| # fmt: on |
| alias_keys = [ |
| "CompositeExplicitAutograd", |
| "Autograd", |
| "CompositeImplicitAutograd", |
| ] |
| supported_keys = runtime_keys + alias_keys |
| |
| def __init__(self) -> None: |
| C._dispatch_check_invariants(self.name) # type: ignore[attr-defined] |
| self.ref = C._dispatch_library("FRAGMENT", self.namespace, "") |
| self.ref.def_("foo(Tensor x) -> Tensor") |
| |
| """ |
| Returns a list of dispatch keys supported by PythonDispatcher. |
| You can register kernels to these keys. |
| """ |
| |
| def keys(self): |
| return self.supported_keys |
| |
| """ |
    Register kernels to the target dispatchKeys.
    dispatchKeys (list[str]): a list of dispatch keys to which you want to register
    your own kernel. Note that you don't need to write the kernel yourself in
    this PythonDispatcher. E.g. for the CPU key, a kernel (e.g. fn_CPU) is
    automatically generated and registered.
| """ |
| |
| def register(self, dispatchKeys): |
        # Overriding is not supported and triggers a warning in the C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriding is not allowed but found duplicates in {dispatchKeys}."
| ) |
        # We currently forbid this in codegen rather than in the C++ dispatcher.
| if ( |
| "CompositeImplicitAutograd" in dispatchKeys |
| and "CompositeExplicitAutograd" in dispatchKeys |
| ): |
| raise RuntimeError( |
| "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed." |
| ) |
| for key in dispatchKeys: |
| if key not in self.supported_keys: |
| raise RuntimeError( |
| f"{key} is not supported, please select a dispatch key in {self.supported_keys}." |
| ) |
| self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key) |
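
    # Minimal sketch of register()'s behavior, using only names defined in this
    # file (the failure cases mirror the checks above):
    #
    #   dispatcher = PythonDispatcher()
    #   dispatcher.register(["CPU", "CPU"])  # raises RuntimeError: duplicate keys
    #   dispatcher.register(["CUDA"])        # raises RuntimeError: unsupported key
    #   dispatcher.register(["CompositeImplicitAutograd",
    #                        "CompositeExplicitAutograd"])  # raises RuntimeError
    #   dispatcher.register(["CPU", "XLA"])  # registers fn_CPU and fn_XLA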
| |
| """ |
| Helper function to format (key, kernel). |
| """ |
| |
| def _format_line(self, key, kernel): |
| return f"{key:<15} {kernel}\n" |
| |
| """ |
| Helper function to print a table header. |
| """ |
| |
| def _format_header(self, header): |
| s = f""" |
| {header} |
| """ |
| s += self._format_line("key", "kernel") |
| s += "---------------------------\n" |
| return s |
| |
| """ |
| Returns raw output of all registration info for debugging only. |
| Use registrations() for a simplified version. |
| """ |
| |
| def rawRegistrations(self): |
| return C._dispatch_dump(f"{self.namespace}::{self.name}") # type: ignore[attr-defined] |
| |
| """ |
| Returns raw output of computed dispatch table for debugging only. |
| Use dispatchTable() for a simplified version. |
| """ |
| |
| def rawDispatchTable(self): |
| return C._dispatch_dump_table(f"{self.namespace}::{self.name}") # type: ignore[attr-defined] |
| |
| """ |
    Returns a table (str) including all the registrations from users.
| Note this includes registrations to both runtime keys and alias keys. |
| """ |
| |
| def registrations(self): |
| output = self._format_header("Registered Kernels") |
| state = self.rawRegistrations() |
| state_entries = state.split("\n") |
        for line in state_entries:
            # Each registration line in the raw dump starts with its dispatch
            # key, e.g. "CPU: fn_CPU :: ..."; keep only supported keys.
            first = line.split(":")[0]
            if any(first.startswith(k) for k in self.supported_keys):
                # Extract the debug kernel name (e.g. fn_CPU), i.e. the token
                # right after the key.
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
| return output |
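
    # Illustrative sketch of the returned table, assuming register(["CPU"]) was
    # called earlier (header and spacing come from _format_header/_format_line):
    #
    #   Registered Kernels
    #   key             kernel
    #   ---------------------------
    #   CPU             fn_CPU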
| |
| """ |
    Returns the computed dispatch table (str). Note this only includes
    runtime keys; registrations to alias keys have been decoded into their
    mapped runtime keys.
| """ |
| |
| def dispatchTable(self): |
| output = self._format_header("Computed Dispatch Table") |
| table = self.rawDispatchTable() |
| table_entries = table.split("\n") |
        # Fallback-kernel entries embed their C++ registration site, e.g.
        # "registered at .../FallbackKernel.cpp:NN [...]"; strip that prefix
        # so only the bracketed annotation remains.
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in table_entries:
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                output += self._format_line(k, entry.split(": ")[1])
| return output |
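

if __name__ == "__main__":
    # Minimal smoke-test sketch mirroring the Usage section in the module
    # docstring; the printed tables depend on the C++ dispatcher build, so
    # the output here is illustrative rather than normative.
    dispatcher = PythonDispatcher()
    dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
    print(dispatcher.keys())
    print(dispatcher.registrations())
    print(dispatcher.dispatchTable())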