Skip to content

Commit

Permalink
feat(internal): Add support for constant value default args in puya
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanmenzel authored and achidlow committed Jan 9, 2025
1 parent 5f0dec0 commit ae2c5ed
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 9 deletions.
13 changes: 8 additions & 5 deletions src/puya/arc32.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from puya import log
from puya.errors import InternalError
from puya.models import (
ABIMethodArgConstantDefault,
ABIMethodArgDefault,
ARC4ABIMethod,
ARC4BareMethod,
ARC4CreateOption,
Expand Down Expand Up @@ -92,26 +94,27 @@ def _get_signature(method: ARC4ABIMethod) -> str:


def _encode_default_arg(
metadata: ContractMetaData, source: str, loc: SourceLocation | None
metadata: ContractMetaData, source: ABIMethodArgDefault, loc: SourceLocation | None
) -> JSONDict:
if state := metadata.global_state.get(source):
if isinstance(source, ABIMethodArgConstantDefault):
return {"source": "constant", "data": source.value}
if state := metadata.global_state.get(source.name):
return {
"source": "global-state",
# TODO: handle non utf-8 bytes
"data": state.key_or_prefix.decode("utf-8"),
}
if state := metadata.local_state.get(source):
if state := metadata.local_state.get(source.name):
return {
"source": "local-state",
"data": state.key_or_prefix.decode("utf-8"),
}
for method in metadata.arc4_methods:
if isinstance(method, ARC4ABIMethod) and method.name == source:
if isinstance(method, ARC4ABIMethod) and method.name == source.name:
return {
"source": "abi-method",
"data": _encode_abi_method(method),
}
# TODO: constants
raise InternalError(f"Cannot find source '{source}' on {metadata.ref}", loc)


Expand Down
26 changes: 25 additions & 1 deletion src/puya/ir/arc4_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from puya.errors import CodeError, InternalError
from puya.models import (
ABIMethodArgConstantDefault,
ARC4ABIMethod,
ARC4ABIMethodConfig,
ARC4BareMethod,
Expand Down Expand Up @@ -474,7 +475,7 @@ def _validate_default_args(
args_by_name = {a.name: a for a in method.args}
for (
parameter_name,
source_name,
default_source,
) in method.arc4_method_config.default_args.items():
# any invalid parameter matches should have been caught earlier
parameter = args_by_name[parameter_name]
Expand All @@ -486,6 +487,17 @@ def _validate_default_args(
case "account":
param_arc4_type = "address"

if isinstance(default_source, ABIMethodArgConstantDefault):
if not _is_valid_client_literal_for_arc4_type(
default_source.value, param_arc4_type
):
logger.warning(
f"'{default_source.value}' is not a valid"
f" default value for parameter '{parameter_name}'"
)
continue

source_name = default_source.name
try:
source = known_sources[source_name]
except KeyError as ex:
Expand Down Expand Up @@ -671,6 +683,18 @@ def _get_abi_signature(subroutine: awst_nodes.ContractMethod, config: ARC4ABIMet
return f"{config.name}({','.join(arg_types)}){return_type}"


def _is_valid_client_literal_for_arc4_type(literal: str | int, arc4_type_alias: str) -> bool:
if arc4_type_alias.startswith(("uint", "ufixed")):
return isinstance(literal, int)

match arc4_type_alias:
case "byte" | "bool":
return isinstance(literal, int)
case "address" | "string":
return isinstance(literal, str)
return False


def _wtype_to_arc4(wtype: wtypes.WType, loc: SourceLocation | None = None) -> str:
match wtype:
case wtypes.ARC4Type(arc4_name=arc4_name):
Expand Down
15 changes: 14 additions & 1 deletion src/puya/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,19 @@ class ARC4BareMethodConfig:
create: ARC4CreateOption = ARC4CreateOption.disallow


@attrs.frozen(kw_only=True)
class ABIMethodArgConstantDefault:
value: int | str


@attrs.frozen(kw_only=True)
class ABIMethodArgMemberDefault:
name: str


ABIMethodArgDefault = ABIMethodArgMemberDefault | ABIMethodArgConstantDefault


@attrs.frozen(kw_only=True)
class ARC4ABIMethodConfig:
source_location: SourceLocation
Expand All @@ -131,7 +144,7 @@ class ARC4ABIMethodConfig:
create: ARC4CreateOption = ARC4CreateOption.disallow
name: str
readonly: bool = False
default_args: immutabledict[str, str] = immutabledict()
default_args: immutabledict[str, ABIMethodArgDefault] = immutabledict()
"""Mapping is from parameter -> source"""


Expand Down
6 changes: 4 additions & 2 deletions src/puyapy/awst_build/arc4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from puya.awst import wtypes
from puya.errors import CodeError, InternalError
from puya.models import (
ABIMethodArgDefault,
ABIMethodArgMemberDefault,
ARC4ABIMethodConfig,
ARC4BareMethodConfig,
ARC4CreateOption,
Expand Down Expand Up @@ -175,7 +177,7 @@ def get_arc4_abimethod_data(
readonly = default_readonly

# map "default_args" param
default_args = dict[str, str]()
default_args = dict[str, ABIMethodArgDefault]()
match evaluated_args.pop("default_args", {}):
case {**options}:
method_arg_names = func_types.keys() - {"output"}
Expand All @@ -190,7 +192,7 @@ def get_arc4_abimethod_data(
else:
# if it's in method_arg_names, it's a str
assert isinstance(parameter, str)
default_args[parameter] = value
default_args[parameter] = ABIMethodArgMemberDefault(name=value)
case invalid_default_args_option:
context.error(f"invalid default_args option: {invalid_default_args_option}", dec_loc)

Expand Down

0 comments on commit ae2c5ed

Please sign in to comment.