Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/geth_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -2349,7 +2349,9 @@ func (b *GethStatusBackend) initProtocol() error {
AccountsPublisher: b.statusNode.AccountsPublisher(),
TimeSource: b.statusNode.TimeSource(),
MetricsEnabled: b.prometheusMetrics != nil,
TokenManager: b.statusNode.TokenManager(),
TokenManager: NewCommunitiesTokenManager(b.statusNode.TokenManager()),
TokenBalanceManager: NewCommunitiesTokenBalanceManager(b.statusNode.TokenManager()),
NetworkManager: NewCommunitiesNetworkManager(b.statusNode.RPCClient().GetNetworkManager()),
}
err = st.InitProtocol(params)
if err != nil {
Expand Down
75 changes: 75 additions & 0 deletions api/protocol_adaptors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package api

import (
"context"

gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/status-im/status-go/protocol/communities"
"github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/wallet/token"
tokenTypes "github.com/status-im/status-go/services/wallet/token/types"
)

var _ communities.NetworkManager = (*CommunitiesNetworkManager)(nil)
var _ communities.TokenManager = (*CommunitiesTokenManager)(nil)
var _ communities.TokenBalanceManager = (*CommunitiesTokenBalanceManager)(nil)

type CommunitiesNetworkManager struct {
networkManager *network.Manager
}

func NewCommunitiesNetworkManager(nm *network.Manager) *CommunitiesNetworkManager {
return &CommunitiesNetworkManager{networkManager: nm}
}

func (m *CommunitiesNetworkManager) GetAllChainIDs() ([]uint64, error) {
networks, err := m.networkManager.GetActiveNetworks()
if err != nil {
return nil, err
}

chainIDs := make([]uint64, 0)
for _, network := range networks {
chainIDs = append(chainIDs, network.ChainID)
}
return chainIDs, nil
}

type CommunitiesTokenManager struct {
tokenManager *token.Manager
}

func NewCommunitiesTokenManager(tm *token.Manager) *CommunitiesTokenManager {
return &CommunitiesTokenManager{tokenManager: tm}
}

func (m *CommunitiesTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *tokenTypes.Token {
return m.tokenManager.FindOrCreateTokenByAddress(ctx, chainID, address)
}

type CommunitiesTokenBalanceManager struct {
tokenManager *token.Manager
}

func NewCommunitiesTokenBalanceManager(tm *token.Manager) *CommunitiesTokenBalanceManager {
return &CommunitiesTokenBalanceManager{tokenManager: tm}
}

func (m *CommunitiesTokenBalanceManager) GetBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (communities.BalancesByChain, error) {
chainClients, err := m.tokenManager.RPCClient.EthClients(chainIDs)
if err != nil {
return nil, err
}

resp, err := m.tokenManager.GetBalancesByChain(context.Background(), chainClients, accounts, tokenAddresses)
return resp, err
}

func (m *CommunitiesTokenBalanceManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (communities.BalancesByChain, error) {
resp, err := m.tokenManager.GetCachedBalancesByChain(accounts, tokenAddresses, chainIDs)
if err != nil {
return resp, err
}

return resp, nil
}
98 changes: 38 additions & 60 deletions protocol/communities/manager.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package communities

//go:generate go tool mockgen -package=mock_communities -source=manager.go -destination=mock/communities/manager.go

import (
"bytes"
"context"
Expand Down Expand Up @@ -41,13 +43,11 @@ import (
"github.com/status-im/status-go/protocol/ens"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/requests"
"github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/server"
"github.com/status-im/status-go/services/personal"
"github.com/status-im/status-go/services/wallet/bigint"
walletcommon "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty"
"github.com/status-im/status-go/services/wallet/token"
tokenTypes "github.com/status-im/status-go/services/wallet/token/types"
"github.com/status-im/status-go/signal"
)
Expand Down Expand Up @@ -104,7 +104,9 @@ type Manager struct {
identity *ecdsa.PrivateKey
installationID string
accountsManager *accsmanagement.AccountsManager
networkManager NetworkManager
tokenManager TokenManager
tokenBalanceManager TokenBalanceManager
collectiblesManager CollectiblesManager
logger *zap.Logger
signer MessageSigner
Expand Down Expand Up @@ -245,7 +247,9 @@ type membersReevaluationTask struct {

type managerOptions struct {
accountsManager *accsmanagement.AccountsManager
networkManager NetworkManager
tokenManager TokenManager
tokenBalanceManager TokenBalanceManager
collectiblesManager CollectiblesManager
communityTokensService CommunityTokensServiceInterface
permissionChecker PermissionChecker
Expand All @@ -256,11 +260,16 @@ type managerOptions struct {
allowForcingCommunityMembersReevaluation bool
}

type NetworkManager interface {
GetAllChainIDs() ([]uint64, error)
}
type TokenManager interface {
GetBalancesByChain(ctx context.Context, accounts, tokens []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *tokenTypes.Token
GetAllChainIDs() ([]uint64, error)
}

type TokenBalanceManager interface {
GetBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error)
}

type CollectibleContractData struct {
Expand All @@ -283,67 +292,15 @@ type CommunityTokensServiceInterface interface {
ProcessCommunityTokenAction(message *protobuf.CommunityTokenAction) error
}

type DefaultTokenManager struct {
tokenManager *token.Manager
networkManager network.ManagerInterface
}

func NewDefaultTokenManager(tm *token.Manager, nm network.ManagerInterface) *DefaultTokenManager {
return &DefaultTokenManager{tokenManager: tm, networkManager: nm}
}

type BalancesByChain = map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big

func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) {
networks, err := m.networkManager.GetAll()
if err != nil {
return nil, err
}

areTestNetworksEnabled, err := m.networkManager.GetTestNetworksEnabled()
if err != nil {
return nil, err
}

chainIDs := make([]uint64, 0)
for _, network := range networks {
if areTestNetworksEnabled == network.IsTest {
chainIDs = append(chainIDs, network.ChainID)
}
}
return chainIDs, nil
}

type CollectiblesManager interface {
FetchBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error)
FetchCachedBalancesByOwnerAndContractAddress(ctx context.Context, chainID walletcommon.ChainID, ownerAddress gethcommon.Address, contractAddresses []gethcommon.Address) (thirdparty.TokenBalancesPerContractAddress, error)
GetCollectibleOwnership(id thirdparty.CollectibleUniqueID) ([]thirdparty.AccountBalance, error)
FetchCollectibleOwnersByContractAddress(ctx context.Context, chainID walletcommon.ChainID, contractAddress gethcommon.Address) (*thirdparty.CollectibleContractOwnership, error)
}

func (m *DefaultTokenManager) GetBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
clients, err := m.tokenManager.RPCClient.EthClients(chainIDs)
if err != nil {
return nil, err
}

resp, err := m.tokenManager.GetBalancesByChain(context.Background(), clients, accounts, tokenAddresses)
return resp, err
}

func (m *DefaultTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
resp, err := m.tokenManager.GetCachedBalancesByChain(accounts, tokenAddresses, chainIDs)
if err != nil {
return resp, err
}

return resp, nil
}

func (m *DefaultTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *tokenTypes.Token {
return m.tokenManager.FindOrCreateTokenByAddress(ctx, chainID, address)
}

type ManagerOption func(*managerOptions)

func WithMessageSigner(signer MessageSigner) ManagerOption {
Expand All @@ -364,12 +321,24 @@ func WithCollectiblesManager(collectiblesManager CollectiblesManager) ManagerOpt
}
}

func WithNetworkManager(networkManager NetworkManager) ManagerOption {
return func(opts *managerOptions) {
opts.networkManager = networkManager
}
}

func WithTokenManager(tokenManager TokenManager) ManagerOption {
return func(opts *managerOptions) {
opts.tokenManager = tokenManager
}
}

func WithTokenBalanceManager(tokenBalanceManager TokenBalanceManager) ManagerOption {
return func(opts *managerOptions) {
opts.tokenBalanceManager = tokenBalanceManager
}
}

func WithCommunityTokensService(communityTokensService CommunityTokensServiceInterface) ManagerOption {
return func(opts *managerOptions) {
opts.communityTokensService = communityTokensService
Expand Down Expand Up @@ -455,6 +424,14 @@ func NewManager(
manager.tokenManager = managerConfig.tokenManager
}

if managerConfig.tokenBalanceManager != nil {
manager.tokenBalanceManager = managerConfig.tokenBalanceManager
}

if managerConfig.networkManager != nil {
manager.networkManager = managerConfig.networkManager
}

if managerConfig.communityTokensService != nil {
manager.communityTokensService = managerConfig.communityTokensService
}
Expand All @@ -469,7 +446,8 @@ func NewManager(
manager.PermissionChecker = managerConfig.permissionChecker
} else {
manager.PermissionChecker = &DefaultPermissionChecker{
tokenManager: manager.tokenManager,
networkManager: manager.networkManager,
tokenBalanceManager: manager.tokenBalanceManager,
collectiblesManager: manager.collectiblesManager,
logger: logger,
ensVerifier: ensverifier,
Expand Down Expand Up @@ -3356,7 +3334,7 @@ func (m *Manager) CheckChannelPermissions(communityID types.HexBytes, chatID str
viewOnlyPreParsedPermissions := preParsedCommunityPermissionsData(viewOnlyPermissions)
viewAndPostPreParsedPermissions := preParsedCommunityPermissionsData(viewAndPostPermissions)

allChainIDs, err := m.tokenManager.GetAllChainIDs()
allChainIDs, err := m.networkManager.GetAllChainIDs()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -3465,7 +3443,7 @@ func (m *Manager) CheckAllChannelsPermissions(communityID types.HexBytes, addres
}
channels := community.Chats()

allChainIDs, err := m.tokenManager.GetAllChainIDs()
allChainIDs, err := m.networkManager.GetAllChainIDs()
if err != nil {
return nil, err
}
Expand Down
24 changes: 17 additions & 7 deletions protocol/communities/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ func (m *testCollectiblesManager) FetchCachedBalancesByOwnerAndContractAddress(c
return m.response[uint64(chainID)][ownerAddress], nil
}

type testTokenManager struct {
type testTokenBalanceManager struct {
response map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big
}

func (m *testTokenManager) setResponse(chainID uint64, walletAddress, tokenAddress gethcommon.Address, balance int64) {
func (m *testTokenBalanceManager) setResponse(chainID uint64, walletAddress, tokenAddress gethcommon.Address, balance int64) {

if m.response == nil {
m.response = make(map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big)
Expand All @@ -188,23 +188,29 @@ func (m *testTokenManager) setResponse(chainID uint64, walletAddress, tokenAddre

}

func (m *testTokenManager) GetAllChainIDs() ([]uint64, error) {
type testNetworkManager struct {
}

func (m *testNetworkManager) GetAllChainIDs() ([]uint64, error) {
return []uint64{5}, nil
}

func (m *testTokenManager) GetBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error) {
func (m *testTokenBalanceManager) GetBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big, error) {
return m.response, nil
}

func (m *testTokenManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
func (m *testTokenBalanceManager) GetCachedBalancesByChain(ctx context.Context, accounts, tokenAddresses []gethcommon.Address, chainIDs []uint64) (BalancesByChain, error) {
return m.response, nil
}

type testTokenManager struct {
}

func (m *testTokenManager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint64, address gethcommon.Address) *tokenTypes.Token {
return nil
}

func (s *ManagerSuite) setupManagerForTokenPermissions() (*Manager, *testCollectiblesManager, *testTokenManager) {
func (s *ManagerSuite) setupManagerForTokenPermissions() (*Manager, *testCollectiblesManager, *testTokenBalanceManager) {
db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
s.NoError(err, "creating sqlite db instance")
err = sqlite.Migrate(db)
Expand All @@ -216,17 +222,21 @@ func (s *ManagerSuite) setupManagerForTokenPermissions() (*Manager, *testCollect

cm := &testCollectiblesManager{}
tm := &testTokenManager{}
tbm := &testTokenBalanceManager{}
nm := &testNetworkManager{}

options := []ManagerOption{
WithCollectiblesManager(cm),
WithTokenManager(tm),
WithTokenBalanceManager(tbm),
WithNetworkManager(nm),
}

m, err := NewManager(key, "", db, nil, nil, nil, nil, &TimeSourceStub{}, nil, nil, options...)
s.Require().NoError(err)
s.Require().NoError(m.Start())

return m, cm, tm
return m, cm, tbm
}

func (s *ManagerSuite) TestRetrieveTokens() {
Expand Down
Loading
Loading