diff --git a/cmd/git-schemalex/main.go b/cmd/git-schemalex/main.go index 49ba239..3d6bba1 100644 --- a/cmd/git-schemalex/main.go +++ b/cmd/git-schemalex/main.go @@ -1,17 +1,22 @@ package main import ( + "context" "flag" "fmt" "log" + "os" + "os/signal" + "syscall" "github.com/schemalex/git-schemalex" ) var ( workspace = flag.String("workspace", "", "workspace of git") + commit = flag.String("commit", "HEAD", "target git commit hash") deploy = flag.Bool("deploy", false, "deploy") - dsn = flag.String("dsn", "", "") + dsn = flag.String("dsn", "root:@tcp(127.0.0.1:3306)/test", "DSN of the target mysql instance") table = flag.String("table", "git_schemalex_version", "table of git revision") schema = flag.String("schema", "", "path to schema file") ) @@ -24,14 +29,31 @@ func main() { } func _main() error { - r := &gitschemalex.Runner{ - Workspace: *workspace, - Deploy: *deploy, - DSN: *dsn, - Table: *table, - Schema: *schema, - } - err := r.Run() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) + + go func() { + select { + case <-ctx.Done(): + return + case <-sigCh: + cancel() + return + } + }() + + r := gitschemalex.New() + r.Workspace = *workspace + r.Commit = *commit + r.Deploy = *deploy + r.DSN = *dsn + r.Table = *table + r.Schema = *schema + + err := r.Run(ctx) if err == gitschemalex.ErrEqualVersion { fmt.Println(err.Error()) return nil diff --git a/gitschemalex.go b/gitschemalex.go index a0b4c3b..2e8db9e 100644 --- a/gitschemalex.go +++ b/gitschemalex.go @@ -2,13 +2,12 @@ package gitschemalex import ( "bytes" + "context" "database/sql" "errors" "fmt" - "io/ioutil" "os" "os/exec" - "path/filepath" "strings" _ "github.com/go-sql-driver/mysql" @@ -22,40 +21,45 @@ var ( type Runner struct { Workspace string + Commit string Deploy bool DSN string Table string Schema string } -func (r *Runner) Run() error { - db, err := r.DB() +func New() *Runner { + return &Runner{ + Commit: "HEAD", + Deploy: false, + } +} +func (r *Runner) Run(ctx context.Context) error { + db, err := r.DB() if err != nil { return err } - defer db.Close() - schemaVersion, err := r.SchemaVersion() - if err != nil { + var schemaVersion string + if err := r.SchemaVersion(ctx, &schemaVersion); err != nil { return err } - dbVersion, err := r.DatabaseVersion(db) - - if err != nil { + var dbVersion string + if err := r.DatabaseVersion(ctx, db, &dbVersion); err != nil { if !strings.Contains(err.Error(), "doesn't exist") { return err } - return r.DeploySchema(db, schemaVersion) + return r.DeploySchema(ctx, db, schemaVersion) } if dbVersion == schemaVersion { return ErrEqualVersion } - if err := r.UpgradeSchema(db, schemaVersion, dbVersion); err != nil { + if err := r.UpgradeSchema(ctx, db, schemaVersion, dbVersion); err != nil { return err } @@ -66,92 +70,83 @@ func (r *Runner) DB() (*sql.DB, error) { return sql.Open("mysql", r.DSN) } -func (r *Runner) DatabaseVersion(db *sql.DB) (version string, err error) { - err = db.QueryRow(fmt.Sprintf("SELECT version FROM `%s`", r.Table)).Scan(&version) - return +func (r *Runner) DatabaseVersion(ctx context.Context, db *sql.DB, version *string) error { + return db.QueryRowContext(ctx, fmt.Sprintf("SELECT version FROM `%s`", r.Table)).Scan(version) } -func (r *Runner) SchemaVersion() (string, error) { - - byt, err := r.execGitCmd("log", "-n", "1", "--pretty=format:%H", "--", r.Schema) +func (r *Runner) SchemaVersion(ctx context.Context, version *string) error { + // git rev-parse takes things like "HEAD" or commit hash, and gives + // us the corresponding commit hash + v, err := r.execGitCmd(ctx, "rev-parse", r.Commit) if err != nil { - return "", err + return err } - return string(byt), nil + *version = string(bytes.TrimSpace(v)) + return nil } -func (r *Runner) DeploySchema(db *sql.DB, version string) error { - content, err := r.schemaContent() - if err != nil { +func (r *Runner) DeploySchema(ctx context.Context, db *sql.DB, version string) error { + var content string + if err := r.schemaSpecificCommit(ctx, version, &content); err != nil { return err } + queries := queryListFromString(content) queries.AppendStmt(fmt.Sprintf("CREATE TABLE `%s` ( version VARCHAR(40) NOT NULL )", r.Table)) queries.AppendStmt(fmt.Sprintf("INSERT INTO `%s` (version) VALUES (?)", r.Table), version) - return r.execSql(db, queries) + return r.execSql(ctx, db, queries) } -func (r *Runner) UpgradeSchema(db *sql.DB, schemaVersion string, dbVersion string) error { - - lastSchema, err := r.schemaSpecificCommit(dbVersion) - if err != nil { +func (r *Runner) UpgradeSchema(ctx context.Context, db *sql.DB, schemaVersion string, dbVersion string) error { + var lastSchema string + if err := r.schemaSpecificCommit(ctx, dbVersion, &lastSchema); err != nil { return err } - currentSchema, err := r.schemaContent() - if err != nil { + var currentSchema string + if err := r.schemaSpecificCommit(ctx, schemaVersion, ¤tSchema); err != nil { return err } stmts := &bytes.Buffer{} p := schemalex.New() - err = diff.Strings(stmts, lastSchema, currentSchema, diff.WithTransaction(true), diff.WithParser(p)) - if err != nil { + if err := diff.Strings(stmts, lastSchema, currentSchema, diff.WithTransaction(true), diff.WithParser(p)); err != nil { return err } queries := queryListFromString(stmts.String()) queries.AppendStmt(fmt.Sprintf("UPDATE %s SET version = ?", r.Table), schemaVersion) - return r.execSql(db, queries) + return r.execSql(ctx, db, queries) } // private -func (r *Runner) schemaSpecificCommit(commit string) (string, error) { - byt, err := r.execGitCmd("ls-tree", commit, "--", r.Schema) - - if err != nil { - return "", err - } - - fields := strings.Fields(string(byt)) - - byt, err = r.execGitCmd("cat-file", "blob", fields[2]) +func (r *Runner) schemaSpecificCommit(ctx context.Context, commit string, dst *string) error { + // Old code used to do ls-tree and then cat-file, but I don't see why + // you need to do this. + // Doing + // > fields := git ls-tree $commit -- $schema_file + // And then taking fields[2] just gives us back $commit. + // showing the contents at the point of commit using "git show" is much simpler + v, err := r.execGitCmd(ctx, "show", fmt.Sprintf("%s:%s", commit, r.Schema)) if err != nil { - return "", err + return err } - return string(byt), nil + *dst = string(v) + return nil } -func (r *Runner) execSql(db *sql.DB, queries queryList) error { +func (r *Runner) execSql(ctx context.Context, db *sql.DB, queries queryList) error { if !r.Deploy { return queries.dump(os.Stdout) } - return queries.execute(db) -} - -func (r *Runner) schemaContent() (string, error) { - byt, err := ioutil.ReadFile(filepath.Join(r.Workspace, r.Schema)) - if err != nil { - return "", err - } - return string(byt), nil + return queries.execute(ctx, db) } -func (r *Runner) execGitCmd(args ...string) ([]byte, error) { - cmd := exec.Command("git", args...) +func (r *Runner) execGitCmd(ctx context.Context, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, "git", args...) if r.Workspace != "" { cmd.Dir = r.Workspace } diff --git a/gitschemalex_test.go b/gitschemalex_test.go index 8059b93..2f862d0 100644 --- a/gitschemalex_test.go +++ b/gitschemalex_test.go @@ -1,6 +1,7 @@ package gitschemalex import ( + "context" "database/sql" "io/ioutil" "os" @@ -84,14 +85,13 @@ func TestRunner(t *testing.T) { // whatever to "test" re := regexp.MustCompile(`/[^/]+$`) dsn = re.ReplaceAllString(dsn, `/test`) - r := &Runner{ - Workspace: dir, - Deploy: true, - DSN: dsn, - Table: "git_schemalex_version", - Schema: "schema.sql", - } - if err := r.Run(); err != nil { + r := New() + r.Workspace = dir + r.Deploy = true + r.DSN = dsn + r.Table = "git_schemalex_version" + r.Schema = "schema.sql" + if err := r.Run(context.TODO()); err != nil { t.Fatal(err) } @@ -113,7 +113,7 @@ func TestRunner(t *testing.T) { t.Fatal(err) } - if err := r.Run(); err != nil { + if err := r.Run(context.TODO()); err != nil { t.Fatal(err) } @@ -122,8 +122,7 @@ func TestRunner(t *testing.T) { } // equal version - - if e, g := ErrEqualVersion, r.Run(); e != g { - t.Fatal("should %v got %v", e, g) + if err := r.Run(context.TODO()); err != ErrEqualVersion { + t.Fatal("should %v got %v", err, ErrEqualVersion) } } diff --git a/query.go b/query.go index 0c5c26a..b3a74db 100644 --- a/query.go +++ b/query.go @@ -1,6 +1,7 @@ package gitschemalex import ( + "context" "database/sql" "fmt" "io" @@ -14,8 +15,8 @@ type query struct { args []interface{} } -func (q *query) execute(db *sql.DB) error { - _, err := db.Exec(q.stmt, q.args...) +func (q *query) execute(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, q.stmt, q.args...) return errors.Wrap(err, `failed to execute query`) } @@ -58,9 +59,9 @@ func (l *queryList) dump(dst io.Writer) error { return nil } -func (l *queryList) execute(db *sql.DB) error { +func (l *queryList) execute(ctx context.Context, db *sql.DB) error { for i, q := range *l { - if err := q.execute(db); err != nil { + if err := q.execute(ctx, db); err != nil { return errors.Wrapf(err, `failed to execute query %d`, i+1) } }