diff --git a/programs/svm-spoke/src/instructions/refund_claims.rs b/programs/svm-spoke/src/instructions/refund_claims.rs index 12276d273..3a48cb6d2 100644 --- a/programs/svm-spoke/src/instructions/refund_claims.rs +++ b/programs/svm-spoke/src/instructions/refund_claims.rs @@ -36,7 +36,6 @@ pub fn initialize_claim_account(ctx: Context) -> Result< #[event_cpi] #[derive(Accounts)] -#[instruction(refund_address: Option)] pub struct ClaimRelayerRefund<'info> { pub signer: Signer<'info>, @@ -59,14 +58,17 @@ pub struct ClaimRelayerRefund<'info> { #[account(mint::token_program = token_program)] pub mint: InterfaceAccount<'info, Mint>, - // If refund_address is not provided this method allows relayer to claim refunds on any custom token account. + /// CHECK: This is used for claim_account PDA derivation and it is up to the caller to ensure it is valid. + pub refund_address: UncheckedAccount<'info>, + + // If refund_address is the same as signer this method allows relayer to claim refunds on any custom token account. // Otherwise this must be the associated token account of the provided refund_address. #[account( mut, token::mint = mint, token::token_program = token_program, - constraint = refund_address.is_none() - || is_valid_associated_token_account(&token_account, &mint, &token_program, &refund_address.unwrap()) + constraint = refund_address.key().eq(&signer.key()) + || is_valid_associated_token_account(&token_account, &mint, &token_program, &refund_address.key()) @ SvmError::InvalidRefundTokenAccount )] pub token_account: InterfaceAccount<'info, TokenAccount>, @@ -74,7 +76,7 @@ pub struct ClaimRelayerRefund<'info> { #[account( mut, close = initializer, - seeds = [b"claim_account", mint.key().as_ref(), refund_address.unwrap_or_else(|| signer.key()).as_ref()], + seeds = [b"claim_account", mint.key().as_ref(), refund_address.key().as_ref()], bump )] pub claim_account: Account<'info, ClaimAccount>, @@ -82,7 +84,7 @@ pub struct ClaimRelayerRefund<'info> { pub token_program: Interface<'info, TokenInterface>, } -pub fn claim_relayer_refund(ctx: Context, refund_address: Option) -> Result<()> { +pub fn claim_relayer_refund(ctx: Context) -> Result<()> { // Ensure the claim account holds a non-zero amount. let claim_amount = ctx.accounts.claim_account.amount; if claim_amount == 0 { @@ -108,7 +110,7 @@ pub fn claim_relayer_refund(ctx: Context, refund_address: Op emit_cpi!(ClaimedRelayerRefund { l2_token_address: ctx.accounts.mint.key(), claim_amount, - refund_address: refund_address.unwrap_or_else(|| ctx.accounts.signer.key()), + refund_address: ctx.accounts.refund_address.key(), }); Ok(()) // There is no need to reset the claim amount as the account will be closed at the end of instruction. diff --git a/programs/svm-spoke/src/lib.rs b/programs/svm-spoke/src/lib.rs index e9a5c90a4..1d5830554 100644 --- a/programs/svm-spoke/src/lib.rs +++ b/programs/svm-spoke/src/lib.rs @@ -455,15 +455,13 @@ pub mod svm_spoke { /// - state (Account): Spoke state PDA. Seed: ["state",state.seed] where seed is 0 on mainnet. /// - vault (InterfaceAccount): The ATA for the refunded mint. Authority must be the state. /// - mint (InterfaceAccount): The mint account for the token being refunded. - /// - token_account (InterfaceAccount): The receiving token account for the refund. When refund_address is provided, - /// this must match its ATA. + /// - refund_address: token account authority receiving the refund. + /// - token_account (InterfaceAccount): The receiving token account for the refund. When refund_address is different + /// from the signer, this must match its ATA. /// - claim_account (Account): The claim account PDA. Seed: ["claim_account",mint,refund_address]. /// - token_program (Interface): The token program. - /// - /// ### Parameters: - /// - refund_address: Optional token account authority receiving the refund. If None, the signer is used. - pub fn claim_relayer_refund(ctx: Context, refund_address: Option) -> Result<()> { - instructions::claim_relayer_refund(ctx, refund_address) + pub fn claim_relayer_refund(ctx: Context) -> Result<()> { + instructions::claim_relayer_refund(ctx) } /// Creates token accounts in batch for a set of addresses. diff --git a/test/svm/SvmSpoke.RefundClaims.ts b/test/svm/SvmSpoke.RefundClaims.ts index 61e61d8ee..dcb0e4890 100644 --- a/test/svm/SvmSpoke.RefundClaims.ts +++ b/test/svm/SvmSpoke.RefundClaims.ts @@ -38,6 +38,7 @@ describe("svm_spoke.refund_claims", () => { state: PublicKey; vault: PublicKey; mint: PublicKey; + refundAddress: PublicKey; tokenAccount: PublicKey; claimAccount: PublicKey; tokenProgram: PublicKey; @@ -149,6 +150,7 @@ describe("svm_spoke.refund_claims", () => { state, vault, mint, + refundAddress: relayer.publicKey, tokenAccount, claimAccount, tokenProgram: TOKEN_PROGRAM_ID, @@ -176,7 +178,7 @@ describe("svm_spoke.refund_claims", () => { const iRelayerBal = (await connection.getTokenAccountBalance(tokenAccount)).value.amount; // Claim refund for the relayer. - const tx = await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + const tx = await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The relayer should have received funds from the vault. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -198,11 +200,11 @@ describe("svm_spoke.refund_claims", () => { await executeRelayerRefundToClaim(relayerRefund); // Claim refund for the relayer. - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The claim account should have been automatically closed, so repeated claim should fail. try { - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund from closed account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); @@ -216,7 +218,7 @@ describe("svm_spoke.refund_claims", () => { // After reinitalizing the claim account, the repeated claim should still fail. await initializeClaimAccount(); try { - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund from reinitalized account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); @@ -235,7 +237,7 @@ describe("svm_spoke.refund_claims", () => { const iRelayerBal = (await connection.getTokenAccountBalance(tokenAccount)).value.amount; // Claim refund for the relayer. - await await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The relayer should have received both refunds. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -260,7 +262,7 @@ describe("svm_spoke.refund_claims", () => { // Claiming with default initializer should fail. try { - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); } catch (error: any) { assert.instanceOf(error, AnchorError); assert.strictEqual( @@ -272,7 +274,7 @@ describe("svm_spoke.refund_claims", () => { // Claim refund for the relayer passing the correct initializer account. claimRelayerRefundAccounts.initializer = anotherInitializer.publicKey; - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); // The relayer should have received funds from the vault. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -344,7 +346,7 @@ describe("svm_spoke.refund_claims", () => { claimRelayerRefundAccounts.tokenAccount = wrongTokenAccount; try { - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund to custom token account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); @@ -369,7 +371,7 @@ describe("svm_spoke.refund_claims", () => { await setAuthority(connection, payer, wrongTokenAccount, wrongOwner, AuthorityType.AccountOwner, relayer.publicKey); try { - await program.methods.claimRelayerRefund(relayer.publicKey).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund to custom token account should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError); @@ -396,11 +398,7 @@ describe("svm_spoke.refund_claims", () => { claimRelayerRefundAccounts.signer = relayer.publicKey; // Only relayer itself should be able to do this. // Relayer can claim refund to custom token account. - const tx = await program.methods - .claimRelayerRefund(null) - .accounts(claimRelayerRefundAccounts) - .signers([relayer]) - .rpc(); + const tx = await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).signers([relayer]).rpc(); // The relayer should have received funds from the vault. const fVaultBal = (await connection.getTokenAccountBalance(vault)).value.amount; @@ -422,8 +420,9 @@ describe("svm_spoke.refund_claims", () => { await executeRelayerRefundToClaim(relayerRefund); // Claim refund for the relayer with the default signer should fail as relayer address is part of claim account derivation. + claimRelayerRefundAccounts.refundAddress = owner; try { - await program.methods.claimRelayerRefund(null).accounts(claimRelayerRefundAccounts).rpc(); + await program.methods.claimRelayerRefund().accounts(claimRelayerRefundAccounts).rpc(); assert.fail("Claiming refund with wrong signer should fail"); } catch (error: any) { assert.instanceOf(error, AnchorError);