diff --git a/tracer.go b/tracer.go index e3ab16f..9b2e9bd 100644 --- a/tracer.go +++ b/tracer.go @@ -145,6 +145,7 @@ func NewTracer(opts ...Option) *Tracer { spanNameCtxFunc: cfg.spanNameCtxFunc, prefixQuerySpanName: cfg.prefixQuerySpanName, logSQLStatement: cfg.logSQLStatement, + logConnectionDetails: cfg.logConnectionDetails, includeParams: cfg.includeParams, disableAcquireTracer: cfg.disableAcquireTracer, } diff --git a/tracer_test.go b/tracer_test.go index 96eb9d8..89f97c9 100644 --- a/tracer_test.go +++ b/tracer_test.go @@ -2,10 +2,18 @@ package otelpgx import ( "context" + "fmt" + "net" "strings" "testing" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" ) func TestTracer_sqlOperationName(t *testing.T) { @@ -173,3 +181,210 @@ func TestTracer_sqlOperationNameFromCtx(t *testing.T) { }) } } + +// newMockConn creates a *pgx.Conn with the given connection details, backed +// by a fake PostgreSQL server over net.Pipe — no real database needed. +func newMockConn(t *testing.T, host string, port uint16, user, database string) *pgx.Conn { + t.Helper() + + client, server := net.Pipe() + errCh := make(chan error, 1) + + go func() { + var gErr error + defer func() { + if err := server.Close(); err != nil && gErr == nil { + gErr = fmt.Errorf("server close: %w", err) + } + errCh <- gErr + }() + + b := pgproto3.NewBackend(server, server) + if _, err := b.ReceiveStartupMessage(); err != nil { + gErr = fmt.Errorf("receive startup: %w", err) + return + } + + for _, msg := range []pgproto3.BackendMessage{ + &pgproto3.AuthenticationOk{}, + &pgproto3.BackendKeyData{SecretKey: make([]byte, 4)}, + &pgproto3.ReadyForQuery{TxStatus: 'I'}, + } { + b.Send(msg) + } + if err := b.Flush(); err != nil { + gErr = fmt.Errorf("flush: %w", err) + return + } + + // Drain until the client disconnects or sends Terminate. + for { + msg, err := b.Receive() + if err != nil { + break + } + if _, ok := msg.(*pgproto3.Terminate); ok { + break + } + } + }() + + dsn := fmt.Sprintf("postgres://%s@%s:%d/%s?sslmode=disable", user, host, int(port), database) + config, err := pgx.ParseConfig(dsn) + require.NoError(t, err) + + config.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + return []string{host}, nil + } + config.DialFunc = func(ctx context.Context, _, _ string) (net.Conn, error) { + return client, nil + } + + conn, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + t.Cleanup(func() { + if err := conn.Close(context.Background()); err != nil { + t.Errorf("close conn: %v", err) + } + if serverErr := <-errCh; serverErr != nil { + t.Errorf("mock server: %v", serverErr) + } + }) + + return conn +} + +// findAttr returns the value for the given key in the attribute slice, and +// whether it was found. +func findAttr(attrs []attribute.KeyValue, key string) (attribute.Value, bool) { + for _, a := range attrs { + if string(a.Key) == key { + return a.Value, true + } + } + return attribute.Value{}, false +} + +func TestTracer_spanAttributes(t *testing.T) { + conn := newMockConn(t, "fakehost", 5432, "fakeuser", "fakedb") + + tests := []struct { + name string + opts []Option + drive func(ctx context.Context, tracer *Tracer, conn *pgx.Conn) + wantStrAttrs map[string]string + wantIntAttrs map[string]int64 + absentAttrs []string + }{ + { + name: "query default", + drive: func(ctx context.Context, tracer *Tracer, conn *pgx.Conn) { + ctx = tracer.TraceQueryStart(ctx, conn, pgx.TraceQueryStartData{SQL: "SELECT * FROM users"}) + tracer.TraceQueryEnd(ctx, conn, pgx.TraceQueryEndData{}) + }, + wantStrAttrs: map[string]string{ + "db.system.name": "postgresql", + "server.address": "fakehost", + "user.name": "fakeuser", + "db.namespace": "fakedb", + "db.query.text": "SELECT * FROM users", + "db.operation.name": "SELECT", + }, + wantIntAttrs: map[string]int64{ + "server.port": 5432, + }, + }, + { + name: "query without connection details", + opts: []Option{WithDisableConnectionDetailsInAttributes()}, + drive: func(ctx context.Context, tracer *Tracer, conn *pgx.Conn) { + ctx = tracer.TraceQueryStart(ctx, conn, pgx.TraceQueryStartData{SQL: "SELECT * FROM users"}) + tracer.TraceQueryEnd(ctx, conn, pgx.TraceQueryEndData{}) + }, + wantStrAttrs: map[string]string{ + "db.system.name": "postgresql", + "db.query.text": "SELECT * FROM users", + "db.operation.name": "SELECT", + }, + absentAttrs: []string{"server.address", "server.port", "user.name", "db.namespace"}, + }, + { + name: "query without SQL statement", + opts: []Option{WithDisableSQLStatementInAttributes()}, + drive: func(ctx context.Context, tracer *Tracer, conn *pgx.Conn) { + ctx = tracer.TraceQueryStart(ctx, conn, pgx.TraceQueryStartData{SQL: "SELECT * FROM users"}) + tracer.TraceQueryEnd(ctx, conn, pgx.TraceQueryEndData{}) + }, + wantStrAttrs: map[string]string{ + "db.system.name": "postgresql", + "server.address": "fakehost", + "user.name": "fakeuser", + "db.namespace": "fakedb", + }, + wantIntAttrs: map[string]int64{ + "server.port": 5432, + }, + absentAttrs: []string{"db.query.text", "db.operation.name"}, + }, + { + name: "query with parameters included", + opts: []Option{WithIncludeQueryParameters()}, + drive: func(ctx context.Context, tracer *Tracer, conn *pgx.Conn) { + ctx = tracer.TraceQueryStart(ctx, conn, pgx.TraceQueryStartData{ + SQL: "SELECT * FROM users WHERE id = $1", + Args: []any{42}, + }) + tracer.TraceQueryEnd(ctx, conn, pgx.TraceQueryEndData{}) + }, + wantStrAttrs: map[string]string{ + "db.system.name": "postgresql", + "server.address": "fakehost", + "user.name": "fakeuser", + "db.namespace": "fakedb", + "db.query.text": "SELECT * FROM users WHERE id = $1", + "db.operation.name": "SELECT", + }, + wantIntAttrs: map[string]int64{ + "server.port": 5432, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + t.Cleanup(func() { require.NoError(t, tp.Shutdown(context.Background())) }) + + opts := append([]Option{WithTracerProvider(tp)}, tt.opts...) + tracer := NewTracer(opts...) + + ctx, parentSpan := tp.Tracer("test").Start(context.Background(), "parent") + tt.drive(ctx, tracer, conn) + parentSpan.End() + + spans := exporter.GetSpans() + require.Greater(t, len(spans), 0, "no spans recorded") + + span := spans[0] + + for key, want := range tt.wantStrAttrs { + v, ok := findAttr(span.Attributes, key) + require.Truef(t, ok, "missing attribute %q", key) + require.Equalf(t, want, v.AsString(), "attr %q = %q, want %q", key, v.AsString(), want) + } + + for key, want := range tt.wantIntAttrs { + v, ok := findAttr(span.Attributes, key) + require.Truef(t, ok, "missing attribute %q", key) + require.Equalf(t, want, v.AsInt64(), "attr %q = %q, want %d", key, v.AsInt64(), want) + } + + for _, key := range tt.absentAttrs { + _, ok := findAttr(span.Attributes, key) + require.Falsef(t, ok, "unexpected attribute %q present") + } + }) + } +}