Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/cli/wormhole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3938,4 +3938,55 @@ mod tests {
}
}
}

/// Wrapper so `WormholeCommands` can be parsed directly via clap in tests.
#[derive(clap::Parser, Debug)]
struct CollectRewardsTestCli {
#[command(subcommand)]
cmd: WormholeCommands,
}

fn try_parse_collect_rewards(extra_args: &[&str]) -> Result<WormholeCommands, clap::Error> {
use clap::Parser;
let mut args = vec!["test", "collect-rewards"];
args.extend_from_slice(extra_args);
CollectRewardsTestCli::try_parse_from(args).map(|cli| cli.cmd)
}

#[test]
fn collect_rewards_requires_one_credential() {
let err = try_parse_collect_rewards(&[]).unwrap_err();
let s = err.to_string();
assert!(
s.contains("--wallet") || s.contains("--mnemonic") || s.contains("--secret"),
"expected missing-credential error, got: {s}"
);
}

#[test]
fn collect_rewards_accepts_each_credential_alone() {
assert!(try_parse_collect_rewards(&["--wallet", "w"]).is_ok());
assert!(try_parse_collect_rewards(&["--mnemonic", "word ".repeat(24).trim()]).is_ok());
assert!(try_parse_collect_rewards(&[
"--secret",
"0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20",
])
.is_ok());
}

#[test]
fn collect_rewards_credentials_mutually_exclusive() {
let pairs: &[(&str, &str, &str, &str)] = &[
("--wallet", "w", "--mnemonic", "m"),
("--wallet", "w", "--secret", "s"),
("--mnemonic", "m", "--secret", "s"),
];
for (a, av, b, bv) in pairs {
let err = try_parse_collect_rewards(&[a, av, b, bv]).unwrap_err().to_string();
assert!(
err.contains("cannot be used with"),
"expected conflict error for {a} + {b}, got: {err}"
);
}
}
}
236 changes: 107 additions & 129 deletions src/collect_rewards_lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,32 @@ pub enum WormholeCredential {
},
}

/// Resolve a `WormholeCredential` into `(ss58_address, address_bytes, secret_bytes)`.
pub fn resolve_credential(credential: &WormholeCredential) -> Result<(String, [u8; 32], [u8; 32])> {
match credential {
WormholeCredential::Mnemonic { phrase, wormhole_index } => {
let path = format!("m/44'/{}/0'/0'/{}'", QUANTUS_WORMHOLE_CHAIN_ID, wormhole_index);
let wormhole_pair = derive_wormhole_from_mnemonic(phrase, None, &path)
.map_err(|e| CollectRewardsError::from(format!("HD derivation failed: {:?}", e)))?;
let address_bytes: [u8; 32] = wormhole_pair.address;
let secret_bytes: [u8; 32] =
wormhole_pair.secret.as_ref().try_into().map_err(|_| {
CollectRewardsError::from(
"Invalid secret length from HD derivation".to_string(),
)
})?;
Ok((AccountId32::from(address_bytes).to_ss58check(), address_bytes, secret_bytes))
},
WormholeCredential::Secret { hex } => {
let secret_bytes = parse_secret_hex_str(hex)
.map_err(|e| CollectRewardsError::from(format!("Invalid secret: {}", e)))?;
let address_bytes = wormhole_lib::compute_wormhole_address(&secret_bytes)
.map_err(|e| CollectRewardsError::from(e.message))?;
Ok((AccountId32::from(address_bytes).to_ss58check(), address_bytes, secret_bytes))
},
}
}

/// Configuration for collect_rewards
#[derive(Debug, Clone)]
pub struct CollectRewardsConfig {
Expand Down Expand Up @@ -263,43 +289,8 @@ pub async fn collect_rewards<P: ProgressCallback>(
config: CollectRewardsConfig,
progress: &P,
) -> Result<CollectRewardsResult> {
// Step 1: Derive wormhole address from credential
// Returns: (ss58_address, raw_address_bytes, secret_bytes)
let (wormhole_address, wormhole_address_bytes, wormhole_secret_bytes) = match &config.credential
{
WormholeCredential::Mnemonic { phrase, wormhole_index } => {
progress.on_step("derive", "Deriving wormhole address from mnemonic");

let path = format!("m/44'/{}/0'/0'/{}'", QUANTUS_WORMHOLE_CHAIN_ID, wormhole_index);
let wormhole_pair = derive_wormhole_from_mnemonic(phrase, None, &path)
.map_err(|e| CollectRewardsError::from(format!("HD derivation failed: {:?}", e)))?;

let address_bytes: [u8; 32] = wormhole_pair.address;
let address = AccountId32::from(address_bytes).to_ss58check();
let secret_bytes: [u8; 32] =
wormhole_pair.secret.as_ref().try_into().map_err(|_| {
CollectRewardsError::from(
"Invalid secret length from HD derivation".to_string(),
)
})?;
(address, address_bytes, secret_bytes)
},
WormholeCredential::Secret { hex } => {
progress.on_step("derive", "Deriving wormhole address from secret");

// Parse the hex secret
let secret_bytes = parse_secret_hex_str(hex)
.map_err(|e| CollectRewardsError::from(format!("Invalid secret: {}", e)))?;

// Compute wormhole address from the secret
let address_bytes = wormhole_lib::compute_wormhole_address(&secret_bytes)
.map_err(|e| CollectRewardsError::from(e.message))?;

let address = AccountId32::from(address_bytes).to_ss58check();
(address, address_bytes, secret_bytes)
},
};

let (wormhole_address, wormhole_address_bytes, wormhole_secret_bytes) =
resolve_credential(&config.credential)?;
progress.on_step("derive", &format!("Derived wormhole address: {}", wormhole_address));

// Parse destination address
Expand Down Expand Up @@ -496,25 +487,18 @@ pub async fn collect_rewards<P: ProgressCallback>(
let digest = header.digest.encode();
let block_number = header.number;

// Compute wormhole address from secret
let computed_wormhole_address =
wormhole_lib::compute_wormhole_address(&wormhole_secret_bytes)
.map_err(|e| CollectRewardsError::from(e.message))?;

// Verify the leaf's to_account matches our computed wormhole address
if leaf_to_account != computed_wormhole_address {
if leaf_to_account != wormhole_address_bytes {
return Err(CollectRewardsError::from(format!(
"Leaf to_account mismatch: expected 0x{}, got 0x{}",
hex::encode(computed_wormhole_address),
hex::encode(wormhole_address_bytes),
hex::encode(leaf_to_account)
)));
}

// Build proof input
let input = wormhole_lib::ProofGenerationInput {
secret: wormhole_secret_bytes,
transfer_count,
wormhole_address: computed_wormhole_address,
wormhole_address: wormhole_address_bytes,
input_amount,
block_hash: proof_block_hash.0,
block_number,
Expand Down Expand Up @@ -1185,109 +1169,103 @@ mod tests {
assert_eq!(result, 100);
}

#[test]
fn test_wormhole_credential_secret_address_derivation() {
// Test that WormholeCredential::Secret correctly derives the wormhole address
// using the same logic as wormhole_lib::compute_wormhole_address

// A known 32-byte secret (hex encoded)
let secret_hex = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20";
let secret_bytes: [u8; 32] = hex::decode(secret_hex).unwrap().try_into().unwrap();
const TEST_SECRET_HEX: &str =
"0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20";
const TEST_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon art";

// Compute expected address using wormhole_lib
let expected_address_bytes = wormhole_lib::compute_wormhole_address(&secret_bytes).unwrap();
let expected_address = AccountId32::from(expected_address_bytes).to_ss58check();
#[test]
fn test_resolve_credential_secret() {
let cred = WormholeCredential::Secret { hex: TEST_SECRET_HEX.to_string() };
let (address, address_bytes, secret_bytes) = resolve_credential(&cred).unwrap();

// Now simulate what WormholeCredential::Secret does
let parsed_secret = parse_secret_hex_str(secret_hex).unwrap();
let derived_address_bytes = wormhole_lib::compute_wormhole_address(&parsed_secret).unwrap();
let derived_address = AccountId32::from(derived_address_bytes).to_ss58check();
let expected_secret: [u8; 32] = hex::decode(TEST_SECRET_HEX).unwrap().try_into().unwrap();
let expected_address_bytes =
wormhole_lib::compute_wormhole_address(&expected_secret).unwrap();

assert_eq!(derived_address, expected_address);
assert_eq!(derived_address_bytes, expected_address_bytes);
assert_eq!(secret_bytes, expected_secret);
assert_eq!(address_bytes, expected_address_bytes);
assert_eq!(address, AccountId32::from(expected_address_bytes).to_ss58check());
}

#[test]
fn test_wormhole_credential_secret_with_0x_prefix() {
// Test that secrets with 0x prefix are handled correctly
let secret_hex_no_prefix =
"abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
let secret_hex_with_prefix =
"0xabcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";

let parsed_no_prefix = parse_secret_hex_str(secret_hex_no_prefix).unwrap();
let parsed_with_prefix = parse_secret_hex_str(secret_hex_with_prefix).unwrap();

// Both should produce the same result
assert_eq!(parsed_no_prefix, parsed_with_prefix);

// Both should derive the same address
let address_no_prefix = wormhole_lib::compute_wormhole_address(&parsed_no_prefix).unwrap();
let address_with_prefix =
wormhole_lib::compute_wormhole_address(&parsed_with_prefix).unwrap();
assert_eq!(address_no_prefix, address_with_prefix);
fn test_resolve_credential_secret_accepts_0x_prefix() {
let cred_plain = WormholeCredential::Secret { hex: TEST_SECRET_HEX.to_string() };
let cred_prefixed = WormholeCredential::Secret { hex: format!("0x{}", TEST_SECRET_HEX) };
assert_eq!(
resolve_credential(&cred_plain).unwrap(),
resolve_credential(&cred_prefixed).unwrap()
);
}

#[test]
fn test_wormhole_credential_secret_invalid_length() {
// Test that invalid secret lengths are rejected
let too_short = "0102030405060708"; // Only 8 bytes
let result = parse_secret_hex_str(too_short);
assert!(result.is_err());

let too_long = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021"; // 33 bytes
let result = parse_secret_hex_str(too_long);
assert!(result.is_err());
fn test_resolve_credential_secret_invalid() {
let too_short = WormholeCredential::Secret { hex: "0102".to_string() };
assert!(resolve_credential(&too_short).unwrap_err().message.contains("Invalid secret"));

let bad_hex = WormholeCredential::Secret { hex: "zz".repeat(32) };
assert!(resolve_credential(&bad_hex).unwrap_err().message.contains("Invalid secret"));
}

#[test]
fn test_wormhole_credential_secret_invalid_hex() {
// Test that invalid hex is rejected
let invalid_hex = "ghijklmnopqrstuvwxyz01234567890123456789012345678901234567890123";
let result = parse_secret_hex_str(invalid_hex);
assert!(result.is_err());
fn test_resolve_credential_mnemonic() {
let cred =
WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 };
let (address, address_bytes, secret_bytes) = resolve_credential(&cred).unwrap();

assert_ne!(secret_bytes, [0u8; 32]);
assert_ne!(address_bytes, [0u8; 32]);
assert_eq!(address, AccountId32::from(address_bytes).to_ss58check());
assert_eq!(address_bytes, wormhole_lib::compute_wormhole_address(&secret_bytes).unwrap());
}

#[test]
fn test_wormhole_credential_secret_deterministic() {
// Test that the same secret always produces the same address
let secret_hex = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef";

let parsed1 = parse_secret_hex_str(secret_hex).unwrap();
let parsed2 = parse_secret_hex_str(secret_hex).unwrap();

let address1 = wormhole_lib::compute_wormhole_address(&parsed1).unwrap();
let address2 = wormhole_lib::compute_wormhole_address(&parsed2).unwrap();
fn test_resolve_credential_mnemonic_pinned_derivation_path() {
// Regression guard for the HD path `m/44'/CHAIN/0'/0'/index'` (fixed in #93).
// If this breaks, the derivation path or the underlying HD library changed.
let cred =
WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 };
let (_, address_bytes, secret_bytes) = resolve_credential(&cred).unwrap();
assert_eq!(
hex::encode(address_bytes),
"b8a7c11fc57b36fbad44e437ec05d91c44231974c058ded1fed66cb7baa41973",
);
assert_eq!(
hex::encode(secret_bytes),
"110684de72bc884f854accf8bc6ba724dcc1cc2f99932a4d28bdf85fc6f28ccf",
);
}

assert_eq!(address1, address2);
#[test]
fn test_resolve_credential_mnemonic_index_changes_output() {
let cred_0 =
WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 };
let cred_1 =
WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 1 };
assert_ne!(resolve_credential(&cred_0).unwrap(), resolve_credential(&cred_1).unwrap());
}

#[test]
fn test_wormhole_credential_enum_variants() {
// Test that both credential variants can be constructed
let mnemonic_cred = WormholeCredential::Mnemonic {
phrase: "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about".to_string(),
fn test_resolve_credential_mnemonic_invalid_phrase() {
let cred = WormholeCredential::Mnemonic {
phrase: "not a real mnemonic".to_string(),
wormhole_index: 0,
};
assert!(resolve_credential(&cred).unwrap_err().message.contains("HD derivation"));
}

let secret_cred = WormholeCredential::Secret {
hex: "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20".to_string(),
};

// Verify they are different variants
match mnemonic_cred {
WormholeCredential::Mnemonic { phrase, wormhole_index } => {
assert!(phrase.contains("abandon"));
assert_eq!(wormhole_index, 0);
},
_ => panic!("Expected Mnemonic variant"),
}

match secret_cred {
WormholeCredential::Secret { hex } => {
assert!(hex.starts_with("0102"));
},
_ => panic!("Expected Secret variant"),
}
#[test]
fn test_resolve_credential_mnemonic_and_secret_equivalence() {
let mnemonic_cred =
WormholeCredential::Mnemonic { phrase: TEST_MNEMONIC.to_string(), wormhole_index: 0 };
let (m_address, m_address_bytes, m_secret_bytes) =
resolve_credential(&mnemonic_cred).unwrap();

let secret_cred = WormholeCredential::Secret { hex: hex::encode(m_secret_bytes) };
let (s_address, s_address_bytes, s_secret_bytes) =
resolve_credential(&secret_cred).unwrap();

assert_eq!(m_address, s_address);
assert_eq!(m_address_bytes, s_address_bytes);
assert_eq!(m_secret_bytes, s_secret_bytes);
}
}
Loading