diff --git a/Makefile b/Makefile index 21aa08c..b8c60da 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: test test: vet - go test -v -count 1 ./... + go tool gotest -v -count 1 ./... .PHONY: vet vet: diff --git a/db.go b/db.go index d84d3af..9371c92 100644 --- a/db.go +++ b/db.go @@ -60,6 +60,7 @@ import ( "fmt" "io" "sync" + "sync/atomic" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" @@ -106,12 +107,14 @@ func register0(name, drv, dsn string, readOnly bool) { // when the Close is called, transaction is rolled back type conn struct { sync.Mutex - tx *sql.Tx - dsn string - opened int - drv *txDriver - savepoints []int - log io.Writer + tx *sql.Tx + dsn string + opened int + drv *txDriver + savepoints []int + log io.Writer + readOnly bool + readOnlyApplied atomic.Bool } type txDriver struct { @@ -166,12 +169,6 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) { return nil, err } opts := []stdlib.OptionOpenDB{} - if d.readOnly { - opts = append(opts, stdlib.OptionAfterConnect(func(ctx context.Context, c *pgx.Conn) error { - _, err := c.Exec(ctx, "SET default_transaction_read_only = true") - return err - })) - } connector := stdlib.GetConnector(*cc, opts...) d.db = sql.OpenDB(connector) } else { @@ -185,7 +182,7 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) { } c, ok := d.conns[dsn] if !ok { - c = &conn{dsn: dsn, drv: d, savepoints: []int{0}, log: d.log} + c = &conn{dsn: dsn, drv: d, savepoints: []int{0}, log: d.log, readOnly: d.readOnly} fmt.Fprintf(c.log, "%s: open\n", c.dsn) c.tx, err = d.db.Begin() d.conns[dsn] = c @@ -194,6 +191,17 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) { return c, err } +func (c *conn) setReadOnlyIfNeed() error { + if !c.readOnly || c.readOnlyApplied.Load() { + return nil + } + if _, err := c.tx.Exec("SET transaction_read_only = true"); err != nil { + return fmt.Errorf("failed to set transaction_read_only: %w", err) + } + c.readOnlyApplied.Store(true) + return nil +} + func (c *conn) Close() (err error) { c.drv.Lock() defer c.drv.Unlock() @@ -255,6 +263,9 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { defer c.Unlock() fmt.Fprintf(c.log, "%s: exec %s\n", c.dsn, query) + if err := c.setReadOnlyIfNeed(); err != nil { + return nil, err + } return c.tx.Exec(query, mapArgs(args)...) } @@ -270,6 +281,10 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { c.Lock() defer c.Unlock() + if err := c.setReadOnlyIfNeed(); err != nil { + return nil, err + } + // query rows rs, err := c.tx.Query(query, mapArgs(args)...) if err != nil { @@ -412,6 +427,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam c.Lock() defer c.Unlock() + if err := c.setReadOnlyIfNeed(); err != nil { + return nil, err + } + rs, err := c.tx.QueryContext(ctx, query, mapNamedArgs(args)...) if err != nil { return nil, err @@ -427,6 +446,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name defer c.Unlock() fmt.Fprintf(c.log, "%s: exec %s\n", c.dsn, query) + if err := c.setReadOnlyIfNeed(); err != nil { + return nil, err + } return c.tx.ExecContext(ctx, query, mapNamedArgs(args)...) } diff --git a/go.mod b/go.mod index 5826cac..cd26362 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,16 @@ require ( ) require ( + github.com/fatih/color v1.18.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/rakyll/gotest v0.0.7 // indirect golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.31.0 // indirect ) + +tool github.com/rakyll/gotest diff --git a/go.sum b/go.sum index ee8ea8a..1c03c41 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -9,10 +11,16 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rakyll/gotest v0.0.7 h1:CL4D+fVEL0cUS5ys1cjrd+pN7sb8s1uf2LUy3UqyhAo= +github.com/rakyll/gotest v0.0.7/go.mod h1:F/7ufCiqpm6I79Epl+SQ7tc03zSdgcf7yZsGyBH60+Q= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -20,6 +28,9 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=