forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torchgen] Introduce Executorch types and signatures (pytorch#90591)
Retry of pytorch#89595. Accidentally closed. ## Forked `BaseCppType` Created a module for Executorch: `torchgen.executorch`. In `torchgen.executorch.api.types.types`: * Define `BaseCppType` with `torch::executor` namespace. In `torchgen.executorch.api.et_cpp`: * Help generate `NamedCType` for `ExecutorchCppSignature` arguments. In `torchgen.executorch.api.types.signatures`: * Define the signature using these types. (`ExecutorchCppSignature`) In `torchgen.executorch.api.types.__init__`: * Suppress flake8 error for `import *`. Differential Revision: [D41501836](https://our.internmc.facebook.com/intern/diff/D41501836/) Pull Request resolved: pytorch#90591 Approved by: https://github.com/iseeyuan
- Loading branch information
1 parent
de6beca
commit ddf00c8
Showing
8 changed files
with
608 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import unittest | ||
|
||
from torchgen import local | ||
from torchgen.api.types import ( | ||
BaseCType, | ||
ConstRefCType, | ||
CType, | ||
longT, | ||
MutRefCType, | ||
NamedCType, | ||
OptionalCType, | ||
TupleCType, | ||
VectorCType, | ||
voidT, | ||
) | ||
from torchgen.executorch.api.et_cpp import argument_type, return_type, returns_type | ||
from torchgen.executorch.api.types import ArrayRefCType, scalarT, tensorListT, tensorT | ||
from torchgen.model import Argument, FunctionSchema, Return | ||
|
||
|
||
class ExecutorchCppTest(unittest.TestCase): | ||
""" | ||
Test torchgen.executorch.api.cpp | ||
""" | ||
|
||
def _test_argumenttype_type(self, arg_str: str, expected: NamedCType) -> None: | ||
arg = Argument.parse(arg_str) | ||
self.assertEqual(str(argument_type(arg, binds=arg.name)), str(expected)) | ||
|
||
@local.parametrize( | ||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False | ||
) | ||
def test_argumenttype_type(self) -> None: | ||
data = [ | ||
("Tensor self", NamedCType("self", ConstRefCType(BaseCType(tensorT)))), | ||
("Tensor(a!) out", NamedCType("out", MutRefCType(BaseCType(tensorT)))), | ||
( | ||
"Tensor? opt", | ||
NamedCType("opt", ConstRefCType(OptionalCType(BaseCType(tensorT)))), | ||
), | ||
("Scalar scalar", NamedCType("scalar", ConstRefCType(BaseCType(scalarT)))), | ||
( | ||
"Scalar? scalar", | ||
NamedCType("scalar", ConstRefCType(OptionalCType(BaseCType(scalarT)))), | ||
), | ||
("int[] size", NamedCType("size", ArrayRefCType(BaseCType(longT)))), | ||
("int? dim", NamedCType("dim", OptionalCType(BaseCType(longT)))), | ||
("Tensor[] weight", NamedCType("weight", BaseCType(tensorListT))), | ||
( | ||
"Scalar[] spacing", | ||
NamedCType("spacing", ArrayRefCType(ConstRefCType(BaseCType(scalarT)))), | ||
), | ||
( | ||
"Tensor?[] weight", | ||
NamedCType("weight", ArrayRefCType(OptionalCType(BaseCType(tensorT)))), | ||
), | ||
( | ||
"SymInt[]? output_size", | ||
NamedCType( | ||
"output_size", OptionalCType(ArrayRefCType(BaseCType(longT))) | ||
), | ||
), | ||
( | ||
"int[]? dims", | ||
NamedCType("dims", OptionalCType(ArrayRefCType(BaseCType(longT)))), | ||
), | ||
] | ||
for d in data: | ||
self._test_argumenttype_type(*d) | ||
|
||
def _test_returntype_type(self, ret_str: str, expected: CType) -> None: | ||
ret = Return.parse(ret_str) | ||
self.assertEqual(str(return_type(ret)), str(expected)) | ||
|
||
@local.parametrize( | ||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False | ||
) | ||
def test_returntype_type(self) -> None: | ||
data = [ | ||
("Tensor", BaseCType(tensorT)), | ||
("Tensor(a!)", MutRefCType(BaseCType(tensorT))), | ||
("Tensor[]", VectorCType(BaseCType(tensorT))), | ||
] | ||
for d in data: | ||
self._test_returntype_type(*d) | ||
|
||
@local.parametrize( | ||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False | ||
) | ||
def test_returns_type(self) -> None: | ||
func = FunctionSchema.parse( | ||
"min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" | ||
) | ||
expected = TupleCType([BaseCType(tensorT), BaseCType(tensorT)]) | ||
self.assertEqual(str(returns_type(func.returns)), str(expected)) | ||
|
||
@local.parametrize( | ||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False | ||
) | ||
def test_void_return_type(self) -> None: | ||
func = FunctionSchema.parse( | ||
"_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()" | ||
) | ||
expected = BaseCType(voidT) | ||
self.assertEqual(str(returns_type(func.returns)), str(expected)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Empty file.
Empty file.
Oops, something went wrong.