diff --git a/src/puya/ir/models.py b/src/puya/ir/models.py index f38db39e0..5963f3238 100644 --- a/src/puya/ir/models.py +++ b/src/puya/ir/models.py @@ -1,7 +1,7 @@ import abc import typing import typing as t -from collections.abc import Iterable, Iterator, Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence, Set import attrs from immutabledict import immutabledict @@ -823,7 +823,7 @@ def _check_blocks(self, _attribute: object, body: list[BasicBlock]) -> None: f" for phi node {phi}", self.source_location, ) - used_registers = frozenset(_get_used_registers(body)) + used_registers = _get_used_registers(body) defined_registers = frozenset(self.parameters) | frozenset(_get_assigned_registers(body)) bad_reads = used_registers - defined_registers if bad_reads: @@ -867,25 +867,14 @@ def _get_assigned_registers(blocks: Sequence[BasicBlock]) -> Iterator[Register]: yield from op.targets -def _get_used_registers(blocks: Sequence[BasicBlock]) -> Iterator[Register]: - # TODO: replace with visitor +def _get_used_registers(blocks: Sequence[BasicBlock]) -> Set[Register]: + from puya.ir.register_read_collector import RegisterReadCollector + + collector = RegisterReadCollector() for block in blocks: - for phi in block.phis: - yield from (arg.value for arg in phi.args) - for op in block.ops: - match op: - case ( - Assignment( - source=Intrinsic(args=args) - | Assignment(source=ValueTuple(values=args)) - | InvokeSubroutine(args=args) - ) - | Intrinsic(args=args) - | InvokeSubroutine(args=args) - ): - yield from (arg for arg in args if isinstance(arg, Register)) - case Assignment(source=Register() as reg): - yield reg + for op in block.all_ops: + op.accept(collector) + return collector.used_registers @attrs.define(kw_only=True, eq=False)