@@ -16,9 +16,11 @@ package connection
1616
1717import (
1818 "context"
19+ "fmt"
1920 "net"
2021 "strconv"
2122 "strings"
23+ "sync"
2224 "time"
2325
2426 "github.com/ShiftLeftSecurity/gaum/v2/db/logging"
@@ -117,6 +119,97 @@ type DB interface {
117119 BulkInsert (ctx context.Context , tableName string , columns []string , values [][]interface {}) (execError error )
118120}
119121
122+ var _ DB = (* FlexibleTransaction )(nil )
123+
124+ // FlexibleTransaction allows for a DB transaction to be passed through functions and avoid multiple commit/rollbacks
125+ // it also takes care of some of the most repeated checks at the time of commit/rollback and tx checking.
126+ type FlexibleTransaction struct {
127+ DB
128+ rolled bool
129+ concurrencySafeguard sync.Mutex
130+ }
131+
132+ // Cleanup is an implementation of TXFinishFunc for FlexibleTransaction, it handles an attempt to either Commit
133+ // or rollback a transaction depending on the perceived outcome: If someone invoked rollback on the FlexibleTransaction
134+ // we assume the process went wrong and will rollback all. This is intended as a way to mitigate the lack of different
135+ // abstractions for Transaction and Connection in the current version of gaum, retaining the ability to finalize the
136+ // transaction at the initiator level.
137+ // This does however allow some bad habits such as functions acting different depending on if they think they receive
138+ // a transaction or a connection instead of having two functions that force the former or later as arguments.
139+ func (f * FlexibleTransaction ) Cleanup (ctx context.Context ) (bool , bool , error ) {
140+ f .concurrencySafeguard .Lock ()
141+ defer f .concurrencySafeguard .Unlock ()
142+ if f .DB == nil {
143+ return false , false , nil
144+ }
145+ if f .rolled {
146+ if err := f .DB .RollbackTransaction (ctx ); err != nil {
147+ return false , false , fmt .Errorf ("rolling back transaction: %w" , err )
148+ }
149+ return false , true , nil
150+ }
151+
152+ if err := f .DB .CommitTransaction (ctx ); err != nil {
153+ return false , false , fmt .Errorf ("committing transaction: %w" , err )
154+ }
155+ return true , false , nil
156+ }
157+
158+ // TXFinishFunc represents an all-encompassing function that either rolls or commits a tx based on the outcome.
159+ type TXFinishFunc func (ctx context.Context ) (committed , rolled bool , err error )
160+
161+ // BeginTransaction will wrap the passed DB into a transaction handler that supports it being used with less care
162+ // and prevents having to check if we are already in a tx and failures due to eager committers.
163+ func BeginTransaction (ctx context.Context , conn DB ) (DB , TXFinishFunc , error ) {
164+ // this can happen so let's work around it
165+ ft , isFT := conn .(* FlexibleTransaction )
166+ if isFT {
167+ return ft , func (ctx2 context.Context ) (bool , bool , error ) {
168+ return false , false , nil
169+ }, nil
170+ }
171+
172+ // the underlying conn is a tx, let's be careful not to commit/rollback it
173+ if conn .IsTransaction () {
174+ return & FlexibleTransaction {
175+ DB : conn ,
176+ },
177+ func (ctx2 context.Context ) (bool , bool , error ) {
178+ return false , false , nil
179+ },
180+ nil
181+
182+ }
183+
184+ tx , err := conn .BeginTransaction (ctx )
185+ if err != nil {
186+ return nil , nil , fmt .Errorf ("beginning transaction: %w" , err )
187+ }
188+
189+ f := & FlexibleTransaction {
190+ DB : tx ,
191+ }
192+ return f , f .Cleanup , nil
193+ }
194+
195+ // BeginTransaction implements DB for FlexibleTransaction
196+ func (f * FlexibleTransaction ) BeginTransaction (ctx context.Context ) (DB , error ) {
197+ return f , nil
198+ }
199+
200+ // CommitTransaction implements DB for FlexibleTransaction
201+ func (f * FlexibleTransaction ) CommitTransaction (ctx context.Context ) error {
202+ return nil
203+ }
204+
205+ // RollbackTransaction implements DB for FlexibleTransaction
206+ func (f * FlexibleTransaction ) RollbackTransaction (ctx context.Context ) error {
207+ f .concurrencySafeguard .Lock ()
208+ defer f .concurrencySafeguard .Unlock ()
209+ f .rolled = true
210+ return nil
211+ }
212+
120213// EscapeArgs return the query and args with the argument placeholder escaped.
121214func EscapeArgs (query string , args []interface {}) (string , []interface {}, error ) {
122215 // TODO: make this a bit less ugly
0 commit comments