Skip to content

Commit c34860d

Browse files
committed
HHH-19735 Add vector support for SQL Server
1 parent b064009 commit c34860d

16 files changed

+616
-30
lines changed

hibernate-vector/hibernate-vector.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ description = 'Hibernate\'s extensions for vector support'
1313
dependencies {
1414
api project( ':hibernate-core' )
1515

16+
compileOnly jdbcLibs.mssql
17+
1618
testImplementation project( ':hibernate-testing' )
1719
testImplementation project( path: ':hibernate-core', configuration: 'tests' )
1820
}

hibernate-vector/src/main/java/org/hibernate/vector/internal/AbstractOracleVectorJdbcType.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818
import org.hibernate.type.descriptor.ValueBinder;
1919
import org.hibernate.type.descriptor.ValueExtractor;
2020
import org.hibernate.type.descriptor.WrapperOptions;
21-
import org.hibernate.type.descriptor.java.BasicPluralJavaType;
22-
import org.hibernate.type.descriptor.java.ByteJavaType;
2321
import org.hibernate.type.descriptor.java.JavaType;
24-
import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType;
2522
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
2623
import org.hibernate.type.descriptor.jdbc.BasicBinder;
2724
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
@@ -82,22 +79,9 @@ public boolean isWriteExpressionTyped(Dialect dialect) {
8279

8380
@Override
8481
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
85-
final JavaType<T> elementJavaType;
86-
if ( javaTypeDescriptor instanceof PrimitiveByteArrayJavaType ) {
87-
// Special handling needed for Byte[], because that would conflict with the VARBINARY mapping
88-
//noinspection unchecked
89-
elementJavaType = (JavaType<T>) ByteJavaType.INSTANCE;
90-
}
91-
else if ( javaTypeDescriptor instanceof BasicPluralJavaType ) {
92-
//noinspection unchecked
93-
elementJavaType = ( (BasicPluralJavaType<T>) javaTypeDescriptor ).getElementJavaType();
94-
}
95-
else {
96-
throw new IllegalArgumentException( "not a BasicPluralJavaType" );
97-
}
9882
return new OracleJdbcLiteralFormatterVector<>(
9983
javaTypeDescriptor,
100-
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType ),
84+
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ),
10185
getVectorParameters().replace( ",sparse", "" )
10286
);
10387
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.dialect.Dialect;
8+
import org.hibernate.sql.ast.spi.SqlAppender;
9+
import org.hibernate.type.descriptor.WrapperOptions;
10+
import org.hibernate.type.descriptor.java.JavaType;
11+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
12+
import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter;
13+
14+
public class MariaDBJdbcLiteralFormatterVector<T> extends BasicJdbcLiteralFormatter<T> {
15+
16+
private final JdbcLiteralFormatter<Object> elementFormatter;
17+
18+
public MariaDBJdbcLiteralFormatterVector(JavaType<T> javaType, JdbcLiteralFormatter<?> elementFormatter) {
19+
super( javaType );
20+
//noinspection unchecked
21+
this.elementFormatter = (JdbcLiteralFormatter<Object>) elementFormatter;
22+
}
23+
24+
@Override
25+
public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) {
26+
final Object[] objects = unwrap( value, Object[].class, wrapperOptions );
27+
appender.appendSql( "vec_fromtext('" );
28+
char separator = '[';
29+
for ( Object o : objects ) {
30+
appender.appendSql( separator );
31+
elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions );
32+
separator = ',';
33+
}
34+
appender.appendSql( "]')" );
35+
}
36+
37+
}

hibernate-vector/src/main/java/org/hibernate/vector/internal/MariaDBVectorJdbcType.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
1717
import org.hibernate.type.descriptor.jdbc.BasicBinder;
1818
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
19+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
1920
import org.hibernate.type.descriptor.jdbc.JdbcType;
2021
import org.hibernate.type.spi.TypeConfiguration;
2122

@@ -46,6 +47,14 @@ public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
4647
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
4748
}
4849

50+
@Override
51+
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
52+
return new MariaDBJdbcLiteralFormatterVector<>(
53+
javaTypeDescriptor,
54+
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) )
55+
);
56+
}
57+
4958
@Override
5059
public void appendWriteExpression(
5160
String writeExpression,
@@ -70,17 +79,17 @@ public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
7079
return new BasicExtractor<>( javaTypeDescriptor, this ) {
7180
@Override
7281
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
73-
return javaTypeDescriptor.wrap( rs.getObject( paramIndex, float[].class ), options );
82+
return getJavaType().wrap( rs.getObject( paramIndex, float[].class ), options );
7483
}
7584

7685
@Override
7786
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
78-
return javaTypeDescriptor.wrap( statement.getObject( index, float[].class ), options );
87+
return getJavaType().wrap( statement.getObject( index, float[].class ), options );
7988
}
8089

8190
@Override
8291
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
83-
return javaTypeDescriptor.wrap( statement.getObject( name, float[].class ), options );
92+
return getJavaType().wrap( statement.getObject( name, float[].class ), options );
8493
}
8594

8695
};
@@ -92,18 +101,18 @@ public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
92101

93102
@Override
94103
protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException {
95-
st.setObject( index, value );
104+
st.setObject( index, getBindValue( value, options ) );
96105
}
97106

98107
@Override
99108
protected void doBind(CallableStatement st, X value, String name, WrapperOptions options)
100109
throws SQLException {
101-
st.setObject( name, value, java.sql.Types.ARRAY );
110+
st.setObject( name, getBindValue( value, options ), java.sql.Types.ARRAY );
102111
}
103112

104113
@Override
105114
public Object getBindValue(X value, WrapperOptions options) {
106-
return value;
115+
return getJavaType().unwrap( value, float[].class, options );
107116
}
108117
};
109118
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.dialect.Dialect;
8+
import org.hibernate.sql.ast.spi.SqlAppender;
9+
import org.hibernate.type.descriptor.WrapperOptions;
10+
import org.hibernate.type.descriptor.java.JavaType;
11+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
12+
import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter;
13+
14+
public class MySQLJdbcLiteralFormatterVector<T> extends BasicJdbcLiteralFormatter<T> {
15+
16+
private final JdbcLiteralFormatter<Object> elementFormatter;
17+
18+
public MySQLJdbcLiteralFormatterVector(JavaType<T> javaType, JdbcLiteralFormatter<?> elementFormatter) {
19+
super( javaType );
20+
//noinspection unchecked
21+
this.elementFormatter = (JdbcLiteralFormatter<Object>) elementFormatter;
22+
}
23+
24+
@Override
25+
public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) {
26+
final Object[] objects = unwrap( value, Object[].class, wrapperOptions );
27+
appender.appendSql( "string_to_vector('" );
28+
char separator = '[';
29+
for ( Object o : objects ) {
30+
appender.appendSql( separator );
31+
elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions );
32+
separator = ',';
33+
}
34+
appender.appendSql( "]')" );
35+
}
36+
37+
}

hibernate-vector/src/main/java/org/hibernate/vector/internal/MySQLVectorJdbcType.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
1717
import org.hibernate.type.descriptor.jdbc.BasicBinder;
1818
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
19+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
1920
import org.hibernate.type.descriptor.jdbc.JdbcType;
2021
import org.hibernate.type.spi.TypeConfiguration;
2122

@@ -49,6 +50,14 @@ public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
4950
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
5051
}
5152

53+
@Override
54+
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
55+
return new MySQLJdbcLiteralFormatterVector<>(
56+
javaTypeDescriptor,
57+
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) )
58+
);
59+
}
60+
5261
@Override
5362
public void appendWriteExpression(
5463
String writeExpression,
@@ -80,17 +89,17 @@ public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
8089
return new BasicExtractor<>( javaTypeDescriptor, this ) {
8190
@Override
8291
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
83-
return javaTypeDescriptor.wrap( parseFloatVector( rs.getBytes( paramIndex ) ), options );
92+
return getJavaType().wrap( parseFloatVector( rs.getBytes( paramIndex ) ), options );
8493
}
8594

8695
@Override
8796
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
88-
return javaTypeDescriptor.wrap( parseFloatVector( statement.getBytes( index ) ), options );
97+
return getJavaType().wrap( parseFloatVector( statement.getBytes( index ) ), options );
8998
}
9099

91100
@Override
92101
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
93-
return javaTypeDescriptor.wrap( parseFloatVector( statement.getBytes( name ) ), options );
102+
return getJavaType().wrap( parseFloatVector( statement.getBytes( name ) ), options );
94103
}
95104

96105
};
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.dialect.Dialect;
8+
import org.hibernate.sql.ast.spi.SqlAppender;
9+
import org.hibernate.type.descriptor.WrapperOptions;
10+
import org.hibernate.type.descriptor.java.JavaType;
11+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
12+
import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter;
13+
14+
public class PGVectorJdbcLiteralFormatterVector<T> extends BasicJdbcLiteralFormatter<T> {
15+
16+
private final JdbcLiteralFormatter<Object> elementFormatter;
17+
18+
public PGVectorJdbcLiteralFormatterVector(JavaType<T> javaType, JdbcLiteralFormatter<?> elementFormatter) {
19+
super( javaType );
20+
//noinspection unchecked
21+
this.elementFormatter = (JdbcLiteralFormatter<Object>) elementFormatter;
22+
}
23+
24+
@Override
25+
public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) {
26+
final Object[] objects = unwrap( value, Object[].class, wrapperOptions );
27+
appender.appendSql( "cast('" );
28+
char separator = '[';
29+
for ( Object o : objects ) {
30+
appender.appendSql( separator );
31+
elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions );
32+
separator = ',';
33+
}
34+
appender.appendSql( "]' as vector)" );
35+
}
36+
37+
}

hibernate-vector/src/main/java/org/hibernate/vector/internal/PGVectorJdbcType.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.hibernate.type.descriptor.java.JavaType;
1515
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
1616
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
17+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
1718
import org.hibernate.type.descriptor.jdbc.JdbcType;
1819
import org.hibernate.type.spi.TypeConfiguration;
1920

@@ -47,6 +48,14 @@ public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
4748
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
4849
}
4950

51+
@Override
52+
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
53+
return new PGVectorJdbcLiteralFormatterVector<>(
54+
javaTypeDescriptor,
55+
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) )
56+
);
57+
}
58+
5059
@Override
5160
public void appendWriteExpression(
5261
String writeExpression,

0 commit comments

Comments
 (0)