Skip to content

Commit

Permalink
[torchgen] Let native function declaration generation logic take a ca…
Browse files Browse the repository at this point in the history
…llable (pytorch#90590)

Retry of pytorch#89594. Accidentally closed.

This PR allows `get_native_function_declarations` API to take a function as argument. This function should take `NativeFunction` as input and emit code for native function declaration. By default it is `dest.compute_native_function_declaration`.

Differential Revision: [D41501838](https://our.internmc.facebook.com/intern/diff/D41501838/)
Pull Request resolved: pytorch#90590
Approved by: https://github.com/iseeyuan
  • Loading branch information
larryliu0820 authored and pytorchmergebot committed Dec 10, 2022
1 parent 453ff96 commit de6beca
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tools/test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import yaml

from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
from torchgen.gen import (
Expand Down Expand Up @@ -356,6 +357,7 @@ def test_native_function_declaration_1_op_2_ns_error(self) -> None:
self.op_2_native_function,
],
backend_indices=self.backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)

def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
Expand All @@ -365,6 +367,7 @@ def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
self.op_1_native_function,
],
backend_indices=self.backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)
target = """
namespace at {
Expand Down
30 changes: 27 additions & 3 deletions torchgen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
import pathlib
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)

import yaml
from typing_extensions import Literal
Expand Down Expand Up @@ -1406,7 +1417,17 @@ def get_native_function_declarations(
*,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex],
native_function_decl_gen: Callable[
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
],
) -> List[str]:
"""
Generate kernel declarations, in `NativeFunction(s).h`.
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
:param backend_indices: kernel collections grouped by dispatch key.
:param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
:return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
"""
declarations: List[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
newline = "\n"
Expand All @@ -1425,7 +1446,7 @@ def get_native_function_declarations(
len(native_function_namespaces) <= 1
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
ns_grouped_kernels[namespace].extend(
dest.compute_native_function_declaration(f, backend_idx)
native_function_decl_gen(f, backend_idx)
)

for namespace, kernels in ns_grouped_kernels.items():
Expand Down Expand Up @@ -1734,6 +1755,7 @@ def gen_aggregated_headers(
declarations = get_native_function_declarations(
grouped_native_functions=grouped_native_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)
cpu_fm.write(
"NativeFunctions.h",
Expand Down Expand Up @@ -1863,7 +1885,9 @@ def gen_per_operator_headers(
},
)
declarations = get_native_function_declarations(
grouped_native_functions=grouped_functions, backend_indices=backend_indices
grouped_native_functions=grouped_functions,
backend_indices=backend_indices,
native_function_decl_gen=dest.compute_native_function_declaration,
)
ops_fm.write_with_template(
f"{name}_native.h",
Expand Down

0 comments on commit de6beca

Please sign in to comment.