Skip to content

Commit a62c55b

Browse files
authored
Add a FlexibleTransaction type and constructor (#61)
* Add a FlexibleTransaction type and constructor This allows to avoid the boilerplate of handling transactions between functions by wrapping DB into a state tracking shim that simulates a Tx.
1 parent eb71091 commit a62c55b

File tree

3 files changed

+285
-2
lines changed

3 files changed

+285
-2
lines changed

db/chain/chain.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ func NewExpressionChain(db connection.DB) *ExpressionChain {
2828
return &ExpressionChain{db: db}
2929
}
3030

31-
// NewNoDB creates an expression chain withouth the db, mostly with the purpose of making a more
32-
// abbreviated syntax for transient ExpresionChains such as CTE or subquery ones.
31+
// NewNoDB creates an expression chain without the db, mostly with the purpose of making a more
32+
// abbreviated syntax for transient ExpressionChains such as CTE or sub-query ones.
3333
func NewNoDB() *ExpressionChain {
3434
return &ExpressionChain{}
3535
}

db/connection/connection.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ package connection
1616

1717
import (
1818
"context"
19+
"fmt"
1920
"net"
2021
"strconv"
2122
"strings"
23+
"sync"
2224
"time"
2325

2426
"github.com/ShiftLeftSecurity/gaum/v2/db/logging"
@@ -117,6 +119,97 @@ type DB interface {
117119
BulkInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (execError error)
118120
}
119121

122+
var _ DB = (*FlexibleTransaction)(nil)
123+
124+
// FlexibleTransaction allows for a DB transaction to be passed through functions and avoid multiple commit/rollbacks
125+
// it also takes care of some of the most repeated checks at the time of commit/rollback and tx checking.
126+
type FlexibleTransaction struct {
127+
DB
128+
rolled bool
129+
concurrencySafeguard sync.Mutex
130+
}
131+
132+
// Cleanup is an implementation of TXFinishFunc for FlexibleTransaction, it handles an attempt to either Commit
133+
// or rollback a transaction depending on the perceived outcome: If someone invoked rollback on the FlexibleTransaction
134+
// we assume the process went wrong and will rollback all. This is intended as a way to mitigate the lack of different
135+
// abstractions for Transaction and Connection in the current version of gaum, retaining the ability to finalize the
136+
// transaction at the initiator level.
137+
// This does however allow some bad habits such as functions acting different depending on if they think they receive
138+
// a transaction or a connection instead of having two functions that force the former or later as arguments.
139+
func (f *FlexibleTransaction) Cleanup(ctx context.Context) (bool, bool, error) {
140+
f.concurrencySafeguard.Lock()
141+
defer f.concurrencySafeguard.Unlock()
142+
if f.DB == nil {
143+
return false, false, nil
144+
}
145+
if f.rolled {
146+
if err := f.DB.RollbackTransaction(ctx); err != nil {
147+
return false, false, fmt.Errorf("rolling back transaction: %w", err)
148+
}
149+
return false, true, nil
150+
}
151+
152+
if err := f.DB.CommitTransaction(ctx); err != nil {
153+
return false, false, fmt.Errorf("committing transaction: %w", err)
154+
}
155+
return true, false, nil
156+
}
157+
158+
// TXFinishFunc represents an all-encompassing function that either rolls or commits a tx based on the outcome.
159+
type TXFinishFunc func(ctx context.Context) (committed, rolled bool, err error)
160+
161+
// BeginTransaction will wrap the passed DB into a transaction handler that supports it being used with less care
162+
// and prevents having to check if we are already in a tx and failures due to eager committers.
163+
func BeginTransaction(ctx context.Context, conn DB) (DB, TXFinishFunc, error) {
164+
// this can happen so let's work around it
165+
ft, isFT := conn.(*FlexibleTransaction)
166+
if isFT {
167+
return ft, func(ctx2 context.Context) (bool, bool, error) {
168+
return false, false, nil
169+
}, nil
170+
}
171+
172+
// the underlying conn is a tx, let's be careful not to commit/rollback it
173+
if conn.IsTransaction() {
174+
return &FlexibleTransaction{
175+
DB: conn,
176+
},
177+
func(ctx2 context.Context) (bool, bool, error) {
178+
return false, false, nil
179+
},
180+
nil
181+
182+
}
183+
184+
tx, err := conn.BeginTransaction(ctx)
185+
if err != nil {
186+
return nil, nil, fmt.Errorf("beginning transaction: %w", err)
187+
}
188+
189+
f := &FlexibleTransaction{
190+
DB: tx,
191+
}
192+
return f, f.Cleanup, nil
193+
}
194+
195+
// BeginTransaction implements DB for FlexibleTransaction
196+
func (f *FlexibleTransaction) BeginTransaction(ctx context.Context) (DB, error) {
197+
return f, nil
198+
}
199+
200+
// CommitTransaction implements DB for FlexibleTransaction
201+
func (f *FlexibleTransaction) CommitTransaction(ctx context.Context) error {
202+
return nil
203+
}
204+
205+
// RollbackTransaction implements DB for FlexibleTransaction
206+
func (f *FlexibleTransaction) RollbackTransaction(ctx context.Context) error {
207+
f.concurrencySafeguard.Lock()
208+
defer f.concurrencySafeguard.Unlock()
209+
f.rolled = true
210+
return nil
211+
}
212+
120213
// EscapeArgs return the query and args with the argument placeholder escaped.
121214
func EscapeArgs(query string, args []interface{}) (string, []interface{}, error) {
122215
// TODO: make this a bit less ugly

db/connection/connection_test.go

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
package connection
2+
3+
import (
4+
"context"
5+
"testing"
6+
)
7+
8+
type fakeConn struct {
9+
DB
10+
begin int
11+
commit int
12+
rollback int
13+
isTx bool
14+
}
15+
16+
func (f *fakeConn) BeginTransaction(ctx context.Context) (DB, error) {
17+
f.begin++
18+
f.isTx = true
19+
return f, nil
20+
}
21+
22+
func (f *fakeConn) CommitTransaction(ctx context.Context) error {
23+
f.commit++
24+
return nil
25+
}
26+
27+
func (f *fakeConn) RollbackTransaction(ctx context.Context) error {
28+
f.rollback++
29+
return nil
30+
}
31+
32+
func (f *fakeConn) IsTransaction() bool {
33+
return f.isTx
34+
}
35+
36+
var _ DB = (*fakeConn)(nil)
37+
38+
func TestFlexibleTransactionSucceeds(t *testing.T) {
39+
// Multiple TX begins, multiple commits
40+
fc := &fakeConn{}
41+
ctx := context.Background()
42+
tx, cleanup, err := BeginTransaction(ctx, fc)
43+
if err != nil {
44+
t.Fatal(err)
45+
}
46+
47+
for i := 0; i < 10; i++ {
48+
if err := tx.CommitTransaction(ctx); err != nil {
49+
t.Logf("Repetitive commit N %d", i+1)
50+
t.Fatal(err)
51+
}
52+
}
53+
committed, rolledBack, err := cleanup(ctx)
54+
if err != nil {
55+
t.Fatal(err)
56+
}
57+
if !committed {
58+
t.Log("tx was not committed but we expected it to")
59+
t.FailNow()
60+
}
61+
if rolledBack {
62+
t.Log("tx was rolled back and we did not expect it to")
63+
t.FailNow()
64+
}
65+
66+
if fc.begin != 1 {
67+
t.Logf("begin was called %d times in the underlying conn but we expected 1", fc.begin)
68+
t.FailNow()
69+
}
70+
71+
if fc.commit != 1 {
72+
t.Logf("commit was called %d times in the underlying conn but we expected 1", fc.commit)
73+
t.FailNow()
74+
}
75+
76+
if fc.rollback != 0 {
77+
t.Logf("rollback was called %d times in the underlying conn but we expected 0", fc.rollback)
78+
t.FailNow()
79+
}
80+
}
81+
82+
func TestFlexibleTransactionRollbackTransaction(t *testing.T) {
83+
// Multiple TX begins, multiple commits
84+
fc := &fakeConn{}
85+
ctx := context.Background()
86+
tx, cleanup, err := BeginTransaction(ctx, fc)
87+
if err != nil {
88+
t.Fatal(err)
89+
}
90+
91+
for i := 0; i < 10; i++ {
92+
if err := tx.CommitTransaction(ctx); err != nil {
93+
t.Logf("Repetitive commit N %d", i+1)
94+
t.Fatal(err)
95+
}
96+
}
97+
if err := tx.RollbackTransaction(ctx); err != nil {
98+
t.Fatal(err)
99+
}
100+
committed, rolledBack, err := cleanup(ctx)
101+
if err != nil {
102+
t.Fatal(err)
103+
}
104+
if committed {
105+
t.Log("tx was committed but we expected it not to")
106+
t.FailNow()
107+
}
108+
if !rolledBack {
109+
t.Log("tx was not rolled back and we expected it to")
110+
t.FailNow()
111+
}
112+
113+
if fc.begin != 1 {
114+
t.Logf("begin was called %d times in the underlying conn but we expected 1", fc.begin)
115+
t.FailNow()
116+
}
117+
118+
if fc.commit != 0 {
119+
t.Logf("commit was called %d times in the underlying conn but we expected 0", fc.commit)
120+
t.FailNow()
121+
}
122+
123+
if fc.rollback != 1 {
124+
t.Logf("rollback was called %d times in the underlying conn but we expected 1", fc.rollback)
125+
t.FailNow()
126+
}
127+
}
128+
129+
func TestFlexibleTransactionRecursive(t *testing.T) {
130+
// Multiple TX begins, multiple commits
131+
fc := &fakeConn{}
132+
ctx := context.Background()
133+
tx, cleanup, err := BeginTransaction(ctx, fc)
134+
if err != nil {
135+
t.Fatal(err)
136+
}
137+
138+
tx, innerCleanup, err := BeginTransaction(ctx, fc)
139+
if err != nil {
140+
t.Fatal(err)
141+
}
142+
143+
// we call it early to see if it really is a noop
144+
innerCommit, innerRollback, err := innerCleanup(ctx)
145+
if err != nil {
146+
t.Fatal(err)
147+
}
148+
if innerCommit {
149+
t.Log("commit should not have happened on inner cleanup function")
150+
t.FailNow()
151+
}
152+
if innerRollback {
153+
t.Log("rollback should not have happened on inner cleanup function")
154+
t.FailNow()
155+
}
156+
// notice that we are using the inner tx
157+
for i := 0; i < 10; i++ {
158+
if err := tx.CommitTransaction(ctx); err != nil {
159+
t.Logf("Repetitive commit N %d", i+1)
160+
t.Fatal(err)
161+
}
162+
}
163+
committed, rolledBack, err := cleanup(ctx)
164+
if err != nil {
165+
t.Fatal(err)
166+
}
167+
if !committed {
168+
t.Log("tx was not committed but we expected it to")
169+
t.FailNow()
170+
}
171+
if rolledBack {
172+
t.Log("tx was rolled back and we did not expect it to")
173+
t.FailNow()
174+
}
175+
176+
if fc.begin != 1 {
177+
t.Logf("begin was called %d times in the underlying conn but we expected 1", fc.begin)
178+
t.FailNow()
179+
}
180+
181+
if fc.commit != 1 {
182+
t.Logf("commit was called %d times in the underlying conn but we expected 1", fc.commit)
183+
t.FailNow()
184+
}
185+
186+
if fc.rollback != 0 {
187+
t.Logf("rollback was called %d times in the underlying conn but we expected 0", fc.rollback)
188+
t.FailNow()
189+
}
190+
}

0 commit comments

Comments
 (0)