diff --git a/code/common/service/java/nu/marginalia/service/client/GrpcMultiNodeChannelPool.java b/code/common/service/java/nu/marginalia/service/client/GrpcMultiNodeChannelPool.java index d4f75e66..de74adb4 100644 --- a/code/common/service/java/nu/marginalia/service/client/GrpcMultiNodeChannelPool.java +++ b/code/common/service/java/nu/marginalia/service/client/GrpcMultiNodeChannelPool.java @@ -64,6 +64,11 @@ public class GrpcMultiNodeChannelPool { return nodeConfigurationWatcher.getQueryNodes(); } + /** Return the number of nodes that are eligible for broadcast-style requests */ + public int getNumNodes() { + return nodeConfigurationWatcher.getQueryNodes().size(); + } + /** Create a new call builder for the given method. This is a fluent-style * method, where you can chain calls to specify how to run the method. *

diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java index 98f7fb6f..4da55bc1 100644 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java @@ -1,19 +1,17 @@ package nu.marginalia.functions.searchquery; +import com.google.common.collect.Lists; import com.google.inject.Inject; import com.google.inject.Singleton; import io.grpc.stub.StreamObserver; import io.prometheus.client.Histogram; -import lombok.SneakyThrows; import nu.marginalia.api.searchquery.*; import nu.marginalia.api.searchquery.model.query.ProcessedQuery; import nu.marginalia.api.searchquery.model.query.QueryParams; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; -import nu.marginalia.db.DomainBlacklist; import nu.marginalia.index.api.IndexClient; import nu.marginalia.functions.searchquery.svc.QueryFactory; import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem; -import nu.marginalia.model.id.UrlIdCodec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,18 +31,18 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { private final QueryFactory queryFactory; - private final DomainBlacklist blacklist; private final IndexClient indexClient; + @Inject public QueryGRPCService(QueryFactory queryFactory, - DomainBlacklist blacklist, IndexClient indexClient) { this.queryFactory = queryFactory; - this.blacklist = blacklist; this.indexClient = indexClient; } + /** GRPC endpoint that parses a query, delegates it to the index partitions, and then collects the results. + */ public void query(RpcQsQuery request, StreamObserver responseObserver) { try { @@ -55,16 +53,20 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { var params = QueryProtobufCodec.convertRequest(request); var query = queryFactory.createQuery(params, ResultRankingParameters.sensibleDefaults()); - RpcIndexQuery indexRequest = QueryProtobufCodec.convertQuery(request, query); - List bestItems = executeQueries(indexRequest, request.getQueryLimits().getResultsTotal()); + var indexRequest = QueryProtobufCodec.convertQuery(request, query); + // Execute the query on the index partitions + List bestItems = indexClient.executeQueries(indexRequest); + + // Convert results to response and send it back var responseBuilder = RpcQsResponse.newBuilder() .addAllResults(bestItems) .setSpecs(indexRequest) .addAllSearchTermsHuman(query.searchTermsHuman); - if (query.domain != null) + if (query.domain != null) { responseBuilder.setDomain(query.domain); + } responseObserver.onNext(responseBuilder.build()); responseObserver.onCompleted(); @@ -75,44 +77,19 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { } } - private static final Comparator comparator = - Comparator.comparing(RpcDecoratedResultItem::getRankingScore); - - - private boolean isBlacklisted(RpcDecoratedResultItem item) { - return blacklist.isBlacklisted(UrlIdCodec.getDomainId(item.getRawItem().getCombinedId())); - } + public record DetailedDirectResult(ProcessedQuery processedQuery, + List result) {} + /** Local query execution, without GRPC. */ public DetailedDirectResult executeDirect( String originalQuery, QueryParams params, - ResultRankingParameters rankingParameters, - int count) { + ResultRankingParameters rankingParameters) { var query = queryFactory.createQuery(params, rankingParameters); + var items = indexClient.executeQueries(QueryProtobufCodec.convertQuery(originalQuery, query)); - var items = executeQueries( - QueryProtobufCodec.convertQuery(originalQuery, query), - count) - .stream().map(QueryProtobufCodec::convertQueryResult) - .toList(); - - return new DetailedDirectResult(query, items); - } - - public record DetailedDirectResult(ProcessedQuery processedQuery, - List result) {} - - @SneakyThrows - List executeQueries(RpcIndexQuery indexRequest, int totalSize) { - var results = indexClient.executeQueries(indexRequest); - - results.sort(comparator); - results.removeIf(this::isBlacklisted); - if (results.size() > totalSize) { - results = results.subList(0, totalSize); - } - return results; + return new DetailedDirectResult(query, Lists.transform(items, QueryProtobufCodec::convertQueryResult)); } } diff --git a/code/index/api/build.gradle b/code/index/api/build.gradle index 1c0873a8..7f958c0e 100644 --- a/code/index/api/build.gradle +++ b/code/index/api/build.gradle @@ -15,6 +15,7 @@ dependencies { implementation project(':code:common:model') implementation project(':code:common:config') implementation project(':code:common:service') + implementation project(':code:common:db') implementation project(':code:libraries:message-queue') implementation project(':code:functions:search-query:api') diff --git a/code/index/api/java/nu/marginalia/index/api/IndexClient.java b/code/index/api/java/nu/marginalia/index/api/IndexClient.java index 9dd14920..e0383a27 100644 --- a/code/index/api/java/nu/marginalia/index/api/IndexClient.java +++ b/code/index/api/java/nu/marginalia/index/api/IndexClient.java @@ -6,6 +6,8 @@ import lombok.SneakyThrows; import nu.marginalia.api.searchquery.IndexApiGrpc; import nu.marginalia.api.searchquery.RpcDecoratedResultItem; import nu.marginalia.api.searchquery.RpcIndexQuery; +import nu.marginalia.db.DomainBlacklistImpl; +import nu.marginalia.model.id.UrlIdCodec; import nu.marginalia.service.client.GrpcChannelPoolFactory; import nu.marginalia.service.client.GrpcMultiNodeChannelPool; import nu.marginalia.service.discovery.property.ServiceKey; @@ -14,6 +16,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -22,22 +25,34 @@ import java.util.concurrent.Executors; public class IndexClient { private static final Logger logger = LoggerFactory.getLogger(IndexClient.class); private final GrpcMultiNodeChannelPool channelPool; + private final DomainBlacklistImpl blacklist; private static final ExecutorService executor = Executors.newFixedThreadPool(32); @Inject - public IndexClient(GrpcChannelPoolFactory channelPoolFactory) { + public IndexClient(GrpcChannelPoolFactory channelPoolFactory, DomainBlacklistImpl blacklist) { this.channelPool = channelPoolFactory.createMulti( ServiceKey.forGrpcApi(IndexApiGrpc.class, ServicePartition.multi()), IndexApiGrpc::newBlockingStub); + this.blacklist = blacklist; } + private static final Comparator comparator = + Comparator.comparing(RpcDecoratedResultItem::getRankingScore); + + + /** Execute a query on the index partitions and return the combined results. */ @SneakyThrows public List executeQueries(RpcIndexQuery indexRequest) { var futures = channelPool.call(IndexApiGrpc.IndexApiBlockingStub::query) .async(executor) .runEach(indexRequest); - List results = new ArrayList<>(); + + final int resultsTotal = indexRequest.getQueryLimits().getResultsTotal(); + final int resultsUpperBound = resultsTotal * channelPool.getNumNodes(); + + List results = new ArrayList<>(resultsUpperBound); + for (var future : futures) { try { future.get().forEachRemaining(results::add); @@ -47,7 +62,20 @@ public class IndexClient { } } + // Sort the results by ranking score and remove blacklisted domains + results.sort(comparator); + results.removeIf(this::isBlacklisted); + + // Keep only as many results as were requested + if (results.size() > resultsTotal) { + results = results.subList(0, resultsTotal); + } + return results; } + private boolean isBlacklisted(RpcDecoratedResultItem item) { + return blacklist.isBlacklisted(UrlIdCodec.getDomainId(item.getRawItem().getCombinedId())); + } + } diff --git a/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java b/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java index 152f6a78..62af8591 100644 --- a/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java +++ b/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java @@ -48,10 +48,9 @@ public class QueryBasicInterface { domainCount, count, 250, 8192 ), set); - var detailedDirectResult = queryGRPCService.executeDirect(queryParams, - params, - ResultRankingParameters.sensibleDefaults(), - count); + var detailedDirectResult = queryGRPCService.executeDirect( + queryParams, params, ResultRankingParameters.sensibleDefaults() + ); var results = detailedDirectResult.result(); @@ -85,10 +84,9 @@ public class QueryBasicInterface { var rankingParams = rankingParamsFromRequest(request); - var detailedDirectResult = queryGRPCService.executeDirect(queryString, - queryParams, - rankingParams, - count); + var detailedDirectResult = queryGRPCService.executeDirect( + queryString, queryParams, rankingParams + ); var results = detailedDirectResult.result();