diff --git a/contracts/Scarb.lock b/contracts/Scarb.lock index 17a40a8..13fc2e2 100644 --- a/contracts/Scarb.lock +++ b/contracts/Scarb.lock @@ -3,8 +3,8 @@ version = 1 [[package]] name = "openzeppelin" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_access", "openzeppelin_account", @@ -21,17 +21,16 @@ dependencies = [ [[package]] name = "openzeppelin_access" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_introspection", - "openzeppelin_utils", ] [[package]] name = "openzeppelin_account" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_introspection", "openzeppelin_utils", @@ -39,8 +38,8 @@ dependencies = [ [[package]] name = "openzeppelin_finance" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_access", "openzeppelin_token", @@ -48,29 +47,30 @@ dependencies = [ [[package]] name = "openzeppelin_governance" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_access", "openzeppelin_account", "openzeppelin_introspection", "openzeppelin_token", + "openzeppelin_utils", ] [[package]] name = "openzeppelin_introspection" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" [[package]] name = "openzeppelin_merkle_tree" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" [[package]] name = "openzeppelin_presets" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_access", "openzeppelin_account", @@ -83,13 +83,13 @@ dependencies = [ [[package]] name = "openzeppelin_security" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" [[package]] name = "openzeppelin_token" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" dependencies = [ "openzeppelin_access", "openzeppelin_account", @@ -99,13 +99,13 @@ dependencies = [ [[package]] name = "openzeppelin_upgrades" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" [[package]] name = "openzeppelin_utils" -version = "0.19.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.19.0#8d49e8c445efd9bdc99b050c8b7d11ae5ad19628" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" [[package]] name = "sncast_std" diff --git a/contracts/Scarb.toml b/contracts/Scarb.toml index 0b686a6..bb0c0a9 100644 --- a/contracts/Scarb.toml +++ b/contracts/Scarb.toml @@ -7,7 +7,7 @@ edition = "2024_07" [dependencies] starknet = "2.9.1" -openzeppelin = { git = "https://github.com/OpenZeppelin/cairo-contracts.git", tag = "v0.19.0" } +openzeppelin = { git = "https://github.com/OpenZeppelin/cairo-contracts.git", tag = "v0.20.0" } [dev-dependencies] snforge_std = { git = "https://github.com/foundry-rs/starknet-foundry", tag = "v0.34.0" } diff --git a/contracts/src/lib.cairo b/contracts/src/lib.cairo index 231951d..fc640c9 100644 --- a/contracts/src/lib.cairo +++ b/contracts/src/lib.cairo @@ -3,7 +3,10 @@ use core::starknet::ContractAddress; #[starknet::interface] pub trait IAgentRegistry { fn register_agent( - ref self: TContractState, name: ByteArray, system_prompt: ByteArray, prompt_price: u256, + ref self: TContractState, + name: ByteArray, + system_prompt_uri: ByteArray, + prompt_price: u256, ) -> ContractAddress; fn get_token(self: @TContractState) -> ContractAddress; fn is_agent_registered(self: @TContractState, address: ContractAddress) -> bool; @@ -17,7 +20,7 @@ pub trait IAgentRegistry { #[starknet::interface] pub trait IAgent { - fn get_system_prompt(self: @TContractState) -> ByteArray; + fn get_system_prompt_uri(self: @TContractState) -> ByteArray; fn get_name(self: @TContractState) -> ByteArray; fn get_creator(self: @TContractState) -> ContractAddress; fn get_prompt_price(self: @TContractState) -> u256; @@ -117,7 +120,10 @@ pub mod AgentRegistry { #[abi(embed_v0)] impl AgentRegistryImpl of super::IAgentRegistry { fn register_agent( - ref self: ContractState, name: ByteArray, system_prompt: ByteArray, prompt_price: u256, + ref self: ContractState, + name: ByteArray, + system_prompt_uri: ByteArray, + prompt_price: u256, ) -> ContractAddress { self.pausable.assert_not_paused(); @@ -132,7 +138,7 @@ pub mod AgentRegistry { let mut constructor_calldata = ArrayTrait::::new(); name.serialize(ref constructor_calldata); - system_prompt.serialize(ref constructor_calldata); + system_prompt_uri.serialize(ref constructor_calldata); self.token.read().serialize(ref constructor_calldata); prompt_price.serialize(ref constructor_calldata); creator.serialize(ref constructor_calldata); @@ -276,7 +282,7 @@ pub mod Agent { #[storage] struct Storage { registry: ContractAddress, - system_prompt: ByteArray, + system_prompt_uri: ByteArray, // URI of encrypted prompt name: ByteArray, token: ContractAddress, prompt_price: u256, @@ -289,14 +295,14 @@ pub mod Agent { fn constructor( ref self: ContractState, name: ByteArray, - system_prompt: ByteArray, + system_prompt_uri: ByteArray, token: ContractAddress, prompt_price: u256, creator: ContractAddress, ) { self.registry.write(get_caller_address()); self.name.write(name); - self.system_prompt.write(system_prompt); + self.system_prompt_uri.write(system_prompt_uri); self.token.write(token); self.prompt_price.write(prompt_price); self.creator.write(creator); @@ -309,8 +315,8 @@ pub mod Agent { self.name.read() } - fn get_system_prompt(self: @ContractState) -> ByteArray { - self.system_prompt.read() + fn get_system_prompt_uri(self: @ContractState) -> ByteArray { + self.system_prompt_uri.read() } fn get_prompt_price(self: @ContractState) -> u256 { diff --git a/contracts/tests/test_contract.cairo b/contracts/tests/test_contract.cairo index 93a1ae5..4384864 100644 --- a/contracts/tests/test_contract.cairo +++ b/contracts/tests/test_contract.cairo @@ -59,12 +59,16 @@ fn test_register_agent() { let registry = IAgentRegistryDispatcher { contract_address: registry_address }; let name = "Test Agent"; - let system_prompt = "I am a test agent"; + let system_prompt_uri = "ipfs://encrypted_prompt_123"; let mut spy = spy_events(); start_cheat_caller_address(registry.contract_address, creator); - let agent_address = registry.register_agent(name.clone(), system_prompt.clone(), prompt_price); + let agent_address = registry.register_agent( + name.clone(), + system_prompt_uri.clone(), + prompt_price + ); stop_cheat_caller_address(registry.contract_address); assert(registry.is_agent_registered(agent_address), 'Agent should be registered'); @@ -75,7 +79,8 @@ fn test_register_agent() { let agent_dispatcher = IAgentDispatcher { contract_address: agent_address }; assert(agent_dispatcher.get_name() == name.clone(), 'Wrong agent name'); - assert(agent_dispatcher.get_system_prompt() == system_prompt.clone(), 'Wrong system prompt'); + assert(agent_dispatcher.get_system_prompt_uri() == system_prompt_uri.clone(), 'Wrong system prompt URI'); + assert(agent_dispatcher.get_creator() == creator, 'Wrong creator'); // Verify event was emitted @@ -105,7 +110,11 @@ fn test_pay_for_prompt() { // Register agent start_cheat_caller_address(registry.contract_address, creator); - let agent_address = registry.register_agent("Test Agent", "Test Prompt", prompt_price); + let agent_address = registry.register_agent( + "Test Agent", + "ipfs://encrypted_prompt_pay", + prompt_price + ); stop_cheat_caller_address(registry.contract_address); let agent = IAgentDispatcher { contract_address: agent_address }; @@ -161,9 +170,17 @@ fn test_register_multiple_agents() { start_cheat_caller_address(registry.contract_address, creator); - let agent1_address = registry.register_agent("Agent 1", "Prompt 1", prompt_price); + let agent1_address = registry.register_agent( + "Agent 1", + "ipfs://encrypted_prompt_1", + prompt_price + ); - let agent2_address = registry.register_agent("Agent 2", "Prompt 2", prompt_price); + let agent2_address = registry.register_agent( + "Agent 2", + "ipfs://encrypted_prompt_2", + prompt_price + ); stop_cheat_caller_address(registry.contract_address); @@ -219,16 +236,20 @@ fn test_get_agent_details() { let registry = IAgentRegistryDispatcher { contract_address: registry_address }; let name = "Complex Agent"; - let system_prompt = "Complex system prompt with multiple words"; + let system_prompt_uri = "ipfs://encrypted_prompt_complex_123"; start_cheat_caller_address(registry.contract_address, creator); - let agent_address = registry.register_agent(name.clone(), system_prompt.clone(), prompt_price); + let agent_address = registry.register_agent( + name.clone(), + system_prompt_uri.clone(), + prompt_price + ); stop_cheat_caller_address(registry.contract_address); let agent = IAgentDispatcher { contract_address: agent_address }; assert(agent.get_name() == name, 'Wrong agent name'); - assert(agent.get_system_prompt() == system_prompt, 'Wrong system prompt'); + assert(agent.get_system_prompt_uri() == system_prompt_uri, 'Wrong system prompt URI'); } #[test] @@ -306,6 +327,8 @@ fn test_pay_for_prompt_without_approval() { agent.pay_for_prompt(12345); } + + #[test] fn test_is_agent_registered() { let tee = starknet::contract_address_const::<0x1>(); diff --git a/pkg/agent/setup/manager.go b/pkg/agent/setup/manager.go index e5bcb7a..fefb24c 100644 --- a/pkg/agent/setup/manager.go +++ b/pkg/agent/setup/manager.go @@ -2,8 +2,12 @@ package setup import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "fmt" "log/slog" + "math/big" "github.com/NethermindEth/juno/core/felt" starknetgoutils "github.com/NethermindEth/starknet.go/utils" @@ -42,6 +46,8 @@ type SetupOutput struct { AgentRegistryAddress *felt.Felt `json:"agent_registry_address"` OpenAIKey string `json:"openai_key"` DstackTappdEndpoint string `json:"dstack_tappd_endpoint"` + // RSA private key is managed securely through the DStack TEE environment + // and should not be stored in this structure } func NewSetupManagerFromEnv() (*SetupManager, error) { @@ -134,6 +140,15 @@ func (m *SetupManager) Setup(ctx context.Context) (*SetupOutput, error) { return nil, fmt.Errorf("failed to encumber proton: %v", err) } + // Generate RSA key pair for system prompt encryption + rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate RSA key pair: %v", err) + } + + // Convert RSA private key to DER format for storage + rsaPrivateKeyBytes := x509.MarshalPKCS1PrivateKey(rsaPrivateKey) + output := &SetupOutput{ TwitterAuthTokens: twitterEncumbererOutput.AuthTokens, TwitterAccessToken: twitterEncumbererOutput.OAuthTokenPair.Token, @@ -148,6 +163,7 @@ func (m *SetupManager) Setup(ctx context.Context) (*SetupOutput, error) { AgentRegistryAddress: agentRegistryAddress, OpenAIKey: m.openAiKey, DstackTappdEndpoint: m.dstackTappdEndpoint, + RsaPrivateKey: rsaPrivateKeyBytes, } if debug.IsDebugShowSetup() {