diff --git a/modules/network/keeper/fixtures_test.go b/modules/network/keeper/fixtures_test.go new file mode 100644 index 0000000..1959882 --- /dev/null +++ b/modules/network/keeper/fixtures_test.go @@ -0,0 +1,122 @@ +package keeper + +import ( + "context" + "maps" + "slices" + "strings" + "time" + + "cosmossdk.io/math" + cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + cmttypes "github.com/cometbft/cometbft/types" + "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/libp2p/go-libp2p/core/crypto" + + rollkittypes "github.com/rollkit/rollkit/types" + + "github.com/rollkit/go-execution-abci/modules/network/types" +) + +func HeaderFixture(signer *ed25519.PrivKey, appHash []byte, mutators ...func(*rollkittypes.SignedHeader)) *rollkittypes.SignedHeader { + header := rollkittypes.Header{ + BaseHeader: rollkittypes.BaseHeader{ + Height: 10, + Time: uint64(time.Now().UnixNano()), + ChainID: "testing", + }, + Version: rollkittypes.Version{Block: 1, App: 1}, + ProposerAddress: signer.PubKey().Address(), + AppHash: appHash, + DataHash: []byte("data_hash"), + ConsensusHash: []byte("consensus_hash"), + ValidatorHash: []byte("validator_hash"), + } + signedHeader := &rollkittypes.SignedHeader{ + Header: header, + Signature: appHash, + Signer: rollkittypes.Signer{PubKey: must(crypto.UnmarshalEd25519PublicKey(signer.PubKey().Bytes()))}, + } + for _, m := range mutators { + m(signedHeader) + } + return signedHeader +} + +func VoteFixture(myAppHash []byte, voteSigner *ed25519.PrivKey, mutators ...func(vote *cmtproto.Vote)) *cmtproto.Vote { + const chainID = "testing" + + vote := &cmtproto.Vote{ + Type: cmtproto.PrecommitType, + Height: 10, + Round: 0, + BlockID: cmtproto.BlockID{Hash: myAppHash, PartSetHeader: cmtproto.PartSetHeader{Total: 1, Hash: myAppHash}}, + Timestamp: time.Now().UTC(), + ValidatorAddress: voteSigner.PubKey().Address(), + ValidatorIndex: 0, + } + vote.Signature = must(voteSigner.Sign(cmttypes.VoteSignBytes(chainID, vote))) + + for _, m := range mutators { + m(vote) + } + return vote +} + +var _ types.StakingKeeper = &MockStakingKeeper{} + +type MockStakingKeeper struct { + activeSet map[string]stakingtypes.Validator +} + +func NewMockStakingKeeper() MockStakingKeeper { + return MockStakingKeeper{ + activeSet: make(map[string]stakingtypes.Validator), + } +} + +func (m *MockStakingKeeper) SetValidator(ctx context.Context, validator stakingtypes.Validator) error { + m.activeSet[validator.GetOperator()] = validator + return nil +} +func (m MockStakingKeeper) GetAllValidators(ctx context.Context) (validators []stakingtypes.Validator, err error) { + return slices.SortedFunc(maps.Values(m.activeSet), func(v1 stakingtypes.Validator, v2 stakingtypes.Validator) int { + return strings.Compare(v1.OperatorAddress, v2.OperatorAddress) + }), nil +} +func (m MockStakingKeeper) GetValidator(ctx context.Context, addr sdk.ValAddress) (validator stakingtypes.Validator, err error) { + // First try to find the validator by address + validator, found := m.activeSet[addr.String()] + if found { + return validator, nil + } + + //// If not found by address, try to find by public key address + //addrStr := addr.String() + //for valAddrStr, pubKey := range m.pubKeys { + // if pubKey.Address().String() == addrStr { + // validator, found = m.activeSet[valAddrStr] + // if found { + // return validator, nil + // } + // } + //} + + return validator, sdkerrors.ErrNotFound +} + +func (m MockStakingKeeper) GetLastValidators(ctx context.Context) (validators []stakingtypes.Validator, err error) { + for _, validator := range m.activeSet { + if validator.IsBonded() { // Assuming IsBonded() identifies if a validator is in the last validators + validators = append(validators, validator) + } + } + return +} + +func (m MockStakingKeeper) GetLastTotalPower(ctx context.Context) (math.Int, error) { + return math.NewInt(int64(len(m.activeSet))), nil +} diff --git a/modules/network/keeper/msg_server.go b/modules/network/keeper/msg_server.go index 48176c7..1e89a0f 100644 --- a/modules/network/keeper/msg_server.go +++ b/modules/network/keeper/msg_server.go @@ -1,16 +1,19 @@ package keeper import ( + "bytes" "context" - "errors" "fmt" "cosmossdk.io/collections" sdkerr "cosmossdk.io/errors" "cosmossdk.io/math" + cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + cmttypes "github.com/cometbft/cometbft/types" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + "github.com/cosmos/gogoproto/proto" "github.com/rollkit/go-execution-abci/modules/network/types" ) @@ -30,33 +33,53 @@ var _ types.MsgServer = msgServer{} func (k msgServer) Attest(goCtx context.Context, msg *types.MsgAttest) (*types.MsgAttestResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if k.GetParams(ctx).SignMode == types.SignMode_SIGN_MODE_CHECKPOINT && - !k.IsCheckpointHeight(ctx, msg.Height) { - return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "height %d is not a checkpoint", msg.Height) - } - has, err := k.IsInAttesterSet(ctx, msg.Validator) - if err != nil { - return nil, sdkerr.Wrapf(err, "in attester set") + if err := k.validateAttestation(ctx, msg); err != nil { + return nil, err } - if !has { - return nil, sdkerr.Wrapf(sdkerrors.ErrUnauthorized, "validator %s not in attester set", msg.Validator) + // can vote only for the last epoch + if delta := ctx.BlockHeight() - msg.Height; delta < 0 || delta > int64(k.GetParams(ctx).EpochLength) { + return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "exceeded voting window: %d blocks", delta) } - index, found := k.GetValidatorIndex(ctx, msg.Validator) + valIndexPos, found := k.GetValidatorIndex(ctx, msg.Validator) if !found { return nil, sdkerr.Wrapf(sdkerrors.ErrNotFound, "validator index not found for %s", msg.Validator) } - // todo (Alex): we need to set a limit to not have validators attest old blocks. Also make sure that this relates with - // the retention period for pruning - bitmap, err := k.GetAttestationBitmap(ctx, msg.Height) - if err != nil && !errors.Is(err, collections.ErrNotFound) { - return nil, sdkerr.Wrap(err, "get attestation bitmap") + vote, err := k.verifyVote(ctx, msg) + if err != nil { + return nil, err } - if bitmap == nil { + + if err := k.updateAttestationBitmap(ctx, msg, valIndexPos); err != nil { + return nil, sdkerr.Wrap(err, "update attestation bitmap") + } + + if err := k.SetSignature(ctx, msg.Height, msg.Validator, vote.Signature); err != nil { + return nil, sdkerr.Wrap(err, "store signature") + } + + if err := k.updateEpochBitmap(ctx, uint64(msg.Height), valIndexPos); err != nil { + return nil, err + } + + // Emit event + ctx.EventManager().EmitEvent( + sdk.NewEvent( + types.TypeMsgAttest, + sdk.NewAttribute("validator", msg.Validator), + sdk.NewAttribute("height", math.NewInt(msg.Height).String()), + ), + ) + return &types.MsgAttestResponse{}, nil +} + +func (k msgServer) updateEpochBitmap(ctx sdk.Context, votedEpoch uint64, index uint16) error { + epochBitmap := k.GetEpochBitmap(ctx, votedEpoch) + if epochBitmap == nil { validators, err := k.stakingKeeper.GetLastValidators(ctx) if err != nil { - return nil, err + return err } numValidators := 0 for _, v := range validators { @@ -64,32 +87,43 @@ func (k msgServer) Attest(goCtx context.Context, msg *types.MsgAttest) (*types.M numValidators++ } } - bitmap = k.bitmapHelper.NewBitmap(numValidators) + epochBitmap = k.bitmapHelper.NewBitmap(numValidators) } - - if k.bitmapHelper.IsSet(bitmap, int(index)) { - return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "validator %s already attested for height %d", msg.Validator, msg.Height) + k.bitmapHelper.SetBit(epochBitmap, int(index)) + if err := k.SetEpochBitmap(ctx, votedEpoch, epochBitmap); err != nil { + return sdkerr.Wrap(err, "set epoch bitmap") } + return nil +} - // TODO: Verify the vote signature here once we implement vote parsing +// validateAttestation validates the attestation request +func (k msgServer) validateAttestation(ctx sdk.Context, msg *types.MsgAttest) error { + if k.GetParams(ctx).SignMode == types.SignMode_SIGN_MODE_CHECKPOINT && + !k.IsCheckpointHeight(ctx, msg.Height) { + return sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "height %d is not a checkpoint", msg.Height) + } - // Set the bit - k.bitmapHelper.SetBit(bitmap, int(index)) - if err := k.SetAttestationBitmap(ctx, msg.Height, bitmap); err != nil { - return nil, sdkerr.Wrap(err, "set attestation bitmap") + has, err := k.IsInAttesterSet(ctx, msg.Validator) + if err != nil { + return sdkerr.Wrapf(err, "in attester set") } + if !has { + return sdkerr.Wrapf(sdkerrors.ErrUnauthorized, "validator %s not in attester set", msg.Validator) + } + return nil +} - // Store signature using the new collection method - if err := k.SetSignature(ctx, msg.Height, msg.Validator, msg.Vote); err != nil { - return nil, sdkerr.Wrap(err, "store signature") +// updateAttestationBitmap handles bitmap operations for attestation +func (k msgServer) updateAttestationBitmap(ctx sdk.Context, msg *types.MsgAttest, index uint16) error { + bitmap, err := k.GetAttestationBitmap(ctx, msg.Height) + if err != nil && !sdkerr.IsOf(err, collections.ErrNotFound) { + return err } - epoch := k.GetCurrentEpoch(ctx) - epochBitmap := k.GetEpochBitmap(ctx, epoch) - if epochBitmap == nil { + if bitmap == nil { validators, err := k.stakingKeeper.GetLastValidators(ctx) if err != nil { - return nil, err + return err } numValidators := 0 for _, v := range validators { @@ -97,23 +131,57 @@ func (k msgServer) Attest(goCtx context.Context, msg *types.MsgAttest) (*types.M numValidators++ } } - epochBitmap = k.bitmapHelper.NewBitmap(numValidators) + bitmap = k.bitmapHelper.NewBitmap(numValidators) } - k.bitmapHelper.SetBit(epochBitmap, int(index)) - if err := k.SetEpochBitmap(ctx, epoch, epochBitmap); err != nil { - return nil, sdkerr.Wrap(err, "set epoch bitmap") + + if k.bitmapHelper.IsSet(bitmap, int(index)) { + return sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "validator %s already attested for height %d", msg.Validator, msg.Height) } - // Emit event - ctx.EventManager().EmitEvent( - sdk.NewEvent( - types.TypeMsgAttest, - sdk.NewAttribute("validator", msg.Validator), - sdk.NewAttribute("height", math.NewInt(msg.Height).String()), - ), - ) + k.bitmapHelper.SetBit(bitmap, int(index)) - return &types.MsgAttestResponse{}, nil + if err := k.SetAttestationBitmap(ctx, msg.Height, bitmap); err != nil { + return sdkerr.Wrap(err, "set attestation bitmap") + } + return nil +} + +// verifyVote verifies the vote signature and block hash +func (k msgServer) verifyVote(ctx sdk.Context, msg *types.MsgAttest) (*cmtproto.Vote, error) { + var vote cmtproto.Vote + if err := proto.Unmarshal(msg.Vote, &vote); err != nil { + return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "unmarshal vote: %s", err) + } + if msg.Height != vote.Height { + return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "vote height does not match attestation height") + } + if len(vote.Signature) == 0 { + return nil, sdkerrors.ErrInvalidRequest.Wrap("empty signature") + } + + // todo (Alex): validate app hash match, vote clock drift + + valAddress, err := sdk.ValAddressFromBech32(msg.Validator) + if err != nil { + return nil, sdkerr.Wrap(err, "invalid validator address") + } + validator, err := k.stakingKeeper.GetValidator(ctx, valAddress) + if err != nil { + return nil, sdkerr.Wrapf(err, "get validator") + } + pubKey, err := validator.ConsPubKey() + if err != nil { + return nil, sdkerr.Wrapf(err, "pubkey") + } + if !bytes.Equal(pubKey.Address().Bytes(), vote.ValidatorAddress) { + return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "pubkey address does not match validator address") + } + voteSignBytes := cmttypes.VoteSignBytes(ctx.ChainID(), &vote) + if !pubKey.VerifySignature(voteSignBytes, vote.Signature) { + return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "invalid vote signature") + } + + return &vote, nil } // JoinAttesterSet handles MsgJoinAttesterSet diff --git a/modules/network/keeper/msg_server_test.go b/modules/network/keeper/msg_server_test.go index 7f8a992..c91b9b3 100644 --- a/modules/network/keeper/msg_server_test.go +++ b/modules/network/keeper/msg_server_test.go @@ -1,17 +1,16 @@ package keeper import ( - "context" - "maps" - "slices" - "strings" + "crypto/sha256" "testing" "time" + "cosmossdk.io/collections" "cosmossdk.io/log" - "cosmossdk.io/math" storetypes "cosmossdk.io/store/types" cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + cmttypes "github.com/cometbft/cometbft/types" + "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" "github.com/cosmos/cosmos-sdk/runtime" "github.com/cosmos/cosmos-sdk/testutil/integration" sdk "github.com/cosmos/cosmos-sdk/types" @@ -19,17 +18,24 @@ import ( moduletestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/cosmos/gogoproto/proto" + ds "github.com/ipfs/go-datastore" + kt "github.com/ipfs/go-datastore/keytransform" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + rollnode "github.com/rollkit/rollkit/node" + rstore "github.com/rollkit/rollkit/pkg/store" + rollkittypes "github.com/rollkit/rollkit/types" + "github.com/rollkit/go-execution-abci/modules/network/types" ) func TestJoinAttesterSet(t *testing.T) { - myValAddr := sdk.ValAddress("validator4") + myValAddr := sdk.ValAddress("validator") type testCase struct { - setup func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) + setup func(t *testing.T, env testEnv) msg *types.MsgJoinAttesterSet expErr error expSet bool @@ -37,150 +43,379 @@ func TestJoinAttesterSet(t *testing.T) { tests := map[string]testCase{ "valid": { - setup: func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) { + setup: func(t *testing.T, env testEnv) { validator := stakingtypes.Validator{ OperatorAddress: myValAddr.String(), Status: stakingtypes.Bonded, } - err := sk.SetValidator(ctx, validator) + err := env.SK.SetValidator(env.Ctx, validator) require.NoError(t, err, "failed to set validator") }, msg: &types.MsgJoinAttesterSet{Validator: myValAddr.String()}, expSet: true, }, "invalid_addr": { - setup: func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) {}, + setup: func(t *testing.T, env testEnv) {}, msg: &types.MsgJoinAttesterSet{Validator: "invalidAddr"}, expErr: sdkerrors.ErrInvalidAddress, }, "val not exists": { - setup: func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) {}, + setup: func(t *testing.T, env testEnv) {}, msg: &types.MsgJoinAttesterSet{Validator: myValAddr.String()}, expErr: sdkerrors.ErrNotFound, }, "val not bonded": { - setup: func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) { + setup: func(t *testing.T, env testEnv) { validator := stakingtypes.Validator{ OperatorAddress: myValAddr.String(), Status: stakingtypes.Unbonded, // Validator is not bonded } - err := sk.SetValidator(ctx, validator) + err := env.SK.SetValidator(env.Ctx, validator) require.NoError(t, err, "failed to set validator") }, msg: &types.MsgJoinAttesterSet{Validator: myValAddr.String()}, expErr: sdkerrors.ErrInvalidRequest, }, "already set": { - setup: func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) { + setup: func(t *testing.T, env testEnv) { validator := stakingtypes.Validator{ OperatorAddress: myValAddr.String(), Status: stakingtypes.Bonded, } - require.NoError(t, sk.SetValidator(ctx, validator)) - require.NoError(t, keeper.SetAttesterSetMember(ctx, myValAddr.String())) + require.NoError(t, env.SK.SetValidator(env.Ctx, validator)) + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, myValAddr.String())) }, msg: &types.MsgJoinAttesterSet{Validator: myValAddr.String()}, expErr: sdkerrors.ErrInvalidRequest, expSet: true, }, - //{ - // name: "failed to set attester set member", - // setup: func(t *testing.T, ctx sdk.Context, keeper *Keeper, sk *MockStakingKeeper) { - // validatorAddr := sdk.ValAddress([]byte("validator5")) - // validator := stakingtypes.Validator{ - // OperatorAddress: validatorAddr.String(), - // Status: stakingtypes.Bonded, - // } - // err := sk.SetValidator(ctx, validator) - // require.NoError(t, err, "failed to set validator") - // keeper.forceError = true - // }, - // msg: &types.MsgJoinAttesterSet{ - // Validator: "validator5", - // }, - // expErr: sdkerrors.ErrInternal, - // expectResponse: false, - //}, } - for name, spec := range tests { t.Run(name, func(t *testing.T) { - sk := NewMockStakingKeeper() - - cdc := moduletestutil.MakeTestEncodingConfig().Codec + // Setup test environment + env := setupTestEnv(t, 10) - keys := storetypes.NewKVStoreKeys(types.StoreKey) - - logger := log.NewTestLogger(t) - cms := integration.CreateMultiStore(keys, logger) - authority := authtypes.NewModuleAddress("gov") - keeper := NewKeeper(cdc, runtime.NewKVStoreService(keys[types.StoreKey]), sk, nil, nil, authority.String()) - server := msgServer{Keeper: keeper} - ctx := sdk.NewContext(cms, cmtproto.Header{ChainID: "test-chain", Time: time.Now().UTC(), Height: 10}, false, logger). - WithContext(t.Context()) - - spec.setup(t, ctx, &keeper, &sk) + // Apply test-specific setup + spec.setup(t, env) // when - rsp, err := server.JoinAttesterSet(ctx, spec.msg) + rsp, err := env.Server.JoinAttesterSet(env.Ctx, spec.msg) + // then if spec.expErr != nil { require.ErrorIs(t, err, spec.expErr) require.Nil(t, rsp) - exists, gotErr := keeper.AttesterSet.Has(ctx, spec.msg.Validator) + exists, gotErr := env.Keeper.AttesterSet.Has(env.Ctx, spec.msg.Validator) require.NoError(t, gotErr) assert.Equal(t, exists, spec.expSet) return } require.NoError(t, err) require.NotNil(t, rsp) - exists, gotErr := keeper.AttesterSet.Has(ctx, spec.msg.Validator) + exists, gotErr := env.Keeper.AttesterSet.Has(env.Ctx, spec.msg.Validator) require.NoError(t, gotErr) assert.True(t, exists) }) } } -var _ types.StakingKeeper = &MockStakingKeeper{} +func TestAttest(t *testing.T) { + const epochLength = 10 + var ( + myHash = sha256.Sum256([]byte("app_hash")) + myAppHash = myHash[:] + voteSigner = ed25519.GenPrivKey() + valAddrStr = sdk.ValAddress(voteSigner.PubKey().Address()).String() + ) -type MockStakingKeeper struct { - activeSet map[string]stakingtypes.Validator -} + // Setup test environment with block store + env := setupTestEnv(t, 2*epochLength) + + // Set up validator + validator, err := stakingtypes.NewValidator(valAddrStr, voteSigner.PubKey(), stakingtypes.Description{}) + require.NoError(t, err) + validator.Status = stakingtypes.Bonded + require.NoError(t, env.SK.SetValidator(env.Ctx, validator)) + + // Save block data + signedHeader := HeaderFixture(voteSigner, myAppHash) + data := &rollkittypes.Data{Txs: rollkittypes.Txs{}} + var signature rollkittypes.Signature + require.NoError(t, env.BlockStore.SaveBlockData(env.Ctx, signedHeader, data, &signature)) -func NewMockStakingKeeper() MockStakingKeeper { - return MockStakingKeeper{ - activeSet: make(map[string]stakingtypes.Validator), + var ( + validVote = VoteFixture(myAppHash, voteSigner) + validVoteBz = must(proto.Marshal(validVote)) + ) + parentCtx := env.Ctx + + specs := map[string]struct { + setup func(t *testing.T, env testEnv) sdk.Context + msg func(t *testing.T) *types.MsgAttest + expErr error + }{ + "valid attestation": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, valAddrStr)) + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx + }, + msg: func(t *testing.T) *types.MsgAttest { + return &types.MsgAttest{Validator: valAddrStr, Height: epochLength, Vote: validVoteBz} + }, + }, + "invalid vote content": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, valAddrStr)) + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx + }, + msg: func(t *testing.T) *types.MsgAttest { + return &types.MsgAttest{Validator: valAddrStr, Height: epochLength, Vote: []byte("not a valid proto vote")} + }, + expErr: sdkerrors.ErrInvalidRequest, + }, + "validator not in attester set": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx + }, + msg: func(t *testing.T) *types.MsgAttest { + return &types.MsgAttest{Validator: valAddrStr, Height: epochLength, Vote: validVoteBz} + }, + expErr: sdkerrors.ErrUnauthorized, + }, + "invalid signature": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, valAddrStr)) + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx + }, + msg: func(t *testing.T) *types.MsgAttest { + invalidVote := VoteFixture(myAppHash, voteSigner, func(vote *cmtproto.Vote) { + vote.Signature = []byte("invalid signature") + }) + return &types.MsgAttest{Validator: valAddrStr, Height: epochLength, Vote: must(proto.Marshal(invalidVote))} + }, + expErr: sdkerrors.ErrInvalidRequest, + }, + "not a checkpoint height": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, valAddrStr)) + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx + }, + msg: func(t *testing.T) *types.MsgAttest { + return &types.MsgAttest{Validator: valAddrStr, Height: epochLength + 1, Vote: validVoteBz} + }, + expErr: sdkerrors.ErrInvalidRequest, + }, + "vote window expired": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, valAddrStr)) + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx.WithBlockHeight(2*epochLength + 1) + }, + msg: func(t *testing.T) *types.MsgAttest { + return &types.MsgAttest{Validator: valAddrStr, Height: epochLength, Vote: validVoteBz} + }, + expErr: sdkerrors.ErrInvalidRequest, + }, + "voting for a future epoch": { + setup: func(t *testing.T, env testEnv) sdk.Context { + require.NoError(t, env.Keeper.SetAttesterSetMember(env.Ctx, valAddrStr)) + require.NoError(t, env.Keeper.SetValidatorIndex(env.Ctx, valAddrStr, 0, 100)) + return env.Ctx.WithBlockHeight(2 * epochLength) + }, + msg: func(t *testing.T) *types.MsgAttest { + return &types.MsgAttest{Validator: valAddrStr, Height: 3 * epochLength, Vote: validVoteBz} + }, + expErr: sdkerrors.ErrInvalidRequest, + }, } -} + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + // Create a new environment for each test case with a cached context + testEnv := env + testEnv.Ctx, _ = parentCtx.CacheContext() + ctx := spec.setup(t, testEnv) -func (m *MockStakingKeeper) SetValidator(ctx context.Context, validator stakingtypes.Validator) error { - m.activeSet[validator.GetOperator()] = validator - return nil + // when + srcMsg := spec.msg(t) + gotRsp, gotErr := testEnv.Server.Attest(ctx, srcMsg) + // then + if spec.expErr != nil { + require.Error(t, gotErr) + require.ErrorIs(t, gotErr, spec.expErr) + // and ensure the signature is not stored + _, err := testEnv.Keeper.GetSignature(ctx, srcMsg.Height, valAddrStr) + assert.ErrorIs(t, err, collections.ErrNotFound) + return + } + + require.NoError(t, gotErr) + require.NotNil(t, gotRsp) + + // and attestation marked + bitmap, gotErr := testEnv.Keeper.GetAttestationBitmap(ctx, srcMsg.Height) + require.NoError(t, gotErr) + require.NotEmpty(t, bitmap) + require.Equal(t, byte(1), bitmap[0]) + + // and the signature was stored properly + gotSig, err := testEnv.Keeper.GetSignature(ctx, srcMsg.Height, valAddrStr) + require.NoError(t, err) + var vote cmtproto.Vote + require.NoError(t, proto.Unmarshal(srcMsg.Vote, &vote)) + assert.Equal(t, vote.Signature, gotSig) + }) + } } -func (m MockStakingKeeper) GetAllValidators(ctx context.Context) (validators []stakingtypes.Validator, err error) { - return slices.SortedFunc(maps.Values(m.activeSet), func(v1 stakingtypes.Validator, v2 stakingtypes.Validator) int { - return strings.Compare(v1.OperatorAddress, v2.OperatorAddress) - }), nil -} -func (m MockStakingKeeper) GetValidator(ctx context.Context, addr sdk.ValAddress) (validator stakingtypes.Validator, err error) { - validator, found := m.activeSet[addr.String()] - if !found { - return validator, sdkerrors.ErrNotFound +func TestVerifyVote(t *testing.T) { + var ( + myHash = sha256.Sum256([]byte("app_hash")) + myAppHash = myHash[:] + validatorPrivKey = ed25519.GenPrivKey() + valAddrStr = sdk.ValAddress(validatorPrivKey.PubKey().Address()).String() + otherValidatorPrivKey = ed25519.GenPrivKey() + ) + + // Setup test environment with block store + env := setupTestEnv(t, 10) + + // Set up validator + validator, err := stakingtypes.NewValidator(valAddrStr, validatorPrivKey.PubKey(), stakingtypes.Description{}) + require.NoError(t, err) + validator.Status = stakingtypes.Bonded + require.NoError(t, env.SK.SetValidator(env.Ctx, validator)) + + // Save block data + header := HeaderFixture(validatorPrivKey, myAppHash) + var signature rollkittypes.Signature + require.NoError(t, env.BlockStore.SaveBlockData(env.Ctx, header, &rollkittypes.Data{}, &signature)) + + parentCtx := env.Ctx + + testCases := map[string]struct { + voteFn func(t *testing.T) *cmtproto.Vote + sender string + expErr error + }{ + "valid vote": { + voteFn: func(t *testing.T) *cmtproto.Vote { + return VoteFixture(myAppHash, validatorPrivKey) + }, + sender: valAddrStr, + }, + "block data not found": { + voteFn: func(t *testing.T) *cmtproto.Vote { + return VoteFixture(myAppHash, validatorPrivKey, func(vote *cmtproto.Vote) { + vote.Height++ + vote.Signature = must(validatorPrivKey.Sign(cmttypes.VoteSignBytes("testing", vote))) + }) + }, + sender: valAddrStr, + expErr: sdkerrors.ErrInvalidRequest, + }, + "validator not found": { + voteFn: func(t *testing.T) *cmtproto.Vote { + return VoteFixture(myAppHash, ed25519.GenPrivKey()) + }, + expErr: sdkerrors.ErrUnauthorized, + sender: sdk.ValAddress(otherValidatorPrivKey.PubKey().Address()).String(), + }, + "invalid vote signature": { + voteFn: func(t *testing.T) *cmtproto.Vote { + return VoteFixture(myAppHash, validatorPrivKey, func(vote *cmtproto.Vote) { + vote.Signature = []byte("invalid signature") + }) + }, + sender: valAddrStr, + expErr: sdkerrors.ErrInvalidRequest, + }, + "invalid sender": { + voteFn: func(t *testing.T) *cmtproto.Vote { + return VoteFixture(myAppHash, validatorPrivKey) + }, + sender: sdk.ValAddress(otherValidatorPrivKey.PubKey().Address()).String(), + expErr: sdkerrors.ErrUnauthorized, + }, + } + for name, spec := range testCases { + t.Run(name, func(t *testing.T) { + ctx, _ := parentCtx.CacheContext() + + // when + vote, err := env.Server.verifyVote(ctx, &types.MsgAttest{ + Height: 10, + Validator: spec.sender, + Vote: must(proto.Marshal(spec.voteFn(t))), + }) + + // then + if spec.expErr != nil { + require.Error(t, err) + require.ErrorIs(t, err, spec.expErr) + require.Nil(t, vote) + return + } + require.NoError(t, err) + require.NotNil(t, vote) + }) } - return validator, nil } -func (m MockStakingKeeper) GetLastValidators(ctx context.Context) (validators []stakingtypes.Validator, err error) { - for _, validator := range m.activeSet { - if validator.IsBonded() { // Assuming IsBonded() identifies if a validator is in the last validators - validators = append(validators, validator) - } +// testEnv contains all the common components needed for testing +type testEnv struct { + Ctx sdk.Context + Keeper Keeper + Server msgServer + SK MockStakingKeeper + BlockStore rstore.Store +} + +func setupTestEnv(t *testing.T, height int64) testEnv { + t.Helper() + // Set up codec and store + cdc := moduletestutil.MakeTestEncodingConfig().Codec + keys := storetypes.NewKVStoreKeys(types.StoreKey) + cms := integration.CreateMultiStore(keys, log.NewTestLogger(t)) + + sk := NewMockStakingKeeper() + rollkitPrefixStore := kt.Wrap(ds.NewMapDatastore(), &kt.PrefixTransform{ + Prefix: ds.NewKey(rollnode.RollkitPrefix), + }) + bs := rstore.New(rollkitPrefixStore) + + authority := authtypes.NewModuleAddress("gov") + keeper := NewKeeper(cdc, runtime.NewKVStoreService(keys[types.StoreKey]), &sk, nil, nil, authority.String()) + + ctx := sdk.NewContext(cms, cmtproto.Header{ + ChainID: "testing", + Time: time.Now().UTC(), + Height: height, + }, false, log.NewTestLogger(t)).WithContext(t.Context()) + + server := msgServer{Keeper: keeper} + + params := types.DefaultParams() + params.EpochLength = 10 // test default + require.NoError(t, keeper.SetParams(ctx, params)) + + return testEnv{ + Ctx: ctx, + Keeper: keeper, + Server: server, + SK: sk, + BlockStore: bs, } - return } -func (m MockStakingKeeper) GetLastTotalPower(ctx context.Context) (math.Int, error) { - return math.NewInt(int64(len(m.activeSet))), nil +func must[T any](r T, err error) T { + if err != nil { + panic(err) + } + return r }