diff --git a/evmd/app.go b/evmd/app.go index 20d4fbf7..317bebd2 100644 --- a/evmd/app.go +++ b/evmd/app.go @@ -449,15 +449,6 @@ func NewExampleApp( runtime.ProvideCometInfoService(), ) // If evidence needs to be handled for the app, set routes in router here and seal - // Note: The evidence precompile allows evidence to be submitted through an EVM transaction. - // If you implement a custom evidence handler in the router that changes token balances (e.g. penalizing - // addresses, deducting fees, etc.), be aware that the precompile logic (e.g. SetBalanceChangeEntries) - // must be properly integrated to reflect these balance changes in the EVM state. Otherwise, there is a risk - // of desynchronization between the Cosmos SDK state and the EVM state when evidence is submitted via the EVM. - // - // For example, if your custom evidence handler deducts tokens from a user’s account, ensure that the evidence - // precompile also applies these deductions through the EVM’s balance tracking. Failing to do so may cause - // inconsistencies in reported balances and break state synchronization. app.EvidenceKeeper = *evidenceKeeper // Cosmos EVM keepers diff --git a/precompiles/bank/bank.go b/precompiles/bank/bank.go index a85b6a7b..c3538204 100644 --- a/precompiles/bank/bank.go +++ b/precompiles/bank/bank.go @@ -104,7 +104,7 @@ func (p Precompile) RequiredGas(input []byte) uint64 { // Run executes the precompiled contract bank query methods defined in the ABI. func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz []byte, err error) { - ctx, stateDB, method, initialGas, args, err := p.RunSetup(evm, contract, readOnly, p.IsTransaction) + ctx, _, method, initialGas, args, err := p.RunSetup(evm, contract, readOnly, p.IsTransaction) if err != nil { return nil, err } @@ -134,9 +134,6 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ if !contract.UseGas(cost, nil, tracing.GasChangeCallPrecompiledContract) { return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { - return nil, err - } return bz, nil } diff --git a/precompiles/common/balance_handler.go b/precompiles/common/balance_handler.go new file mode 100644 index 00000000..79f5f5b3 --- /dev/null +++ b/precompiles/common/balance_handler.go @@ -0,0 +1,107 @@ +package common + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/tracing" + "github.com/holiman/uint256" + + "github.com/cosmos/evm/utils" + "github.com/cosmos/evm/x/vm/statedb" + evmtypes "github.com/cosmos/evm/x/vm/types" + + sdk "github.com/cosmos/cosmos-sdk/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" +) + +// BalanceHandler is a struct that handles balance changes in the Cosmos SDK context. +type BalanceHandler struct { + prevEventsLen int +} + +// NewBalanceHandler creates a new BalanceHandler instance. +func NewBalanceHandler() *BalanceHandler { + return &BalanceHandler{ + prevEventsLen: 0, + } +} + +// BeforeBalanceChange is called before any balance changes by precompile methods. +// It records the current number of events in the context to later process balance changes +// using the recorded events. +func (bh *BalanceHandler) BeforeBalanceChange(ctx sdk.Context) { + bh.prevEventsLen = len(ctx.EventManager().Events()) +} + +// AfterBalanceChange processes the recorded events and updates the stateDB accordingly. +// It handles the bank events for coin spent and coin received, updating the balances +// of the spender and receiver addresses respectively. +func (bh *BalanceHandler) AfterBalanceChange(ctx sdk.Context, stateDB *statedb.StateDB) error { + events := ctx.EventManager().Events() + + for _, event := range events[bh.prevEventsLen:] { + switch event.Type { + case banktypes.EventTypeCoinSpent: + spenderHexAddr, err := parseHexAddress(event, banktypes.AttributeKeySpender) + if err != nil { + return fmt.Errorf("failed to parse spender address from event %q: %w", banktypes.EventTypeCoinSpent, err) + } + + amount, err := parseAmount(event) + if err != nil { + return fmt.Errorf("failed to parse amount from event %q: %w", banktypes.EventTypeCoinSpent, err) + } + + stateDB.SubBalance(spenderHexAddr, amount, tracing.BalanceChangeUnspecified) + + case banktypes.EventTypeCoinReceived: + receiverHexAddr, err := parseHexAddress(event, banktypes.AttributeKeyReceiver) + if err != nil { + return fmt.Errorf("failed to parse receiver address from event %q: %w", banktypes.EventTypeCoinReceived, err) + } + + amount, err := parseAmount(event) + if err != nil { + return fmt.Errorf("failed to parse amount from event %q: %w", banktypes.EventTypeCoinReceived, err) + } + + stateDB.AddBalance(receiverHexAddr, amount, tracing.BalanceChangeUnspecified) + } + } + + return nil +} + +func parseHexAddress(event sdk.Event, key string) (common.Address, error) { + attr, ok := event.GetAttribute(key) + if !ok { + return common.Address{}, fmt.Errorf("event %q missing attribute %q", event.Type, key) + } + + accAddr, err := sdk.AccAddressFromBech32(attr.Value) + if err != nil { + return common.Address{}, fmt.Errorf("invalid address %q: %w", attr.Value, err) + } + + return common.Address(accAddr.Bytes()), nil +} + +func parseAmount(event sdk.Event) (*uint256.Int, error) { + amountAttr, ok := event.GetAttribute(sdk.AttributeKeyAmount) + if !ok { + return nil, fmt.Errorf("event %q missing attribute %q", banktypes.EventTypeCoinSpent, sdk.AttributeKeyAmount) + } + + amountCoins, err := sdk.ParseCoinsNormalized(amountAttr.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse coins from %q: %w", amountAttr.Value, err) + } + + amountBigInt := amountCoins.AmountOf(evmtypes.GetEVMCoinDenom()).BigInt() + amount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(amountBigInt)) + if err != nil { + return nil, fmt.Errorf("failed to convert coin amount to Uint256: %w", err) + } + return amount, nil +} diff --git a/precompiles/common/balance_handler_test.go b/precompiles/common/balance_handler_test.go new file mode 100644 index 00000000..fdbe0952 --- /dev/null +++ b/precompiles/common/balance_handler_test.go @@ -0,0 +1,144 @@ +package common + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/tracing" + "github.com/holiman/uint256" + "github.com/stretchr/testify/require" + + testutil "github.com/cosmos/evm/testutil" + testconstants "github.com/cosmos/evm/testutil/constants" + "github.com/cosmos/evm/x/vm/statedb" + evmtypes "github.com/cosmos/evm/x/vm/types" + + storetypes "cosmossdk.io/store/types" + + sdktestutil "github.com/cosmos/cosmos-sdk/testutil" + sdk "github.com/cosmos/cosmos-sdk/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" +) + +func setupBalanceHandlerTest(t *testing.T) { + t.Helper() + + sdk.GetConfig().SetBech32PrefixForAccount(testconstants.ExampleBech32Prefix, "") + configurator := evmtypes.NewEVMConfigurator() + configurator.ResetTestConfig() + require.NoError(t, configurator.WithEVMCoinInfo(testconstants.ExampleChainCoinInfo[testconstants.ExampleChainID]).Configure()) +} + +func TestParseHexAddress(t *testing.T) { + setupBalanceHandlerTest(t) + + _, addrs, err := testutil.GeneratePrivKeyAddressPairs(1) + require.NoError(t, err) + accAddr := addrs[0] + + // valid address + ev := sdk.NewEvent("bank", sdk.NewAttribute(banktypes.AttributeKeySpender, accAddr.String())) + addr, err := parseHexAddress(ev, banktypes.AttributeKeySpender) + require.NoError(t, err) + require.Equal(t, common.Address(accAddr.Bytes()), addr) + + // missing attribute + ev = sdk.NewEvent("bank") + _, err = parseHexAddress(ev, banktypes.AttributeKeySpender) + require.Error(t, err) + + // invalid address + ev = sdk.NewEvent("bank", sdk.NewAttribute(banktypes.AttributeKeySpender, "invalid")) + _, err = parseHexAddress(ev, banktypes.AttributeKeySpender) + require.Error(t, err) +} + +func TestParseAmount(t *testing.T) { + setupBalanceHandlerTest(t) + + coinStr := sdk.NewCoins(sdk.NewInt64Coin(evmtypes.GetEVMCoinDenom(), 5)).String() + ev := sdk.NewEvent("bank", sdk.NewAttribute(sdk.AttributeKeyAmount, coinStr)) + amt, err := parseAmount(ev) + require.NoError(t, err) + require.True(t, amt.Eq(uint256.NewInt(5))) + + // missing amount + ev = sdk.NewEvent("bank") + _, err = parseAmount(ev) + require.Error(t, err) + + // invalid coins + ev = sdk.NewEvent("bank", sdk.NewAttribute(sdk.AttributeKeyAmount, "invalid")) + _, err = parseAmount(ev) + require.Error(t, err) +} + +func TestAfterBalanceChange(t *testing.T) { + setupBalanceHandlerTest(t) + + storeKey := storetypes.NewKVStoreKey("test") + tKey := storetypes.NewTransientStoreKey("test_t") + ctx := sdktestutil.DefaultContext(storeKey, tKey) + + stateDB := statedb.New(ctx, testutil.NewMockKeeper(), statedb.NewEmptyTxConfig(common.BytesToHash(ctx.HeaderHash()))) + + _, addrs, err := testutil.GeneratePrivKeyAddressPairs(2) + require.NoError(t, err) + spenderAcc := addrs[0] + receiverAcc := addrs[1] + spender := common.Address(spenderAcc.Bytes()) + receiver := common.Address(receiverAcc.Bytes()) + + // initial balance for spender + stateDB.AddBalance(spender, uint256.NewInt(5), tracing.BalanceChangeUnspecified) + + bh := NewBalanceHandler() + bh.BeforeBalanceChange(ctx) + + coins := sdk.NewCoins(sdk.NewInt64Coin(evmtypes.GetEVMCoinDenom(), 3)) + ctx.EventManager().EmitEvents(sdk.Events{ + banktypes.NewCoinSpentEvent(spenderAcc, coins), + banktypes.NewCoinReceivedEvent(receiverAcc, coins), + }) + + err = bh.AfterBalanceChange(ctx, stateDB) + require.NoError(t, err) + + require.Equal(t, "2", stateDB.GetBalance(spender).String()) + require.Equal(t, "3", stateDB.GetBalance(receiver).String()) +} + +func TestAfterBalanceChangeErrors(t *testing.T) { + setupBalanceHandlerTest(t) + + storeKey := storetypes.NewKVStoreKey("test") + tKey := storetypes.NewTransientStoreKey("test_t") + ctx := sdktestutil.DefaultContext(storeKey, tKey) + stateDB := statedb.New(ctx, testutil.NewMockKeeper(), statedb.NewEmptyTxConfig(common.BytesToHash(ctx.HeaderHash()))) + + _, addrs, err := testutil.GeneratePrivKeyAddressPairs(1) + require.NoError(t, err) + addr := addrs[0] + + bh := NewBalanceHandler() + bh.BeforeBalanceChange(ctx) + + // invalid address in event + coins := sdk.NewCoins(sdk.NewInt64Coin(evmtypes.GetEVMCoinDenom(), 1)) + ctx.EventManager().EmitEvent(banktypes.NewCoinSpentEvent(addr, coins)) + ctx.EventManager().Events()[len(ctx.EventManager().Events())-1].Attributes[0].Value = "invalid" + err = bh.AfterBalanceChange(ctx, stateDB) + require.Error(t, err) + + // reset events + ctx = ctx.WithEventManager(sdk.NewEventManager()) + bh.BeforeBalanceChange(ctx) + + // invalid amount + ev := sdk.NewEvent(banktypes.EventTypeCoinSpent, + sdk.NewAttribute(banktypes.AttributeKeySpender, addr.String()), + sdk.NewAttribute(sdk.AttributeKeyAmount, "invalid")) + ctx.EventManager().EmitEvent(ev) + err = bh.AfterBalanceChange(ctx, stateDB) + require.Error(t, err) +} diff --git a/precompiles/common/precompile.go b/precompiles/common/precompile.go index 3252f48d..a78ecfbe 100644 --- a/precompiles/common/precompile.go +++ b/precompiles/common/precompile.go @@ -23,7 +23,7 @@ type Precompile struct { KvGasConfig storetypes.GasConfig TransientKVGasConfig storetypes.GasConfig address common.Address - journalEntries []BalanceChangeEntry + balanceHandler *BalanceHandler } // Operation is a type that defines if the precompile call @@ -172,30 +172,6 @@ func HandleGasError(ctx sdk.Context, contract *vm.Contract, initialGas storetype } } -// AddJournalEntries adds the balanceChange (if corresponds) -func (p Precompile) AddJournalEntries(stateDB *statedb.StateDB) error { - for _, entry := range p.journalEntries { - switch entry.Op { - case Sub: - // add the corresponding balance change to the journal - stateDB.SubBalance(entry.Account, entry.Amount, tracing.BalanceChangeUnspecified) - case Add: - // add the corresponding balance change to the journal - stateDB.AddBalance(entry.Account, entry.Amount, tracing.BalanceChangeUnspecified) - } - } - - return nil -} - -// SetBalanceChangeEntries sets the balanceChange entries -// as the journalEntries field of the precompile. -// These entries will be added to the stateDB's journal -// when calling the AddJournalEntries function -func (p *Precompile) SetBalanceChangeEntries(entries ...BalanceChangeEntry) { - p.journalEntries = entries -} - func (p Precompile) Address() common.Address { return p.address } @@ -248,3 +224,10 @@ func (p Precompile) standardCallData(contract *vm.Contract) (method *abi.Method, return method, nil } + +func (p *Precompile) GetBalanceHandler() *BalanceHandler { + if p.balanceHandler == nil { + p.balanceHandler = NewBalanceHandler() + } + return p.balanceHandler +} diff --git a/precompiles/distribution/distribution.go b/precompiles/distribution/distribution.go index 921a66df..a9bb42fa 100644 --- a/precompiles/distribution/distribution.go +++ b/precompiles/distribution/distribution.go @@ -90,6 +90,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -140,7 +143,8 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + if err = p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB); err != nil { return nil, err } diff --git a/precompiles/distribution/tx.go b/precompiles/distribution/tx.go index 0cee9f17..5cda94a9 100644 --- a/precompiles/distribution/tx.go +++ b/precompiles/distribution/tx.go @@ -4,13 +4,9 @@ import ( "fmt" "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" - "github.com/holiman/uint256" cmn "github.com/cosmos/evm/precompiles/common" - "github.com/cosmos/evm/utils" - evmtypes "github.com/cosmos/evm/x/vm/types" sdk "github.com/cosmos/cosmos-sdk/types" distributionkeeper "github.com/cosmos/cosmos-sdk/x/distribution/keeper" @@ -82,20 +78,6 @@ func (p *Precompile) ClaimRewards( totalCoins = totalCoins.Add(coins...) } - withdrawerHexAddr, err := p.getWithdrawerHexAddr(ctx, delegatorAddr) - if err != nil { - return nil, err - } - - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(totalCoins.AmountOf(evmtypes.GetEVMCoinDenom()).BigInt())) - if err != nil { - return nil, err - } - // check if converted amount is greater than zero - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(withdrawerHexAddr, convertedAmount, cmn.Add)) - } - if err := p.EmitClaimRewardsEvent(ctx, stateDB, delegatorAddr, totalCoins); err != nil { return nil, err } @@ -157,21 +139,6 @@ func (p *Precompile) WithdrawDelegatorReward( return nil, err } - // rewards go to the withdrawer address - withdrawerHexAddr, err := p.getWithdrawerHexAddr(ctx, delegatorHexAddr) - if err != nil { - return nil, err - } - - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(res.Amount.AmountOf(evmtypes.GetEVMCoinDenom()).BigInt())) - if err != nil { - return nil, err - } - // check if converted amount is greater than zero - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(withdrawerHexAddr, convertedAmount, cmn.Add)) - } - if err = p.EmitWithdrawDelegatorRewardEvent(ctx, stateDB, delegatorHexAddr, msg.ValidatorAddress, res.Amount); err != nil { return nil, err } @@ -203,21 +170,6 @@ func (p *Precompile) WithdrawValidatorCommission( return nil, err } - // commissions go to the withdrawer address - withdrawerHexAddr, err := p.getWithdrawerHexAddr(ctx, validatorHexAddr) - if err != nil { - return nil, err - } - - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(res.Amount.AmountOf(evmtypes.GetEVMCoinDenom()).BigInt())) - if err != nil { - return nil, err - } - // check if converted amount is greater than zero - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(withdrawerHexAddr, convertedAmount, cmn.Add)) - } - if err = p.EmitWithdrawValidatorCommissionEvent(ctx, stateDB, msg.ValidatorAddress, res.Amount); err != nil { return nil, err } @@ -249,15 +201,6 @@ func (p *Precompile) FundCommunityPool( return nil, err } - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(msg.Amount.AmountOf(evmtypes.GetEVMCoinDenom()).BigInt())) - if err != nil { - return nil, err - } - // check if converted amount is greater than zero - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(depositorHexAddr, convertedAmount, cmn.Sub)) - } - if err = p.EmitFundCommunityPoolEvent(ctx, stateDB, depositorHexAddr, msg.Amount); err != nil { return nil, err } @@ -289,16 +232,6 @@ func (p *Precompile) DepositValidatorRewardsPool( if err != nil { return nil, err } - if found, evmCoinAmount := msg.Amount.Find(evmtypes.GetEVMCoinDenom()); found { - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(evmCoinAmount.Amount.BigInt())) - if err != nil { - return nil, err - } - // check if converted amount is greater than zero - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(depositorHexAddr, convertedAmount, cmn.Sub)) - } - } if err = p.EmitDepositValidatorRewardsPoolEvent(ctx, stateDB, depositorHexAddr, msg.ValidatorAddress, msg.Amount); err != nil { return nil, err @@ -306,14 +239,3 @@ func (p *Precompile) DepositValidatorRewardsPool( return method.Outputs.Pack(true) } - -// getWithdrawerHexAddr is a helper function to get the hex address -// of the withdrawer for the specified account address -func (p Precompile) getWithdrawerHexAddr(ctx sdk.Context, delegatorAddr common.Address) (common.Address, error) { - withdrawerAccAddr, err := p.distributionKeeper.GetDelegatorWithdrawAddr(ctx, delegatorAddr.Bytes()) - if err != nil { - return common.Address{}, err - } - - return common.BytesToAddress(withdrawerAccAddr), nil -} diff --git a/precompiles/erc20/erc20.go b/precompiles/erc20/erc20.go index a8f543d2..cc3f9176 100644 --- a/precompiles/erc20/erc20.go +++ b/precompiles/erc20/erc20.go @@ -149,6 +149,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -163,9 +166,13 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ if !contract.UseGas(cost, nil, tracing.GasChangeCallPrecompiledContract) { return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + + // Process the native balance changes after the method execution. + err = p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB) + if err != nil { return nil, err } + return bz, nil } diff --git a/precompiles/erc20/tx.go b/precompiles/erc20/tx.go index 365e0cf1..05be1a11 100644 --- a/precompiles/erc20/tx.go +++ b/precompiles/erc20/tx.go @@ -7,10 +7,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" - cmn "github.com/cosmos/evm/precompiles/common" - "github.com/cosmos/evm/utils" - evmtypes "github.com/cosmos/evm/x/vm/types" - "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -125,16 +121,6 @@ func (p *Precompile) transfer( return nil, ConvertErrToERC20Error(err) } - evmDenom := evmtypes.GetEVMCoinDenom() - if p.tokenPair.Denom == evmDenom { - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(amount)) - if err != nil { - return nil, err - } - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(from, convertedAmount, cmn.Sub), - cmn.NewBalanceChangeEntry(to, convertedAmount, cmn.Add)) - } - if err = p.EmitTransferEvent(ctx, stateDB, from, to, amount); err != nil { return nil, err } diff --git a/precompiles/evidence/evidence.go b/precompiles/evidence/evidence.go index 36320e4c..d9db833b 100644 --- a/precompiles/evidence/evidence.go +++ b/precompiles/evidence/evidence.go @@ -87,6 +87,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -114,7 +117,8 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + if err = p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB); err != nil { return nil, err } diff --git a/precompiles/gov/gov.go b/precompiles/gov/gov.go index 9e66c926..447fab58 100644 --- a/precompiles/gov/gov.go +++ b/precompiles/gov/gov.go @@ -91,6 +91,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -141,7 +144,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + err = p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB) + if err != nil { return nil, err } diff --git a/precompiles/gov/tx.go b/precompiles/gov/tx.go index ca2d24d7..572381f9 100644 --- a/precompiles/gov/tx.go +++ b/precompiles/gov/tx.go @@ -5,13 +5,8 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/core/vm" - "github.com/holiman/uint256" cmn "github.com/cosmos/evm/precompiles/common" - "github.com/cosmos/evm/utils" - evmtypes "github.com/cosmos/evm/x/vm/types" - - "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" govkeeper "github.com/cosmos/cosmos-sdk/x/gov/keeper" @@ -53,15 +48,6 @@ func (p *Precompile) SubmitProposal( return nil, err } - deposit := msg.InitialDeposit - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(deposit.AmountOf(evmtypes.GetEVMCoinDenom()).BigInt())) - if err != nil { - return nil, err - } - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(proposerHexAddr, convertedAmount, cmn.Sub)) - } - if err = p.EmitSubmitProposalEvent(ctx, stateDB, proposerHexAddr, res.ProposalId); err != nil { return nil, err } @@ -90,18 +76,6 @@ func (p *Precompile) Deposit( if _, err = govkeeper.NewMsgServerImpl(&p.govKeeper).Deposit(ctx, msg); err != nil { return nil, err } - for _, coin := range msg.Amount { - if coin.Denom != evmtypes.GetEVMCoinDenom() { - continue - } - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(coin.Amount.BigInt())) - if err != nil { - return nil, err - } - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(depositorHexAddr, convertedAmount, cmn.Sub)) - } - } if err = p.EmitDepositEvent(ctx, stateDB, depositorHexAddr, msg.ProposalId, msg.Amount); err != nil { return nil, err @@ -128,43 +102,10 @@ func (p *Precompile) CancelProposal( return nil, fmt.Errorf(cmn.ErrRequesterIsNotMsgSender, msgSender.String(), proposerHexAddr.String()) } - // pre-calculate the remaining deposit - govParams, err := p.govKeeper.Params.Get(ctx) - if err != nil { - return nil, err - } - cancelRate, err := math.LegacyNewDecFromStr(govParams.ProposalCancelRatio) - if err != nil { - return nil, err - } - deposits, err := p.govKeeper.GetDeposits(ctx, msg.ProposalId) - if err != nil { - return nil, err - } - var remaninig math.Int - for _, deposit := range deposits { - if deposit.Depositor != sdk.AccAddress(proposerHexAddr.Bytes()).String() { - continue - } - for _, coin := range deposit.Amount { - if coin.Denom == evmtypes.GetEVMCoinDenom() { - cancelFee := coin.Amount.ToLegacyDec().Mul(cancelRate).TruncateInt() - remaninig = coin.Amount.Sub(cancelFee) - } - } - } if _, err = govkeeper.NewMsgServerImpl(&p.govKeeper).CancelProposal(ctx, msg); err != nil { return nil, err } - convertedAmount, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(remaninig.BigInt())) - if err != nil { - return nil, err - } - if convertedAmount.Cmp(uint256.NewInt(0)) == 1 { - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(proposerHexAddr, convertedAmount, cmn.Add)) - } - if err = p.EmitCancelProposalEvent(ctx, stateDB, proposerHexAddr, msg.ProposalId); err != nil { return nil, err } diff --git a/precompiles/ics20/ics20.go b/precompiles/ics20/ics20.go index 300f9ae6..06cb7dac 100644 --- a/precompiles/ics20/ics20.go +++ b/precompiles/ics20/ics20.go @@ -97,6 +97,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -126,7 +129,8 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + if err = p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB); err != nil { return nil, err } diff --git a/precompiles/ics20/tx.go b/precompiles/ics20/tx.go index 6ee9b645..936ba707 100644 --- a/precompiles/ics20/tx.go +++ b/precompiles/ics20/tx.go @@ -4,13 +4,9 @@ import ( "fmt" "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" cmn "github.com/cosmos/evm/precompiles/common" - "github.com/cosmos/evm/utils" - evmtypes "github.com/cosmos/evm/x/vm/types" - transfertypes "github.com/cosmos/ibc-go/v10/modules/apps/transfer/types" channeltypes "github.com/cosmos/ibc-go/v10/modules/core/04-channel/types" host "github.com/cosmos/ibc-go/v10/modules/core/24-host" @@ -64,35 +60,11 @@ func (p *Precompile) Transfer( return nil, fmt.Errorf(cmn.ErrRequesterIsNotMsgSender, msgSender.String(), sender.String()) } - evmDenom := evmtypes.GetEVMCoinDenom() - sendAmt := msg.Token.Amount - if sendAmt.GTE(transfertypes.UnboundedSpendLimit()) { - spendable := p.bankKeeper.SpendableCoin(ctx, sender.Bytes(), evmDenom) - sendAmt = spendable.Amount - } res, err := p.transferKeeper.Transfer(ctx, msg) if err != nil { return nil, err } - if msg.Token.Denom == evmDenom { - // escrow address is also changed on this tx, and it is not a module account - // so we need to account for this on the UpdateDirties - escrowAccAddress := transfertypes.GetEscrowAddress(msg.SourcePort, msg.SourceChannel) - escrowHexAddr := common.BytesToAddress(escrowAccAddress) - // NOTE: This ensures that the changes in the bank keeper are correctly mirrored to the EVM stateDB - // when calling the precompile from another smart contract. - // This prevents the stateDB from overwriting the changed balance in the bank keeper when committing the EVM state. - amt, err := utils.Uint256FromBigInt(sendAmt.BigInt()) - if err != nil { - return nil, err - } - p.SetBalanceChangeEntries( - cmn.NewBalanceChangeEntry(sender, amt, cmn.Sub), - cmn.NewBalanceChangeEntry(escrowHexAddr, amt, cmn.Add), - ) - } - if err = EmitIBCTransferEvent( ctx, stateDB, diff --git a/precompiles/slashing/slashing.go b/precompiles/slashing/slashing.go index 04132d7b..352bde02 100644 --- a/precompiles/slashing/slashing.go +++ b/precompiles/slashing/slashing.go @@ -87,6 +87,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -114,7 +117,8 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + if err := p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB); err != nil { return nil, err } diff --git a/precompiles/staking/staking.go b/precompiles/staking/staking.go index 1439ac94..56fa9e53 100644 --- a/precompiles/staking/staking.go +++ b/precompiles/staking/staking.go @@ -86,6 +86,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of a precompile tx or query. // It avoids panics and returns the out of gas error so the EVM can continue gracefully. defer cmn.HandleGasError(ctx, contract, initialGas, &err)() @@ -129,7 +132,8 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + if err = p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB); err != nil { return nil, err } diff --git a/precompiles/staking/tx.go b/precompiles/staking/tx.go index ce61ab95..f84bff80 100644 --- a/precompiles/staking/tx.go +++ b/precompiles/staking/tx.go @@ -8,8 +8,6 @@ import ( "github.com/ethereum/go-ethereum/core/vm" cmn "github.com/cosmos/evm/precompiles/common" - "github.com/cosmos/evm/utils" - evmtypes "github.com/cosmos/evm/x/vm/types" sdk "github.com/cosmos/cosmos-sdk/types" stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" @@ -175,19 +173,6 @@ func (p *Precompile) Delegate( return nil, err } - if msg.Amount.Denom == evmtypes.GetEVMCoinDenom() { - // NOTE: This ensures that the changes in the bank keeper are correctly mirrored to the EVM stateDB - // when calling the precompile from a smart contract - // This prevents the stateDB from overwriting the changed balance in the bank keeper when committing the EVM state. - - // Need to scale the amount to 18 decimals for the EVM balance change entry - scaledAmt, err := utils.Uint256FromBigInt(evmtypes.ConvertAmountTo18DecimalsBigInt(msg.Amount.Amount.BigInt())) - if err != nil { - return nil, err - } - p.SetBalanceChangeEntries(cmn.NewBalanceChangeEntry(delegatorHexAddr, scaledAmt, cmn.Sub)) - } - return method.Outputs.Pack(true) } diff --git a/precompiles/werc20/tx.go b/precompiles/werc20/tx.go index 6ba36c0a..4990e98a 100644 --- a/precompiles/werc20/tx.go +++ b/precompiles/werc20/tx.go @@ -6,7 +6,6 @@ import ( "github.com/ethereum/go-ethereum/core/vm" - cmn "github.com/cosmos/evm/precompiles/common" evmtypes "github.com/cosmos/evm/x/vm/types" "cosmossdk.io/math" @@ -49,13 +48,6 @@ func (p Precompile) Deposit( return nil, err } - // Add the entries to the statedb journal since the function signature of - // the associated Solidity interface payable. - p.SetBalanceChangeEntries( - cmn.NewBalanceChangeEntry(caller, depositedAmount, cmn.Add), - cmn.NewBalanceChangeEntry(p.Address(), depositedAmount, cmn.Sub), - ) - if err := p.EmitDepositEvent(ctx, stateDB, caller, depositedAmount.ToBig()); err != nil { return nil, err } diff --git a/precompiles/werc20/werc20.go b/precompiles/werc20/werc20.go index d8827b16..607eefe5 100644 --- a/precompiles/werc20/werc20.go +++ b/precompiles/werc20/werc20.go @@ -112,6 +112,9 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, err } + // Start the balance change handler before executing the precompile. + p.GetBalanceHandler().BeforeBalanceChange(ctx) + // This handles any out of gas errors that may occur during the execution of // a precompile tx or query. It avoids panics and returns the out of gas error so // the EVM can continue gracefully. @@ -139,9 +142,11 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ return nil, vm.ErrOutOfGas } - if err = p.AddJournalEntries(stateDB); err != nil { + // Process the native balance changes after the method execution. + if err := p.GetBalanceHandler().AfterBalanceChange(ctx, stateDB); err != nil { return nil, err } + return bz, nil } diff --git a/tests/integration/precompiles/distribution/test_integration.go b/tests/integration/precompiles/distribution/test_integration.go index 73284087..66b35089 100644 --- a/tests/integration/precompiles/distribution/test_integration.go +++ b/tests/integration/precompiles/distribution/test_integration.go @@ -2274,7 +2274,7 @@ func TestPrecompileIntegrationTestSuite(t *testing.T, create network.CreateEvmAp // set gas such that the internal keeper function called by the precompile fails out mid-execution txArgs.GasLimit = 80_000 - _, _, err = s.factory.CallContractAndCheckLogs( + _, txRes, err := s.factory.CallContractAndCheckLogs( s.keyring.GetPrivKey(0), txArgs, callArgs, @@ -2286,7 +2286,7 @@ func TestPrecompileIntegrationTestSuite(t *testing.T, create network.CreateEvmAp balRes, err := s.grpcHandler.GetBalanceFromBank(s.keyring.GetAccAddr(0), s.bondDenom) Expect(err).To(BeNil()) finalBalance := balRes.Balance - expectedGasCost := math.NewInt(79_416_000_000_000) + expectedGasCost := math.NewIntFromUint64(txRes.GasUsed).Mul(math.NewIntFromBigInt(txArgs.GasPrice)) Expect(finalBalance.Amount.Equal(initialBalance.Amount.Sub(expectedGasCost))).To(BeTrue(), "expected final balance must be initial balance minus any gas spent") res, err = s.grpcHandler.GetDelegationTotalRewards(s.keyring.GetAccAddr(0).String()) diff --git a/testutil/statedb.go b/testutil/statedb.go index f248ae4d..f8849985 100644 --- a/testutil/statedb.go +++ b/testutil/statedb.go @@ -1,15 +1,123 @@ package testutil import ( + "errors" + "maps" + "math/big" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" anteinterfaces "github.com/cosmos/evm/ante/interfaces" "github.com/cosmos/evm/x/vm/statedb" + "github.com/cosmos/evm/x/vm/types" sdk "github.com/cosmos/cosmos-sdk/types" ) +var ( + _ statedb.Keeper = &MockKeeper{} + ErrAddress common.Address = common.BigToAddress(big.NewInt(100)) + EmptyCodeHash = crypto.Keccak256(nil) +) + // NewStateDB returns a new StateDB for testing purposes. func NewStateDB(ctx sdk.Context, evmKeeper anteinterfaces.EVMKeeper) *statedb.StateDB { return statedb.New(ctx, evmKeeper, statedb.NewEmptyTxConfig(common.BytesToHash(ctx.HeaderHash()))) } + +type MockAcount struct { + account statedb.Account + states statedb.Storage +} + +type MockKeeper struct { + accounts map[common.Address]MockAcount + codes map[common.Hash][]byte +} + +func NewMockKeeper() *MockKeeper { + return &MockKeeper{ + accounts: make(map[common.Address]MockAcount), + codes: make(map[common.Hash][]byte), + } +} + +func (k MockKeeper) GetAccount(_ sdk.Context, addr common.Address) *statedb.Account { + acct, ok := k.accounts[addr] + if !ok { + return nil + } + return &acct.account +} + +func (k MockKeeper) GetState(_ sdk.Context, addr common.Address, key common.Hash) common.Hash { + return k.accounts[addr].states[key] +} + +func (k MockKeeper) GetCode(_ sdk.Context, codeHash common.Hash) []byte { + return k.codes[codeHash] +} + +func (k MockKeeper) ForEachStorage(_ sdk.Context, addr common.Address, cb func(key, value common.Hash) bool) { + if acct, ok := k.accounts[addr]; ok { + for k, v := range acct.states { + if !cb(k, v) { + return + } + } + } +} + +func (k MockKeeper) SetAccount(_ sdk.Context, addr common.Address, account statedb.Account) error { + if addr == ErrAddress { + return errors.New("mock db error") + } + acct, exists := k.accounts[addr] + if exists { + // update + acct.account = account + k.accounts[addr] = acct + } else { + k.accounts[addr] = MockAcount{account: account, states: make(statedb.Storage)} + } + return nil +} + +func (k MockKeeper) SetState(_ sdk.Context, addr common.Address, key common.Hash, value []byte) { + if acct, ok := k.accounts[addr]; ok { + acct.states[key] = common.BytesToHash(value) + } +} + +func (k MockKeeper) DeleteState(_ sdk.Context, addr common.Address, key common.Hash) { + if acct, ok := k.accounts[addr]; ok { + delete(acct.states, key) + } +} + +func (k MockKeeper) SetCode(_ sdk.Context, codeHash []byte, code []byte) { + k.codes[common.BytesToHash(codeHash)] = code +} + +func (k MockKeeper) DeleteCode(_ sdk.Context, codeHash []byte) { + delete(k.codes, common.BytesToHash(codeHash)) +} + +func (k MockKeeper) DeleteAccount(_ sdk.Context, addr common.Address) error { + if addr == ErrAddress { + return errors.New("mock db error") + } + old := k.accounts[addr] + delete(k.accounts, addr) + if !types.IsEmptyCodeHash(old.account.CodeHash) { + delete(k.codes, common.BytesToHash(old.account.CodeHash)) + } + return nil +} + +func (k MockKeeper) Clone() *MockKeeper { + accounts := maps.Clone(k.accounts) + codes := maps.Clone(k.codes) + return &MockKeeper{accounts, codes} +} diff --git a/x/vm/statedb/mock_test.go b/x/vm/statedb/mock_test.go deleted file mode 100644 index 130d771e..00000000 --- a/x/vm/statedb/mock_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package statedb_test - -import ( - "errors" - "maps" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" - - "github.com/cosmos/evm/x/vm/statedb" - "github.com/cosmos/evm/x/vm/types" - - sdk "github.com/cosmos/cosmos-sdk/types" -) - -var ( - _ statedb.Keeper = &MockKeeper{} - errAddress common.Address = common.BigToAddress(big.NewInt(100)) - emptyCodeHash = crypto.Keccak256(nil) -) - -type MockAcount struct { - account statedb.Account - states statedb.Storage -} - -type MockKeeper struct { - accounts map[common.Address]MockAcount - codes map[common.Hash][]byte -} - -func NewMockKeeper() *MockKeeper { - return &MockKeeper{ - accounts: make(map[common.Address]MockAcount), - codes: make(map[common.Hash][]byte), - } -} - -func (k MockKeeper) GetAccount(_ sdk.Context, addr common.Address) *statedb.Account { - acct, ok := k.accounts[addr] - if !ok { - return nil - } - return &acct.account -} - -func (k MockKeeper) GetState(_ sdk.Context, addr common.Address, key common.Hash) common.Hash { - return k.accounts[addr].states[key] -} - -func (k MockKeeper) GetCode(_ sdk.Context, codeHash common.Hash) []byte { - return k.codes[codeHash] -} - -func (k MockKeeper) ForEachStorage(_ sdk.Context, addr common.Address, cb func(key, value common.Hash) bool) { - if acct, ok := k.accounts[addr]; ok { - for k, v := range acct.states { - if !cb(k, v) { - return - } - } - } -} - -func (k MockKeeper) SetAccount(_ sdk.Context, addr common.Address, account statedb.Account) error { - if addr == errAddress { - return errors.New("mock db error") - } - acct, exists := k.accounts[addr] - if exists { - // update - acct.account = account - k.accounts[addr] = acct - } else { - k.accounts[addr] = MockAcount{account: account, states: make(statedb.Storage)} - } - return nil -} - -func (k MockKeeper) SetState(_ sdk.Context, addr common.Address, key common.Hash, value []byte) { - if acct, ok := k.accounts[addr]; ok { - acct.states[key] = common.BytesToHash(value) - } -} - -func (k MockKeeper) DeleteState(_ sdk.Context, addr common.Address, key common.Hash) { - if acct, ok := k.accounts[addr]; ok { - delete(acct.states, key) - } -} - -func (k MockKeeper) SetCode(_ sdk.Context, codeHash []byte, code []byte) { - k.codes[common.BytesToHash(codeHash)] = code -} - -func (k MockKeeper) DeleteCode(_ sdk.Context, codeHash []byte) { - delete(k.codes, common.BytesToHash(codeHash)) -} - -func (k MockKeeper) DeleteAccount(_ sdk.Context, addr common.Address) error { - if addr == errAddress { - return errors.New("mock db error") - } - old := k.accounts[addr] - delete(k.accounts, addr) - if !types.IsEmptyCodeHash(old.account.CodeHash) { - delete(k.codes, common.BytesToHash(old.account.CodeHash)) - } - return nil -} - -func (k MockKeeper) Clone() *MockKeeper { - accounts := maps.Clone(k.accounts) - codes := maps.Clone(k.codes) - return &MockKeeper{accounts, codes} -} diff --git a/x/vm/statedb/statedb_test.go b/x/vm/statedb/statedb_test.go index ff3d1338..6bcfe2dc 100644 --- a/x/vm/statedb/statedb_test.go +++ b/x/vm/statedb/statedb_test.go @@ -14,6 +14,7 @@ import ( "github.com/holiman/uint256" "github.com/stretchr/testify/suite" + "github.com/cosmos/evm/testutil" "github.com/cosmos/evm/x/vm/statedb" sdk "github.com/cosmos/cosmos-sdk/types" @@ -38,9 +39,9 @@ func (suite *StateDBTestSuite) TestAccount() { value2 := common.BigToHash(big.NewInt(4)) testCases := []struct { name string - malleate func(*statedb.StateDB) + malleate func(sdk.Context, *statedb.StateDB) }{ - {"non-exist account", func(db *statedb.StateDB) { + {"non-exist account", func(_ sdk.Context, db *statedb.StateDB) { suite.Require().Equal(false, db.Exist(address)) suite.Require().Equal(true, db.Empty(address)) suite.Require().Equal(common.U2560, db.GetBalance(address)) @@ -48,25 +49,25 @@ func (suite *StateDBTestSuite) TestAccount() { suite.Require().Equal(common.Hash{}, db.GetCodeHash(address)) suite.Require().Equal(uint64(0), db.GetNonce(address)) }}, - {"empty account", func(db *statedb.StateDB) { + {"empty account", func(ctx sdk.Context, db *statedb.StateDB) { db.CreateAccount(address) suite.Require().NoError(db.Commit()) - keeper := db.Keeper().(*MockKeeper) - acct := keeper.accounts[address] - suite.Require().Equal(statedb.NewEmptyAccount(), &acct.account) - suite.Require().Empty(acct.states) - suite.Require().False(acct.account.IsContract()) + keeper := db.Keeper().(*testutil.MockKeeper) + acct := keeper.GetAccount(ctx, address) + suite.Require().Equal(statedb.NewEmptyAccount(), acct) + suite.Require().Empty(acct.Balance) + suite.Require().False(acct.IsContract()) db = statedb.New(sdk.Context{}, keeper, emptyTxConfig) suite.Require().Equal(true, db.Exist(address)) suite.Require().Equal(true, db.Empty(address)) suite.Require().Equal(common.U2560, db.GetBalance(address)) suite.Require().Equal([]byte(nil), db.GetCode(address)) - suite.Require().Equal(common.BytesToHash(emptyCodeHash), db.GetCodeHash(address)) + suite.Require().Equal(common.BytesToHash(testutil.EmptyCodeHash), db.GetCodeHash(address)) suite.Require().Equal(uint64(0), db.GetNonce(address)) }}, - {"suicide", func(db *statedb.StateDB) { + {"suicide", func(ctx sdk.Context, db *statedb.StateDB) { // non-exist account. db.SelfDestruct(address) suite.Require().False(db.HasSelfDestructed(address)) @@ -99,22 +100,24 @@ func (suite *StateDBTestSuite) TestAccount() { suite.Require().False(db.Exist(address)) // and cleared in keeper too - keeper := db.Keeper().(*MockKeeper) - suite.Require().Empty(keeper.accounts) - suite.Require().Empty(keeper.codes) + keeper := db.Keeper().(*testutil.MockKeeper) + keeper.ForEachStorage(ctx, address, func(key, value common.Hash) bool { + return len(value) == 0 + }) }}, } for _, tc := range testCases { suite.Run(tc.name, func() { - keeper := NewMockKeeper() + ctx := sdk.Context{} + keeper := testutil.NewMockKeeper() db := statedb.New(sdk.Context{}, keeper, emptyTxConfig) - tc.malleate(db) + tc.malleate(ctx, db) }) } } func (suite *StateDBTestSuite) TestAccountOverride() { - keeper := NewMockKeeper() + keeper := testutil.NewMockKeeper() db := statedb.New(sdk.Context{}, keeper, emptyTxConfig) // test balance carry over when overwritten amount := uint256.NewInt(1) @@ -138,16 +141,16 @@ func (suite *StateDBTestSuite) TestDBError() { malleate func(vm.StateDB) }{ {"set account", func(db vm.StateDB) { - db.SetNonce(errAddress, 1, tracing.NonceChangeUnspecified) + db.SetNonce(testutil.ErrAddress, 1, tracing.NonceChangeUnspecified) }}, {"delete account", func(db vm.StateDB) { - db.SetNonce(errAddress, 1, tracing.NonceChangeUnspecified) - db.SelfDestruct(errAddress) - suite.Require().True(db.HasSelfDestructed(errAddress)) + db.SetNonce(testutil.ErrAddress, 1, tracing.NonceChangeUnspecified) + db.SelfDestruct(testutil.ErrAddress) + suite.Require().True(db.HasSelfDestructed(testutil.ErrAddress)) }}, } for _, tc := range testCases { - db := statedb.New(sdk.Context{}, NewMockKeeper(), emptyTxConfig) + db := statedb.New(sdk.Context{}, testutil.NewMockKeeper(), emptyTxConfig) tc.malleate(db) suite.Require().Error(db.Commit()) } @@ -179,7 +182,8 @@ func (suite *StateDBTestSuite) TestBalance() { for _, tc := range testCases { suite.Run(tc.name, func() { - keeper := NewMockKeeper() + ctx := sdk.Context{} + keeper := testutil.NewMockKeeper() db := statedb.New(sdk.Context{}, keeper, emptyTxConfig) tc.malleate(db) @@ -187,7 +191,7 @@ func (suite *StateDBTestSuite) TestBalance() { suite.Require().Equal(tc.expBalance, db.GetBalance(address)) suite.Require().NoError(db.Commit()) // check committed balance too - suite.Require().Equal(tc.expBalance, keeper.accounts[address].account.Balance) + suite.Require().Equal(tc.expBalance, keeper.GetAccount(ctx, address).Balance) }) } } @@ -233,13 +237,16 @@ func (suite *StateDBTestSuite) TestState() { for _, tc := range testCases { suite.Run(tc.name, func() { - keeper := NewMockKeeper() + ctx := sdk.Context{} + keeper := testutil.NewMockKeeper() db := statedb.New(sdk.Context{}, keeper, emptyTxConfig) tc.malleate(db) suite.Require().NoError(db.Commit()) // check committed states in keeper - suite.Require().Equal(tc.expStates, keeper.accounts[address].states) + for _, key := range tc.expStates.SortedKeys() { + suite.Require().Equal(tc.expStates[key], keeper.GetState(ctx, address, key)) + } // check ForEachStorage db = statedb.New(sdk.Context{}, keeper, emptyTxConfig) @@ -266,7 +273,7 @@ func (suite *StateDBTestSuite) TestCode() { {"non-exist account", func(vm.StateDB) {}, nil, common.Hash{}}, {"empty account", func(db vm.StateDB) { db.CreateAccount(address) - }, nil, common.BytesToHash(emptyCodeHash)}, + }, nil, common.BytesToHash(testutil.EmptyCodeHash)}, {"set code", func(db vm.StateDB) { db.SetCode(address, code) }, code, codeHash}, @@ -274,7 +281,7 @@ func (suite *StateDBTestSuite) TestCode() { for _, tc := range testCases { suite.Run(tc.name, func() { - keeper := NewMockKeeper() + keeper := testutil.NewMockKeeper() db := statedb.New(sdk.Context{}, keeper, emptyTxConfig) tc.malleate(db) @@ -341,7 +348,7 @@ func (suite *StateDBTestSuite) TestRevertSnapshot() { for _, tc := range testCases { suite.Run(tc.name, func() { ctx := sdk.Context{} - keeper := NewMockKeeper() + keeper := testutil.NewMockKeeper() { // do some arbitrary changes to the storage @@ -379,7 +386,7 @@ func (suite *StateDBTestSuite) TestNestedSnapshot() { value1 := common.BigToHash(big.NewInt(1)) value2 := common.BigToHash(big.NewInt(2)) - db := statedb.New(sdk.Context{}, NewMockKeeper(), emptyTxConfig) + db := statedb.New(sdk.Context{}, testutil.NewMockKeeper(), emptyTxConfig) rev1 := db.Snapshot() db.SetState(address, key, value1) @@ -396,7 +403,7 @@ func (suite *StateDBTestSuite) TestNestedSnapshot() { } func (suite *StateDBTestSuite) TestInvalidSnapshotId() { - db := statedb.New(sdk.Context{}, NewMockKeeper(), emptyTxConfig) + db := statedb.New(sdk.Context{}, testutil.NewMockKeeper(), emptyTxConfig) suite.Require().Panics(func() { db.RevertToSnapshot(1) }) @@ -486,7 +493,7 @@ func (suite *StateDBTestSuite) TestAccessList() { } for _, tc := range testCases { - db := statedb.New(sdk.Context{}, NewMockKeeper(), emptyTxConfig) + db := statedb.New(sdk.Context{}, testutil.NewMockKeeper(), emptyTxConfig) tc.malleate(db) } } @@ -499,7 +506,7 @@ func (suite *StateDBTestSuite) TestLog() { txHash, 1, 1, ) - db := statedb.New(sdk.Context{}, NewMockKeeper(), txConfig) + db := statedb.New(sdk.Context{}, testutil.NewMockKeeper(), txConfig) data := []byte("hello world") db.AddLog(ðtypes.Log{ Address: address, @@ -551,7 +558,7 @@ func (suite *StateDBTestSuite) TestRefund() { }, 0, true}, } for _, tc := range testCases { - db := statedb.New(sdk.Context{}, NewMockKeeper(), emptyTxConfig) + db := statedb.New(sdk.Context{}, testutil.NewMockKeeper(), emptyTxConfig) if !tc.expPanic { tc.malleate(db) suite.Require().Equal(tc.expRefund, db.GetRefund()) @@ -564,12 +571,14 @@ func (suite *StateDBTestSuite) TestRefund() { } func (suite *StateDBTestSuite) TestIterateStorage() { + ctx := sdk.Context{} + key1 := common.BigToHash(big.NewInt(1)) value1 := common.BigToHash(big.NewInt(2)) key2 := common.BigToHash(big.NewInt(3)) value2 := common.BigToHash(big.NewInt(4)) - keeper := NewMockKeeper() + keeper := testutil.NewMockKeeper() db := statedb.New(sdk.Context{}, keeper, emptyTxConfig) db.SetState(address, key1, value1) db.SetState(address, key2, value2) @@ -581,7 +590,9 @@ func (suite *StateDBTestSuite) TestIterateStorage() { storage := CollectContractStorage(db) suite.Require().Equal(2, len(storage)) - suite.Require().Equal(keeper.accounts[address].states, storage) + for _, key := range storage.SortedKeys() { + suite.Require().Equal(keeper.GetState(ctx, address, key), storage[key]) + } // break early iteration storage = make(statedb.Storage)