Skip to content

Commit 70c7f3b

Browse files
committed
universe/supplycommit: add unit tests
1 parent 53cd06a commit 70c7f3b

File tree

2 files changed

+1882
-0
lines changed

2 files changed

+1882
-0
lines changed

universe/supplycommit/mock.go

Lines changed: 391 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,391 @@
1+
package supplycommit
2+
3+
import (
4+
"context"
5+
"sync"
6+
7+
"github.com/btcsuite/btcd/btcec/v2"
8+
"github.com/btcsuite/btcd/btcutil"
9+
"github.com/btcsuite/btcd/btcutil/psbt"
10+
"github.com/btcsuite/btcd/chaincfg/chainhash"
11+
"github.com/btcsuite/btcd/wire"
12+
"github.com/lightninglabs/taproot-assets/asset"
13+
"github.com/lightninglabs/taproot-assets/mssmt"
14+
"github.com/lightninglabs/taproot-assets/proof"
15+
"github.com/lightninglabs/taproot-assets/tapsend"
16+
"github.com/lightningnetwork/lnd/chainntnfs"
17+
lfn "github.com/lightningnetwork/lnd/fn/v2"
18+
"github.com/lightningnetwork/lnd/keychain"
19+
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
20+
"github.com/lightningnetwork/lnd/lnwire"
21+
"github.com/stretchr/testify/mock"
22+
)
23+
24+
// mockSupplyTreeView is a mock implementation of the SupplyTreeView interface.
25+
type mockSupplyTreeView struct {
26+
mock.Mock
27+
}
28+
29+
func (m *mockSupplyTreeView) FetchSubTree(assetSpec asset.Specifier,
30+
treeType SupplySubTree) lfn.Result[mssmt.Tree] {
31+
32+
args := m.Called(assetSpec, treeType)
33+
return args.Get(0).(lfn.Result[mssmt.Tree])
34+
}
35+
36+
func (m *mockSupplyTreeView) FetchSubTrees(
37+
assetSpec asset.Specifier) lfn.Result[SupplyTrees] {
38+
39+
args := m.Called(assetSpec)
40+
return args.Get(0).(lfn.Result[SupplyTrees])
41+
}
42+
43+
func (m *mockSupplyTreeView) FetchRootSupplyTree(
44+
assetSpec asset.Specifier) lfn.Result[mssmt.Tree] {
45+
46+
args := m.Called(assetSpec)
47+
return args.Get(0).(lfn.Result[mssmt.Tree])
48+
}
49+
50+
// mockCommitmentTracker is a mock implementation of the CommitmentTracker
51+
// interface.
52+
type mockCommitmentTracker struct {
53+
mock.Mock
54+
}
55+
56+
func (m *mockCommitmentTracker) UnspentPrecommits(ctx context.Context,
57+
assetSpec asset.Specifier) lfn.Result[PreCommits] {
58+
59+
args := m.Called(ctx, assetSpec)
60+
return args.Get(0).(lfn.Result[PreCommits])
61+
}
62+
63+
func (m *mockCommitmentTracker) SupplyCommit(ctx context.Context,
64+
assetSpec asset.Specifier) RootCommitResp {
65+
66+
args := m.Called(ctx, assetSpec)
67+
return args.Get(0).(RootCommitResp)
68+
}
69+
70+
// fundPsbtMockFn defines a type for the mock function used in FundPsbt,
71+
// to simplify a long type assertion.
72+
type fundPsbtMockFn func(
73+
context.Context, *psbt.Packet, uint32,
74+
chainfee.SatPerKWeight, int32,
75+
) (*tapsend.FundedPsbt, error)
76+
77+
// signAndFinalizePsbtMockFn defines a type for the mock function used in
78+
// SignAndFinalizePsbt, to simplify a long type assertion.
79+
type signAndFinalizePsbtMockFn func(
80+
context.Context, *psbt.Packet,
81+
) (*psbt.Packet, error)
82+
83+
// mockWallet is a mock implementation of the Wallet interface.
84+
type mockWallet struct {
85+
mock.Mock
86+
}
87+
88+
func (m *mockWallet) FundPsbt(
89+
ctx context.Context, packet *psbt.Packet, minConfs uint32,
90+
feeRate chainfee.SatPerKWeight, changeIdx int32,
91+
) (*tapsend.FundedPsbt, error) {
92+
93+
args := m.Called(ctx, packet, minConfs, feeRate, changeIdx)
94+
95+
// Check if the first argument returned by the mock is a function.
96+
// If so, this indicates a custom mock implementation that should be
97+
// executed to get the actual return values.
98+
arg0 := args.Get(0)
99+
if fn, ok := arg0.(fundPsbtMockFn); ok {
100+
return fn(ctx, packet, minConfs, feeRate, changeIdx)
101+
}
102+
103+
if args.Get(0) == nil {
104+
return nil, args.Error(1)
105+
}
106+
return args.Get(0).(*tapsend.FundedPsbt), args.Error(1)
107+
}
108+
109+
func (m *mockWallet) SignAndFinalizePsbt(ctx context.Context,
110+
packet *psbt.Packet) (*psbt.Packet, error) {
111+
112+
args := m.Called(ctx, packet)
113+
114+
// Check if the first argument returned by the mock is a function.
115+
// If so, this indicates a custom mock implementation that should be
116+
// executed to get the actual return values.
117+
arg0 := args.Get(0)
118+
if fn, ok := arg0.(signAndFinalizePsbtMockFn); ok {
119+
return fn(ctx, packet)
120+
}
121+
122+
if args.Get(0) == nil {
123+
return nil, args.Error(1)
124+
}
125+
return args.Get(0).(*psbt.Packet), args.Error(1)
126+
}
127+
128+
func (m *mockWallet) ImportTaprootOutput(ctx context.Context,
129+
pubKey *btcec.PublicKey) (btcutil.Address, error) {
130+
131+
args := m.Called(ctx, pubKey)
132+
if args.Get(0) == nil {
133+
return nil, args.Error(1)
134+
}
135+
return args.Get(0).(btcutil.Address), args.Error(1)
136+
}
137+
138+
func (m *mockWallet) UnlockInput(ctx context.Context, op wire.OutPoint) error {
139+
args := m.Called(ctx, op)
140+
return args.Error(0)
141+
}
142+
143+
func (m *mockWallet) DeriveNextKey(
144+
ctx context.Context) (keychain.KeyDescriptor, error) {
145+
146+
args := m.Called(ctx)
147+
if args.Get(0) == nil {
148+
return keychain.KeyDescriptor{}, args.Error(1)
149+
}
150+
return args.Get(0).(keychain.KeyDescriptor), args.Error(1)
151+
}
152+
153+
// mockChainBridge is a mock implementation of the tapgarden.ChainBridge
154+
// interface.
155+
type mockChainBridge struct {
156+
mock.Mock
157+
}
158+
159+
func (m *mockChainBridge) RegisterConfirmationsNtfn(
160+
ctx context.Context, txid *chainhash.Hash, pkScript []byte,
161+
numConfs, heightHint uint32, includeBlock bool,
162+
reOrgChan chan struct{},
163+
) (*chainntnfs.ConfirmationEvent, chan error, error) {
164+
165+
args := m.Called(
166+
ctx, txid, pkScript, numConfs, heightHint, includeBlock,
167+
reOrgChan,
168+
)
169+
if args.Get(0) == nil {
170+
return nil, nil, args.Error(2)
171+
}
172+
return args.Get(0).(*chainntnfs.ConfirmationEvent),
173+
args.Get(1).(chan error), args.Error(2)
174+
}
175+
176+
func (m *mockChainBridge) RegisterSpendNtfn(ctx context.Context,
177+
outpoint *wire.OutPoint, pkScript []byte,
178+
heightHint uint32) (*chainntnfs.SpendEvent, error) {
179+
180+
args := m.Called(ctx, outpoint, pkScript, heightHint)
181+
if args.Get(0) == nil {
182+
return nil, args.Error(1)
183+
}
184+
return args.Get(0).(*chainntnfs.SpendEvent), args.Error(1)
185+
}
186+
187+
func (m *mockChainBridge) PublishTransaction(ctx context.Context,
188+
tx *wire.MsgTx, label string) error {
189+
190+
args := m.Called(ctx, tx, label)
191+
return args.Error(0)
192+
}
193+
194+
func (m *mockChainBridge) EstimateFee(ctx context.Context,
195+
confTarget uint32) (chainfee.SatPerKWeight, error) {
196+
197+
args := m.Called(ctx, confTarget)
198+
if args.Get(0) == nil {
199+
return chainfee.SatPerKWeight(0), args.Error(1)
200+
}
201+
return args.Get(0).(chainfee.SatPerKWeight), args.Error(1)
202+
}
203+
204+
func (m *mockChainBridge) CurrentHeight(ctx context.Context) (uint32, error) {
205+
args := m.Called(ctx)
206+
return args.Get(0).(uint32), args.Error(1)
207+
}
208+
209+
func (m *mockChainBridge) RegisterBlockEpochNtfn(
210+
ctx context.Context) (chan int32, chan error, error) {
211+
212+
args := m.Called(ctx)
213+
if args.Get(0) == nil {
214+
return nil, nil, args.Error(2)
215+
}
216+
return args.Get(0).(chan int32), args.Get(1).(chan error), args.Error(2)
217+
}
218+
219+
func (m *mockChainBridge) GetBlock(ctx context.Context,
220+
hash chainhash.Hash) (*wire.MsgBlock, error) {
221+
222+
args := m.Called(ctx, hash)
223+
if args.Get(0) == nil {
224+
return nil, args.Error(1)
225+
}
226+
return args.Get(0).(*wire.MsgBlock), args.Error(1)
227+
}
228+
229+
func (m *mockChainBridge) GetBlockHash(ctx context.Context,
230+
height int64) (chainhash.Hash, error) {
231+
232+
args := m.Called(ctx, height)
233+
return args.Get(0).(chainhash.Hash), args.Error(1)
234+
}
235+
236+
func (m *mockChainBridge) VerifyBlock(ctx context.Context,
237+
header wire.BlockHeader, height uint32) error {
238+
239+
args := m.Called(ctx, header, height)
240+
return args.Error(0)
241+
}
242+
243+
func (m *mockChainBridge) GetBlockTimestamp(ctx context.Context,
244+
height uint32) int64 {
245+
246+
args := m.Called(ctx, height)
247+
return args.Get(0).(int64)
248+
}
249+
250+
func (m *mockChainBridge) GenFileChainLookup(f *proof.File) asset.ChainLookup {
251+
args := m.Called(f)
252+
return args.Get(0).(asset.ChainLookup)
253+
}
254+
255+
func (m *mockChainBridge) GenProofChainLookup(
256+
p *proof.Proof) (asset.ChainLookup, error) {
257+
258+
args := m.Called(p)
259+
if args.Get(0) == nil {
260+
return nil, args.Error(1)
261+
}
262+
return args.Get(0).(asset.ChainLookup), args.Error(1)
263+
}
264+
265+
// mockStateMachineStore is a mock implementation of the StateMachineStore
266+
// interface.
267+
type mockStateMachineStore struct {
268+
mock.Mock
269+
}
270+
271+
func (m *mockStateMachineStore) InsertPendingUpdate(ctx context.Context,
272+
spec asset.Specifier, event SupplyUpdateEvent) error {
273+
274+
args := m.Called(ctx, spec, event)
275+
return args.Error(0)
276+
}
277+
278+
func (m *mockStateMachineStore) InsertSignedCommitTx(ctx context.Context,
279+
spec asset.Specifier, tx SupplyCommitTxn) error {
280+
281+
args := m.Called(ctx, spec, tx)
282+
return args.Error(0)
283+
}
284+
285+
func (m *mockStateMachineStore) CommitState(ctx context.Context,
286+
spec asset.Specifier, state State) error {
287+
288+
args := m.Called(ctx, spec, state)
289+
return args.Error(0)
290+
}
291+
292+
func (m *mockStateMachineStore) FetchState(ctx context.Context,
293+
spec asset.Specifier) (State, lfn.Option[SupplyStateTransition],
294+
error) {
295+
296+
args := m.Called(ctx, spec)
297+
if args.Get(2) != nil {
298+
return nil, lfn.None[SupplyStateTransition](), args.Error(2)
299+
}
300+
state := args.Get(0)
301+
if state == nil {
302+
return nil, args.Get(1).(lfn.Option[SupplyStateTransition]),
303+
args.Error(2)
304+
}
305+
return state.(State),
306+
args.Get(1).(lfn.Option[SupplyStateTransition]), args.Error(2)
307+
}
308+
309+
func (m *mockStateMachineStore) ApplyStateTransition(ctx context.Context,
310+
spec asset.Specifier, transition SupplyStateTransition) error {
311+
312+
args := m.Called(ctx, spec, transition)
313+
return args.Error(0)
314+
}
315+
316+
// mockDaemonAdapters is a mock implementation of the protofsm.DaemonAdapters
317+
// interface.
318+
type mockDaemonAdapters struct {
319+
mock.Mock
320+
321+
confChan chan *chainntnfs.TxConfirmation
322+
spendChan chan *chainntnfs.SpendDetail
323+
}
324+
325+
func newMockDaemonAdapters() *mockDaemonAdapters {
326+
return &mockDaemonAdapters{
327+
confChan: make(chan *chainntnfs.TxConfirmation, 1),
328+
spendChan: make(chan *chainntnfs.SpendDetail, 1),
329+
}
330+
}
331+
332+
func (m *mockDaemonAdapters) BroadcastTransaction(
333+
tx *wire.MsgTx, label string) error {
334+
335+
args := m.Called(tx, label)
336+
return args.Error(0)
337+
}
338+
339+
func (m *mockDaemonAdapters) RegisterConfirmationsNtfn(
340+
txid *chainhash.Hash, pkScript []byte,
341+
numConfs, heightHint uint32, opts ...chainntnfs.NotifierOption,
342+
) (*chainntnfs.ConfirmationEvent, error) {
343+
344+
args := m.Called(txid, pkScript, numConfs, heightHint, opts)
345+
346+
err := args.Error(0)
347+
348+
return &chainntnfs.ConfirmationEvent{
349+
Confirmed: m.confChan,
350+
}, err
351+
}
352+
353+
func (m *mockDaemonAdapters) RegisterSpendNtfn(outpoint *wire.OutPoint,
354+
pkScript []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
355+
356+
args := m.Called(outpoint, pkScript, heightHint)
357+
358+
err := args.Error(0)
359+
360+
return &chainntnfs.SpendEvent{
361+
Spend: m.spendChan,
362+
}, err
363+
}
364+
365+
func (m *mockDaemonAdapters) SendMessages(pub btcec.PublicKey,
366+
msgs []lnwire.Message) error {
367+
368+
args := m.Called(pub, msgs)
369+
return args.Error(0)
370+
}
371+
372+
// mockErrorReporter is a mock implementation of the protofsm.ErrorReporter
373+
// interface.
374+
type mockErrorReporter struct {
375+
mock.Mock
376+
reportedError error
377+
mu sync.Mutex
378+
}
379+
380+
func (m *mockErrorReporter) ReportError(err error) {
381+
m.mu.Lock()
382+
defer m.mu.Unlock()
383+
m.reportedError = err
384+
m.Called(err)
385+
}
386+
387+
func (m *mockErrorReporter) GetReportedError() error {
388+
m.mu.Lock()
389+
defer m.mu.Unlock()
390+
return m.reportedError
391+
}

0 commit comments

Comments
 (0)