diff --git a/src/Config/DatabasePrimitives/DatabaseObject.cs b/src/Config/DatabasePrimitives/DatabaseObject.cs index 6e0db9d8b8..e7ba603b78 100644 --- a/src/Config/DatabasePrimitives/DatabaseObject.cs +++ b/src/Config/DatabasePrimitives/DatabaseObject.cs @@ -229,6 +229,16 @@ public bool IsAnyColumnNullable(List columnsToCheck) return null; } + + public virtual int? GetLengthForParam(string paramName) + { + if (Columns.TryGetValue(paramName, out ColumnDefinition? columnDefinition)) + { + return columnDefinition.Length; + } + + return null; + } } /// @@ -264,6 +274,7 @@ public class ColumnDefinition public bool IsNullable { get; set; } public bool IsReadOnly { get; set; } public object? DefaultValue { get; set; } + public int? Length { get; set; } public ColumnDefinition() { } diff --git a/src/Core/Models/DbConnectionParam.cs b/src/Core/Models/DbConnectionParam.cs index 9426f8fd49..0c2c54a5e0 100644 --- a/src/Core/Models/DbConnectionParam.cs +++ b/src/Core/Models/DbConnectionParam.cs @@ -10,11 +10,12 @@ namespace Azure.DataApiBuilder.Core.Models; /// public class DbConnectionParam { - public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbType = null) + public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbType = null, int? length = null) { Value = value; DbType = dbType; SqlDbType = sqlDbType; + Length = length; } /// @@ -31,4 +32,7 @@ public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbT // This is being made nullable // because it's not populated for DB's other than MSSQL. public SqlDbType? SqlDbType { get; set; } + + // Nullable integer parameter representing length. nullable for back compatibility and for where its not needed + public int? Length { get; set; } } diff --git a/src/Core/Models/GraphQLFilterParsers.cs b/src/Core/Models/GraphQLFilterParsers.cs index 0d367fcd68..11ca98b1eb 100644 --- a/src/Core/Models/GraphQLFilterParsers.cs +++ b/src/Core/Models/GraphQLFilterParsers.cs @@ -486,7 +486,7 @@ private static Predicate ParseScalarType( string schemaName, string tableName, string tableAlias, - Func processLiterals, + Func processLiterals, bool isListType = false) { Column column = new(schemaName, tableName, columnName: fieldName, tableAlias); @@ -614,7 +614,7 @@ public static Predicate Parse( IInputField argumentSchema, Column column, List fields, - Func processLiterals, + Func processLiterals, bool isListType = false) { List predicates = new(); @@ -635,6 +635,8 @@ public static Predicate Parse( continue; } + bool lengthOverride = false; + PredicateOperation op; switch (name) { @@ -665,6 +667,7 @@ public static Predicate Parse( { op = PredicateOperation.LIKE; value = $"%{EscapeLikeString((string)value)}%"; + lengthOverride = true; } break; @@ -677,16 +680,19 @@ public static Predicate Parse( { op = PredicateOperation.NOT_LIKE; value = $"%{EscapeLikeString((string)value)}%"; + lengthOverride = true; } break; case "startsWith": op = PredicateOperation.LIKE; value = $"{EscapeLikeString((string)value)}%"; + lengthOverride = true; break; case "endsWith": op = PredicateOperation.LIKE; value = $"%{EscapeLikeString((string)value)}"; + lengthOverride = true; break; case "isNull": processLiteral = false; @@ -699,10 +705,10 @@ public static Predicate Parse( } predicates.Push(new PredicateOperand(new Predicate( - new PredicateOperand(column), + new(column), op, - new PredicateOperand(processLiteral ? $"{processLiterals(value, column.ColumnName)}" : value.ToString())) - )); + new(processLiteral ? $"{processLiterals(value, column.ColumnName, lengthOverride)}" : value.ToString()) + ))); } return GQLFilterParser.MakeChainPredicate(predicates, PredicateOperation.AND); diff --git a/src/Core/Resolvers/BaseQueryStructure.cs b/src/Core/Resolvers/BaseQueryStructure.cs index 48940a17d5..a5066e968e 100644 --- a/src/Core/Resolvers/BaseQueryStructure.cs +++ b/src/Core/Resolvers/BaseQueryStructure.cs @@ -117,7 +117,7 @@ public BaseQueryStructure( /// /// Value to be assigned to parameter, which can be null for nullable columns. /// The name of the parameter - backing column name for table/views or parameter name for stored procedures. - public virtual string MakeDbConnectionParam(object? value, string? paramName = null) + public virtual string MakeDbConnectionParam(object? value, string? paramName = null, bool lengthOverride = false) { string encodedParamName = GetEncodedParamName(Counter.Next()); if (!string.IsNullOrEmpty(paramName)) @@ -125,7 +125,8 @@ public virtual string MakeDbConnectionParam(object? value, string? paramName = n Parameters.Add(encodedParamName, new(value, dbType: GetUnderlyingSourceDefinition().GetDbTypeForParam(paramName), - sqlDbType: GetUnderlyingSourceDefinition().GetSqlDbTypeForParam(paramName))); + sqlDbType: GetUnderlyingSourceDefinition().GetSqlDbTypeForParam(paramName), + length: lengthOverride ? -1 : GetUnderlyingSourceDefinition().GetLengthForParam(paramName))); } else { diff --git a/src/Core/Resolvers/CosmosQueryStructure.cs b/src/Core/Resolvers/CosmosQueryStructure.cs index f919426f97..bbb0d7ea6c 100644 --- a/src/Core/Resolvers/CosmosQueryStructure.cs +++ b/src/Core/Resolvers/CosmosQueryStructure.cs @@ -67,7 +67,7 @@ public CosmosQueryStructure( } /// - public override string MakeDbConnectionParam(object? value, string? columnName = null) + public override string MakeDbConnectionParam(object? value, string? columnName = null, bool lengthOverride = false) { string encodedParamName = $"{PARAM_NAME_PREFIX}param{Counter.Next()}"; Parameters.Add(encodedParamName, new(value)); diff --git a/src/Core/Resolvers/MsSqlQueryExecutor.cs b/src/Core/Resolvers/MsSqlQueryExecutor.cs index 45d641bb32..c75498a917 100644 --- a/src/Core/Resolvers/MsSqlQueryExecutor.cs +++ b/src/Core/Resolvers/MsSqlQueryExecutor.cs @@ -392,8 +392,16 @@ public override SqlCommand PrepareDbCommand( { SqlParameter parameter = cmd.CreateParameter(); parameter.ParameterName = parameterEntry.Key; - parameter.Value = parameterEntry.Value.Value ?? DBNull.Value; + parameter.Value = parameterEntry.Value?.Value ?? DBNull.Value; + PopulateDbTypeForParameter(parameterEntry, parameter); + + //if sqldbtype is varchar, nvarchar then set the length + if (parameter.SqlDbType is SqlDbType.VarChar or SqlDbType.NVarChar or SqlDbType.Char or SqlDbType.NChar) + { + parameter.Size = parameterEntry.Value?.Length ?? -1; + } + cmd.Parameters.Add(parameter); } } diff --git a/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs index a6704bf1e5..f43d67315e 100644 --- a/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/SqlMetadataProvider.cs @@ -1321,7 +1321,8 @@ private async Task PopulateSourceDefinitionAsync( SystemType = (Type)columnInfoFromAdapter["DataType"], // An auto-increment column is also considered as a read-only column. For other types of read-only columns, // the flag is populated later via PopulateColumnDefinitionsWithReadOnlyFlag() method. - IsReadOnly = (bool)columnInfoFromAdapter["IsAutoIncrement"] + IsReadOnly = (bool)columnInfoFromAdapter["IsAutoIncrement"], + Length = GetDatabaseType() is DatabaseType.MSSQL ? (int)columnInfoFromAdapter["ColumnSize"] : null }; // Tests may try to add the same column simultaneously diff --git a/src/Service.Tests/DatabaseSchema-MsSql.sql b/src/Service.Tests/DatabaseSchema-MsSql.sql index 3605b2628a..3bf61cff9f 100644 --- a/src/Service.Tests/DatabaseSchema-MsSql.sql +++ b/src/Service.Tests/DatabaseSchema-MsSql.sql @@ -82,7 +82,7 @@ CREATE TABLE publishers_mm( CREATE TABLE books( id int IDENTITY(5001, 1) PRIMARY KEY, - title varchar(max) NOT NULL, + title varchar(30) NOT NULL, publisher_id int NOT NULL ); @@ -514,7 +514,7 @@ SET IDENTITY_INSERT books ON INSERT INTO books(id, title, publisher_id) VALUES (1, 'Awesome book', 1234), (2, 'Also Awesome book', 1234), -(3, 'Great wall of china explained', 2345), +(3, 'Great wall of china explained]', 2345), (4, 'US history in a nutshell', 2345), (5, 'Chernobyl Diaries', 2323), (6, 'The Palace Door', 2324), diff --git a/src/Service.Tests/SqlTests/GraphQLQueryTests/MsSqlGraphQLQueryTests.cs b/src/Service.Tests/SqlTests/GraphQLQueryTests/MsSqlGraphQLQueryTests.cs index 856a8d7d0d..6ea6bf7aa1 100644 --- a/src/Service.Tests/SqlTests/GraphQLQueryTests/MsSqlGraphQLQueryTests.cs +++ b/src/Service.Tests/SqlTests/GraphQLQueryTests/MsSqlGraphQLQueryTests.cs @@ -165,6 +165,109 @@ SELECT TOP 1 content FROM reviews await QueryWithMultipleColumnPrimaryKey(msSqlQuery); } + /// + /// Test if filter param successfully filters when string filter + /// + [TestMethod] + public virtual async Task TestFilterParamForStringFilter() + { + string graphQLQueryName = "books"; + string graphQLQuery = @"{ + books( " + Service.GraphQLBuilder.Queries.QueryBuilder.FILTER_FIELD_NAME + @":{ title: {eq:""Awesome book""}}) { + items { + id + title + } + } + }"; + + string expected = @" +[ + { + ""id"": 1, + ""title"": ""Awesome book"" + } +]"; + + JsonElement actual = await ExecuteGraphQLRequestAsync(graphQLQuery, graphQLQueryName, isAuthenticated: false); + + SqlTestHelper.PerformTestEqualJsonStrings(expected, actual.GetProperty("items").ToString()); + } + + /// + /// Test if filter param successfully filters when string filter results in a value longer than the column + /// + /// + /// When using complex operators i.e. NotContains due to wildcards being added or special characters being escaped + /// the string being passed as a parameter maybe longer than the length of the column. The parameter data type + /// can't be fixed to the length of the underlying column, otherwise the parameter value would be truncated and + /// we'd get incorrect results + /// Thus checking the parameter length is overridden to cater for the extra length i.e. lengthOverride = true codepath. + /// + [DataTestMethod] + [DataRow("contains")] + [DataRow("startsWith")] + [DataRow("endsWith")] + public virtual async Task TestFilterParamForStringFilterWorkWithComplexOp(string op) + { + string graphQLQueryName = "books"; + + //using a lookup value that is the length of the title column AND includes special characters + string graphQLQuery = @"{ + books( " + Service.GraphQLBuilder.Queries.QueryBuilder.FILTER_FIELD_NAME + @":{ title: {" + op + @":""Great wall of china explained]""}}) { + items { + id + title + } + } + }"; + + string expected = @" +[ + { + ""id"": 3, + ""title"": ""Great wall of china explained]"" + } +]"; + + JsonElement actual = await ExecuteGraphQLRequestAsync(graphQLQuery, graphQLQueryName, isAuthenticated: false); + + SqlTestHelper.PerformTestEqualJsonStrings(expected, actual.GetProperty("items").ToString()); + } + + /// + /// Test if filter param successfully filters when string filter results in a value longer than the column + /// + /// + /// When using complex operators i.e. NotContains due to wildcards being added or special characters being escaped + /// the string being passed as a parameter maybe longer than the length of the column. The parameter data type + /// can't be fixed to the length of the underlying column, otherwise the parameter value would be truncated and + /// we'd get incorrect results. + /// Thus checking the parameter length is overridden to cater for the extra length i.e. lengthOverride = true codepath. + /// + [TestMethod] + public virtual async Task TestFilterParamForStringFilterWorkWithNotContains(string op) + { + string graphQLQueryName = "books"; + //using a lookup value that is the length of the title column AND includes special characters + string graphQLQuery = @"{ + books( " + Service.GraphQLBuilder.Queries.QueryBuilder.FILTER_FIELD_NAME + @":{ title: { notContains:""Great wall of china explained]""},id:{eq:3} }) { + items { + id + title + } + } + }"; + + string expected = @" +[ +]"; + + JsonElement actual = await ExecuteGraphQLRequestAsync(graphQLQuery, graphQLQueryName, isAuthenticated: false); + + SqlTestHelper.PerformTestEqualJsonStrings(expected, actual.GetProperty("items").ToString()); + } + [TestMethod] public async Task QueryWithNullableForeignKey() { @@ -421,8 +524,8 @@ public async Task TestStoredProcedureQueryWithNoDefaultInConfig() public async Task TestSupportForAggregationsWithAliases() { string msSqlQuery = @" - SELECT - MAX(categoryid) AS max, + SELECT + MAX(categoryid) AS max, MAX(price) AS max_price, MIN(price) AS min_price, AVG(price) AS avg_price,