Skip to content

fix(conn): ensure ReadTimeout applies to each conn.Read #1612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
4 changes: 3 additions & 1 deletion clickhouse_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ type Options struct {
// Use this instead of Auth.Username and Auth.Password if you're using JWT auth.
GetJWT GetJWTFunc

scheme string
scheme string
// ReadTimeout is the maximum duration the client will wait for ClickHouse
// to respond to a single Rady call for bytes over the connection.
ReadTimeout time.Duration
}

Expand Down
19 changes: 19 additions & 0 deletions conn_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"io"
"time"

"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
)
Expand Down Expand Up @@ -74,6 +75,15 @@ func (c *connect) firstBlockImpl(ctx context.Context, on *onProcess) (*proto.Blo
c.readerMutex.Lock()
defer c.readerMutex.Unlock()

// set a read deadline - alternative to context.Read operation will fail if no data is received after deadline.
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
defer c.conn.SetReadDeadline(time.Time{})
// context level deadlines override any read deadline
if deadline, ok := ctx.Deadline(); ok {
c.conn.SetDeadline(deadline)
defer c.conn.SetDeadline(time.Time{})
}

for {
if c.reader == nil {
return nil, errors.New("unexpected state: c.reader is nil")
Expand Down Expand Up @@ -144,6 +154,15 @@ func (c *connect) processImpl(ctx context.Context, on *onProcess) error {
c.readerMutex.Lock()
defer c.readerMutex.Unlock()

// set a read deadline - alternative to context.Read operation will fail if no data is received after deadline.
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
defer c.conn.SetReadDeadline(time.Time{})
// context level deadlines override any read deadline
if deadline, ok := ctx.Deadline(); ok {
c.conn.SetDeadline(deadline)
defer c.conn.SetDeadline(time.Time{})
}

for {
if c.reader == nil {
return errors.New("unexpected state: c.reader is nil")
Expand Down
10 changes: 0 additions & 10 deletions conn_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package clickhouse

import (
"context"
"time"

"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
)
Expand All @@ -38,15 +37,6 @@ func (c *connect) query(ctx context.Context, release nativeTransportRelease, que
return nil, err
}

// set a read deadline - alternative to context.Read operation will fail if no data is received after deadline.
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
defer c.conn.SetReadDeadline(time.Time{})
// context level deadlines override any read deadline
if deadline, ok := ctx.Deadline(); ok {
c.conn.SetDeadline(deadline)
defer c.conn.SetDeadline(time.Time{})
}

if err = c.sendQuery(body, &options); err != nil {
release(c, err)
return nil, err
Expand Down
122 changes: 122 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package clickhouse

import (
"context"
"log/slog"
"os"
"testing"
"time"

"github.com/ClickHouse/clickhouse-go/v2/lib/column"
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
chtesting "github.com/ClickHouse/clickhouse-go/v2/lib/testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestConn_Query(t *testing.T) {
slog.SetDefault(slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{
Level: slog.LevelDebug,
},
)))

handlers := chtesting.DefaultHandlers()
handlers.OnQuery = func(q *proto.Query, blocks []*proto.Block, c chan<- *proto.Block) error {
col := &column.UInt8{}
require.NoError(t, col.AppendRow(uint8(1)))

c <- &proto.Block{
Columns: []column.Interface{
col,
},
}

return nil
}

server, err := chtesting.NewTestServer(":0", handlers)
require.NoError(t, err)

server.Start()
t.Cleanup(func() { server.Stop() })

conn, err := Open(&Options{
Addr: []string{server.Address()},
MaxOpenConns: 2,
})
require.NoError(t, err)

ctx := context.TODO()
require.NoError(t, conn.Ping(ctx))

rows, err := conn.Query(ctx, "SELECT 1")
require.NoError(t, err)

var num uint8
for rows.Next() {
err := rows.Scan(&num)
require.NoError(t, err)
}

assert.Equal(t, uint8(1), num)
require.NoError(t, rows.Err())
}

func TestConn_Query_ReadTimeout(t *testing.T) {
slog.SetDefault(slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{
Level: slog.LevelDebug,
},
)))

handlers := chtesting.DefaultHandlers()
handlers.OnQuery = func(q *proto.Query, blocks []*proto.Block, c chan<- *proto.Block) error {
col := &column.UInt8{}
require.NoError(t, col.AppendRow(uint8(1)))

// sends first block
c <- &proto.Block{
Columns: []column.Interface{
col,
},
}

// then blocks indefinitely
select {}
}

server, err := chtesting.NewTestServer(":0", handlers)
require.NoError(t, err)

server.Start()
t.Cleanup(func() { server.Stop() })

conn, err := Open(&Options{
Addr: []string{server.Address()},
MaxOpenConns: 2,
ReadTimeout: 2 * time.Second,
})
require.NoError(t, err)

ctx := context.TODO()
require.NoError(t, conn.Ping(ctx))

rows, err := conn.Query(ctx, "SELECT 1")
require.NoError(t, err)

t.Run("first row is returned", func(t *testing.T) {
assert.True(t, rows.Next())
var num uint8
err = rows.Scan(&num)
require.NoError(t, err)
assert.Equal(t, uint8(1), num)
})

t.Run("second row timeout", func(t *testing.T) {
assert.False(t, rows.Next())
chtesting.AssertIsTimeoutError(t, rows.Err())
})
}
37 changes: 35 additions & 2 deletions lib/proto/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package proto

import (
"fmt"
chproto "github.com/ClickHouse/ch-go/proto"
"gopkg.in/yaml.v3"
"strconv"
"strings"
"time"

chproto "github.com/ClickHouse/ch-go/proto"
"gopkg.in/yaml.v3"

"github.com/ClickHouse/clickhouse-go/v2/lib/timezone"
)

Expand All @@ -42,6 +43,22 @@ func (h ClientHandshake) Encode(buffer *chproto.Buffer) {
buffer.PutUVarInt(h.ProtocolVersion)
}

func (h *ClientHandshake) Decode(reader *chproto.Reader) (err error) {
if h.ClientName, err = reader.Str(); err != nil {
return fmt.Errorf("could not read client name: %v", err)
}
if h.ClientVersion.Major, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read client major version: %v", err)
}
if h.ClientVersion.Minor, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read client minor version: %v", err)
}
if h.ProtocolVersion, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read protocol version: %v", err)
}
return nil
}

func (h ClientHandshake) String() string {
return fmt.Sprintf("%s %d.%d.%d", h.ClientName, h.ClientVersion.Major, h.ClientVersion.Minor, h.ClientVersion.Patch)
}
Expand Down Expand Up @@ -85,6 +102,22 @@ func CheckMinVersion(constraint Version, version Version) bool {
return true
}

func (srv *ServerHandshake) Encode(buffer *chproto.Buffer) {
buffer.PutString(srv.Name)
buffer.PutUVarInt(srv.Version.Major)
buffer.PutUVarInt(srv.Version.Minor)
buffer.PutUVarInt(srv.Revision)
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
buffer.PutString(srv.Timezone.String())
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
buffer.PutString(srv.DisplayName)
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
buffer.PutUVarInt(srv.Version.Patch)
}
}

func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
if srv.Name, err = reader.Str(); err != nil {
return fmt.Errorf("could not read server name: %v", err)
Expand Down
Loading
Loading