diff --git a/builder.go b/builder.go index 203da7c..060df0a 100644 --- a/builder.go +++ b/builder.go @@ -4,6 +4,7 @@ import ( "reflect" "strings" + "github.com/tinh-tinh/tinhtinh/v2/dto/validator" "gorm.io/gorm" ) @@ -24,18 +25,27 @@ type QueryBuilder struct { } func (q *QueryBuilder) Equal(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " = ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) Not(column string, args ...interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " = ?" q.qb = q.qb.Not(query, args...) return q } func (q *QueryBuilder) Or(column string, args ...interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " = ?" q.qb = q.qb.Or(query, args...) return q @@ -46,54 +56,81 @@ func (q *QueryBuilder) In(column string, values ...interface{}) *QueryBuilder { for i := range values { placeholders[i] = "?" } + if !isValidColumn(column) { + return q + } query := column + " IN (" + strings.Join(placeholders, ", ") + ")" q.qb = q.qb.Where(query, values...) return q } func (q *QueryBuilder) MoreThan(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " > ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) MoreThanOrEqual(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " >= ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) LessThan(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " < ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) LessThanOrEqual(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " <= ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) Like(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " LIKE ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) ILike(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " ILIKE ?" q.qb = q.qb.Where(query, value) return q } func (q *QueryBuilder) Between(column string, start interface{}, end interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " BETWEEN ? AND ?" q.qb = q.qb.Where(query, start, end) return q } func (q *QueryBuilder) NotEqual(column string, value interface{}) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " <> ?" q.qb = q.qb.Where(query, value) return q @@ -104,12 +141,18 @@ func (q *QueryBuilder) NotIn(column string, values ...interface{}) *QueryBuilder for i := range values { placeholders[i] = "?" } + if !isValidColumn(column) { + return q + } query := column + " NOT IN (" + strings.Join(placeholders, ", ") + ")" q.qb = q.qb.Where(query, values...) return q } func (q *QueryBuilder) IsNull(column string) *QueryBuilder { + if !isValidColumn(column) { + return q + } query := column + " IS NULL" q.qb = q.qb.Where(query) return q @@ -119,3 +162,7 @@ func (q *QueryBuilder) Raw(sql string, values ...interface{}) *QueryBuilder { q.qb = q.qb.Raw(sql, values...) return q } + +func isValidColumn(column string) bool { + return validator.IsAlphanumeric(column) +} diff --git a/builder_test.go b/builder_test.go index 2a3c901..c4a5638 100644 --- a/builder_test.go +++ b/builder_test.go @@ -171,3 +171,239 @@ func Test_QueryBuilder(t *testing.T) { require.Equal(t, 1, len(docs)) require.Equal(t, "test", docs[0].Name) } + +func Test_IsValidColumn(t *testing.T) { + require.NotPanics(t, func() { + createDatabaseForTest("test_valid_column") + }) + dsn := "host=localhost user=postgres password=postgres dbname=test_valid_column port=5432 sslmode=disable TimeZone=Asia/Shanghai" + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + require.Nil(t, err) + db.Exec("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";") + + type TestEntity struct { + gorm.Model + Name string `gorm:"type:varchar(255)"` + Value int `gorm:"type:int"` + } + err = db.AutoMigrate(&TestEntity{}) + require.Nil(t, err) + + repo := sqlorm.Repository[TestEntity]{DB: db} + + // Create test data + count, err := repo.Count(nil) + require.Nil(t, err) + + if count == 0 { + _, err = repo.Create(&TestEntity{Name: "valid", Value: 1}) + require.Nil(t, err) + } + + // Define invalid column names to test + invalidColumns := []string{ + "name; DROP TABLE users;--", + "column' OR '1'='1", + "col=1", + "column!", + "col@name", + "col#name", + "col$name", + "col%name", + "col^name", + "col&name", + "col*name", + "column()", + "[column]", + "", + "col/name", + "col\\name", + "col|name", + "col`name", + "col~name", + "col+name", + "col name", + "col-name", + "col.name", + "col:name", + "col;name", + "col'name", + "col\"name", + "col{name}", + } + + // Test Equal with invalid columns + t.Run("Equal_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.Equal(invalidCol, "test") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test Not with invalid columns + t.Run("Not_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.Not(invalidCol, "test") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test Or with invalid columns + t.Run("Or_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.Or(invalidCol, "test") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test In with invalid columns + t.Run("In_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.In(invalidCol, "test", "test2") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test MoreThan with invalid columns + t.Run("MoreThan_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.MoreThan(invalidCol, 0) + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test MoreThanOrEqual with invalid columns + t.Run("MoreThanOrEqual_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.MoreThanOrEqual(invalidCol, 0) + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test LessThan with invalid columns + t.Run("LessThan_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.LessThan(invalidCol, 100) + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test LessThanOrEqual with invalid columns + t.Run("LessThanOrEqual_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.LessThanOrEqual(invalidCol, 100) + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test Like with invalid columns + t.Run("Like_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.Like(invalidCol, "%test%") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test ILike with invalid columns + t.Run("ILike_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.ILike(invalidCol, "%TEST%") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test Between with invalid columns + t.Run("Between_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.Between(invalidCol, 0, 100) + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test NotEqual with invalid columns + t.Run("NotEqual_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.NotEqual(invalidCol, "nonexistent") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test NotIn with invalid columns + t.Run("NotIn_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.NotIn(invalidCol, "nonexistent1", "nonexistent2") + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test IsNull with invalid columns + t.Run("IsNull_InvalidColumn", func(t *testing.T) { + for _, invalidCol := range invalidColumns { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.IsNull(invalidCol) + }) + require.Nil(t, err) + require.GreaterOrEqual(t, len(docs), 1, "Invalid column %q should not filter results", invalidCol) + } + }) + + // Test valid column names still work + t.Run("ValidColumn_Equal", func(t *testing.T) { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.Equal("Name", "valid") + }) + require.Nil(t, err) + require.Equal(t, 1, len(docs)) + require.Equal(t, "valid", docs[0].Name) + }) + + t.Run("ValidColumn_MoreThan", func(t *testing.T) { + docs, err := repo.FindAll(func(qb *sqlorm.QueryBuilder) { + qb.MoreThan("Value", 0) + }) + require.Nil(t, err) + require.Equal(t, 1, len(docs)) + }) +}