diff --git a/src/cli/wormhole.rs b/src/cli/wormhole.rs index 7e1cc9f..2313b84 100644 --- a/src/cli/wormhole.rs +++ b/src/cli/wormhole.rs @@ -3938,4 +3938,55 @@ mod tests { } } } + + /// Wrapper so `WormholeCommands` can be parsed directly via clap in tests. + #[derive(clap::Parser, Debug)] + struct CollectRewardsTestCli { + #[command(subcommand)] + cmd: WormholeCommands, + } + + fn try_parse_collect_rewards(extra_args: &[&str]) -> Result { + use clap::Parser; + let mut args = vec!["test", "collect-rewards"]; + args.extend_from_slice(extra_args); + CollectRewardsTestCli::try_parse_from(args).map(|cli| cli.cmd) + } + + #[test] + fn collect_rewards_requires_one_credential() { + let err = try_parse_collect_rewards(&[]).unwrap_err(); + let s = err.to_string(); + assert!( + s.contains("--wallet") || s.contains("--mnemonic") || s.contains("--secret"), + "expected missing-credential error, got: {s}" + ); + } + + #[test] + fn collect_rewards_accepts_each_credential_alone() { + assert!(try_parse_collect_rewards(&["--wallet", "w"]).is_ok()); + assert!(try_parse_collect_rewards(&["--mnemonic", "word ".repeat(24).trim()]).is_ok()); + assert!(try_parse_collect_rewards(&[ + "--secret", + "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20", + ]) + .is_ok()); + } + + #[test] + fn collect_rewards_credentials_mutually_exclusive() { + let pairs: &[(&str, &str, &str, &str)] = &[ + ("--wallet", "w", "--mnemonic", "m"), + ("--wallet", "w", "--secret", "s"), + ("--mnemonic", "m", "--secret", "s"), + ]; + for (a, av, b, bv) in pairs { + let err = try_parse_collect_rewards(&[a, av, b, bv]).unwrap_err().to_string(); + assert!( + err.contains("cannot be used with"), + "expected conflict error for {a} + {b}, got: {err}" + ); + } + } } diff --git a/src/collect_rewards_lib.rs b/src/collect_rewards_lib.rs index 379b826..9402d65 100644 --- a/src/collect_rewards_lib.rs +++ b/src/collect_rewards_lib.rs @@ -223,6 +223,32 @@ pub enum WormholeCredential { }, } +/// Resolve a `WormholeCredential` into `(ss58_address, address_bytes, secret_bytes)`. +pub fn resolve_credential(credential: &WormholeCredential) -> Result<(String, [u8; 32], [u8; 32])> { + match credential { + WormholeCredential::Mnemonic { phrase, wormhole_index } => { + let path = format!("m/44'/{}/0'/0'/{}'", QUANTUS_WORMHOLE_CHAIN_ID, wormhole_index); + let wormhole_pair = derive_wormhole_from_mnemonic(phrase, None, &path) + .map_err(|e| CollectRewardsError::from(format!("HD derivation failed: {:?}", e)))?; + let address_bytes: [u8; 32] = wormhole_pair.address; + let secret_bytes: [u8; 32] = + wormhole_pair.secret.as_ref().try_into().map_err(|_| { + CollectRewardsError::from( + "Invalid secret length from HD derivation".to_string(), + ) + })?; + Ok((AccountId32::from(address_bytes).to_ss58check(), address_bytes, secret_bytes)) + }, + WormholeCredential::Secret { hex } => { + let secret_bytes = parse_secret_hex_str(hex) + .map_err(|e| CollectRewardsError::from(format!("Invalid secret: {}", e)))?; + let address_bytes = wormhole_lib::compute_wormhole_address(&secret_bytes) + .map_err(|e| CollectRewardsError::from(e.message))?; + Ok((AccountId32::from(address_bytes).to_ss58check(), address_bytes, secret_bytes)) + }, + } +} + /// Configuration for collect_rewards #[derive(Debug, Clone)] pub struct CollectRewardsConfig { @@ -263,43 +289,8 @@ pub async fn collect_rewards( config: CollectRewardsConfig, progress: &P, ) -> Result { - // Step 1: Derive wormhole address from credential - // Returns: (ss58_address, raw_address_bytes, secret_bytes) - let (wormhole_address, wormhole_address_bytes, wormhole_secret_bytes) = match &config.credential - { - WormholeCredential::Mnemonic { phrase, wormhole_index } => { - progress.on_step("derive", "Deriving wormhole address from mnemonic"); - - let path = format!("m/44'/{}/0'/0'/{}'", QUANTUS_WORMHOLE_CHAIN_ID, wormhole_index); - let wormhole_pair = derive_wormhole_from_mnemonic(phrase, None, &path) - .map_err(|e| CollectRewardsError::from(format!("HD derivation failed: {:?}", e)))?; - - let address_bytes: [u8; 32] = wormhole_pair.address; - let address = AccountId32::from(address_bytes).to_ss58check(); - let secret_bytes: [u8; 32] = - wormhole_pair.secret.as_ref().try_into().map_err(|_| { - CollectRewardsError::from( - "Invalid secret length from HD derivation".to_string(), - ) - })?; - (address, address_bytes, secret_bytes) - }, - WormholeCredential::Secret { hex } => { - progress.on_step("derive", "Deriving wormhole address from secret"); - - // Parse the hex secret - let secret_bytes = parse_secret_hex_str(hex) - .map_err(|e| CollectRewardsError::from(format!("Invalid secret: {}", e)))?; - - // Compute wormhole address from the secret - let address_bytes = wormhole_lib::compute_wormhole_address(&secret_bytes) - .map_err(|e| CollectRewardsError::from(e.message))?; - - let address = AccountId32::from(address_bytes).to_ss58check(); - (address, address_bytes, secret_bytes) - }, - }; - + let (wormhole_address, wormhole_address_bytes, wormhole_secret_bytes) = + resolve_credential(&config.credential)?; progress.on_step("derive", &format!("Derived wormhole address: {}", wormhole_address)); // Parse destination address @@ -496,25 +487,18 @@ pub async fn collect_rewards( let digest = header.digest.encode(); let block_number = header.number; - // Compute wormhole address from secret - let computed_wormhole_address = - wormhole_lib::compute_wormhole_address(&wormhole_secret_bytes) - .map_err(|e| CollectRewardsError::from(e.message))?; - - // Verify the leaf's to_account matches our computed wormhole address - if leaf_to_account != computed_wormhole_address { + if leaf_to_account != wormhole_address_bytes { return Err(CollectRewardsError::from(format!( "Leaf to_account mismatch: expected 0x{}, got 0x{}", - hex::encode(computed_wormhole_address), + hex::encode(wormhole_address_bytes), hex::encode(leaf_to_account) ))); } - // Build proof input let input = wormhole_lib::ProofGenerationInput { secret: wormhole_secret_bytes, transfer_count, - wormhole_address: computed_wormhole_address, + wormhole_address: wormhole_address_bytes, input_amount, block_hash: proof_block_hash.0, block_number, @@ -1185,109 +1169,103 @@ mod tests { assert_eq!(result, 100); } - #[test] - fn test_wormhole_credential_secret_address_derivation() { - // Test that WormholeCredential::Secret correctly derives the wormhole address - // using the same logic as wormhole_lib::compute_wormhole_address - - // A known 32-byte secret (hex encoded) - let secret_hex = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"; - let secret_bytes: [u8; 32] = hex::decode(secret_hex).unwrap().try_into().unwrap(); + const TEST_SECRET_HEX: &str = + "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"; + const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon art"; - // Compute expected address using wormhole_lib - let expected_address_bytes = wormhole_lib::compute_wormhole_address(&secret_bytes).unwrap(); - let expected_address = AccountId32::from(expected_address_bytes).to_ss58check(); + #[test] + fn test_resolve_credential_secret() { + let cred = WormholeCredential::Secret { hex: TEST_SECRET_HEX.to_string() }; + let (address, address_bytes, secret_bytes) = resolve_credential(&cred).unwrap(); - // Now simulate what WormholeCredential::Secret does - let parsed_secret = parse_secret_hex_str(secret_hex).unwrap(); - let derived_address_bytes = wormhole_lib::compute_wormhole_address(&parsed_secret).unwrap(); - let derived_address = AccountId32::from(derived_address_bytes).to_ss58check(); + let expected_secret: [u8; 32] = hex::decode(TEST_SECRET_HEX).unwrap().try_into().unwrap(); + let expected_address_bytes = + wormhole_lib::compute_wormhole_address(&expected_secret).unwrap(); - assert_eq!(derived_address, expected_address); - assert_eq!(derived_address_bytes, expected_address_bytes); + assert_eq!(secret_bytes, expected_secret); + assert_eq!(address_bytes, expected_address_bytes); + assert_eq!(address, AccountId32::from(expected_address_bytes).to_ss58check()); } #[test] - fn test_wormhole_credential_secret_with_0x_prefix() { - // Test that secrets with 0x prefix are handled correctly - let secret_hex_no_prefix = - "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"; - let secret_hex_with_prefix = - "0xabcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"; - - let parsed_no_prefix = parse_secret_hex_str(secret_hex_no_prefix).unwrap(); - let parsed_with_prefix = parse_secret_hex_str(secret_hex_with_prefix).unwrap(); - - // Both should produce the same result - assert_eq!(parsed_no_prefix, parsed_with_prefix); - - // Both should derive the same address - let address_no_prefix = wormhole_lib::compute_wormhole_address(&parsed_no_prefix).unwrap(); - let address_with_prefix = - wormhole_lib::compute_wormhole_address(&parsed_with_prefix).unwrap(); - assert_eq!(address_no_prefix, address_with_prefix); + fn test_resolve_credential_secret_accepts_0x_prefix() { + let cred_plain = WormholeCredential::Secret { hex: TEST_SECRET_HEX.to_string() }; + let cred_prefixed = WormholeCredential::Secret { hex: format!("0x{}", TEST_SECRET_HEX) }; + assert_eq!( + resolve_credential(&cred_plain).unwrap(), + resolve_credential(&cred_prefixed).unwrap() + ); } #[test] - fn test_wormhole_credential_secret_invalid_length() { - // Test that invalid secret lengths are rejected - let too_short = "0102030405060708"; // Only 8 bytes - let result = parse_secret_hex_str(too_short); - assert!(result.is_err()); - - let too_long = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021"; // 33 bytes - let result = parse_secret_hex_str(too_long); - assert!(result.is_err()); + fn test_resolve_credential_secret_invalid() { + let too_short = WormholeCredential::Secret { hex: "0102".to_string() }; + assert!(resolve_credential(&too_short).unwrap_err().message.contains("Invalid secret")); + + let bad_hex = WormholeCredential::Secret { hex: "zz".repeat(32) }; + assert!(resolve_credential(&bad_hex).unwrap_err().message.contains("Invalid secret")); } #[test] - fn test_wormhole_credential_secret_invalid_hex() { - // Test that invalid hex is rejected - let invalid_hex = "ghijklmnopqrstuvwxyz01234567890123456789012345678901234567890123"; - let result = parse_secret_hex_str(invalid_hex); - assert!(result.is_err()); + fn test_resolve_credential_mnemonic() { + let cred = + WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 }; + let (address, address_bytes, secret_bytes) = resolve_credential(&cred).unwrap(); + + assert_ne!(secret_bytes, [0u8; 32]); + assert_ne!(address_bytes, [0u8; 32]); + assert_eq!(address, AccountId32::from(address_bytes).to_ss58check()); + assert_eq!(address_bytes, wormhole_lib::compute_wormhole_address(&secret_bytes).unwrap()); } #[test] - fn test_wormhole_credential_secret_deterministic() { - // Test that the same secret always produces the same address - let secret_hex = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef"; - - let parsed1 = parse_secret_hex_str(secret_hex).unwrap(); - let parsed2 = parse_secret_hex_str(secret_hex).unwrap(); - - let address1 = wormhole_lib::compute_wormhole_address(&parsed1).unwrap(); - let address2 = wormhole_lib::compute_wormhole_address(&parsed2).unwrap(); + fn test_resolve_credential_mnemonic_pinned_derivation_path() { + // Regression guard for the HD path `m/44'/CHAIN/0'/0'/index'` (fixed in #93). + // If this breaks, the derivation path or the underlying HD library changed. + let cred = + WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 }; + let (_, address_bytes, secret_bytes) = resolve_credential(&cred).unwrap(); + assert_eq!( + hex::encode(address_bytes), + "b8a7c11fc57b36fbad44e437ec05d91c44231974c058ded1fed66cb7baa41973", + ); + assert_eq!( + hex::encode(secret_bytes), + "110684de72bc884f854accf8bc6ba724dcc1cc2f99932a4d28bdf85fc6f28ccf", + ); + } - assert_eq!(address1, address2); + #[test] + fn test_resolve_credential_mnemonic_index_changes_output() { + let cred_0 = + WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 }; + let cred_1 = + WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 1 }; + assert_ne!(resolve_credential(&cred_0).unwrap(), resolve_credential(&cred_1).unwrap()); } #[test] - fn test_wormhole_credential_enum_variants() { - // Test that both credential variants can be constructed - let mnemonic_cred = WormholeCredential::Mnemonic { - phrase: "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about".to_string(), + fn test_resolve_credential_mnemonic_invalid_phrase() { + let cred = WormholeCredential::Mnemonic { + phrase: "not a real mnemonic".to_string(), wormhole_index: 0, }; + assert!(resolve_credential(&cred).unwrap_err().message.contains("HD derivation")); + } - let secret_cred = WormholeCredential::Secret { - hex: "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20".to_string(), - }; - - // Verify they are different variants - match mnemonic_cred { - WormholeCredential::Mnemonic { phrase, wormhole_index } => { - assert!(phrase.contains("abandon")); - assert_eq!(wormhole_index, 0); - }, - _ => panic!("Expected Mnemonic variant"), - } - - match secret_cred { - WormholeCredential::Secret { hex } => { - assert!(hex.starts_with("0102")); - }, - _ => panic!("Expected Secret variant"), - } + #[test] + fn test_resolve_credential_mnemonic_and_secret_equivalence() { + let mnemonic_cred = + WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 }; + let (m_address, m_address_bytes, m_secret_bytes) = + resolve_credential(&mnemonic_cred).unwrap(); + + let secret_cred = WormholeCredential::Secret { hex: hex::encode(m_secret_bytes) }; + let (s_address, s_address_bytes, s_secret_bytes) = + resolve_credential(&secret_cred).unwrap(); + + assert_eq!(m_address, s_address); + assert_eq!(m_address_bytes, s_address_bytes); + assert_eq!(m_secret_bytes, s_secret_bytes); } }