diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6665865 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out + +.idea diff --git a/pglogrepl.go b/pglogrepl.go index bd199b4..5386494 100644 --- a/pglogrepl.go +++ b/pglogrepl.go @@ -13,6 +13,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "io" "strconv" "strings" "time" @@ -393,6 +394,19 @@ type BaseBackupOptions struct { // Disable checksums being verified during a base backup. // Note that NoVerifyChecksums=true is only supported since PG11 NoVerifyChecksums bool + // When this option is specified with a value of yes or force-encode, a backup manifest is created and sent along with the backup. + // The manifest is a list of every file present in the backup with the exception of any WAL files that may be included. + // It also stores the size, last modification time, and optionally a checksum for each file. + // A value of force-encode forces all filenames to be hex-encoded; otherwise, this type of encoding is performed only for files whose names are non-UTF8 octet sequences. + // force-encode is intended primarily for testing purposes, to be sure that clients which read the backup manifest can handle this case. + // For compatibility with previous releases, the default is MANIFEST 'no'. + Manifest bool + // Specifies the checksum algorithm that should be applied to each file included in the backup manifest. + // Currently, the available algorithms are NONE, CRC32C, SHA224, SHA256, SHA384, and SHA512. The default is CRC32C. + ManifestChecksums string + // Requests an incremental backup. + // The UPLOAD_MANIFEST command must be executed before running a base backup with this option. + Incremental bool } func (bbo BaseBackupOptions) sql(serverVersion int) string { @@ -433,6 +447,17 @@ func (bbo BaseBackupOptions) sql(serverVersion int) string { parts = append(parts, "NOVERIFY_CHECKSUMS") } } + if bbo.Manifest { + parts = append(parts, "MANIFEST 'yes'") + if bbo.ManifestChecksums != "" { + parts = append(parts, fmt.Sprintf("MANIFEST_CHECKSUMS '%s'", bbo.ManifestChecksums)) + } + } + if serverVersion >= 17 { + if bbo.Incremental { + parts = append(parts, "INCREMENTAL") + } + } if serverVersion >= 15 { return "BASE_BACKUP(" + strings.Join(parts, ", ") + ")" } @@ -650,6 +675,21 @@ func FinishBaseBackup(ctx context.Context, conn *pgconn.PgConn) (result BaseBack return } +func UploadManifest(ctx context.Context, conn *pgconn.PgConn, r io.Reader) error { + serverVersion, err := serverMajorVersion(conn) + if err != nil { + return err + } + if serverVersion < 17 { + return fmt.Errorf("upload_manifest required version >= 17, current version is: %d", serverVersion) + } + + if _, err := conn.CopyFrom(ctx, r, "UPLOAD_MANIFEST"); err != nil { + return fmt.Errorf("UPLOAD_MANIFEST: %w", err) + } + return nil +} + type PrimaryKeepaliveMessage struct { ServerWALEnd LSN ServerTime time.Time diff --git a/pglogrepl_test.go b/pglogrepl_test.go index ee684b2..37eebcf 100644 --- a/pglogrepl_test.go +++ b/pglogrepl_test.go @@ -1,10 +1,15 @@ package pglogrepl_test import ( + "bytes" "context" + "encoding/json" "fmt" + "io" "os" + "path/filepath" "strconv" + "strings" "testing" "time" @@ -390,6 +395,80 @@ func TestBaseBackup(t *testing.T) { require.NoError(t, err) } +func TestBaseBackupManifest(t *testing.T) { + // base backup test could take a long time. Therefore it can be disabled. + envSkipTest := os.Getenv("PGLOGREPL_SKIP_BASE_BACKUP") + if envSkipTest != "" { + skipTest, err := strconv.ParseBool(envSkipTest) + require.NoError(t, err) + if skipTest { + t.Skip("PGLOGREPL_SKIP_BASE_BACKUP=true, skipping base backup test") + } + } + + // Use timeout so the test cannot hang forever. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + manifestData, filesWritten := streamBB(ctx, t, conn, false) + require.Greater(t, len(manifestData), 1) + require.GreaterOrEqual(t, len(filesWritten), 2) // manifest + base +} + +func TestBaseBackupIncremental(t *testing.T) { + // base backup test could take a long time. Therefore it can be disabled. + envSkipTest := os.Getenv("PGLOGREPL_SKIP_BASE_BACKUP") + if envSkipTest != "" { + skipTest, err := strconv.ParseBool(envSkipTest) + require.NoError(t, err) + if skipTest { + t.Skip("PGLOGREPL_SKIP_BASE_BACKUP=true, skipping base backup test") + } + } + + // Use timeout so the test cannot hang forever. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + // skip when major version < 17 + serverVersion, err := serverMajorVersion(conn) + require.NoError(t, err) + if serverVersion < 17 { + t.Skip() + } + + // create basebackup + manifestData, filesWritten := streamBB(ctx, t, conn, false) + require.Greater(t, len(manifestData), 1) + require.GreaterOrEqual(t, len(filesWritten), 2) // manifest + base + manifestRdr := io.NopCloser(bytes.NewReader([]byte(manifestData))) + + for _, f := range filesWritten { + t.Logf("base. written file: %s\n", f) + } + + // create incremental backup + // 1) upload manifest + err = pglogrepl.UploadManifest(ctx, conn, manifestRdr) + require.NoError(t, err) + // 2) streaming incremental backup + manifestDataIncremental, filesWritten := streamBB(ctx, t, conn, true) + require.Greater(t, len(manifestDataIncremental), 1) + require.GreaterOrEqual(t, len(filesWritten), 2) // manifest + base + + for _, f := range filesWritten { + t.Logf("incremental. written file: %s\n", f) + } +} + func TestSendStandbyStatusUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -404,3 +483,171 @@ func TestSendStandbyStatusUpdate(t *testing.T) { err = pglogrepl.SendStandbyStatusUpdate(ctx, conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: sysident.XLogPos}) require.NoError(t, err) } + +// Helpers + +//nolint:gocritic +func readCString(buf []byte) (string, []byte, error) { + idx := bytes.IndexByte(buf, 0) + if idx < 0 { + return "", nil, fmt.Errorf("invalid CString: %q", string(buf)) + } + return string(buf[:idx]), buf[idx+1:], nil +} + +// Wrap any io.Writer to give it a no-op Close() so it satisfies writeCloser. +type nopCloser struct { + io.Writer +} + +// Generalized writer interface for "current stream target". +type writeCloser interface { + io.Writer + io.Closer +} + +func (n nopCloser) Close() error { + return nil +} + +func streamBB(ctx context.Context, t *testing.T, conn *pgconn.PgConn, incremental bool) (string, []string) { + t.Helper() + + var ( + curTarget writeCloser + curTargetName string + manifestBuf bytes.Buffer + filesWritten []string + ) + + _, err := pglogrepl.StartBaseBackup(ctx, conn, pglogrepl.BaseBackupOptions{ + Label: "pglogrepltest", + Progress: false, + Fast: true, + WAL: false, + NoWait: true, + MaxRate: 0, + TablespaceMap: true, + Manifest: true, + Incremental: incremental, + }) + require.NoError(t, err) + + closeCurrent := func() { + if curTarget == nil { + return + } + if len(curTargetName) > 0 { + filesWritten = append(filesWritten, curTargetName) + } + require.NoError(t, curTarget.Close()) + curTarget = nil + curTargetName = "" + } + + for { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + + switch m := msg.(type) { + case *pgproto3.CopyOutResponse: + // nothing interesting here + continue + + case *pgproto3.CopyData: + switch m.Data[0] { + case 'n': + // New file header (tar member) + closeCurrent() + + filename, rest, err := readCString(m.Data[1:]) + require.NoError(t, err) + + tsPath, _, err := readCString(rest) + require.NoError(t, err) + + if !strings.Contains(filename, "base") { + assert.Greater(t, len(tsPath), 1) + } + + bbTar := strings.TrimPrefix(filename, "./") + + // Still write backup files to temp, but we don't care about contents in this test. + f, err := os.CreateTemp("", "*-"+bbTar) + require.NoError(t, err) + curTarget = f + curTargetName = filepath.ToSlash(f.Name()) + + case 'd': + // File or manifest data + require.NotNil(t, curTarget, "received data but no active writer") + + _, err := curTarget.Write(m.Data[1:]) + require.NoError(t, err) + + case 'm': + // Switch to manifest stream -> write into buffer instead of a file. + closeCurrent() + manifestBuf.Reset() + curTarget = nopCloser{Writer: &manifestBuf} + + case 'p': + // only if Progress: true (we disabled Progress above) + + default: + // unexpected data type – fail fast so we don't spin forever + t.Fatalf("unexpected CopyData message type: %q", m.Data[0]) + } + + case *pgproto3.CopyDone: + // backup stream complete + closeCurrent() + + _, err := pglogrepl.FinishBaseBackup(ctx, conn) + require.NoError(t, err) + + // assert manifest is meaningful + + // 1) non-empty + manStr := strings.TrimSpace(manifestBuf.String()) + require.NotEmpty(t, manStr, "manifest should not be empty") + + // save for inspecting + f, err := os.CreateTemp(os.TempDir(), "manifest") + require.NoError(t, err) + _, err = f.Write(manifestBuf.Bytes()) + require.NoError(t, err) + filesWritten = append(filesWritten, filepath.ToSlash(f.Name())) + + // 2) valid json with some keys + var manifestJSON map[string]any + err = json.Unmarshal(manifestBuf.Bytes(), &manifestJSON) + require.NoError(t, err, "manifest must be valid JSON") + require.NotEmpty(t, manifestJSON, "manifest JSON must have at least one key") + + // 3) expect keys + _, hasVersion := manifestJSON["PostgreSQL-Backup-Manifest-Version"] + assert.True(t, hasVersion, "manifest should contain 'PostgreSQL-Backup-Manifest-Version' field") + + // 4) check incremental + if incremental { + assert.True(t, strings.Contains(manStr, "INCREMENTAL")) + } + + return manifestBuf.String(), filesWritten + + default: + // For this test, any other message is unexpected; better to fail than hang. + t.Fatalf("unexpected message type: %T", msg) + } + } +} + +func serverMajorVersion(conn *pgconn.PgConn) (int, error) { + verString := conn.ParameterStatus("server_version") + dot := strings.IndexByte(verString, '.') + if dot == -1 { + return 0, fmt.Errorf("bad server version string: '%s'", verString) + } + return strconv.Atoi(verString[:dot]) +}