[torchgen] Fix multiple backends with custom namespace (#82133)

Summary:
Some quantized operators need the `QuantizedCPU` backend. Due to an issue in namespace checking, if a native function has two backends as well as a custom namespace, codegen currently hits an assertion error. This PR fixes that issue.

The root cause is that codegen currently asserts that a native function has only one namespace. If a native function is not found in a `BackendIndex`, we fall back to the default namespace for that backend (for fallback kernels). However, that default namespace may not be listed in the yaml file, and it should not be counted when checking whether the native function spans two different namespaces. In our error case, we have two `BackendIndex`es, one for `QuantizedCPU` and one for `CPU`. The native function doesn't have a kernel for `QuantizedCPU`, but we still assign it the default namespace (`at::native`). Since the kernel for dispatch key `CPU` uses a custom namespace, we hit the assertion error.
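
A minimal sketch (not the actual torchgen code; the backend and namespace names are illustrative) of why the old check fired: the fallback namespace for a backend with no kernel was counted alongside the real kernel's custom namespace.

```python
DEFAULT_KERNEL_NAMESPACE = "at::native"  # default used when a backend has no kernel

def old_namespace_check(kernel_namespaces_by_backend):
    seen = set()
    for ns in kernel_namespaces_by_backend.values():
        # A backend with no kernel for this op falls back to the default namespace ...
        seen.add(ns if ns is not None else DEFAULT_KERNEL_NAMESPACE)
        # ... and that fallback still participated in the uniqueness check.
        assert len(seen) == 1, "Codegen only supports one namespace per operator."

# The error case from above: a custom namespace for CPU plus no QuantizedCPU kernel
# (None) trips the assertion even though only one real kernel exists.
old_namespace_check({"CPU": "custom", "QuantizedCPU": None})  # raises AssertionError
```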

This PR changes the assertion criteria: we only error out if a native function has two or more kernels and those kernels live in two or more different namespaces.
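
In the same sketch form (again illustrative, mirroring the `torchgen/gen.py` change in the diff below), only backends that actually provide a kernel contribute to the uniqueness check:

```python
def new_namespace_check(kernel_namespaces_by_backend):
    seen = set()
    for ns in kernel_namespaces_by_backend.values():
        if ns is not None:  # only real kernels contribute to the namespace set
            seen.add(ns)
        assert len(seen) <= 1, (
            f"Codegen only supports one namespace per operator, got {seen}"
        )

new_namespace_check({"CPU": "custom", "QuantizedCPU": None})          # now passes
new_namespace_check({"CPU": "at::native", "QuantizedCPU": "custom"})  # still errors
```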

Test Plan: rely on the newly added unit test

Differential Revision: D38101345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82133
Approved by: https://github.com/iseeyuan
diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl
index 959a73d..2ca59ee 100644
--- a/tools/BUCK.bzl
+++ b/tools/BUCK.bzl
@@ -261,3 +261,16 @@
             ":gen_operators_yaml_lib",
         ],
     )
+
+    python_test(
+        name = "test_codegen",
+        srcs = [
+            "test/test_codegen.py",
+        ],
+        contacts = contacts,
+        visibility = ["PUBLIC"],
+        deps = [
+            torchgen_deps,
+            ":autograd",
+        ],
+    )
diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py
index 21437d0..a965337 100644
--- a/tools/test/test_codegen.py
+++ b/tools/test/test_codegen.py
@@ -1,11 +1,22 @@
 import dataclasses
 import typing
 import unittest
+from typing import Dict
 
 import torchgen.model
 
 from tools.autograd import gen_autograd_functions, load_derivatives
-from torchgen.gen import get_native_function_schema_registrations
+from torchgen.gen import (
+    get_native_function_declarations,
+    get_native_function_schema_registrations,
+)
+from torchgen.model import (
+    BackendIndex,
+    BackendMetadata,
+    DispatchKey,
+    NativeFunction,
+    OperatorName,
+)
 from torchgen.selective_build.selector import SelectiveBuilder
 
 
@@ -198,6 +209,67 @@
             )
 
 
+class TestGenNativeFunctionDeclaration(unittest.TestCase):
+    def setUp(self) -> None:
+        self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
+            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
+            loc=torchgen.model.Location(__file__, 1),
+            valid_tags=set(),
+        )
+        self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
+            {
+                "func": "op_2() -> bool",
+                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
+            },
+            loc=torchgen.model.Location(__file__, 1),
+            valid_tags=set(),
+        )
+
+        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
+            DispatchKey.CPU: {},
+            DispatchKey.QuantizedCPU: {},
+        }
+        BackendIndex.grow_index(backend_indices, op_1_backend_index)
+        BackendIndex.grow_index(backend_indices, op_2_backend_index)
+        self.backend_indices = {
+            k: BackendIndex(
+                dispatch_key=k,
+                use_out_as_primary=True,
+                external=False,
+                device_guard=False,
+                index=backend_indices[k],
+            )
+            for k in backend_indices
+        }
+
+    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
+        with self.assertRaises(AssertionError):
+            get_native_function_declarations(
+                grouped_native_functions=[
+                    self.op_1_native_function,
+                    self.op_2_native_function,
+                ],
+                backend_indices=self.backend_indices,
+            )
+
+    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
+        self.assertIsInstance(self.op_1_native_function, NativeFunction)
+        declaration = get_native_function_declarations(
+            grouped_native_functions=[
+                self.op_1_native_function,
+            ],
+            backend_indices=self.backend_indices,
+        )
+        target = """
+namespace at {
+namespace native {
+TORCH_API bool kernel_1();
+} // namespace native
+} // namespace at
+        """
+        self.assertEqual("\n".join(declaration), target)
+
+
 # Represents the most basic NativeFunction. Use dataclasses.replace()
 # to edit for use.
 DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 4ca3564..ab1c0c7 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -1366,17 +1366,18 @@
     newline = "\n"
     for f in grouped_native_functions:
         native_function_namespaces = set()
-        for backend_idx in backend_indices.values():
+        dispatch_keys = set()
+        for dispatch_key, backend_idx in backend_indices.items():
             backend_metadata = backend_idx.get_kernel(f)
-            namespace = (
-                backend_metadata.cpp_namespace
-                if backend_metadata
-                else DEFAULT_KERNEL_NAMESPACE
-            )
-            native_function_namespaces.add(namespace)
+            if backend_metadata:
+                namespace = backend_metadata.cpp_namespace
+                dispatch_keys.add(dispatch_key)
+                native_function_namespaces.add(namespace)
+            else:
+                namespace = DEFAULT_KERNEL_NAMESPACE
             assert (
-                len(native_function_namespaces) == 1
-            ), "Codegen only supports one namespace per operator."
+                len(native_function_namespaces) <= 1
+            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
             ns_grouped_kernels[namespace].extend(
                 dest.compute_native_function_declaration(f, backend_idx)
             )
diff --git a/torchgen/model.py b/torchgen/model.py
index 11eafe6..99e1e92 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -1115,7 +1115,7 @@
         elif isinstance(g, NativeFunctionsGroup):
             f = self.primary(g)
         else:
-            assert_never(f)
+            assert_never(g)
         if f.func.name not in self.index:
             return None
         return self.index[f.func.name]