Skip to content
Closed
178 changes: 111 additions & 67 deletions crates/precompiles/src/account_keychain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub use tempo_contracts::precompiles::{
use crate::{
ACCOUNT_KEYCHAIN_ADDRESS,
error::Result,
storage::{Handler, Mapping, Set, packing::insert_into_word},
storage::{EnumerableMap, Handler, Mapping, Set, packing::insert_into_word},
tip20_factory::TIP20Factory,
};
use alloy::primitives::{Address, B256, FixedBytes, TxKind, U256, keccak256};
Expand Down Expand Up @@ -85,12 +85,11 @@ pub struct SelectorScope {
/// mode:
/// - 0 => unset/disabled
/// - 1 => all selectors allowed
/// - 2 => only selectors in the set are allowed
/// - 2 => only selectors in the configured list are allowed
#[derive(Debug, Clone, Storable, Default)]
pub struct TargetScope {
pub mode: u8,
pub selectors: Set<FixedBytes<4>>,
pub selector_scopes: Mapping<FixedBytes<4>, SelectorScope>,
pub selectors: EnumerableMap<FixedBytes<4>, SelectorScope>,
}

/// Key-level call scope.
Expand All @@ -102,8 +101,7 @@ pub struct TargetScope {
#[derive(Debug, Clone, Storable, Default)]
pub struct KeyScope {
pub mode: u8,
pub targets: Set<Address>,
pub target_scopes: Mapping<Address, TargetScope>,
pub targets: EnumerableMap<Address, TargetScope>,
}

/// Key information stored in the precompile
Expand Down Expand Up @@ -617,7 +615,7 @@ impl AccountKeychain {
}]);
}

let targets = self.key_scopes[key_hash].targets.read()?;
let targets = self.key_scopes[key_hash].targets.keys()?;
if targets.is_empty() {
return Ok(vec![CallScope {
target: Address::ZERO,
Expand All @@ -628,9 +626,7 @@ impl AccountKeychain {

let mut scopes = Vec::new();
for target in targets {
let target_mode = self.key_scopes[key_hash].target_scopes[target]
.mode
.read()?;
let target_mode = self.key_scopes[key_hash].targets[target].mode.read()?;

let scope = match target_mode {
1 => CallScope {
Expand All @@ -640,22 +636,19 @@ impl AccountKeychain {
},
2 => {
let mut rules = Vec::new();
let selectors = self.key_scopes[key_hash].target_scopes[target]
.selectors
.read()?;
let selectors = self.key_scopes[key_hash].targets[target].selectors.keys()?;
for selector in selectors {
let selector_mode = self.key_scopes[key_hash].target_scopes[target]
.selector_scopes[selector]
let selector_mode = self.key_scopes[key_hash].targets[target].selectors
[selector]
.mode
.read()?;

let recipients = if selector_mode == 2 {
let recipients: Vec<Address> = self.key_scopes[key_hash].target_scopes
[target]
.selector_scopes[selector]
.recipients
.read()?
.into();
let recipients: Vec<Address> =
self.key_scopes[key_hash].targets[target].selectors[selector]
.recipients
.read()?
.into();
recipients
} else if selector_mode == 1 {
Vec::new()
Expand Down Expand Up @@ -785,9 +778,7 @@ impl AccountKeychain {
TxKind::Create => return Err(AccountKeychainError::call_not_allowed().into()),
};

let target_mode = self.key_scopes[key_hash].target_scopes[target]
.mode
.read()?;
let target_mode = self.key_scopes[key_hash].targets[target].mode.read()?;
if target_mode == 1 {
return Ok(());
}
Expand All @@ -800,16 +791,8 @@ impl AccountKeychain {
return Err(AccountKeychainError::call_not_allowed().into());
}

let selector = FixedBytes::<4>::from([input[0], input[1], input[2], input[3]]);
if !self.key_scopes[key_hash].target_scopes[target]
.selectors
.contains(&selector)?
{
return Err(AccountKeychainError::call_not_allowed().into());
}

let selector_mode = self.key_scopes[key_hash].target_scopes[target].selector_scopes
[selector]
let selector: FixedBytes<4> = [input[0], input[1], input[2], input[3]].into();
let selector_mode = self.key_scopes[key_hash].targets[target].selectors[selector]
.mode
.read()?;
if selector_mode == 1 {
Expand All @@ -829,7 +812,7 @@ impl AccountKeychain {
}

let recipient = Address::from_slice(&recipient_word[12..]);
if self.key_scopes[key_hash].target_scopes[target].selector_scopes[selector]
if self.key_scopes[key_hash].targets[target].selectors[selector]
.recipients
.contains(&recipient)?
{
Expand All @@ -846,8 +829,9 @@ impl AccountKeychain {
) -> Result<()> {
// Fresh authorizations should not have any pre-existing call-scope rows because
// `authorize_key` rejects both existing and previously revoked keys before reaching this
// path. We still clear the scope tree first as a defense-in-depth measure against stale or
// out-of-band state, and keep it because the valid-path cost is low (empty target set).
// path. We still clear the indexed scope tree first as a defense-in-depth measure against
// stale state. Mapping rows are the internal source of truth; vec indexes are maintained
// by mutator paths so externally enumerated scopes remain correct.
self.clear_all_target_scopes(account_key)?;

match allowed_calls {
Expand Down Expand Up @@ -888,7 +872,7 @@ impl AccountKeychain {
}

fn clear_all_target_scopes(&mut self, account_key: B256) -> Result<()> {
let targets = self.key_scopes[account_key].targets.read()?;
let targets = self.key_scopes[account_key].targets.keys()?;
for target in targets {
self.remove_target_scope(account_key, target)?;
}
Expand All @@ -897,32 +881,32 @@ impl AccountKeychain {
}

fn remove_target_scope(&mut self, account_key: B256, target: Address) -> Result<()> {
if !self.key_scopes[account_key].targets.remove(&target)? {
if self.key_scopes[account_key].targets[target].mode.read()? == 0 {
return Ok(());
}

self.clear_target_selectors(account_key, target)?;
self.key_scopes[account_key].target_scopes[target]
.mode
.write(0)
self.key_scopes[account_key].targets[target].mode.write(0)?;
self.key_scopes[account_key].targets.remove_key(&target)?;
Ok(())
}

fn clear_target_selectors(&mut self, account_key: B256, target: Address) -> Result<()> {
let selectors = self.key_scopes[account_key].target_scopes[target]
let selectors = self.key_scopes[account_key].targets[target]
.selectors
.read()?;
.keys()?;
for selector in selectors {
self.key_scopes[account_key].target_scopes[target].selector_scopes[selector]
self.key_scopes[account_key].targets[target].selectors[selector]
.recipients
.delete()?;
self.key_scopes[account_key].target_scopes[target].selector_scopes[selector]
self.key_scopes[account_key].targets[target].selectors[selector]
.mode
.write(0)?;
}

self.key_scopes[account_key].target_scopes[target]
self.key_scopes[account_key].targets[target]
.selectors
.delete()
.clear_keys()
}

fn upsert_target_scope(
Expand All @@ -938,52 +922,49 @@ impl AccountKeychain {
self.validate_selector_rules(target, rules)?;
}

if !self.key_scopes[account_key].targets.contains(&target)? {
if !self.key_scopes[account_key]
.targets
.contains_mapped(&target, |scope| scope.mode.read().map(|mode| mode != 0))?
{
let count = self.key_scopes[account_key].targets.len()?;
if count >= MAX_CALL_SCOPES as usize {
return Err(AccountKeychainError::scope_limit_exceeded().into());
}

self.key_scopes[account_key].targets.insert(target)?;
self.key_scopes[account_key]
.targets
.insert_key_unchecked(target)?;
}

self.clear_target_selectors(account_key, target)?;

match selector_rules {
None => {
self.key_scopes[account_key].target_scopes[target]
.mode
.write(1)?;
self.key_scopes[account_key].targets[target].mode.write(1)?;
}
Some(rules) => {
self.key_scopes[account_key].target_scopes[target]
.mode
.write(2)?;
self.key_scopes[account_key].targets[target].mode.write(2)?;

for rule in rules {
let selector = FixedBytes::<4>::from(rule.selector);
self.key_scopes[account_key].target_scopes[target]
let selector: FixedBytes<4> = rule.selector.into();
self.key_scopes[account_key].targets[target]
.selectors
.insert(selector)?;
.insert_key_unchecked(selector)?;

match rule.recipients {
None => {
self.key_scopes[account_key].target_scopes[target].selector_scopes
[selector]
self.key_scopes[account_key].targets[target].selectors[selector]
.mode
.write(1)?;
self.key_scopes[account_key].target_scopes[target].selector_scopes
[selector]
self.key_scopes[account_key].targets[target].selectors[selector]
.recipients
.delete()?;
}
Some(recipients) => {
self.key_scopes[account_key].target_scopes[target].selector_scopes
[selector]
self.key_scopes[account_key].targets[target].selectors[selector]
.mode
.write(2)?;
self.key_scopes[account_key].target_scopes[target].selector_scopes
[selector]
self.key_scopes[account_key].targets[target].selectors[selector]
.recipients
.write(Set::from(recipients))?;
}
Expand Down Expand Up @@ -3551,6 +3532,69 @@ mod tests {
})
}

#[test]
fn test_t3_set_allowed_calls_replaces_selector_index_without_duplicates() -> eyre::Result<()> {
let mut storage = HashMapStorageProvider::new_with_spec(1, TempoHardfork::T3);
let account = Address::random();
let key_id = Address::random();
let target = DEFAULT_FEE_TOKEN;
let first_selector = TIP20_TRANSFER_SELECTOR;
let second_selector = TIP20_APPROVE_SELECTOR;

StorageCtx::enter(&mut storage, || {
let mut keychain = AccountKeychain::new();
keychain.initialize()?;
keychain.set_transaction_key(Address::ZERO)?;
keychain.set_tx_origin(account)?;

keychain.authorize_key(
account,
authorizeKeyCall {
keyId: key_id,
signatureType: SignatureType::Secp256k1,
config: KeyRestrictions {
expiry: u64::MAX,
enforceLimits: false,
limits: vec![],
enforceAllowedCalls: false,
allowedCalls: vec![],
},
},
)?;

let set_scope = |keychain: &mut AccountKeychain, selector: [u8; 4]| {
keychain.set_allowed_calls(
account,
setAllowedCallsCall {
keyId: key_id,
scope: CallScope {
target,
allowAllSelectors: false,
selectorRules: vec![SelectorRule {
selector: selector.into(),
recipients: vec![],
}],
},
},
)
};

set_scope(&mut keychain, first_selector)?;
set_scope(&mut keychain, second_selector)?;

let scopes = keychain.get_allowed_calls(getAllowedCallsCall {
account,
keyId: key_id,
})?;
assert_eq!(scopes.len(), 1);
assert_eq!(scopes[0].target, target);
assert_eq!(scopes[0].selectorRules.len(), 1);
assert_eq!(*scopes[0].selectorRules[0].selector, second_selector);

Ok(())
})
}

#[test]
fn test_t3_set_allowed_calls_allow_all_selectors_ignores_selector_rules() -> eyre::Result<()> {
let mut storage = HashMapStorageProvider::new_with_spec(1, TempoHardfork::T3);
Expand Down
Loading