Skip to content

Commit

Permalink
fix: move refund_address to accounts
Browse files Browse the repository at this point in the history
Signed-off-by: Reinis Martinsons <[email protected]>
  • Loading branch information
Reinis-FRP committed Jan 9, 2025
1 parent 17222a0 commit 66caed6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
16 changes: 9 additions & 7 deletions programs/svm-spoke/src/instructions/refund_claims.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ pub fn initialize_claim_account(ctx: Context<InitializeClaimAccount>) -> Result<

#[event_cpi]
#[derive(Accounts)]
#[instruction(refund_address: Option<Pubkey>)]
pub struct ClaimRelayerRefund<'info> {
pub signer: Signer<'info>,

Expand All @@ -59,30 +58,33 @@ pub struct ClaimRelayerRefund<'info> {
#[account(mint::token_program = token_program)]
pub mint: InterfaceAccount<'info, Mint>,

// If refund_address is not provided this method allows relayer to claim refunds on any custom token account.
/// CHECK: This is used for claim_account PDA derivation and it is up to the caller to ensure it is valid.
pub refund_address: UncheckedAccount<'info>,

// If refund_address is the same as signer this method allows relayer to claim refunds on any custom token account.
// Otherwise this must be the associated token account of the provided refund_address.
#[account(
mut,
token::mint = mint,
token::token_program = token_program,
constraint = refund_address.is_none()
|| is_valid_associated_token_account(&token_account, &mint, &token_program, &refund_address.unwrap())
constraint = refund_address.key().eq(&signer.key())
|| is_valid_associated_token_account(&token_account, &mint, &token_program, &refund_address.key())
@ SvmError::InvalidRefundTokenAccount
)]
pub token_account: InterfaceAccount<'info, TokenAccount>,

#[account(
mut,
close = initializer,
seeds = [b"claim_account", mint.key().as_ref(), refund_address.unwrap_or_else(|| signer.key()).as_ref()],
seeds = [b"claim_account", mint.key().as_ref(), refund_address.key().as_ref()],
bump
)]
pub claim_account: Account<'info, ClaimAccount>,

pub token_program: Interface<'info, TokenInterface>,
}

pub fn claim_relayer_refund(ctx: Context<ClaimRelayerRefund>, refund_address: Option<Pubkey>) -> Result<()> {
pub fn claim_relayer_refund(ctx: Context<ClaimRelayerRefund>) -> Result<()> {
// Ensure the claim account holds a non-zero amount.
let claim_amount = ctx.accounts.claim_account.amount;
if claim_amount == 0 {
Expand All @@ -108,7 +110,7 @@ pub fn claim_relayer_refund(ctx: Context<ClaimRelayerRefund>, refund_address: Op
emit_cpi!(ClaimedRelayerRefund {
l2_token_address: ctx.accounts.mint.key(),
claim_amount,
refund_address: refund_address.unwrap_or_else(|| ctx.accounts.signer.key()),
refund_address: ctx.accounts.refund_address.key(),
});

Ok(()) // There is no need to reset the claim amount as the account will be closed at the end of instruction.
Expand Down
12 changes: 5 additions & 7 deletions programs/svm-spoke/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,15 +455,13 @@ pub mod svm_spoke {
/// - state (Account): Spoke state PDA. Seed: ["state",state.seed] where seed is 0 on mainnet.
/// - vault (InterfaceAccount): The ATA for the refunded mint. Authority must be the state.
/// - mint (InterfaceAccount): The mint account for the token being refunded.
/// - token_account (InterfaceAccount): The receiving token account for the refund. When refund_address is provided,
/// this must match its ATA.
/// - refund_address: token account authority receiving the refund.
/// - token_account (InterfaceAccount): The receiving token account for the refund. When refund_address is different
/// from the signer, this must match its ATA.
/// - claim_account (Account): The claim account PDA. Seed: ["claim_account",mint,refund_address].
/// - token_program (Interface): The token program.
///
/// ### Parameters:
/// - refund_address: Optional token account authority receiving the refund. If None, the signer is used.
pub fn claim_relayer_refund(ctx: Context<ClaimRelayerRefund>, refund_address: Option<Pubkey>) -> Result<()> {
instructions::claim_relayer_refund(ctx, refund_address)
pub fn claim_relayer_refund(ctx: Context<ClaimRelayerRefund>) -> Result<()> {
instructions::claim_relayer_refund(ctx)
}

/// Creates token accounts in batch for a set of addresses.
Expand Down
29 changes: 14 additions & 15 deletions test/svm/SvmSpoke.RefundClaims.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ describe("svm_spoke.refund_claims", () => {
state: PublicKey;
vault: PublicKey;
mint: PublicKey;
refundAddress: PublicKey;
tokenAccount: PublicKey;
claimAccount: PublicKey;
tokenProgram: PublicKey;
Expand Down Expand Up @@ -149,6 +150,7 @@ describe("svm_spoke.refund_claims", () => {
state,
vault,
mint,
refundAddress: relayer.publicKey,
tokenAccount,
claimAccount,
tokenProgram: TOKEN_PROGRAM_ID,
Expand Down Expand Up @@ -176,7 +178,7 @@ describe("svm_spoke.refund_claims", () => {
const iRelayerBal = (await connection.getTokenAccountBalance(tokenAccount)).value.amount;

// Claim refund for the relayer.
const tx = await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
const tx = await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();

// The relayer should have received funds from the vault.
const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount;
Expand All @@ -198,11 +200,11 @@ describe("svm_spoke.refund_claims", () => {
await executeRelayerRefundToClaim(relayerRefund);

// Claim refund for the relayer.
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();

// The claim account should have been automatically closed, so repeated claim should fail.
try {
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();
assert.fail("Claiming refund from closed account should fail");
} catch (error: any) {
assert.instanceOf(error, AnchorError);
Expand All @@ -216,7 +218,7 @@ describe("svm_spoke.refund_claims", () => {
// After reinitalizing the claim account, the repeated claim should still fail.
await initializeClaimAccount();
try {
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();
assert.fail("Claiming refund from reinitalized account should fail");
} catch (error: any) {
assert.instanceOf(error, AnchorError);
Expand All @@ -235,7 +237,7 @@ describe("svm_spoke.refund_claims", () => {
const iRelayerBal = (await connection.getTokenAccountBalance(tokenAccount)).value.amount;

// Claim refund for the relayer.
await await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();

// The relayer should have received both refunds.
const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount;
Expand All @@ -260,7 +262,7 @@ describe("svm_spoke.refund_claims", () => {

// Claiming with default initializer should fail.
try {
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();
} catch (error: any) {
assert.instanceOf(error, AnchorError);
assert.strictEqual(
Expand All @@ -272,7 +274,7 @@ describe("svm_spoke.refund_claims", () => {

// Claim refund for the relayer passing the correct initializer account.
claimRelayerRefundAccounts.initializer = anotherInitializer.publicKey;
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();

// The relayer should have received funds from the vault.
const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount;
Expand Down Expand Up @@ -344,7 +346,7 @@ describe("svm_spoke.refund_claims", () => {
claimRelayerRefundAccounts.tokenAccount = wrongTokenAccount;

try {
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();
assert.fail("Claiming refund to custom token account should fail");
} catch (error: any) {
assert.instanceOf(error, AnchorError);
Expand All @@ -369,7 +371,7 @@ describe("svm_spoke.refund_claims", () => {
await setAuthority(connection, payer, wrongTokenAccount, wrongOwner, AuthorityType.AccountOwner, relayer.publicKey);

try {
await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();
assert.fail("Claiming refund to custom token account should fail");
} catch (error: any) {
assert.instanceOf(error, AnchorError);
Expand All @@ -396,11 +398,7 @@ describe("svm_spoke.refund_claims", () => {
claimRelayerRefundAccounts.signer = relayer.publicKey; // Only relayer itself should be able to do this.

// Relayer can claim refund to custom token account.
const tx = await program.methods
.claimRelayerRefund(null)
.accounts(claimRelayerRefundAccounts)
.signers([relayer])
.rpc();
const tx = await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).signers([relayer]).rpc();

// The relayer should have received funds from the vault.
const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount;
Expand All @@ -422,8 +420,9 @@ describe("svm_spoke.refund_claims", () => {
await executeRelayerRefundToClaim(relayerRefund);

// Claim refund for the relayer with the default signer should fail as relayer address is part of claim account derivation.
claimRelayerRefundAccounts.refundAddress = owner;
try {
await program.methods.claimRelayerRefund(null).accounts(claimRelayerRefundAccounts).rpc();
await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc();
assert.fail("Claiming refund with wrong signer should fail");
} catch (error: any) {
assert.instanceOf(error, AnchorError);
Expand Down

0 comments on commit 66caed6

Please sign in to comment.