| import torch |
| from torch import Tensor |
| from typing import Callable, List |
| |
| import re |
| |
__all__: List[str] = []
| |
| class _CodeParser: |
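    """Parse the entry function out of a CUDA C++ code string.

    Illustrative sketch (not a full grammar) of what the regex below captures: for a string like
    ``template <typename T> T my_kernel(T x) { return x * x; }``, the named groups resolve to
    ``template_params`` = ``<typename T>``, ``return_type`` = ``T``, ``function_name`` = ``my_kernel``,
    ``function_params`` = ``(T x)``, and ``function_body`` = ``{ return x * x; }``.
    Because the groups match greedily under ``re.DOTALL``, a code string containing several
    function definitions resolves to the last one.
    """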
| def __init__(self, code_string: str): |
| optional_ws = r"\s*" |
| required_ws = r"\s+" |
| template_params = r"(?P<template_params>\<.+\>)" |
| return_type = r"(?P<return_type>\w+)" |
| function_name = r"(?P<function_name>\w+)" |
| function_params = r"(?P<function_params>\(.+\))" |
| function_body = r"(?P<function_body>\{.+\})" |
| |
        pattern = (
            optional_ws
            + "template"
            + optional_ws + template_params
            + optional_ws + return_type
            + required_ws + function_name
            + optional_ws + function_params
            + optional_ws + function_body
            + optional_ws
        )
| |
        result = re.match(pattern, code_string, re.DOTALL)  # re.DOTALL lets '.' match newlines in multi-line code strings
| |
| if result is None: |
| raise Exception(f"Couldn't parse code, please check correctness:\n {code_string}") |
| |
| self.template_params = result["template_params"] |
| self.return_type = result["return_type"] |
| self.function_name = result["function_name"] |
| self.function_params = result["function_params"] |
| self.function_body = result["function_body"] |
| |
| class _JittedFunction: |
| def __init__(self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs): |
| self.code_string = code_string |
| |
        assert return_by_ref or num_outputs == 1, "Return by value only works for a single output."
| self.return_by_ref = return_by_ref |
| self.num_outputs = num_outputs |
| |
| parsed_code = _CodeParser(code_string) |
| self.kernel_name = parsed_code.function_name |
| |
| self.kwargs_dict = kwargs |
| self.is_cuda_available = torch.cuda.is_available() |
| |
| def __call__(self, *tensors: Tensor, **kwargs): |
        # Jiterator follows torch.cuda's lazy initialization behavior:
        # defer checking CUDA's availability until the function is invoked
        assert self.is_cuda_available, "Jiterator is only supported on CUDA GPUs; no CUDA GPU is available."
| |
| assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." |
| |
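        # Merge kwargs captured at creation time with call-time overrides;
        # only keys declared at creation time may be overridden.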
| expanded_kwargs = self.kwargs_dict.copy() |
| for key, value in kwargs.items(): |
| if key in self.kwargs_dict: |
| expanded_kwargs[key] = value |
| else: |
| raise KeyError(f"{key} is not declared in function definition") |
| |
| return torch._C._cuda_jiterator_compile_and_launch_kernel( |
| self.code_string, |
| self.kernel_name, |
| self.return_by_ref, |
| self.num_outputs, |
| tensors, |
| expanded_kwargs) |
| |
| |
| def _create_jit_fn(code_string: str, **kwargs) -> Callable: |
| """ |
    Create a jiterator-generated CUDA kernel for an elementwise op.
| |
    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
    into the elementwise kernel template and compiled on the fly. The compiled kernel will be cached in memory, as
    well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.
| |
| Args: |
| code_string (string): CUDA code string to be compiled by jiterator. The entry functor must return by value. |
        kwargs (Dict, optional): Keyword arguments for the generated function
| |
| Example:: |
| |
| code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" |
        jitted_fn = _create_jit_fn(code_string, alpha=1.0)
| a = torch.rand(3, device='cuda') |
| b = torch.rand(3, device='cuda') |
| # invoke jitted function like a regular python function |
| result = jitted_fn(a, b, alpha=3.14) |
| |
    code_string also allows multiple function definitions, and the last function will be treated as the entry function.
| |
| Example:: |
| |
| code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" |
| code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" |
        jitted_fn = _create_jit_fn(code_string, val=0.0)
| a = torch.rand(3, device='cuda') |
| b = torch.rand(3, device='cuda') |
| # invoke jitted function like a regular python function |
| result = jitted_fn(a, b) # using default val=0.0 |
| |
    Jiterator can be used together with Python registration to override an operator's CUDA kernel.
    The following example overrides gelu's CUDA kernel with relu.
| |
| Example:: |
| |
| code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }" |
        my_gelu = _create_jit_fn(code_string)
| my_lib = torch.library.Library("aten", "IMPL") |
| my_lib.impl('aten::gelu', my_gelu, "CUDA") |
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
| a = torch.rand(3, device='cuda') |
| torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a)) |
| |
| .. warning:: |
| This API is in beta and may change in future releases. |
| |
| .. warning:: |
| This API only supports up to 8 inputs and 1 output |
| |
| .. warning:: |
        All input tensors must be on a CUDA device
| """ |
| |
| return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs) |
| |
| def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> Callable: |
| """ |
    Create a jiterator-generated CUDA kernel for an elementwise op that supports returning one or more outputs.
| |
| Args: |
        code_string (string): CUDA code string to be compiled by jiterator. The entry functor must return its outputs by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for the generated function
| |
| Example:: |
| |
| code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" |
        jitted_fn = _create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
| a = torch.rand(3, device='cuda') |
| b = torch.rand(3, device='cuda') |
| # invoke jitted function like a regular python function |
| result = jitted_fn(a, b, alpha=3.14) |
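
    code_string can also declare more than one output; a minimal sketch (assuming, as in the
    single-output example above, that outputs are trailing reference parameters):

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T& out0, T& out1) { out0 = x + y; out1 = x - y; }"
        jitted_fn = _create_multi_output_jit_fn(code_string, num_outputs=2)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # one tensor is returned per declared output
        out0, out1 = jitted_fn(a, b)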
| |
| .. warning:: |
| This API is in beta and may change in future releases. |
| |
| .. warning:: |
| This API only supports up to 8 inputs and 8 outputs |
| """ |
| |
| return _JittedFunction(code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs) |