diff --git a/engine.go b/engine.go index f09c167c5f..9cdd194ee2 100644 --- a/engine.go +++ b/engine.go @@ -120,7 +120,7 @@ func (e *Engine) Query( case *plan.CreateIndex: typ = sql.CreateIndexProcess perm = auth.ReadPerm | auth.WritePerm - case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables: + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables, *plan.CreateView, *plan.DropView: perm = auth.ReadPerm | auth.WritePerm } diff --git a/engine_test.go b/engine_test.go index e88cc2404a..a5151d17f8 100644 --- a/engine_test.go +++ b/engine_test.go @@ -3661,17 +3661,19 @@ func TestReadOnly(t *testing.T) { _, _, err := e.Query(newCtx(), `SELECT i FROM mytable`) require.NoError(err) - _, _, err = e.Query(newCtx(), `CREATE INDEX foo ON mytable USING pilosa (i, s)`) - require.Error(err) - require.True(auth.ErrNotAuthorized.Is(err)) - - _, _, err = e.Query(newCtx(), `DROP INDEX foo ON mytable`) - require.Error(err) - require.True(auth.ErrNotAuthorized.Is(err)) + writingQueries := []string{ + `CREATE INDEX foo ON mytable USING pilosa (i, s)`, + `DROP INDEX foo ON mytable`, + `INSERT INTO mytable (i, s) VALUES(42, 'yolo')`, + `CREATE VIEW myview AS SELECT i FROM mytable`, + `DROP VIEW myview`, + } - _, _, err = e.Query(newCtx(), `INSERT INTO mytable (i, s) VALUES(42, 'yolo')`) - require.Error(err) - require.True(auth.ErrNotAuthorized.Is(err)) + for _, query := range writingQueries { + _, _, err = e.Query(newCtx(), query) + require.Error(err) + require.True(auth.ErrNotAuthorized.Is(err)) + } } func TestSessionVariables(t *testing.T) { diff --git a/sql/analyzer/assign_catalog.go b/sql/analyzer/assign_catalog.go index 0a8f76ee81..1accc5ed86 100644 --- a/sql/analyzer/assign_catalog.go +++ b/sql/analyzer/assign_catalog.go @@ -60,6 +60,14 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) nc := *node nc.Catalog = a.Catalog return &nc, nil + case *plan.CreateView: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil + case *plan.DropView: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil default: return n, nil } diff --git a/sql/analyzer/assign_catalog_test.go b/sql/analyzer/assign_catalog_test.go index 39aed1a4f5..22c66059e9 100644 --- a/sql/analyzer/assign_catalog_test.go +++ b/sql/analyzer/assign_catalog_test.go @@ -73,4 +73,18 @@ func TestAssignCatalog(t *testing.T) { ut, ok := node.(*plan.UnlockTables) require.True(ok) require.Equal(c, ut.Catalog) + + mockSubquery := plan.NewSubqueryAlias("mock", plan.NewResolvedTable(tbl)) + mockView := plan.NewCreateView(db, "", nil, mockSubquery, false) + node, err = f.Apply(sql.NewEmptyContext(), a, mockView) + require.NoError(err) + cv, ok := node.(*plan.CreateView) + require.True(ok) + require.Equal(c, cv.Catalog) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewDropView(nil, false)) + require.NoError(err) + dv, ok := node.(*plan.DropView) + require.True(ok) + require.Equal(c, dv.Catalog) } diff --git a/sql/analyzer/resolve_tables.go b/sql/analyzer/resolve_tables.go index ec495cd362..0ab516a9c8 100644 --- a/sql/analyzer/resolve_tables.go +++ b/sql/analyzer/resolve_tables.go @@ -39,17 +39,26 @@ func resolveTables(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) } rt, err := a.Catalog.Table(db, name) - if err != nil { - if sql.ErrTableNotFound.Is(err) && name == dualTableName { + if err == nil { + a.Log("table resolved: %q", t.Name()) + return plan.NewResolvedTable(rt), nil + } + + if sql.ErrTableNotFound.Is(err) { + if name == dualTableName { rt = dualTable name = dualTableName - } else { - return nil, err + + a.Log("table resolved: %q", t.Name()) + return plan.NewResolvedTable(rt), nil } - } - a.Log("table resolved: %q", t.Name()) + if view, err := a.Catalog.ViewRegistry.View(db, name); err == nil { + a.Log("table %q is a view: replacing plans", t.Name()) + return view.Definition(), nil + } + } - return plan.NewResolvedTable(rt), nil + return nil, err }) } diff --git a/sql/analyzer/resolve_tables_test.go b/sql/analyzer/resolve_tables_test.go index fbc4fcca0e..d2fc34fedb 100644 --- a/sql/analyzer/resolve_tables_test.go +++ b/sql/analyzer/resolve_tables_test.go @@ -91,3 +91,45 @@ func TestResolveTablesNested(t *testing.T) { ) require.Equal(expected, analyzed) } + +func TestResolveViews(t *testing.T) { + require := require.New(t) + + f := getRule("resolve_tables") + + table := memory.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) + db := memory.NewDatabase("mydb") + db.AddTable("mytable", table) + + // Resolved plan that corresponds to query "SELECT i FROM mytable" + subquery := plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable( + 1, sql.Int32, table.Name(), "i", true), + }, + plan.NewResolvedTable(table), + ) + subqueryAlias := plan.NewSubqueryAlias("myview", subquery) + view := sql.NewView("myview", subqueryAlias) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + err := catalog.ViewRegistry.Register(db.Name(), view) + require.NoError(err) + + a := NewBuilder(catalog).AddPostAnalyzeRule(f.Name, f.Apply).Build() + + var notAnalyzed sql.Node = plan.NewUnresolvedTable("myview", "") + analyzed, err := f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + require.Equal(subqueryAlias, analyzed) + + notAnalyzed = plan.NewUnresolvedTable("MyVieW", "") + analyzed, err = f.Apply(sql.NewEmptyContext(), a, notAnalyzed) + require.NoError(err) + require.Equal(subqueryAlias, analyzed) + + analyzed, err = f.Apply(sql.NewEmptyContext(), a, subqueryAlias) + require.NoError(err) + require.Equal(subqueryAlias, analyzed) +} diff --git a/sql/catalog.go b/sql/catalog.go index 4be0bad629..abfaf6fa6b 100644 --- a/sql/catalog.go +++ b/sql/catalog.go @@ -18,6 +18,7 @@ var ErrDatabaseNotFound = errors.NewKind("database not found: %s") type Catalog struct { FunctionRegistry *IndexRegistry + *ViewRegistry *ProcessList *MemoryManager @@ -38,6 +39,7 @@ func NewCatalog() *Catalog { return &Catalog{ FunctionRegistry: NewFunctionRegistry(), IndexRegistry: NewIndexRegistry(), + ViewRegistry: NewViewRegistry(), MemoryManager: NewMemoryManager(ProcessMemory), ProcessList: NewProcessList(), locks: make(sessionLocks), diff --git a/sql/index.go b/sql/index.go index cd5512f507..97ee716ae4 100644 --- a/sql/index.go +++ b/sql/index.go @@ -589,7 +589,7 @@ func exprListsEqual(a, b []string) bool { // marked as creating, so nobody can't register two indexes with the same // expression or id while the other is still being created. // When something is sent through the returned channel, it means the index has -// finished it's creation and will be marked as ready. +// finished its creation and will be marked as ready. // Another channel is returned to notify the user when the index is ready. func (r *IndexRegistry) AddIndex( idx Index, diff --git a/sql/parse/parse.go b/sql/parse/parse.go index dd8532e5a3..3a262ff804 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -47,7 +47,6 @@ var ( unlockTablesRegex = regexp.MustCompile(`^unlock\s+tables$`) lockTablesRegex = regexp.MustCompile(`^lock\s+tables\s`) setRegex = regexp.MustCompile(`^set\s+`) - createViewRegex = regexp.MustCompile(`^create\s+view\s+`) ) // These constants aren't exported from vitess for some reason. This could be removed if we changed this. @@ -104,9 +103,6 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { return parseLockTables(ctx, s) case setRegex.MatchString(lowerQuery): s = fixSetQuery(s) - case createViewRegex.MatchString(lowerQuery): - // CREATE VIEW parses as a CREATE DDL statement with an empty table spec - return nil, ErrUnsupportedFeature.New("CREATE VIEW") } stmt, err := sqlparser.Parse(s) @@ -163,7 +159,7 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node if err != nil { return nil, err } - return convertDDL(ddl.(*sqlparser.DDL)) + return convertDDL(ctx, ddl.(*sqlparser.DDL)) case *sqlparser.Set: return convertSet(ctx, n) case *sqlparser.Use: @@ -369,11 +365,17 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { return node, nil } -func convertDDL(c *sqlparser.DDL) (sql.Node, error) { +func convertDDL(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { switch c.Action { case sqlparser.CreateStr: + if !c.View.IsEmpty() { + return convertCreateView(ctx, c) + } return convertCreateTable(c) case sqlparser.DropStr: + if len(c.FromViews) != 0 { + return convertDropView(ctx, c) + } return convertDropTable(c) default: return nil, ErrUnsupportedSyntax.New(c) @@ -398,6 +400,31 @@ func convertCreateTable(c *sqlparser.DDL) (sql.Node, error) { sql.UnresolvedDatabase(""), c.Table.Name.String(), schema), nil } +func convertCreateView(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { + selectStatement, ok := c.ViewExpr.(*sqlparser.Select) + if !ok { + return nil, ErrUnsupportedSyntax.New(c.ViewExpr) + } + + queryNode, err := convertSelect(ctx, selectStatement) + if err != nil { + return nil, err + } + + queryAlias := plan.NewSubqueryAlias(c.View.Name.String(), queryNode) + + return plan.NewCreateView( + sql.UnresolvedDatabase(""), c.View.Name.String(), []string{}, queryAlias, c.OrReplace), nil +} + +func convertDropView(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { + plans := make([]sql.Node, len(c.FromViews)) + for i, v := range c.FromViews { + plans[i] = plan.NewSingleDropView(sql.UnresolvedDatabase(""), v.Name.String()) + } + return plan.NewDropView(plans, c.IfExists), nil +} + func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { if len(i.OnDup) > 0 { return nil, ErrUnsupportedFeature.New("ON DUPLICATE KEY") diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 929092c60c..7fe9453b22 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -55,14 +55,14 @@ var fixtures = map[string]sql.Node{ sql.UnresolvedDatabase(""), "t1", sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, + Name: "a", + Type: sql.Int32, + Nullable: false, PrimaryKey: true, }, { - Name: "b", - Type: sql.Text, - Nullable: true, + Name: "b", + Type: sql.Text, + Nullable: true, PrimaryKey: false, }}, ), @@ -70,14 +70,14 @@ var fixtures = map[string]sql.Node{ sql.UnresolvedDatabase(""), "t1", sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, + Name: "a", + Type: sql.Int32, + Nullable: false, PrimaryKey: true, }, { - Name: "b", - Type: sql.Text, - Nullable: true, + Name: "b", + Type: sql.Text, + Nullable: true, PrimaryKey: false, }}, ), @@ -85,14 +85,14 @@ var fixtures = map[string]sql.Node{ sql.UnresolvedDatabase(""), "t1", sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, + Name: "a", + Type: sql.Int32, + Nullable: false, PrimaryKey: true, }, { - Name: "b", - Type: sql.Text, - Nullable: false, + Name: "b", + Type: sql.Text, + Nullable: false, PrimaryKey: true, }}, ), @@ -1321,7 +1321,7 @@ var fixturesErrors = map[string]*errors.Kind{ `SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax, `SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax, `SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax, - `CREATE VIEW view1 AS SELECT x FROM t1 WHERE x>0`: ErrUnsupportedFeature, + `CREATE VIEW myview AS SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax, } func TestParseErrors(t *testing.T) { diff --git a/sql/parse/util.go b/sql/parse/util.go index bfb358b1d5..b14375a2af 100644 --- a/sql/parse/util.go +++ b/sql/parse/util.go @@ -64,6 +64,14 @@ func expect(expected string) parseFunc { } func skipSpaces(r *bufio.Reader) error { + var unusedCount int + return readSpaces(r, &unusedCount) +} + +// readSpaces reads every contiguous space from the reader, populating +// numSpacesRead with the number of spaces read. +func readSpaces(r *bufio.Reader, numSpacesRead *int) error { + *numSpacesRead = 0 for { ru, _, err := r.ReadRune() if err == io.EOF { @@ -77,6 +85,7 @@ func skipSpaces(r *bufio.Reader) error { if !unicode.IsSpace(ru) { return r.UnreadRune() } + *numSpacesRead++ } } @@ -127,6 +136,29 @@ func readLetter(r *bufio.Reader, buf *bytes.Buffer) error { return nil } +// readLetterOrPoint parses a single rune from the reader and consumes it, +// copying it to the buffer, if it is either a letter or a point +func readLetterOrPoint(r *bufio.Reader, buf *bytes.Buffer) error { + ru, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + return nil + } + + return err + } + + if !unicode.IsLetter(ru) && ru != '.' { + if err := r.UnreadRune(); err != nil { + return err + } + return nil + } + + buf.WriteRune(ru) + return nil +} + func readValidIdentRune(r *bufio.Reader, buf *bytes.Buffer) error { ru, _, err := r.ReadRune() if err != nil { @@ -144,6 +176,27 @@ func readValidIdentRune(r *bufio.Reader, buf *bytes.Buffer) error { return nil } +// readValidScopedIdentRune parses a single rune from the reader and consumes +// it, copying it to the buffer, if is a letter, a digit, an underscore or the +// specified separator. +// If the returned error is not nil, the returned rune equals the null +// character. +func readValidScopedIdentRune(r *bufio.Reader, separator rune) (rune, error) { + ru, _, err := r.ReadRune() + if err != nil { + return 0, err + } + + if !unicode.IsLetter(ru) && !unicode.IsDigit(ru) && ru != '_' && ru != separator { + if err := r.UnreadRune(); err != nil { + return 0, err + } + return 0, io.EOF + } + + return ru, nil +} + func readValidQuotedIdentRune(r *bufio.Reader, buf *bytes.Buffer) error { bs, err := r.Peek(2) if err != nil { @@ -199,6 +252,44 @@ func readIdent(ident *string) parseFunc { } } +// readIdentList reads a scoped identifier, populating the specified slice +// with the different parts of the identifier if it is correctly formed. +// A scoped identifier is a sequence of identifiers separated by the specified +// rune in separator. An identifier is a string of runes whose first character +// is a letter and the following ones are either letters, digits or underscores. +// An example of a correctly formed scoped identifier is "dbName.tableName", +// that would populate the slice with the values ["dbName", "tableName"] +func readIdentList(separator rune, idents *[]string) parseFunc { + return func(r *bufio.Reader) error { + var buf bytes.Buffer + if err := readLetter(r, &buf); err != nil { + return err + } + + for { + currentRune, err := readValidScopedIdentRune(r, separator) + if err != nil { + if err == io.EOF { + break + } + return err + } + + if currentRune == separator { + *idents = append(*idents, buf.String()) + buf.Reset() + } else { + buf.WriteRune(currentRune) + } + } + + if readString := buf.String(); len(readString) > 0 { + *idents = append(*idents, readString) + } + return nil + } +} + func readQuotedIdent(ident *string) parseFunc { return func(r *bufio.Reader) error { var buf bytes.Buffer @@ -308,3 +399,182 @@ func expectQuote(r *bufio.Reader) error { return nil } + +// maybe tries to read the specified string, consuming the reader if the string +// is found. The `matched` boolean is set to true if the string is found +func maybe(matched *bool, str string) parseFunc { + return func(rd *bufio.Reader) error { + *matched = false + strLength := len(str) + + data, err := rd.Peek(strLength) + if err != nil { + // If there are not enough runes, what we expected was not there, which + // is not an error per se. + if len(data) < strLength { + return nil + } + + return err + } + + if strings.ToLower(string(data)) == str { + _, err := rd.Discard(strLength) + if err != nil { + return err + } + + *matched = true + return nil + } + + return nil + } +} + +// multiMaybe tries to read the specified strings, one after the other, +// separated by an arbitrary number of spaces. It consumes the reader if and +// only if all the strings are found. +func multiMaybe(matched *bool, strings ...string) parseFunc { + return func(rd *bufio.Reader) error { + *matched = false + var read string + for _, str := range strings { + if err := maybe(matched, str)(rd); err != nil { + return err + } + + if !*matched { + unreadString(rd, read) + return nil + } + + var numSpaces int + if err := readSpaces(rd, &numSpaces); err != nil { + return err + } + + read = read + str + for i := 0; i < numSpaces; i++ { + read = read + " " + } + } + *matched = true + return nil + } +} + +// maybeList reads a list of strings separated by the specified separator, with +// a rune indicating the opening of the list and another one specifying its +// closing. +// For example, readList('(', ',', ')', list) parses "(uno, dos,tres)" and +// populates list with the array of strings ["uno", "dos", "tres"] +// If the opening is not found, this does not consumes any rune from the +// reader. If there is a parsing error after some elements were found, the list +// is partially populated with the correct fields +func maybeList(opening, separator, closing rune, list *[]string) parseFunc { + return func(rd *bufio.Reader) error { + r, _, err := rd.ReadRune() + if err != nil { + return err + } + + if r != opening { + return rd.UnreadRune() + } + + for { + var newItem string + err := parseFuncs{ + skipSpaces, + readIdent(&newItem), + skipSpaces, + }.exec(rd) + + if err != nil { + return err + } + + r, _, err := rd.ReadRune() + if err != nil { + return err + } + + switch r { + case closing: + *list = append(*list, newItem) + return nil + case separator: + *list = append(*list, newItem) + continue + default: + return errUnexpectedSyntax.New( + fmt.Sprintf("%v or %v", separator, closing), + string(r), + ) + } + } + } +} + +// A qualifiedName represents an identifier of type "db_name.table_name" +type qualifiedName struct { + qualifier string + name string +} + +// readQualifiedIdentifierList reads a comma-separated list of qualifiedNames. +// Any number of spaces between the qualified names are accepted. The qualifier +// may be empty, in which case the period is optional. +// An example of a correctly formed list is: +// "my_db.myview, db_2.mytable , aTable" +func readQualifiedIdentifierList(list *[]qualifiedName) parseFunc { + return func(rd *bufio.Reader) error { + for { + var newItem []string + err := parseFuncs{ + skipSpaces, + readIdentList('.', &newItem), + skipSpaces, + }.exec(rd) + + if err != nil { + return err + } + + if len(newItem) < 1 || len(newItem) > 2 { + return errUnexpectedSyntax.New( + "[qualifier.]name", + strings.Join(newItem, "."), + ) + } + + var qualifier, name string + + if len(newItem) == 1 { + qualifier = "" + name = newItem[0] + } else { + qualifier = newItem[0] + name = newItem[1] + } + + *list = append(*list, qualifiedName{qualifier, name}) + + r, _, err := rd.ReadRune() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + switch r { + case ',': + continue + default: + return rd.UnreadRune() + } + } + } +} diff --git a/sql/parse/util_test.go b/sql/parse/util_test.go new file mode 100644 index 0000000000..3cb5f49f8b --- /dev/null +++ b/sql/parse/util_test.go @@ -0,0 +1,524 @@ +package parse + +import ( + "bufio" + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// Tests that readLetterOrPoint reads only letters and points, not consuming +// the reader when the rune is not of those kinds +func TestReadLetterOrPoint(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + string string + expectedBuffer string + expectedRemaining string + }{ + { + "asd.ASD.ñu", + "asd.ASD.ñu", + "", + }, + { + "5anytext", + "", + "5anytext", + }, + { + "", + "", + "", + }, + { + "as df", + "as", + " df", + }, + { + "a.s df", + "a.s", + " df", + }, + { + "a.s-", + "a.s", + "-", + }, + { + "a.s_", + "a.s", + "_", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.string)) + var buffer bytes.Buffer + + for i := 0; i < len(fixture.string); i++ { + err := readLetterOrPoint(reader, &buffer) + require.NoError(err) + } + + remaining, _ := reader.ReadString('\n') + require.Equal(remaining, fixture.expectedRemaining) + + require.Equal(buffer.String(), fixture.expectedBuffer) + } +} + +// Tests that readValidScopedIdentRune reads a single rune that is either part +// of an identifier or the specified separator. It checks that the function +// does not consume the reader when it encounters any other rune +func TestReadValidScopedIdentRune(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + string string + separator rune + expectedBuffer string + expectedRemaining string + expectedError bool + }{ + { + "ident_1.ident_2", + '.', + "ident_1.ident_2", + "", + false, + }, + { + "$ident_1.ident_2", + '.', + "", + "$ident_1.ident_2", + true, + }, + { + "", + '.', + "", + "", + false, + }, + { + "ident_1 ident_2", + '.', + "ident_1", + " ident_2", + true, + }, + { + "ident_1 ident_2", + ' ', + "ident_1 ident_2", + "", + false, + }, + { + "ident_1.ident_2 ident_3", + '.', + "ident_1.ident_2", + " ident_3", + true, + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.string)) + var buffer bytes.Buffer + + var rune rune + var err error + for i := 0; i < len(fixture.string); i++ { + if rune, err = readValidScopedIdentRune(reader, fixture.separator); err != nil { + break + } + + buffer.WriteRune(rune) + } + if fixture.expectedError { + require.Error(err) + } else { + require.NoError(err) + } + + remaining, _ := reader.ReadString('\n') + require.Equal(remaining, fixture.expectedRemaining) + + require.Equal(fixture.expectedBuffer, buffer.String()) + } +} + +// Tests that readIdentList reads a list of identifiers separated by a user- +// specified rune, populating the passed slice with the identifiers found. +func TestReadIdentList(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + string string + separator rune + expectedIdents []string + expectedRemaining string + }{ + { + "ident_1.ident_2", + '.', + []string{"ident_1", "ident_2"}, + "", + }, + { + "$ident_1.ident_2", + '.', + nil, + "$ident_1.ident_2", + }, + { + "", + '.', + nil, + "", + }, + { + "ident_1 ident_2", + '.', + []string{"ident_1"}, + " ident_2", + }, + { + "ident_1 ident_2", + ' ', + []string{"ident_1", "ident_2"}, + "", + }, + { + "ident_1.ident_2 ident_3", + '.', + []string{"ident_1", "ident_2"}, + " ident_3", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.string)) + var actualIdents []string + + err := readIdentList(fixture.separator, &actualIdents)(reader) + require.NoError(err) + + remaining, _ := reader.ReadString('\n') + require.Equal(fixture.expectedRemaining, remaining) + + require.Equal(fixture.expectedIdents, actualIdents) + } +} + +// Tests that maybe reads, and consumes, the specified string, if and only if +// it is there, reporting the result in the boolean passed. +func TestMaybe(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + input string + maybeString string + expectedMatched bool + expectedRemaining string + }{ + { + "ident_1.ident_2", + "ident_1", + true, + ".ident_2", + }, + { + "ident_1.ident_2", + "random", + false, + "ident_1.ident_2", + }, + { + "ident_1.ident_2", + "ident_1.ident_2", + true, + "", + }, + { + "ident_1", + "ident_1butlonger", + false, + "ident_1", + }, + { + "ident_1", + "", + true, + "ident_1", + }, + { + "", + "", + true, + "", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.input)) + var actualMatched bool + + err := maybe(&actualMatched, fixture.maybeString)(reader) + require.NoError(err) + + remaining, _ := reader.ReadString('\n') + require.Equal(fixture.expectedRemaining, remaining) + + require.Equal(fixture.expectedMatched, actualMatched) + } +} + +// Tests that multiMaybe reads, and consumes, the list of strings passed if all +// of them are in the reader, reporting the result in the boolean passed. +func TestMultiMaybe(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + input string + maybeStrings []string + expectedMatched bool + expectedRemaining string + }{ + { + "unodostres", + []string{"uno", "dos", "tres"}, + true, + "", + }, + { + "uno dos tres", + []string{"uno", "dos", "tres"}, + true, + "", + }, + { + "uno dos tres", + []string{"uno", "dos", "tres"}, + true, + "", + }, + { + "uno dos tres", + []string{"random"}, + false, + "uno dos tres", + }, + { + "uno dos tres", + []string{"uno", "random"}, + false, + "uno dos tres", + }, + { + "uno dos tres", + []string{"uno", "dos", "tres", "cuatro"}, + false, + "uno dos tres", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.input)) + var actualMatched bool + + err := multiMaybe(&actualMatched, fixture.maybeStrings...)(reader) + require.NoError(err) + + remaining, _ := reader.ReadString('\n') + require.Equal(fixture.expectedRemaining, remaining) + + require.Equal(fixture.expectedMatched, actualMatched) + } +} + +// Tests that maybeList reads the specified list of strings separated by the +// user-specified separator, not consuming the reader if the opening rune is +// not found. It checks that the function populates the list with the found +// strings even if there is an error in the middle of the parsing. +func TestMaybeList(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + stringWithList string + openingRune rune + separatorRune rune + closingRune rune + expectedList []string + expectedError bool + }{ + { + "(uno, dos, tres)", + '(', ',', ')', + []string{"uno", "dos", "tres"}, + false, + }, + { + "-uno&dos & tres-", + '-', '&', '-', + []string{"uno", "dos", "tres"}, + false, + }, + { + "-(uno, dos, tres)", + '(', ',', ')', + nil, + false, + }, + { + "(uno, dos,( tres)", + '(', ',', ')', + []string{"uno", "dos"}, + true, + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.stringWithList)) + var actualList []string + + err := maybeList(fixture.openingRune, fixture.separatorRune, fixture.closingRune, &actualList)(reader) + + if fixture.expectedError { + require.Error(err) + } else { + require.NoError(err) + } + + require.Equal(fixture.expectedList, actualList) + } +} + +// Tests that readSpaces consumes all the spaces it ecounters in the reader, +// reporting the number of spaces read to the user through the integer passed. +func TestReadSpaces(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + stringWithSpaces string + runesBeforeSpaces int + expectedNumSpaces int + expectedRemaining string + }{ + { + "one", + 3, 0, + "", + }, + { + "two", + 0, 0, + "two", + }, + { + " three", + 0, 3, + "three", + }, + { + "four four ", + 4, 4, + "four ", + }, + { + "five ", + 4, 5, + "", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.stringWithSpaces)) + var actualNumSpaces int + + // Check that readSpaces does not read spaces when there are none + if fixture.runesBeforeSpaces > 0 { + err := readSpaces(reader, &actualNumSpaces) + require.NoError(err) + require.Equal(0, actualNumSpaces) + } + + // Read all the runes before the spaces + for i := 0; i < fixture.runesBeforeSpaces; i++ { + _, _, err := reader.ReadRune() + require.NoError(err) + } + + // Read all the spaces + err := readSpaces(reader, &actualNumSpaces) + require.NoError(err) + require.Equal(fixture.expectedNumSpaces, actualNumSpaces) + + actualRemaining, _ := reader.ReadString('\n') + require.Equal(fixture.expectedRemaining, actualRemaining) + } +} + +// Tests that readQualifiedIdentifierList correctly parses well-formed lists, +// populating the list of identifiers, and that it errors with partial lists +// and when it does not found any identifiers +func TestReadQualifiedIdentifierList(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + string string + expectedList []qualifiedName + expectedError bool + expectedRemaining string + }{ + { + "my_db.myview, db_2.mytable , aTable", + []qualifiedName{{"my_db", "myview"}, {"db_2", "mytable"}, {"", "aTable"}}, + false, + "", + }, + { + "single_identifier -remaining", + []qualifiedName{{"", "single_identifier"}}, + false, + "-remaining", + }, + { + "", + nil, + true, + "", + }, + { + "partial_list,", + []qualifiedName{{"", "partial_list"}}, + true, + "", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.string)) + var actualList []qualifiedName + + err := readQualifiedIdentifierList(&actualList)(reader) + + if fixture.expectedError { + require.Error(err) + } else { + require.NoError(err) + } + + require.Equal(fixture.expectedList, actualList) + + actualRemaining, _ := reader.ReadString('\n') + require.Equal(fixture.expectedRemaining, actualRemaining) + } +} diff --git a/sql/plan/create_view.go b/sql/plan/create_view.go new file mode 100644 index 0000000000..6d7e7c8bde --- /dev/null +++ b/sql/plan/create_view.go @@ -0,0 +1,116 @@ +package plan + +import ( + "fmt" + "strings" + + "github.com/src-d/go-mysql-server/sql" +) + +// CreateView is a node representing the creation (or replacement) of a view, +// which is defined by the Child node. The Columns member represent the +// explicit columns specified by the query, if any. +type CreateView struct { + UnaryNode + database sql.Database + Name string + Columns []string + Catalog *sql.Catalog + IsReplace bool +} + +// NewCreateView creates a CreateView node with the specified parameters, +// setting its catalog to nil. +func NewCreateView( + database sql.Database, + name string, + columns []string, + definition *SubqueryAlias, + isReplace bool, +) *CreateView { + return &CreateView{ + UnaryNode{Child: definition}, + database, + name, + columns, + nil, + isReplace, + } +} + +// View returns the view that will be created by this node. +func (cv *CreateView) View() sql.View { + return sql.NewView(cv.Name, cv.Child) +} + +// Children implements the Node interface. It returns the Child of the +// CreateView node; i.e., the definition of the view that will be created. +func (cv *CreateView) Children() []sql.Node { + return []sql.Node{cv.Child} +} + +// Resolved implements the Node interface. This node is resolved if and only if +// the database and the Child are both resolved. +func (cv *CreateView) Resolved() bool { + _, ok := cv.database.(sql.UnresolvedDatabase) + return !ok && cv.Child.Resolved() +} + +// RowIter implements the Node interface. When executed, this function creates +// (or replaces) the view. It can error if the CraeteView's IsReplace member is +// set to false and the view already exists. The RowIter returned is always +// empty. +func (cv *CreateView) RowIter(ctx *sql.Context) (sql.RowIter, error) { + view := sql.NewView(cv.Name, cv.Child) + registry := cv.Catalog.ViewRegistry + + if cv.IsReplace { + err := registry.Delete(cv.database.Name(), view.Name()) + if err != nil && !sql.ErrNonExistingView.Is(err) { + return sql.RowsToRowIter(), err + } + } + + return sql.RowsToRowIter(), registry.Register(cv.database.Name(), view) +} + +// Schema implements the Node interface. It always returns nil. +func (cv *CreateView) Schema() sql.Schema { return nil } + +// String implements the fmt.Stringer interface, using sql.TreePrinter to +// generate the string. +func (cv *CreateView) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("CreateView(%s)", cv.Name) + _ = pr.WriteChildren( + fmt.Sprintf("Columns (%s)", strings.Join(cv.Columns, ", ")), + cv.Child.String(), + ) + return pr.String() +} + +// WithChildren implements the Node interface. It only succeeds if the length +// of the specified children equals 1. +func (cv *CreateView) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(cv, len(children), 1) + } + + newCreate := cv + newCreate.Child = children[0] + return newCreate, nil +} + +// Database implements the Databaser interface, and it returns the database in +// which CreateView will create the view. +func (cv *CreateView) Database() sql.Database { + return cv.database +} + +// Database implements the Databaser interface, and it returns a copy of this +// node with the specified database. +func (cv *CreateView) WithDatabase(database sql.Database) (sql.Node, error) { + newCreate := *cv + newCreate.database = database + return &newCreate, nil +} diff --git a/sql/plan/create_view_test.go b/sql/plan/create_view_test.go new file mode 100644 index 0000000000..d9b4444a0f --- /dev/null +++ b/sql/plan/create_view_test.go @@ -0,0 +1,94 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + + "github.com/stretchr/testify/require" +) + +func mockCreateView(isReplace bool) *CreateView { + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Source: "mytable", Type: sql.Int32}, + {Name: "s", Source: "mytable", Type: sql.Text}, + }) + + db := memory.NewDatabase("db") + db.AddTable("db", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + subqueryAlias := NewSubqueryAlias("myview", + NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int32, table.Name(), "i", true), + }, + NewUnresolvedTable("dual", ""), + ), + ) + + createView := NewCreateView(db, subqueryAlias.Name(), nil, subqueryAlias, isReplace) + createView.Catalog = catalog + + return createView +} + +// Tests that CreateView works as expected and that the view is registered in +// the catalog when RowIter is called +func TestCreateView(t *testing.T) { + require := require.New(t) + + createView := mockCreateView(false) + + ctx := sql.NewEmptyContext() + _, err := createView.RowIter(ctx) + require.NoError(err) + + expectedView := sql.NewView(createView.Name, createView.Child) + actualView, err := createView.Catalog.ViewRegistry.View(createView.database.Name(), createView.Name) + require.NoError(err) + require.Equal(expectedView, *actualView) +} + +// Tests that CreateView RowIter returns an error when the view exists +func TestCreateExistingView(t *testing.T) { + require := require.New(t) + + createView := mockCreateView(false) + + view := createView.View() + err := createView.Catalog.ViewRegistry.Register(createView.database.Name(), view) + require.NoError(err) + + ctx := sql.NewEmptyContext() + _, err = createView.RowIter(ctx) + require.Error(err) + require.True(sql.ErrExistingView.Is(err)) +} + +// Tests that CreateView RowIter succeeds when the view exists and the +// IsReplace flag is set to true +func TestReplaceExistingView(t *testing.T) { + require := require.New(t) + + createView := mockCreateView(true) + + view := sql.NewView(createView.Name, nil) + err := createView.Catalog.ViewRegistry.Register(createView.database.Name(), view) + require.NoError(err) + + createView.IsReplace = true + + ctx := sql.NewEmptyContext() + _, err = createView.RowIter(ctx) + require.NoError(err) + + expectedView := createView.View() + actualView, err := createView.Catalog.ViewRegistry.View(createView.database.Name(), createView.Name) + require.NoError(err) + require.Equal(expectedView, *actualView) +} diff --git a/sql/plan/drop_view.go b/sql/plan/drop_view.go new file mode 100644 index 0000000000..718d2bfcde --- /dev/null +++ b/sql/plan/drop_view.go @@ -0,0 +1,148 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errDropViewChild = errors.NewKind("any child of DropView must be of type SingleDropView") + +type SingleDropView struct { + database sql.Database + viewName string +} + +// NewSingleDropView creates a SingleDropView. +func NewSingleDropView( + database sql.Database, + viewName string, +) *SingleDropView { + return &SingleDropView{database, viewName} +} + +// Children implements the Node interface. It always returns nil. +func (dv *SingleDropView) Children() []sql.Node { + return nil +} + +// Resolved implements the Node interface. This node is resolved if and only if +// its database is resolved. +func (dv *SingleDropView) Resolved() bool { + _, ok := dv.database.(sql.UnresolvedDatabase) + return !ok +} + +// RowIter implements the Node interface. It always returns an empty iterator. +func (dv *SingleDropView) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} + +// Schema implements the Node interface. It always returns nil. +func (dv *SingleDropView) Schema() sql.Schema { return nil } + +// String implements the fmt.Stringer interface, using sql.TreePrinter to +// generate the string. +func (dv *SingleDropView) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("SingleDropView(%s.%s)", dv.database.Name(), dv.viewName) + + return pr.String() +} + +// WithChildren implements the Node interface. It only succeeds if the length +// of the specified children equals 0. +func (dv *SingleDropView) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(dv, len(children), 0) + } + + return dv, nil +} + +// Database implements the Databaser interfacee. It returns the node's database. +func (dv *SingleDropView) Database() sql.Database { + return dv.database +} + +// Database implements the Databaser interface, and it returns a copy of this +// node with the specified database. +func (dv *SingleDropView) WithDatabase(database sql.Database) (sql.Node, error) { + newDrop := *dv + newDrop.database = database + return &newDrop, nil +} + +// DropView is a node representing the removal of a list of views, defined by +// the children member. The flag ifExists represents whether the user wants the +// node to fail if any of the views in children does not exist. +type DropView struct { + children []sql.Node + Catalog *sql.Catalog + ifExists bool +} + +// NewDropView creates a DropView node with the specified parameters, +// setting its catalog to nil. +func NewDropView(children []sql.Node, ifExists bool) *DropView { + return &DropView{children, nil, ifExists} +} + +// Children implements the Node interface. It returns the children of the +// CreateView node; i.e., all the views that will be dropped. +func (dvs *DropView) Children() []sql.Node { + return dvs.children +} + +// Resolved implements the Node interface. This node is resolved if and only if +// all of its children are resolved. +func (dvs *DropView) Resolved() bool { + for _, child := range dvs.children { + if !child.Resolved() { + return false + } + } + return true +} + +// RowIter implements the Node interface. When executed, this function drops +// all the views defined by the node's children. It errors if the flag ifExists +// is set to false and there is some view that does not exist. +func (dvs *DropView) RowIter(ctx *sql.Context) (sql.RowIter, error) { + viewList := make([]sql.ViewKey, len(dvs.children)) + for i, child := range dvs.children { + drop, ok := child.(*SingleDropView) + if !ok { + return sql.RowsToRowIter(), errDropViewChild.New() + } + + viewList[i] = sql.NewViewKey(drop.database.Name(), drop.viewName) + } + + return sql.RowsToRowIter(), dvs.Catalog.ViewRegistry.DeleteList(viewList, !dvs.ifExists) +} + +// Schema implements the Node interface. It always returns nil. +func (dvs *DropView) Schema() sql.Schema { return nil } + +// String implements the fmt.Stringer interface, using sql.TreePrinter to +// generate the string. +func (dvs *DropView) String() string { + childrenStrings := make([]string, len(dvs.children)) + for i, child := range dvs.children { + childrenStrings[i] = child.String() + } + + pr := sql.NewTreePrinter() + _ = pr.WriteNode("DropView") + _ = pr.WriteChildren(childrenStrings...) + + return pr.String() +} + +// WithChildren implements the Node interface. It always suceeds, returning a +// copy of this node with the new array of nodes as children. +func (dvs *DropView) WithChildren(children ...sql.Node) (sql.Node, error) { + newDrop := dvs + newDrop.children = children + return newDrop, nil +} diff --git a/sql/plan/drop_view_test.go b/sql/plan/drop_view_test.go new file mode 100644 index 0000000000..0f9287fcc7 --- /dev/null +++ b/sql/plan/drop_view_test.go @@ -0,0 +1,94 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + + "github.com/stretchr/testify/require" +) + +// Generates a database with a single table called mytable and a catalog with +// the view that is also returned. The context returned is the one used to +// create the view. +func mockData(require *require.Assertions) (sql.Database, *sql.Catalog, *sql.Context, sql.View) { + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Source: "mytable", Type: sql.Int32}, + {Name: "s", Source: "mytable", Type: sql.Text}, + }) + + db := memory.NewDatabase("db") + db.AddTable("db", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + subqueryAlias := NewSubqueryAlias("myview", + NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int32, table.Name(), "i", true), + }, + NewUnresolvedTable("dual", ""), + ), + ) + + createView := NewCreateView(db, subqueryAlias.Name(), nil, subqueryAlias, false) + createView.Catalog = catalog + + ctx := sql.NewEmptyContext() + + _, err := createView.RowIter(ctx) + require.NoError(err) + + return db, catalog, ctx, createView.View() +} + +// Tests that DropView works as expected and that the view is dropped in +// the catalog when RowIter is called, regardless of the value of ifExists +func TestDropExistingView(t *testing.T) { + require := require.New(t) + + test := func(ifExists bool) { + db, catalog, ctx, view := mockData(require) + + singleDropView := NewSingleDropView(db, view.Name()) + dropView := NewDropView([]sql.Node{singleDropView}, ifExists) + dropView.Catalog = catalog + + _, err := dropView.RowIter(ctx) + require.NoError(err) + + require.False(catalog.ViewRegistry.Exists(db.Name(), view.Name())) + } + + test(false) + test(true) +} + +// Tests that DropView errors when trying to delete a non-existing view if and +// only if the flag ifExists is set to false +func TestDropNonExistingView(t *testing.T) { + require := require.New(t) + + test := func(ifExists bool) error { + db, catalog, ctx, view := mockData(require) + + singleDropView := NewSingleDropView(db, "non-existing-view") + dropView := NewDropView([]sql.Node{singleDropView}, ifExists) + dropView.Catalog = catalog + + _, err := dropView.RowIter(ctx) + + require.True(catalog.ViewRegistry.Exists(db.Name(), view.Name())) + + return err + } + + err := test(true) + require.NoError(err) + + err = test(false) + require.Error(err) +} diff --git a/sql/viewregistry.go b/sql/viewregistry.go new file mode 100644 index 0000000000..d7ebb577cd --- /dev/null +++ b/sql/viewregistry.go @@ -0,0 +1,167 @@ +package sql + +import ( + "strings" + "sync" + + "gopkg.in/src-d/go-errors.v1" +) + +var ( + ErrExistingView = errors.NewKind("the view %s.%s already exists in the registry") + ErrNonExistingView = errors.NewKind("the view %s.%s does not exist in the registry") +) + +// View is defined by a Node and has a name. +type View struct { + name string + definition Node +} + +// NewView creates a View with the specified name and definition. +func NewView(name string, definition Node) View { + return View{name, definition} +} + +// Name returns the name of the view. +func (v *View) Name() string { + return v.name +} + +// Definition returns the definition of the view. +func (v *View) Definition() Node { + return v.definition +} + +// Views are scoped by the databases in which they were defined, so a key in +// the view registry is a pair of names: database and view. +type ViewKey struct { + dbName, viewName string +} + +// NewViewKey creates a ViewKey ensuring both names are lowercase. +func NewViewKey(databaseName, viewName string) ViewKey { + return ViewKey{strings.ToLower(databaseName), strings.ToLower(viewName)} +} + +// ViewRegistry is a map of ViewKey to View whose access is protected by a +// RWMutex. +type ViewRegistry struct { + mutex sync.RWMutex + views map[ViewKey]View +} + +// NewViewRegistry creates an empty ViewRegistry. +func NewViewRegistry() *ViewRegistry { + return &ViewRegistry{ + views: make(map[ViewKey]View), + } +} + +// Register adds the view specified by the pair {database, view.Name()}, +// returning an error if there is already an element with that key. +func (r *ViewRegistry) Register(database string, view View) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + key := NewViewKey(database, view.Name()) + + if _, ok := r.views[key]; ok { + return ErrExistingView.New(database, view.Name()) + } + + r.views[key] = view + return nil +} + +// Delete deletes the view specified by the pair {databaseName, viewName}, +// returning an error if it does not exist. +func (r *ViewRegistry) Delete(databaseName, viewName string) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + key := NewViewKey(databaseName, viewName) + + if _, ok := r.views[key]; !ok { + return ErrNonExistingView.New(databaseName, viewName) + } + + delete(r.views, key) + return nil +} + +// DeleteList tries to delete a list of view keys. +// If the list contains views that do exist and views that do not, the existing +// views are deleted if and only if the errIfNotExists flag is set to false; if +// it is set to true, no views are deleted and an error is returned. +func (r *ViewRegistry) DeleteList(keys []ViewKey, errIfNotExists bool) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if errIfNotExists { + for _, key := range keys { + if !r.exists(key.dbName, key.viewName) { + return ErrNonExistingView.New(key.dbName, key.viewName) + } + } + } + + for _, key := range keys { + delete(r.views, key) + } + + return nil +} + +// View returns a pointer to the view specified by the pair {databaseName, +// viewName}, returning an error if it does not exist. +func (r *ViewRegistry) View(databaseName, viewName string) (*View, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + key := NewViewKey(databaseName, viewName) + + if view, ok := r.views[key]; ok { + return &view, nil + } + + return nil, ErrNonExistingView.New(databaseName, viewName) +} + +// AllViews returns the map of all views in the registry. +func (r *ViewRegistry) AllViews() map[ViewKey]View { + r.mutex.RLock() + defer r.mutex.RUnlock() + + return r.views +} + +// ViewsInDatabase returns an array of all the views registered under the +// specified database. +func (r *ViewRegistry) ViewsInDatabase(databaseName string) (views []View) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + for key, value := range r.views { + if key.dbName == databaseName { + views = append(views, value) + } + } + + return views +} + +func (r *ViewRegistry) exists(databaseName, viewName string) bool { + key := NewViewKey(databaseName, viewName) + _, ok := r.views[key] + + return ok +} + +// Exists returns whether the specified key is already registered +func (r *ViewRegistry) Exists(databaseName, viewName string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + return r.exists(databaseName, viewName) +} diff --git a/sql/viewregistry_test.go b/sql/viewregistry_test.go new file mode 100644 index 0000000000..adf1ef6c07 --- /dev/null +++ b/sql/viewregistry_test.go @@ -0,0 +1,224 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +var ( + dbName = "db" + viewName = "myview" + mockView = NewView(viewName, nil) +) + +func newMockRegistry(require *require.Assertions) *ViewRegistry { + registry := NewViewRegistry() + + err := registry.Register(dbName, mockView) + require.NoError(err) + require.Equal(1, len(registry.AllViews())) + + return registry +} + +// Tests the creation of an empty ViewRegistry with no views registered. +func TestNewViewRegistry(t *testing.T) { + require := require.New(t) + + registry := NewViewRegistry() + require.Equal(0, len(registry.AllViews())) +} + +// Tests that registering a non-existing view succeeds. +func TestRegisterNonExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + actualView, err := registry.View(dbName, viewName) + require.NoError(err) + require.Equal(mockView, *actualView) +} + +// Tests that registering an existing view fails. +func TestRegisterExistingVIew(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + err := registry.Register(dbName, mockView) + require.Error(err) + require.True(ErrExistingView.Is(err)) +} + +// Tests that deleting an existing view succeeds. +func TestDeleteExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + err := registry.Delete(dbName, viewName) + require.NoError(err) + require.Equal(0, len(registry.AllViews())) +} + +// Tests that deleting a non-existing view fails. +func TestDeleteNonExistingView(t *testing.T) { + require := require.New(t) + + registry := NewViewRegistry() + + err := registry.Delete("random", "randomer") + require.Error(err) + require.True(ErrNonExistingView.Is(err)) +} + +// Tests that retrieving an existing view succeeds and that the view returned +// is the correct one. +func TestGetExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + actualView, err := registry.View(dbName, viewName) + require.NoError(err) + require.Equal(mockView, *actualView) +} + +// Tests that retrieving a non-existing view fails. +func TestGetNonExistingView(t *testing.T) { + require := require.New(t) + + registry := NewViewRegistry() + + actualView, err := registry.View(dbName, viewName) + require.Error(err) + require.Nil(actualView) + require.True(ErrNonExistingView.Is(err)) +} + +// Tests that retrieving the views registered under a database succeeds, +// returning the list of all the correct views. +func TestViewsInDatabase(t *testing.T) { + require := require.New(t) + + registry := NewViewRegistry() + + databases := []struct { + name string + numViews int + }{ + {"db0", 0}, + {"db1", 5}, + {"db2", 10}, + } + + for _, db := range databases { + for i := 0; i < db.numViews; i++ { + view := NewView(viewName+string(i), nil) + err := registry.Register(db.name, view) + require.NoError(err) + } + + views := registry.ViewsInDatabase(db.name) + require.Equal(db.numViews, len(views)) + } +} + +var viewKeys = []ViewKey{ + { + "db1", + "view1", + }, + { + "db1", + "view2", + }, + { + "db2", + "view1", + }, +} + +func registerKeys(registry *ViewRegistry, require *require.Assertions) { + for _, key := range viewKeys { + err := registry.Register(key.dbName, NewView(key.viewName, nil)) + require.NoError(err) + } + require.Equal(len(viewKeys), len(registry.AllViews())) +} + +func TestDeleteExistingList(t *testing.T) { + require := require.New(t) + + test := func(errIfNotExists bool) { + registry := NewViewRegistry() + + registerKeys(registry, require) + err := registry.DeleteList(viewKeys, errIfNotExists) + require.NoError(err) + require.Equal(0, len(registry.AllViews())) + } + + test(true) + test(false) +} + +func TestDeleteNonExistingList(t *testing.T) { + require := require.New(t) + + test := func(errIfNotExists bool) { + registry := NewViewRegistry() + + registerKeys(registry, require) + err := registry.DeleteList([]ViewKey{{"random", "random"}}, errIfNotExists) + if errIfNotExists { + require.Error(err) + } else { + require.NoError(err) + } + require.Equal(len(viewKeys), len(registry.AllViews())) + } + + test(false) + test(true) +} + +func TestDeletePartiallyExistingList(t *testing.T) { + require := require.New(t) + + test := func(errIfNotExists bool) { + registry := NewViewRegistry() + + registerKeys(registry, require) + toDelete := append(viewKeys, ViewKey{"random", "random"}) + err := registry.DeleteList(toDelete, errIfNotExists) + if errIfNotExists { + require.Error(err) + require.Equal(len(viewKeys), len(registry.AllViews())) + } else { + require.NoError(err) + require.Equal(0, len(registry.AllViews())) + } + } + + test(false) + test(true) +} + +func TestExistsOnExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + require.True(registry.Exists(dbName, viewName)) +} + +func TestExistsOnNonExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + require.False(registry.Exists("non", "existing")) +}