From ae2c5edab5e0c1d713e8ff90de111ebb38a89048 Mon Sep 17 00:00:00 2001 From: Tristan Menzel Date: Tue, 5 Nov 2024 16:58:14 -0800 Subject: [PATCH] feat(internal): Add support for constant value default args in puya --- src/puya/arc32.py | 13 ++++++++----- src/puya/ir/arc4_router.py | 26 +++++++++++++++++++++++++- src/puya/models.py | 15 ++++++++++++++- src/puyapy/awst_build/arc4_utils.py | 6 ++++-- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/puya/arc32.py b/src/puya/arc32.py index 73f9aa2d2e..d9e91a6d00 100644 --- a/src/puya/arc32.py +++ b/src/puya/arc32.py @@ -6,6 +6,8 @@ from puya import log from puya.errors import InternalError from puya.models import ( + ABIMethodArgConstantDefault, + ABIMethodArgDefault, ARC4ABIMethod, ARC4BareMethod, ARC4CreateOption, @@ -92,26 +94,27 @@ def _get_signature(method: ARC4ABIMethod) -> str: def _encode_default_arg( - metadata: ContractMetaData, source: str, loc: SourceLocation | None + metadata: ContractMetaData, source: ABIMethodArgDefault, loc: SourceLocation | None ) -> JSONDict: - if state := metadata.global_state.get(source): + if isinstance(source, ABIMethodArgConstantDefault): + return {"source": "constant", "data": source.value} + if state := metadata.global_state.get(source.name): return { "source": "global-state", # TODO: handle non utf-8 bytes "data": state.key_or_prefix.decode("utf-8"), } - if state := metadata.local_state.get(source): + if state := metadata.local_state.get(source.name): return { "source": "local-state", "data": state.key_or_prefix.decode("utf-8"), } for method in metadata.arc4_methods: - if isinstance(method, ARC4ABIMethod) and method.name == source: + if isinstance(method, ARC4ABIMethod) and method.name == source.name: return { "source": "abi-method", "data": _encode_abi_method(method), } - # TODO: constants raise InternalError(f"Cannot find source '{source}' on {metadata.ref}", loc) diff --git a/src/puya/ir/arc4_router.py b/src/puya/ir/arc4_router.py index 8cd97ee732..34d65b060d 100644 --- a/src/puya/ir/arc4_router.py +++ b/src/puya/ir/arc4_router.py @@ -11,6 +11,7 @@ ) from puya.errors import CodeError, InternalError from puya.models import ( + ABIMethodArgConstantDefault, ARC4ABIMethod, ARC4ABIMethodConfig, ARC4BareMethod, @@ -474,7 +475,7 @@ def _validate_default_args( args_by_name = {a.name: a for a in method.args} for ( parameter_name, - source_name, + default_source, ) in method.arc4_method_config.default_args.items(): # any invalid parameter matches should have been caught earlier parameter = args_by_name[parameter_name] @@ -486,6 +487,17 @@ def _validate_default_args( case "account": param_arc4_type = "address" + if isinstance(default_source, ABIMethodArgConstantDefault): + if not _is_valid_client_literal_for_arc4_type( + default_source.value, param_arc4_type + ): + logger.warning( + f"'{default_source.value}' is not a valid" + f" default value for parameter '{parameter_name}'" + ) + continue + + source_name = default_source.name try: source = known_sources[source_name] except KeyError as ex: @@ -671,6 +683,18 @@ def _get_abi_signature(subroutine: awst_nodes.ContractMethod, config: ARC4ABIMet return f"{config.name}({','.join(arg_types)}){return_type}" +def _is_valid_client_literal_for_arc4_type(literal: str | int, arc4_type_alias: str) -> bool: + if arc4_type_alias.startswith(("uint", "ufixed")): + return isinstance(literal, int) + + match arc4_type_alias: + case "byte" | "bool": + return isinstance(literal, int) + case "address" | "string": + return isinstance(literal, str) + return False + + def _wtype_to_arc4(wtype: wtypes.WType, loc: SourceLocation | None = None) -> str: match wtype: case wtypes.ARC4Type(arc4_name=arc4_name): diff --git a/src/puya/models.py b/src/puya/models.py index d71ca2f662..a6cde22511 100644 --- a/src/puya/models.py +++ b/src/puya/models.py @@ -120,6 +120,19 @@ class ARC4BareMethodConfig: create: ARC4CreateOption = ARC4CreateOption.disallow +@attrs.frozen(kw_only=True) +class ABIMethodArgConstantDefault: + value: int | str + + +@attrs.frozen(kw_only=True) +class ABIMethodArgMemberDefault: + name: str + + +ABIMethodArgDefault = ABIMethodArgMemberDefault | ABIMethodArgConstantDefault + + @attrs.frozen(kw_only=True) class ARC4ABIMethodConfig: source_location: SourceLocation @@ -131,7 +144,7 @@ class ARC4ABIMethodConfig: create: ARC4CreateOption = ARC4CreateOption.disallow name: str readonly: bool = False - default_args: immutabledict[str, str] = immutabledict() + default_args: immutabledict[str, ABIMethodArgDefault] = immutabledict() """Mapping is from parameter -> source""" diff --git a/src/puyapy/awst_build/arc4_utils.py b/src/puyapy/awst_build/arc4_utils.py index 8afdc480ba..995f7fd62d 100644 --- a/src/puyapy/awst_build/arc4_utils.py +++ b/src/puyapy/awst_build/arc4_utils.py @@ -12,6 +12,8 @@ from puya.awst import wtypes from puya.errors import CodeError, InternalError from puya.models import ( + ABIMethodArgDefault, + ABIMethodArgMemberDefault, ARC4ABIMethodConfig, ARC4BareMethodConfig, ARC4CreateOption, @@ -175,7 +177,7 @@ def get_arc4_abimethod_data( readonly = default_readonly # map "default_args" param - default_args = dict[str, str]() + default_args = dict[str, ABIMethodArgDefault]() match evaluated_args.pop("default_args", {}): case {**options}: method_arg_names = func_types.keys() - {"output"} @@ -190,7 +192,7 @@ def get_arc4_abimethod_data( else: # if it's in method_arg_names, it's a str assert isinstance(parameter, str) - default_args[parameter] = value + default_args[parameter] = ABIMethodArgMemberDefault(name=value) case invalid_default_args_option: context.error(f"invalid default_args option: {invalid_default_args_option}", dec_loc)