diff --git a/solr/core/src/java/org/apache/solr/handler/component/CombinedQueryComponent.java b/solr/core/src/java/org/apache/solr/handler/component/CombinedQueryComponent.java new file mode 100644 index 00000000000..977bd439665 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/handler/component/CombinedQueryComponent.java @@ -0,0 +1,567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.handler.component; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.solr.client.solrj.SolrServerException; +import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.SolrDocumentList; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.CombinerParams; +import org.apache.solr.common.params.CursorMarkParams; +import org.apache.solr.common.params.ShardParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.common.util.SimpleOrderedMap; +import org.apache.solr.common.util.StrUtils; +import org.apache.solr.core.SolrCore; +import org.apache.solr.response.BasicResultContext; +import org.apache.solr.response.ResultContext; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; +import org.apache.solr.search.DocListAndSet; +import org.apache.solr.search.QueryResult; +import org.apache.solr.search.SolrReturnFields; +import org.apache.solr.search.SortSpec; +import org.apache.solr.search.combine.QueryAndResponseCombiner; +import org.apache.solr.search.combine.ReciprocalRankFusion; +import org.apache.solr.util.SolrResponseUtil; +import org.apache.solr.util.plugin.SolrCoreAware; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The CombinedQueryComponent class extends QueryComponent and provides 
support for executing + * multiple queries and combining their results. + */ +public class CombinedQueryComponent extends QueryComponent implements SolrCoreAware { + + public static final String COMPONENT_NAME = "combined_query"; + protected NamedList initParams; + private Map combiners = new ConcurrentHashMap<>(); + private int maxCombinerQueries; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Override + public void init(NamedList args) { + super.init(args); + this.initParams = args; + this.maxCombinerQueries = CombinerParams.DEFAULT_MAX_COMBINER_QUERIES; + } + + @Override + public void inform(SolrCore core) { + if (initParams != null && initParams.size() > 0) { + log.info("Initializing CombinedQueryComponent"); + NamedList all = (NamedList) initParams.get("combiners"); + for (int i = 0; i < all.size(); i++) { + String name = all.getName(i); + NamedList combinerConfig = (NamedList) all.getVal(i); + String className = (String) combinerConfig.get("class"); + QueryAndResponseCombiner combiner = + core.getResourceLoader().newInstance(className, QueryAndResponseCombiner.class); + combiner.init(combinerConfig); + combiners.computeIfAbsent(name, combinerName -> combiner); + } + Object maxQueries = initParams.get("maxCombinerQueries"); + if (maxQueries != null) { + this.maxCombinerQueries = Integer.parseInt(maxQueries.toString()); + } + } + combiners.computeIfAbsent( + CombinerParams.RECIPROCAL_RANK_FUSION, + key -> { + ReciprocalRankFusion reciprocalRankFusion = new ReciprocalRankFusion(); + reciprocalRankFusion.init(initParams); + return reciprocalRankFusion; + }); + } + + /** + * Overrides the prepare method to handle combined queries. 
+ * + * @param rb the ResponseBuilder to prepare + * @throws IOException if an I/O error occurs during preparation + */ + @Override + public void prepare(ResponseBuilder rb) throws IOException { + if (rb instanceof CombinedQueryResponseBuilder crb) { + SolrParams params = crb.req.getParams(); + String[] queriesToCombineKeys = params.getParams(CombinerParams.COMBINER_QUERY); + if (queriesToCombineKeys.length > maxCombinerQueries) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "Too many queries to combine: limit is " + maxCombinerQueries); + } + for (String queryKey : queriesToCombineKeys) { + final var unparsedQuery = params.get(queryKey); + ResponseBuilder rbNew = new ResponseBuilder(rb.req, new SolrQueryResponse(), rb.components); + rbNew.setQueryString(unparsedQuery); + super.prepare(rbNew); + crb.responseBuilders.add(rbNew); + } + } + super.prepare(rb); + } + + /** + * Overrides the process method to handle CombinedQueryResponseBuilder instances. This method + * processes the responses from multiple shards, combines them using the specified + * QueryAndResponseCombiner strategy, and sets the appropriate results and metadata in the + * CombinedQueryResponseBuilder. 
+ * + * @param rb the ResponseBuilder object to process + * @throws IOException if an I/O error occurs during processing + */ + @Override + public void process(ResponseBuilder rb) throws IOException { + if (rb instanceof CombinedQueryResponseBuilder crb) { + boolean partialResults = false; + boolean segmentTerminatedEarly = false; + List queryResults = new ArrayList<>(); + for (ResponseBuilder rbNow : crb.responseBuilders) { + super.process(rbNow); + DocListAndSet docListAndSet = rbNow.getResults(); + QueryResult queryResult = new QueryResult(); + queryResult.setDocListAndSet(docListAndSet); + queryResults.add(queryResult); + partialResults |= SolrQueryResponse.isPartialResults(rbNow.rsp.getResponseHeader()); + rbNow.setCursorMark(rbNow.getCursorMark()); + if (rbNow.rsp.getResponseHeader() != null) { + segmentTerminatedEarly |= + (boolean) + rbNow + .rsp + .getResponseHeader() + .getOrDefault( + SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY, false); + } + } + QueryAndResponseCombiner combinerStrategy = + QueryAndResponseCombiner.getImplementation(rb.req.getParams(), combiners); + QueryResult combinedQueryResult = combinerStrategy.combine(queryResults, rb.req.getParams()); + combinedQueryResult.setPartialResults(partialResults); + combinedQueryResult.setSegmentTerminatedEarly(segmentTerminatedEarly); + crb.setResult(combinedQueryResult); + if (rb.isDebug()) { + String[] queryKeys = rb.req.getParams().getParams(CombinerParams.COMBINER_QUERY); + List queries = crb.responseBuilders.stream().map(ResponseBuilder::getQuery).toList(); + NamedList explanations = + combinerStrategy.getExplanations( + queryKeys, + queries, + queryResults, + rb.req.getSearcher(), + rb.req.getSchema(), + rb.req.getParams()); + rb.addDebugInfo("combinerExplanations", explanations); + } + ResultContext ctx = new BasicResultContext(crb); + crb.rsp.addResponse(ctx); + crb.rsp + .getToLog() + .add( + "hits", + crb.getResults() == null || crb.getResults().docList == null + ? 
0 + : crb.getResults().docList.matches()); + if (!crb.req.getParams().getBool(ShardParams.IS_SHARD, false)) { + if (null != crb.getNextCursorMark()) { + crb.rsp.add( + CursorMarkParams.CURSOR_MARK_NEXT, crb.getNextCursorMark().getSerializedTotem()); + } + } + + if (crb.mergeFieldHandler != null) { + crb.mergeFieldHandler.handleMergeFields(crb, crb.req.getSearcher()); + } else { + doFieldSortValues(rb, crb.req.getSearcher()); + } + doPrefetch(crb); + } else { + super.process(rb); + } + } + + @Override + protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) { + List mergeStrategies = rb.getMergeStrategies(); + if (mergeStrategies != null) { + mergeStrategies.sort(MergeStrategy.MERGE_COMP); + boolean idsMerged = false; + for (MergeStrategy mergeStrategy : mergeStrategies) { + mergeStrategy.merge(rb, sreq); + if (mergeStrategy.mergesIds()) { + idsMerged = true; + } + } + + if (idsMerged) { + return; // ids were merged above so return. + } + } + + SortSpec ss = rb.getSortSpec(); + + // If the shard request was also used to get fields (along with the scores), there is no reason + // to copy over the score dependent fields, since those will already exist in the document with + // the return fields + Set scoreDependentFields; + if ((sreq.purpose & ShardRequest.PURPOSE_GET_FIELDS) == 0) { + scoreDependentFields = + rb.rsp.getReturnFields().getScoreDependentReturnFields().keySet().stream() + .filter(field -> !field.equals(SolrReturnFields.SCORE)) + .collect(Collectors.toSet()); + } else { + scoreDependentFields = Collections.emptySet(); + } + + IndexSchema schema = rb.req.getSchema(); + SchemaField uniqueKeyField = schema.getUniqueKeyField(); + + // id to shard mapping, to eliminate any accidental dups + HashMap uniqueDoc = new HashMap<>(); + + NamedList shardInfo = null; + if (rb.req.getParams().getBool(ShardParams.SHARDS_INFO, false)) { + shardInfo = new SimpleOrderedMap<>(); + rb.rsp.getValues().add(ShardParams.SHARDS_INFO, shardInfo); + } + + long numFound = 0; 
+ boolean hitCountIsExact = true; + Float maxScore = null; + boolean thereArePartialResults = false; + Boolean segmentTerminatedEarly = null; + boolean maxHitsTerminatedEarly = false; + long approximateTotalHits = 0; + int failedShardCount = 0; + Map> shardDocMap = new HashMap<>(); + for (ShardResponse srsp : sreq.responses) { + SolrDocumentList docs = null; + NamedList responseHeader = null; + + if (shardInfo != null) { + SimpleOrderedMap nl = new SimpleOrderedMap<>(); + + if (srsp.getException() != null) { + Throwable t = srsp.getException(); + if (t instanceof SolrServerException && t.getCause() != null) { + t = t.getCause(); + } + nl.add("error", t.toString()); + if (!rb.req.getCore().getCoreContainer().hideStackTrace()) { + StringWriter trace = new StringWriter(); + t.printStackTrace(new PrintWriter(trace)); + nl.add("trace", trace.toString()); + } + if (!StrUtils.isNullOrEmpty(srsp.getShardAddress())) { + nl.add("shardAddress", srsp.getShardAddress()); + } + } else { + responseHeader = + (NamedList) + SolrResponseUtil.getSubsectionFromShardResponse( + rb, srsp, "responseHeader", false); + if (responseHeader == null) { + continue; + } + final Object rhste = + responseHeader.get(SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY); + if (rhste != null) { + nl.add(SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY, rhste); + } + final Object rhmhte = + responseHeader.get(SolrQueryResponse.RESPONSE_HEADER_MAX_HITS_TERMINATED_EARLY_KEY); + if (rhmhte != null) { + nl.add(SolrQueryResponse.RESPONSE_HEADER_MAX_HITS_TERMINATED_EARLY_KEY, rhmhte); + } + final Object rhath = + responseHeader.get(SolrQueryResponse.RESPONSE_HEADER_APPROXIMATE_TOTAL_HITS_KEY); + if (rhath != null) { + nl.add(SolrQueryResponse.RESPONSE_HEADER_APPROXIMATE_TOTAL_HITS_KEY, rhath); + } + docs = + (SolrDocumentList) + SolrResponseUtil.getSubsectionFromShardResponse(rb, srsp, "response", false); + if (docs == null) { + continue; + } + nl.add("numFound", docs.getNumFound()); 
+ nl.add("numFoundExact", docs.getNumFoundExact()); + nl.add("maxScore", docs.getMaxScore()); + nl.add("shardAddress", srsp.getShardAddress()); + } + if (srsp.getSolrResponse() != null) { + nl.add("time", srsp.getSolrResponse().getElapsedTime()); + } + // This ought to be better, but at least this ensures no duplicate keys in JSON result + String shard = srsp.getShard(); + if (StrUtils.isNullOrEmpty(shard)) { + failedShardCount += 1; + shard = "unknown_shard_" + failedShardCount; + } + shardInfo.add(shard, nl); + } + // now that we've added the shard info, let's only proceed if we have no error. + if (srsp.getException() != null) { + thereArePartialResults = true; + continue; + } + + if (docs == null) { // could have been initialized in the shards info block above + docs = + Objects.requireNonNull( + (SolrDocumentList) + SolrResponseUtil.getSubsectionFromShardResponse(rb, srsp, "response", false)); + } + + if (responseHeader == null) { // could have been initialized in the shards info block above + responseHeader = + Objects.requireNonNull( + (NamedList) + SolrResponseUtil.getSubsectionFromShardResponse( + rb, srsp, "responseHeader", false)); + } + + final boolean thisResponseIsPartial; + thisResponseIsPartial = + Boolean.TRUE.equals( + responseHeader.getBooleanArg(SolrQueryResponse.RESPONSE_HEADER_PARTIAL_RESULTS_KEY)); + thereArePartialResults |= thisResponseIsPartial; + + if (!Boolean.TRUE.equals(segmentTerminatedEarly)) { + final Object ste = + responseHeader.get(SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY); + if (Boolean.TRUE.equals(ste)) { + segmentTerminatedEarly = Boolean.TRUE; + } else if (Boolean.FALSE.equals(ste)) { + segmentTerminatedEarly = Boolean.FALSE; + } + } + + if (!maxHitsTerminatedEarly) { + if (Boolean.TRUE.equals( + responseHeader.get(SolrQueryResponse.RESPONSE_HEADER_MAX_HITS_TERMINATED_EARLY_KEY))) { + maxHitsTerminatedEarly = true; + } + } + Object ath = 
responseHeader.get(SolrQueryResponse.RESPONSE_HEADER_APPROXIMATE_TOTAL_HITS_KEY); + if (ath == null) { + approximateTotalHits += numFound; + } else { + approximateTotalHits += ((Number) ath).longValue(); + } + + // calculate global maxScore and numDocsFound + if (docs.getMaxScore() != null) { + maxScore = maxScore == null ? docs.getMaxScore() : Math.max(maxScore, docs.getMaxScore()); + } + numFound += docs.getNumFound(); + + if (hitCountIsExact && Boolean.FALSE.equals(docs.getNumFoundExact())) { + hitCountIsExact = false; + } + + @SuppressWarnings("unchecked") + NamedList> sortFieldValues = + (NamedList>) + SolrResponseUtil.getSubsectionFromShardResponse(rb, srsp, "sort_values", true); + if (null == sortFieldValues) { + sortFieldValues = new NamedList<>(); + } + + // if the SortSpec contains a field besides score or the Lucene docid, then the values will + // need to be unmarshalled from sortFieldValues. + boolean needsUnmarshalling = ss.includesNonScoreOrDocField(); + + // if we need to unmarshal the sortFieldValues for sorting but we have none, which can happen + // if partial results are being returned from the shard, then skip merging the results for the + // shard. This avoids an exception below. if the shard returned partial results but we don't + // need to unmarshal (a normal scoring query), then merge what we got. + if (thisResponseIsPartial && sortFieldValues.size() == 0 && needsUnmarshalling) { + continue; + } + + // Checking needsUnmarshalling saves on iterating the SortFields in the SortSpec again. + NamedList> unmarshalledSortFieldValues = + needsUnmarshalling ? unmarshalSortValues(ss, sortFieldValues, schema) : new NamedList<>(); + + // go through every doc in this response, construct a ShardDoc, and + // put it in the priority queue so it can be ordered. 
+ for (int i = 0; i < docs.size(); i++) { + SolrDocument doc = docs.get(i); + Object id = doc.getFieldValue(uniqueKeyField.getName()); + ShardDoc shardDoc = new ShardDoc(); + shardDoc.id = id; + shardDoc.shard = srsp.getShard(); + shardDoc.orderInShard = i; + Object scoreObj = doc.getFieldValue(SolrReturnFields.SCORE); + if (scoreObj != null) { + if (scoreObj instanceof String) { + shardDoc.score = Float.parseFloat((String) scoreObj); + } else { + shardDoc.score = ((Number) scoreObj).floatValue(); + } + } + if (!scoreDependentFields.isEmpty()) { + shardDoc.scoreDependentFields = doc.getSubsetOfFields(scoreDependentFields); + } + + shardDoc.sortFieldValues = unmarshalledSortFieldValues; + shardDocMap.computeIfAbsent(srsp.getShard(), list -> new ArrayList<>()).add(shardDoc); + String prevShard = uniqueDoc.put(id, srsp.getShard()); + if (prevShard != null) { + // duplicate detected + numFound--; + } + } // end for-each-doc-in-response + } // end for-each-response + + SolrDocumentList responseDocs = new SolrDocumentList(); + if (maxScore != null) responseDocs.setMaxScore(maxScore); + rb.rsp.addToLog("hits", numFound); + + responseDocs.setNumFound(numFound); + responseDocs.setNumFoundExact(hitCountIsExact); + responseDocs.setStart(ss.getOffset()); + + // save these results in a private area so we can access them + // again when retrieving stored fields. + // TODO: use ResponseBuilder (w/ comments) or the request context? 
+ rb.resultIds = createShardResult(rb, shardDocMap, responseDocs); + rb.setResponseDocs(responseDocs); + + populateNextCursorMarkFromMergedShards(rb); + + if (thereArePartialResults) { + rb.rsp + .getResponseHeader() + .asShallowMap() + .put(SolrQueryResponse.RESPONSE_HEADER_PARTIAL_RESULTS_KEY, Boolean.TRUE); + } + if (segmentTerminatedEarly != null) { + final Object existingSegmentTerminatedEarly = + rb.rsp + .getResponseHeader() + .get(SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY); + if (existingSegmentTerminatedEarly == null) { + rb.rsp + .getResponseHeader() + .add( + SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY, + segmentTerminatedEarly); + } else if (!Boolean.TRUE.equals(existingSegmentTerminatedEarly) + && Boolean.TRUE.equals(segmentTerminatedEarly)) { + rb.rsp + .getResponseHeader() + .remove(SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY); + rb.rsp + .getResponseHeader() + .add( + SolrQueryResponse.RESPONSE_HEADER_SEGMENT_TERMINATED_EARLY_KEY, + segmentTerminatedEarly); + } + } + if (maxHitsTerminatedEarly) { + rb.rsp + .getResponseHeader() + .add(SolrQueryResponse.RESPONSE_HEADER_MAX_HITS_TERMINATED_EARLY_KEY, Boolean.TRUE); + if (approximateTotalHits > 0) { + rb.rsp + .getResponseHeader() + .add( + SolrQueryResponse.RESPONSE_HEADER_APPROXIMATE_TOTAL_HITS_KEY, approximateTotalHits); + } + } + } + + /** + * Combines and sorts documents from multiple shards to create the final result set. This method + * uses a combiner strategy to merge shard responses, then sorts the resulting documents using a + * priority queue based on the request's sort specification. It handles pagination (offset and + * count) and calculates the maximum score for the response. + * + * @param rb The ResponseBuilder containing the request and context, such as sort specifications. + * @param shardDocMap A map from shard addresses to the list of documents returned by each shard. 
+ * @param responseDocs The final response document list, which will be populated with null + * placeholders and have its max score set. + * @return A map from document IDs to the corresponding ShardDoc objects for the documents in the + * final sorted page of results. + */ + protected Map createShardResult( + ResponseBuilder rb, Map> shardDocMap, SolrDocumentList responseDocs) { + QueryAndResponseCombiner combinerStrategy = + QueryAndResponseCombiner.getImplementation(rb.req.getParams(), combiners); + List combinedShardDocs = combinerStrategy.combine(shardDocMap, rb.req.getParams()); + Map shardDocIdMap = new HashMap<>(); + shardDocMap.forEach( + (shardKey, shardDocs) -> + shardDocs.forEach(shardDoc -> shardDocIdMap.put(shardDoc.id.toString(), shardDoc))); + Map resultIds = new HashMap<>(); + float maxScore = 0.0f; + Sort sort = rb.getSortSpec().getSort(); + SortField[] sortFields; + if (sort != null) { + sortFields = sort.getSort(); + } else { + sortFields = new SortField[] {SortField.FIELD_SCORE}; + } + final ShardFieldSortedHitQueue queue = + new ShardFieldSortedHitQueue( + sortFields, + rb.getSortSpec().getOffset() + rb.getSortSpec().getCount(), + rb.req.getSearcher()); + combinedShardDocs.forEach(queue::insertWithOverflow); + int resultSize = queue.size() - rb.getSortSpec().getOffset(); + resultSize = Math.max(0, resultSize); + for (int i = resultSize - 1; i >= 0; i--) { + ShardDoc shardDoc = queue.pop(); + shardDoc.positionInResponse = i; + maxScore = Math.max(maxScore, shardDoc.score); + if (Float.isNaN(shardDocIdMap.get(shardDoc.id.toString()).score)) { + shardDoc.score = Float.NaN; + } + resultIds.put(shardDoc.id.toString(), shardDoc); + } + responseDocs.setMaxScore(maxScore); + for (int i = 0; i < resultSize; i++) responseDocs.add(null); + return resultIds; + } + + @Override + public String getDescription() { + return "Combined Query Component to support multiple query execution"; + } +} diff --git 
a/solr/core/src/java/org/apache/solr/handler/component/CombinedQueryResponseBuilder.java b/solr/core/src/java/org/apache/solr/handler/component/CombinedQueryResponseBuilder.java new file mode 100644 index 00000000000..c5a93c86e47 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/handler/component/CombinedQueryResponseBuilder.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.handler.component; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.SolrQueryResponse; + +/** + * The CombinedQueryResponseBuilder class extends the ResponseBuilder class and is responsible for + * building a combined response for multiple SearchComponent objects. It orchestrates the process of + * constructing the SolrQueryResponse by aggregating results from various components. + */ +public class CombinedQueryResponseBuilder extends ResponseBuilder { + + public final List responseBuilders = new ArrayList<>(); + + /** + * Constructs a CombinedQueryResponseBuilder instance. + * + * @param req the SolrQueryRequest object containing the query parameters and context. 
+ * @param rsp the SolrQueryResponse object to which the combined results will be added. + * @param components a list of SearchComponent objects that will be used to build the response. + */ + public CombinedQueryResponseBuilder( + SolrQueryRequest req, SolrQueryResponse rsp, List components) { + super(req, rsp, components); + } +} diff --git a/solr/core/src/java/org/apache/solr/handler/component/CombinedQuerySearchHandler.java b/solr/core/src/java/org/apache/solr/handler/component/CombinedQuerySearchHandler.java new file mode 100644 index 00000000000..2605e4c9aec --- /dev/null +++ b/solr/core/src/java/org/apache/solr/handler/component/CombinedQuerySearchHandler.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.handler.component; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.params.CombinerParams; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.search.facet.FacetModule; + +/** + * The CombinedQuerySearchHandler class extends the SearchHandler and provides custom behavior for + * handling combined queries. 
It overrides methods to create a response builder based on the {@link + * CombinerParams#COMBINER} parameter and to define the default components included in the search + * configuration. + */ +public class CombinedQuerySearchHandler extends SearchHandler { + + /** + * Overrides the default response builder creation method. This method checks if the {@link + * CombinerParams#COMBINER} parameter is set to true in the request. If it is, it returns an + * instance of {@link CombinedQueryResponseBuilder}, otherwise, it returns an instance of {@link + * ResponseBuilder}. + * + * @param req the SolrQueryRequest object + * @param rsp the SolrQueryResponse object + * @param components the list of SearchComponent objects + * @return the appropriate ResponseBuilder instance based on the CombinerParams.COMBINER parameter + */ + @Override + protected ResponseBuilder newResponseBuilder( + SolrQueryRequest req, SolrQueryResponse rsp, List components) { + if (req.getParams().getBool(CombinerParams.COMBINER, false)) { + return new CombinedQueryResponseBuilder(req, rsp, components); + } + return new ResponseBuilder(req, rsp, components); + } + + /** + * Overrides the default components and returns a list of component names that are included in the + * default configuration. 
+ * + * @return a list of component names + */ + @Override + @SuppressWarnings("unchecked") + protected List getDefaultComponents() { + List names = new ArrayList<>(9); + names.add(CombinedQueryComponent.COMPONENT_NAME); + names.add(FacetComponent.COMPONENT_NAME); + names.add(FacetModule.COMPONENT_NAME); + names.add(MoreLikeThisComponent.COMPONENT_NAME); + names.add(HighlightComponent.COMPONENT_NAME); + names.add(StatsComponent.COMPONENT_NAME); + names.add(DebugComponent.COMPONENT_NAME); + names.add(ExpandComponent.COMPONENT_NAME); + names.add(TermsComponent.COMPONENT_NAME); + return names; + } +} diff --git a/solr/core/src/java/org/apache/solr/search/combine/QueryAndResponseCombiner.java b/solr/core/src/java/org/apache/solr/search/combine/QueryAndResponseCombiner.java new file mode 100644 index 00000000000..06e1b8bc85d --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/combine/QueryAndResponseCombiner.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.search.combine; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.CombinerParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.handler.component.ShardDoc; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.search.QueryResult; +import org.apache.solr.search.SolrIndexSearcher; + +/** + * The QueryAndResponseCombiner class is an abstract base class for combining query results and + * shard documents. It provides a framework for different algorithms to be implemented for merging + * ranked lists and shard documents. + */ +public abstract class QueryAndResponseCombiner { + + public abstract void init(NamedList args); + + /** + * Combines multiple ranked lists into a single QueryResult. + * + * @param rankedLists a list of ranked lists to be combined + * @param solrParams params to be used when provided at query time + * @return a new QueryResult containing the combined results + * @throws IllegalArgumentException if the input list is empty + */ + public abstract QueryResult combine(List rankedLists, SolrParams solrParams); + + /** + * Combines shard documents based on the provided map. + * + * @param shardDocMap a map where keys represent shard IDs and values are lists of ShardDocs for + * each shard + * @param solrParams params to be used when provided at query time + * @return a combined list of ShardDocs from all shards + */ + public abstract List combine( + Map> shardDocMap, SolrParams solrParams); + + /** + * Retrieves a list of explanations for the given queries and results. 
+ * + * @param queryKeys the keys associated with the queries + * @param queries the list of queries for which explanations are requested + * @param queryResult the list of QueryResult corresponding to each query + * @param searcher the SolrIndexSearcher used to perform the search + * @param schema the IndexSchema used to interpret the search results + * @param solrParams params to be used when provided at query time + * @return a list of explanations for the given queries and results + * @throws IOException if an I/O error occurs during the explanation retrieval process + */ + public abstract NamedList getExplanations( + String[] queryKeys, + List queries, + List queryResult, + SolrIndexSearcher searcher, + IndexSchema schema, + SolrParams solrParams) + throws IOException; + + /** + * Retrieves an implementation of the QueryAndResponseCombiner based on the specified algorithm. + * + * @param requestParams the SolrParams containing the request parameters, including the combiner + * algorithm. + * @param combiners The already initialised map of QueryAndResponseCombiner + * @return an instance of QueryAndResponseCombiner corresponding to the specified algorithm. + * @throws SolrException if an unknown combiner algorithm is specified. 
+ */ + public static QueryAndResponseCombiner getImplementation( + SolrParams requestParams, Map combiners) { + String algorithm = + requestParams.get(CombinerParams.COMBINER_ALGORITHM, CombinerParams.DEFAULT_COMBINER); + if (combiners.containsKey(algorithm)) { + return combiners.get(algorithm); + } + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, "Unknown Combining algorithm: " + algorithm); + } +} diff --git a/solr/core/src/java/org/apache/solr/search/combine/ReciprocalRankFusion.java b/solr/core/src/java/org/apache/solr/search/combine/ReciprocalRankFusion.java new file mode 100644 index 00000000000..81028c8a4f0 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/combine/ReciprocalRankFusion.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.search.combine; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; +import org.apache.lucene.document.Document; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TotalHits; +import org.apache.solr.common.params.CombinerParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.common.util.SimpleOrderedMap; +import org.apache.solr.handler.component.ShardDoc; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.search.DocIterator; +import org.apache.solr.search.DocList; +import org.apache.solr.search.DocSlice; +import org.apache.solr.search.QueryResult; +import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.search.SortedIntDocSet; + +/** + * The ReciprocalRankFusion class implements a query and response combiner that uses the Reciprocal + * Rank Fusion (RRF) algorithm to combine multiple ranked lists into a single ranked list. 
+ */ +public class ReciprocalRankFusion extends QueryAndResponseCombiner { + + private int k; + + public int getK() { + return k; + } + + public ReciprocalRankFusion() { + this.k = CombinerParams.COMBINER_RRF_K_DEFAULT; + } + + @Override + public void init(NamedList args) { + Object kParam = args.get("k"); + if (kParam != null) { + this.k = Integer.parseInt(kParam.toString()); + } + } + + @Override + public QueryResult combine(List rankedLists, SolrParams solrParams) { + int kVal = solrParams.getInt(CombinerParams.COMBINER_RRF_K, this.k); + List docLists = getDocListsFromQueryResults(rankedLists); + QueryResult combinedResult = new QueryResult(); + combineResults(combinedResult, docLists, false, kVal); + return combinedResult; + } + + private static List getDocListsFromQueryResults(List rankedLists) { + List docLists = new ArrayList<>(rankedLists.size()); + for (QueryResult rankedList : rankedLists) { + docLists.add(rankedList.getDocList()); + } + return docLists; + } + + @Override + public List combine(Map> shardDocMap, SolrParams solrParams) { + int kVal = solrParams.getInt(CombinerParams.COMBINER_RRF_K, this.k); + HashMap docIdToScore = new HashMap<>(); + Map docIdToShardDoc = new HashMap<>(); + List finalShardDocList = new ArrayList<>(); + for (Map.Entry> shardDocEntry : shardDocMap.entrySet()) { + List shardDocList = shardDocEntry.getValue(); + int ranking = 1; + while (ranking <= shardDocList.size()) { + String docId = shardDocList.get(ranking - 1).id.toString(); + docIdToShardDoc.put(docId, shardDocList.get(ranking - 1)); + float rrfScore = 1f / (kVal + ranking); + docIdToScore.compute(docId, (id, score) -> (score == null) ? 
rrfScore : score + rrfScore); + ranking++; + } + } + List> sortedByScoreDescending = + docIdToScore.entrySet().stream() + .sorted(Collections.reverseOrder(Map.Entry.comparingByValue())) + .toList(); + for (Map.Entry scoredDoc : sortedByScoreDescending) { + String docId = scoredDoc.getKey(); + Float score = scoredDoc.getValue(); + ShardDoc shardDoc = docIdToShardDoc.get(docId); + shardDoc.score = score; + finalShardDocList.add(shardDoc); + } + return finalShardDocList; + } + + private Map combineResults( + QueryResult combinedRankedList, + List rankedLists, + boolean saveRankPositionsForExplain, + int kVal) { + Map docIdToRanks = null; + HashMap docIdToScore = new HashMap<>(); + long totalMatches = 0; + for (DocList rankedList : rankedLists) { + DocIterator docs = rankedList.iterator(); + totalMatches = Math.max(totalMatches, rankedList.matches()); + int ranking = 1; + while (docs.hasNext()) { + int docId = docs.nextDoc(); + float rrfScore = 1f / (kVal + ranking); + docIdToScore.compute(docId, (id, score) -> (score == null) ? rrfScore : score + rrfScore); + ranking++; + } + } + List> sortedByScoreDescending = + docIdToScore.entrySet().stream() + .sorted(Collections.reverseOrder(Map.Entry.comparingByValue())) + .toList(); + + int combinedResultsLength = docIdToScore.size(); + int[] combinedResultsDocIds = new int[combinedResultsLength]; + float[] combinedResultScores = new float[combinedResultsLength]; + + int i = 0; + for (Map.Entry scoredDoc : sortedByScoreDescending) { + combinedResultsDocIds[i] = scoredDoc.getKey(); + combinedResultScores[i] = scoredDoc.getValue(); + i++; + } + + if (saveRankPositionsForExplain) { + docIdToRanks = getRanks(rankedLists, combinedResultsDocIds); + } + + DocSlice combinedResultSlice = + new DocSlice( + 0, + combinedResultsLength, + combinedResultsDocIds, + combinedResultScores, + Math.max(combinedResultsLength, totalMatches), + combinedResultScores.length > 0 ? 
combinedResultScores[0] : 0, + TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + combinedRankedList.setDocList(combinedResultSlice); + SortedIntDocSet docSet = new SortedIntDocSet(combinedResultsDocIds, combinedResultsLength); + combinedRankedList.setDocSet(docSet); + return docIdToRanks; + } + + private Map getRanks(List docLists, int[] combinedResultsDocIDs) { + Map docIdToRanks; + docIdToRanks = new HashMap<>(); + for (int docID : combinedResultsDocIDs) { + docIdToRanks.put(docID, new Integer[docLists.size()]); + } + for (int j = 0; j < docLists.size(); j++) { + DocIterator iterator = docLists.get(j).iterator(); + int rank = 1; + while (iterator.hasNext()) { + int docId = iterator.nextDoc(); + docIdToRanks.get(docId)[j] = rank; + rank++; + } + } + return docIdToRanks; + } + + @Override + public NamedList getExplanations( + String[] queryKeys, + List queries, + List queryResult, + SolrIndexSearcher searcher, + IndexSchema schema, + SolrParams solrParams) + throws IOException { + int kVal = solrParams.getInt(CombinerParams.COMBINER_RRF_K, this.k); + NamedList docIdsExplanations = new SimpleOrderedMap<>(); + QueryResult combinedRankedList = new QueryResult(); + Map docIdToRanks = + combineResults(combinedRankedList, getDocListsFromQueryResults(queryResult), true, kVal); + DocList combinedDocList = combinedRankedList.getDocList(); + + DocIterator iterator = combinedDocList.iterator(); + for (int i = 0; i < combinedDocList.size(); i++) { + int docId = iterator.nextDoc(); + Integer[] rankPerQuery = docIdToRanks.get(docId); + Document doc = searcher.doc(docId); + String docUniqueKey = schema.printableUniqueKey(doc); + List originalExplanations = new ArrayList<>(queryKeys.length); + for (int queryIndex = 0; queryIndex < queryKeys.length; queryIndex++) { + Explanation originalQueryExplain = searcher.explain(queries.get(queryIndex), docId); + Explanation originalQueryExplainWithKey = + Explanation.match( + originalQueryExplain.getValue(), queryKeys[queryIndex], 
originalQueryExplain); + originalExplanations.add(originalQueryExplainWithKey); + } + Explanation fullDocIdExplanation = + Explanation.match( + iterator.score(), + getReciprocalRankFusionExplain(queryKeys, rankPerQuery, kVal), + originalExplanations); + docIdsExplanations.add(docUniqueKey, fullDocIdExplanation); + } + return docIdsExplanations; + } + + private String getReciprocalRankFusionExplain( + String[] queryKeys, Integer[] rankPerQuery, int kVal) { + StringBuilder reciprocalRankFusionExplain = new StringBuilder(); + StringJoiner scoreComponents = new StringJoiner(" + "); + for (Integer rank : rankPerQuery) { + if (rank != null) { + scoreComponents.add("1/(" + kVal + "+" + rank + ")"); + } + } + reciprocalRankFusionExplain.append(scoreComponents); + reciprocalRankFusionExplain.append(" because its ranks were: "); + StringJoiner rankComponents = new StringJoiner(", "); + for (int i = 0; i < queryKeys.length; i++) { + Integer rank = rankPerQuery[i]; + if (rank == null) { + rankComponents.add("not in the results for query(" + queryKeys[i] + ")"); + } else { + rankComponents.add(rank + " for query(" + queryKeys[i] + ")"); + } + } + reciprocalRankFusionExplain.append(rankComponents); + return reciprocalRankFusionExplain.toString(); + } +} diff --git a/solr/core/src/java/org/apache/solr/search/combine/package-info.java b/solr/core/src/java/org/apache/solr/search/combine/package-info.java new file mode 100644 index 00000000000..4c1225e6e21 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/combine/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * This contains the classes to combine the scores from search index results as well as from across + * shards. Multiple implementation of algorithms can be added to support them. + */ +package org.apache.solr.search.combine; diff --git a/solr/core/src/test-files/solr/collection1/conf/solrconfig-combined-query.xml b/solr/core/src/test-files/solr/collection1/conf/solrconfig-combined-query.xml new file mode 100644 index 00000000000..7c0ef9020fb --- /dev/null +++ b/solr/core/src/test-files/solr/collection1/conf/solrconfig-combined-query.xml @@ -0,0 +1,555 @@ + + + + + + + + + + ${solr.data.dir:} + + + + 1000000 + 2000000 + 3000000 + 4000000 + + + + + ${tests.luceneMatchVersion:LATEST} + + + + + + + + + ${solr.autoCommit.maxTime:-1} + + + + + + ${solr.ulog.dir:} + + + + ${solr.commitwithin.softcommit:true} + + + + + + + ${solr.max.booleanClauses:1024} + + + ${solr.query.minPrefixLength:-1} + + + + + + + + + + + + true + + + + + + 10 + + + + + + + + + + + + 2000 + + + + + + + + true + + + + true + + + + 2 + + + org.apache.solr.search.combine.TestCombiner + 60 + test + + + + + + + dismax + *:* + 0.01 + + text^0.5 features_t^1.0 subject^1.4 title_stemmed^2.0 + + + text^0.2 features_t^1.1 subject^1.4 title_stemmed^2.0 title^1.5 + + + weight^0.5 recip(rord(id),1,1000,1000)^0.3 + + + 3<-1 5<-2 6<90% + + 100 + + + + + + + + + 4 + true + text,name,subject,title,whitetok + + + + + + + 4 + true + text,name,subject,title,whitetok + + + + + + + + lowerpunctfilt + + + default + lowerfilt + spellchecker1 + false + + + direct + DirectSolrSpellChecker + 
lowerfilt + 3 + + + directMQF2 + DirectSolrSpellChecker + lowerfilt + 3 + 2 + + + wordbreak + solr.WordBreakSolrSpellChecker + lowerfilt + true + true + 10 + + + multipleFields + lowerfilt1and2 + spellcheckerMultipleFields + false + + + + jarowinkler + lowerfilt + + org.apache.lucene.search.spell.JaroWinklerDistance + spellchecker2 + + + + solr.FileBasedSpellChecker + external + spellings.txt + UTF-8 + spellchecker3 + + + + freq + lowerfilt + spellcheckerFreq + + freq + false + + + fqcn + lowerfilt + spellcheckerFQCN + org.apache.solr.spelling.SampleComparator + false + + + perDict + org.apache.solr.handler.component.DummyCustomParamSpellChecker + lowerfilt + + + + + + + + + + + + + false + + false + + 1 + + + spellcheck + + + + + direct + false + false + 1 + + + spellcheck + + + + + default + wordbreak + 20 + + + spellcheck + + + + + direct + wordbreak + 20 + + + spellcheck + + + + + dismax + lowerfilt1^1 + + + spellcheck + + + + + + + + + + + + + + tvComponent + + + + + + + + + + + + 100 + + + + + + 70 + + + + + + + ]]> + ]]> + + + + + + + + + + + + + 10 + .,!? 
+ + + + + + WORD + en + US + + + + + + + + + max-age=30, public + + + + + + foo_s + + + foo_s:bar + + + + + foo_s + foo_s:bar + + + + + prefix-${solr.test.sys.prop2}-suffix + + + + + + + uniq + uniq2 + uniq3 + + + + + + + + + regex_dup_A_s + x + x_x + + + + regex_dup_B_s + x + x_x + + + + + + + + regex_dup_A_s + x + x_x + + + regex_dup_B_s + x + x_x + + + + + + + org.apache.solr.rest.ManagedResourceStorage$InMemoryStorageIO + + + + + + text + + + + + + text + + + nl + + + diff --git a/solr/core/src/test/org/apache/solr/handler/component/CombinedQueryComponentTest.java b/solr/core/src/test/org/apache/solr/handler/component/CombinedQueryComponentTest.java new file mode 100644 index 00000000000..d8750573b1d --- /dev/null +++ b/solr/core/src/test/org/apache/solr/handler/component/CombinedQueryComponentTest.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.handler.component; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.params.CommonParams; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * The CombinedQueryComponentTest class is a unit test suite for the CombinedQueryComponent in Solr. + * It verifies the functionality of the component by performing various queries and validating the + * responses. + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class CombinedQueryComponentTest extends SolrTestCaseJ4 { + + private static final int NUM_DOCS = 10; + private static final String vectorField = "vector"; + + /** + * Sets up the test class by initializing the core and adding test documents to the index. This + * method prepares the Solr index with a set of documents for subsequent test cases. + * + * @throws Exception if any error occurs during setup, such as initialization failures or indexing + * issues. 
+ */ + @BeforeClass + public static void setUpClass() throws Exception { + initCore("solrconfig-combined-query.xml", "schema-vector-catchall.xml"); + List docs = new ArrayList<>(); + for (int i = 1; i <= NUM_DOCS; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.addField("id", Integer.toString(i)); + doc.addField("text", "test text for doc " + i); + doc.addField("title", "title test for doc " + i); + docs.add(doc); + } + // cosine distance vector1= 1.0 + docs.get(0).addField(vectorField, Arrays.asList(1f, 2f, 3f, 4f)); + // cosine distance vector1= 0.998 + docs.get(1).addField(vectorField, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f)); + // cosine distance vector1= 0.992 + docs.get(2).addField(vectorField, Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f)); + // cosine distance vector1= 0.999 + docs.get(3).addField(vectorField, Arrays.asList(1.4f, 2.4f, 3.4f, 4.4f)); + // cosine distance vector1= 0.862 + docs.get(4).addField(vectorField, Arrays.asList(30f, 22f, 35f, 20f)); + // cosine distance vector1= 0.756 + docs.get(5).addField(vectorField, Arrays.asList(40f, 1f, 1f, 200f)); + // cosine distance vector1= 0.970 + docs.get(6).addField(vectorField, Arrays.asList(5f, 10f, 20f, 40f)); + // cosine distance vector1= 0.515 + docs.get(7).addField(vectorField, Arrays.asList(120f, 60f, 30f, 15f)); + // cosine distance vector1= 0.554 + docs.get(8).addField(vectorField, Arrays.asList(200f, 50f, 100f, 25f)); + // cosine distance vector1= 0.997 + docs.get(9).addField(vectorField, Arrays.asList(1.8f, 2.5f, 3.7f, 4.9f)); + for (SolrInputDocument doc : docs) { + assertU(adoc(doc)); + } + assertU(commit()); + } + + /** Performs a single lexical query using the provided JSON request and verifies the response. 
*/ + public void testSingleLexicalQuery() { + assertQ( + req( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"title:title test for doc 5\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical1\"]}}", + CommonParams.QT, + "/search"), + "//result[@numFound='10']", + "//result/doc[1]/str[@name='id'][.='5']"); + } + + /** Performs multiple lexical queries and verifies the results. */ + public void testMultipleLexicalQueryWithDebug() { + assertQ( + req( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"title:title test for doc 1\"}}," + + "\"lexical2\":{\"lucene\":{\"query\":\"text:test text for doc 2\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"debugQuery\":true,\"combiner.query\":[\"lexical1\",\"lexical2\"]}}", + CommonParams.QT, + "/search"), + "//result[@numFound='10']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='2']", + "//lst[@name='debug']/lst[@name='combinerExplanations'][node()]"); + } + + /** Tests the functionality of a hybrid query that combines lexical and vector search. */ + public void testHybridQuery() { + // lexical => 2,3 + // vector => 1,4,2,10,3 + assertQ( + req( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical\":{\"lucene\":{\"query\":\"id:(2^=2 OR 3^=1)\"}}," + + "\"vector\":{\"knn\":{ \"f\": \"vector\", \"topK\": 5, \"query\": \"[1.0, 2.0, 3.0, 4.0]\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical\",\"vector\"]}}", + CommonParams.QT, + "/search"), + "//result[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='2']", + "//result/doc[2]/str[@name='id'][.='3']", + "//result/doc[3]/str[@name='id'][.='1']"); + } + + /** Test no results in combined queries. 
*/ + @Test + public void testNoResults() { + assertQ( + req( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"title:Solr is the blazing-fast, open source search platform\"}}," + + "\"lexical2\":{\"lucene\":{\"query\":\"text:Solr powers the search\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical1\",\"lexical2\"]}}", + CommonParams.QT, + "/search"), + "//result[@numFound='0']"); + } + + /** Test max combiner queries limit set from solrconfig to 2. */ + @Test + public void testMaxQueriesLimit() { + assertQEx( + "Too many queries to combine: limit is 2", + req( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"id:(2^=2 OR 3^=1)\"}}," + + "\"vector\":{\"knn\":{ \"f\": \"vector\", \"topK\": 5, \"query\": \"[1.0, 2.0, 3.0, 4.0]\"}}}," + + "\"lexical2\":{\"lucene\":{\"query\":\"text:test text for doc 2\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical1\",\"vector\", \"lexical2\"]}}", + CommonParams.QT, + "/search"), + SolrException.ErrorCode.BAD_REQUEST); + } + + /** + * Test to ensure the TestCombiner Algorithm is injected through solrconfigs and is being executed + * when sent the command through SolrParams + */ + @Test + public void testCombinerPlugin() { + assertQ( + req( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"title:title test for doc 1\"}}," + + "\"lexical2\":{\"lucene\":{\"query\":\"text:test text for doc 2\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.algorithm\":test,\"combiner.query\":[\"lexical1\",\"lexical2\"]}}", + CommonParams.QT, + "/search"), + "//result[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='2']", + "//result/doc[2]/str[@name='id'][.='1']"); + } +} diff --git 
a/solr/core/src/test/org/apache/solr/handler/component/CombinedQuerySearchHandlerTest.java b/solr/core/src/test/org/apache/solr/handler/component/CombinedQuerySearchHandlerTest.java new file mode 100644 index 00000000000..387bb80c0ea --- /dev/null +++ b/solr/core/src/test/org/apache/solr/handler/component/CombinedQuerySearchHandlerTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.handler.component; + +import java.io.IOException; +import java.util.ArrayList; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.params.CombinerParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.SolrQueryResponse; +import org.junit.BeforeClass; +import org.junit.Test; + +/** The type Combined query search handler test. */ +public class CombinedQuerySearchHandlerTest extends SolrTestCaseJ4 { + + /** + * Before tests. + * + * @throws Exception the exception + */ + @BeforeClass + public static void beforeTests() throws Exception { + initCore("solrconfig.xml", "schema.xml"); + } + + /** Test combined component init in search components list. 
*/ + @Test + public void testCombinedComponentInit() { + SolrCore core = h.getCore(); + + try (CombinedQuerySearchHandler handler = new CombinedQuerySearchHandler()) { + handler.init(new NamedList<>()); + handler.inform(core); + assertEquals(9, handler.getComponents().size()); + assertEquals( + core.getSearchComponent(CombinedQueryComponent.COMPONENT_NAME), + handler.getComponents().getFirst()); + } catch (IOException e) { + fail("Exception when closing CombinedQuerySearchHandler"); + } + } + + /** Test combined response buildr type create dynamically. */ + @Test + public void testCombinedResponseBuilder() { + SolrQueryRequest request = req("q", "testQuery"); + try (CombinedQuerySearchHandler handler = new CombinedQuerySearchHandler()) { + assertFalse( + handler.newResponseBuilder(request, new SolrQueryResponse(), new ArrayList<>()) + instanceof CombinedQueryResponseBuilder); + request = req("q", "testQuery", CombinerParams.COMBINER, "true"); + assertTrue( + handler.newResponseBuilder(request, new SolrQueryResponse(), new ArrayList<>()) + instanceof CombinedQueryResponseBuilder); + } catch (IOException e) { + fail("Exception when closing CombinedQuerySearchHandler"); + } + } +} diff --git a/solr/core/src/test/org/apache/solr/handler/component/DistributedCombinedQueryComponentTest.java b/solr/core/src/test/org/apache/solr/handler/component/DistributedCombinedQueryComponentTest.java new file mode 100644 index 00000000000..e5369d3c2f9 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/handler/component/DistributedCombinedQueryComponentTest.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.handler.component; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.solr.BaseDistributedSearchTestCase; +import org.apache.solr.client.solrj.response.QueryResponse; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.params.CommonParams; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * The DistributedCombinedQueryComponentTest class is a JUnit test suite that evaluates the + * functionality of the CombinedQueryComponent in a Solr distributed search environment. It focuses + * on testing the integration of lexical and vector queries using the combiner component. + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class DistributedCombinedQueryComponentTest extends BaseDistributedSearchTestCase { + + private static final int NUM_DOCS = 10; + private static final String vectorField = "vector"; + + /** + * Sets up the test class by initializing the core and setting system properties. This method is + * executed before all test methods in the class. 
+ * + * @throws Exception if any exception occurs during initialization + */ + @BeforeClass + public static void setUpClass() throws Exception { + initCore("solrconfig-combined-query.xml", "schema-vector-catchall.xml"); + System.setProperty("validateAfterInactivity", "200"); + System.setProperty("solr.httpclient.retries", "0"); + System.setProperty("distribUpdateSoTimeout", "5000"); + } + + /** + * Prepares Solr input documents for indexing, including adding sample data and vector fields. + * This method populates the Solr index with test data, including text, title, and vector fields. + * The vector fields are used to calculate cosine distance for testing purposes. + * + * @throws Exception if any error occurs during the indexing process. + */ + private synchronized void prepareIndexDocs() throws Exception { + List docs = new ArrayList<>(); + for (int i = 1; i <= NUM_DOCS; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.addField("id", Integer.toString(i)); + doc.addField("text", "test text for doc " + i); + doc.addField("title", "title test for doc " + i); + doc.addField("nullfirst", String.valueOf(i % 3)); + docs.add(doc); + } + // cosine distance vector1= 1.0 + docs.get(0).addField(vectorField, Arrays.asList(1f, 2f, 3f, 4f)); + // cosine distance vector1= 0.998 + docs.get(1).addField(vectorField, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f)); + // cosine distance vector1= 0.992 + docs.get(2).addField(vectorField, Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f)); + // cosine distance vector1= 0.999 + docs.get(3).addField(vectorField, Arrays.asList(1.4f, 2.4f, 3.4f, 4.4f)); + // cosine distance vector1= 0.862 + docs.get(4).addField(vectorField, Arrays.asList(30f, 22f, 35f, 20f)); + // cosine distance vector1= 0.756 + docs.get(5).addField(vectorField, Arrays.asList(40f, 1f, 1f, 200f)); + // cosine distance vector1= 0.970 + docs.get(6).addField(vectorField, Arrays.asList(5f, 10f, 20f, 40f)); + // cosine distance vector1= 0.515 + docs.get(7).addField(vectorField, 
Arrays.asList(120f, 60f, 30f, 15f)); + // cosine distance vector1= 0.554 + docs.get(8).addField(vectorField, Arrays.asList(200f, 50f, 100f, 25f)); + // cosine distance vector1= 0.997 + docs.get(9).addField(vectorField, Arrays.asList(1.8f, 2.5f, 3.7f, 4.9f)); + del("*:*"); + for (SolrInputDocument doc : docs) { + indexDoc(doc); + } + commit(); + } + + /** + * Tests a single lexical query against the Solr server. + * + * @throws Exception if any exception occurs during the test execution + */ + public void testSingleLexicalQuery() throws Exception { + prepareIndexDocs(); + QueryResponse rsp; + rsp = + queryServer( + createParams( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"id:2^=10\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical1\"]}}", + "shards", + getShardsString(), + CommonParams.QT, + "/search")); + assertEquals(1, rsp.getResults().size()); + assertFieldValues(rsp.getResults(), id, "2"); + } + + /** + * Tests multiple lexical queries using the Solr server. + * + * @throws Exception if any error occurs during the test execution + */ + public void testMultipleLexicalQuery() throws Exception { + prepareIndexDocs(); + QueryResponse rsp; + rsp = + queryServer( + createParams( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"title:title test for doc 1\"}}," + + "\"lexical2\":{\"lucene\":{\"query\":\"text:test text for doc 2\"}}}," + + "\"limit\":5," + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical1\",\"lexical2\"]}}", + "shards", + getShardsString(), + CommonParams.QT, + "/search")); + assertEquals(5, rsp.getResults().size()); + assertFieldValues(rsp.getResults(), id, "2", "1", "4", "5", "8"); + } + + /** + * Test multiple query execution with sort. 
+ * + * @throws Exception the exception + */ + @Test + public void testMultipleQueryWithSort() throws Exception { + prepareIndexDocs(); + QueryResponse rsp; + rsp = + queryServer( + createParams( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical1\":{\"lucene\":{\"query\":\"title:title test for doc 1\"}}," + + "\"lexical2\":{\"lucene\":{\"query\":\"text:test text for doc 2\"}}}," + + "\"limit\":5,\"sort\":\"nullfirst desc\"" + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical1\",\"lexical2\"]}}", + "shards", + getShardsString(), + CommonParams.QT, + "/search")); + assertEquals(5, rsp.getResults().size()); + assertFieldValues(rsp.getResults(), id, "2", "8", "5", "4", "1"); + } + + /** + * Tests the hybrid query functionality of the system. + * + * @throws Exception if any unexpected error occurs during the test execution. + */ + public void testHybridQueryWithPagination() throws Exception { + prepareIndexDocs(); + // lexical => 2,3 + // vector => 1,4,2,10,3,6 + QueryResponse rsp; + rsp = + queryServer( + createParams( + CommonParams.JSON, + "{\"queries\":" + + "{\"lexical\":{\"lucene\":{\"query\":\"id:(2^=2 OR 3^=1)\"}}," + + "\"vector\":{\"knn\":{ \"f\": \"vector\", \"topK\": 5, \"query\": \"[1.0, 2.0, 3.0, 4.0]\"}}}," + + "\"limit\":4,\"offset\":1" + + "\"fields\":[\"id\",\"score\",\"title\"]," + + "\"params\":{\"combiner\":true,\"combiner.query\":[\"lexical\",\"vector\"]}}", + "shards", + getShardsString(), + CommonParams.QT, + "/search")); + assertEquals(4, rsp.getResults().size()); + assertFieldValues(rsp.getResults(), id, "3", "4", "1", "6"); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/combine/ReciprocalRankFusionTest.java b/solr/core/src/test/org/apache/solr/search/combine/ReciprocalRankFusionTest.java new file mode 100644 index 00000000000..05679a84f46 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/combine/ReciprocalRankFusionTest.java @@ -0,0 +1,134 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.combine; + +import static org.apache.solr.common.params.CombinerParams.COMBINER_RRF_K; +import static org.apache.solr.common.params.CombinerParams.RECIPROCAL_RANK_FUSION; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.search.TotalHits; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.CombinerParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.handler.component.ShardDoc; +import org.apache.solr.search.DocSlice; +import org.apache.solr.search.QueryResult; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * The ReciprocalRankFusionTest class is a unit test suite for the {@link ReciprocalRankFusion} + * class. It verifies the correctness of the fusion algorithm and its supporting methods. 
+ */
+@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
+public class ReciprocalRankFusionTest extends SolrTestCaseJ4 {
+
+  public static ReciprocalRankFusion reciprocalRankFusion;
+
+  /**
+   * Creates the shared {@link ReciprocalRankFusion} instance used by every test in this class and
+   * initializes it with the plugin argument {@code k = 20}.
+   */
+  @BeforeClass
+  public static void beforeClass() {
+    NamedList<?> args = new NamedList<>(Map.of("k", "20"));
+    reciprocalRankFusion = new ReciprocalRankFusion();
+    reciprocalRankFusion.init(args);
+  }
+
+  /** Combines two local ranked lists via RRF; the init-time k (20) must be left unchanged. */
+  @Test
+  public void testSearcherCombine() {
+    List<QueryResult> rankedList = getQueryResults();
+    SolrParams solrParams = params(COMBINER_RRF_K, "10");
+    QueryResult result = reciprocalRankFusion.combine(rankedList, solrParams);
+    assertEquals(20, reciprocalRankFusion.getK());
+    assertEquals(3, result.getDocList().size()); // docs 1 and 2 from r1, doc 0 from r2
+  }
+
+  private static List<QueryResult> getQueryResults() {
+    QueryResult r1 = new QueryResult();
+    r1.setDocList(
+        new DocSlice(
+            0,
+            2,
+            new int[] {1, 2},
+            new float[] {0.67f, 0.62f}, // one score per doc id, best first
+            3,
+            0.67f,
+            TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO));
+    QueryResult r2 = new QueryResult();
+    r2.setDocList(
+        new DocSlice(
+            0,
+            1,
+            new int[] {0},
+            new float[] {0.87f},
+            2,
+            0.87f,
+            TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO));
+    return List.of(r1, r2);
+  }
+
+  /** Tests the shard-level RRF combine: two shard lists that share the document "id2". */
+  @Test
+  public void testShardCombine() {
+    Map<String, List<ShardDoc>> shardDocMap = new HashMap<>();
+    ShardDoc shardDoc = new ShardDoc();
+    shardDoc.id = "id1";
+    shardDoc.shard = "shard1";
+    shardDoc.orderInShard = 1;
+    List<ShardDoc> shardDocList = new ArrayList<>();
+    shardDocList.add(shardDoc);
+    shardDoc = new ShardDoc();
+    shardDoc.id = "id2";
+    shardDoc.shard = "shard1";
+    shardDoc.orderInShard = 2;
+    shardDocList.add(shardDoc);
+    shardDocMap.put(shardDoc.shard, shardDocList);
+
+    shardDoc = new ShardDoc();
+    shardDoc.id = "id2";
+    shardDoc.shard = "shard2";
+    shardDoc.orderInShard = 1;
+    shardDocMap.put(shardDoc.shard, List.of(shardDoc));
+    SolrParams solrParams = params();
+    List<ShardDoc> shardDocs = reciprocalRankFusion.combine(shardDocMap, solrParams);
+    assertEquals(2, shardDocs.size()); // "id2" from both shards collapses into one entry
+    assertEquals("id2", shardDocs.getFirst().id); // listed by both shards, so it ranks first
+  }
+
+  @Test
+  public void testImplementationFactory() {
+    Map<String, QueryAndResponseCombiner> combinerMap = new HashMap<>(1);
+    SolrParams emptySolrParams = params();
+    assertThrows(
+        SolrException.class,
+        () -> QueryAndResponseCombiner.getImplementation(emptySolrParams, combinerMap));
+    SolrParams solrParams = params(CombinerParams.COMBINER_ALGORITHM, RECIPROCAL_RANK_FUSION);
+    combinerMap.put(RECIPROCAL_RANK_FUSION, new ReciprocalRankFusion());
+    assertTrue(
+        QueryAndResponseCombiner.getImplementation(solrParams, combinerMap)
+            instanceof ReciprocalRankFusion);
+  }
+}
diff --git a/solr/core/src/test/org/apache/solr/search/combine/TestCombiner.java b/solr/core/src/test/org/apache/solr/search/combine/TestCombiner.java
new file mode 100644
index 00000000000..677d0b45689
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/search/combine/TestCombiner.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.combine; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TotalHits; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.handler.component.ShardDoc; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.search.DocIterator; +import org.apache.solr.search.DocSlice; +import org.apache.solr.search.QueryResult; +import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.search.SortedIntDocSet; + +/** + * The TestCombiner class is an extension of QueryAndResponseCombiner that implements custom logic + * for combining ranked lists using linear sorting of score from all rank lists. This is just for + * testing purpose which has been used in test suite CombinedQueryComponentTest for e2e testing of + * Plugin based Combiner approach. 
+ */
+public class TestCombiner extends QueryAndResponseCombiner {
+
+  private int testInt; // parsed from the "var1" init argument; stays 0 when it is absent
+
+  public int getTestInt() {
+    return testInt;
+  }
+
+  @Override
+  public void init(NamedList<?> args) {
+    Object kParam = args.get("var1");
+    if (kParam != null) {
+      this.testInt = Integer.parseInt(kParam.toString());
+    }
+  }
+
+  @Override
+  public QueryResult combine(List<QueryResult> rankedLists, SolrParams solrParams) {
+    HashMap<Integer, Float> docIdToScore = new HashMap<>();
+    QueryResult queryResult = new QueryResult();
+    for (QueryResult rankedList : rankedLists) {
+      DocIterator docs = rankedList.getDocList().iterator();
+      while (docs.hasNext()) {
+        int docId = docs.nextDoc();
+        docIdToScore.put(docId, docs.score()); // later lists overwrite earlier scores
+      }
+    }
+    List<Map.Entry<Integer, Float>> sortedByScoreDescending =
+        docIdToScore.entrySet().stream()
+            .sorted(Collections.reverseOrder(Map.Entry.comparingByValue()))
+            .toList();
+    int combinedResultsLength = docIdToScore.size();
+    int[] combinedResultsDocIds = new int[combinedResultsLength];
+    float[] combinedResultScores = new float[combinedResultsLength];
+    int i = 0;
+    for (Map.Entry<Integer, Float> scoredDoc : sortedByScoreDescending) {
+      combinedResultsDocIds[i] = scoredDoc.getKey();
+      combinedResultScores[i] = scoredDoc.getValue();
+      i++;
+    }
+    DocSlice combinedResultSlice =
+        new DocSlice(
+            0,
+            combinedResultsLength,
+            combinedResultsDocIds,
+            combinedResultScores,
+            combinedResultsLength,
+            combinedResultScores.length > 0 ? combinedResultScores[0] : 0,
+            TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
+    queryResult.setDocList(combinedResultSlice);
+    // SortedIntDocSet requires ascending doc ids; the slice array above is in score order.
+    int[] sortedDocIds = docIdToScore.keySet().stream().mapToInt(id -> id).sorted().toArray();
+    queryResult.setDocSet(new SortedIntDocSet(sortedDocIds, combinedResultsLength));
+    return queryResult;
+  }
+
+  @Override
+  public List<ShardDoc> combine(Map<String, List<ShardDoc>> shardDocMap, SolrParams solrParams) {
+    return List.of(); // shard-level combining always yields an empty list here
+  }
+
+  @Override
+  public NamedList<Explanation> getExplanations(
+      String[] queryKeys,
+      List<Query> queries,
+      List<QueryResult> queryResults,
+      SolrIndexSearcher searcher,
+      IndexSchema schema,
+      SolrParams solrParams)
+      throws IOException {
+    return null; // this combiner produces no explanations
+  }
+}
diff --git a/solr/solrj/src/java/org/apache/solr/common/params/CombinerParams.java b/solr/solrj/src/java/org/apache/solr/common/params/CombinerParams.java
new file mode 100644
index 00000000000..40c911a6507
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/common/params/CombinerParams.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.common.params;
+
+/**
+ * The CombinerParams class provides constants for configuration parameters related to the combiner.
+ * It defines keys for various properties used in the combiner configuration.
+ */
+public class CombinerParams {
+  // Parameter names; every key shares the "combiner" prefix.
+  public static final String COMBINER = "combiner";
+  public static final String COMBINER_ALGORITHM = COMBINER + ".algorithm";
+  public static final String COMBINER_QUERY = COMBINER + ".query";
+  // Reciprocal Rank Fusion: algorithm name, its "k" parameter key, and the defaults.
+  public static final String RECIPROCAL_RANK_FUSION = "rrf";
+  public static final String COMBINER_RRF_K = COMBINER + "." + RECIPROCAL_RANK_FUSION + ".k";
+  public static final int COMBINER_RRF_K_DEFAULT = 60;
+  public static final String DEFAULT_COMBINER = RECIPROCAL_RANK_FUSION;
+  public static final int DEFAULT_MAX_COMBINER_QUERIES = 1;
+  private CombinerParams() {} // constants holder, never instantiated
+}