(search-query) Tidy up QueryGRPCService and IndexClient

This commit is contained in:
Viktor Lofgren 2024-06-26 14:03:30 +02:00
parent 6973712480
commit 3faa5bf521
5 changed files with 59 additions and 50 deletions

View File

@ -64,6 +64,11 @@ public class GrpcMultiNodeChannelPool<STUB> {
return nodeConfigurationWatcher.getQueryNodes();
}
/** Return the number of nodes that are eligible for broadcast-style requests */
public int getNumNodes() {
return nodeConfigurationWatcher.getQueryNodes().size();
}
/** Create a new call builder for the given method. This is a fluent-style
* method, where you can chain calls to specify how to run the method.
* <p></p>

View File

@ -1,19 +1,17 @@
package nu.marginalia.functions.searchquery;
import com.google.common.collect.Lists;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import io.grpc.stub.StreamObserver;
import io.prometheus.client.Histogram;
import lombok.SneakyThrows;
import nu.marginalia.api.searchquery.*;
import nu.marginalia.api.searchquery.model.query.ProcessedQuery;
import nu.marginalia.api.searchquery.model.query.QueryParams;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.db.DomainBlacklist;
import nu.marginalia.index.api.IndexClient;
import nu.marginalia.functions.searchquery.svc.QueryFactory;
import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem;
import nu.marginalia.model.id.UrlIdCodec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -33,18 +31,18 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase {
private final QueryFactory queryFactory;
private final DomainBlacklist blacklist;
private final IndexClient indexClient;
@Inject
public QueryGRPCService(QueryFactory queryFactory,
DomainBlacklist blacklist,
IndexClient indexClient)
{
this.queryFactory = queryFactory;
this.blacklist = blacklist;
this.indexClient = indexClient;
}
/** GRPC endpoint that parses a query, delegates it to the index partitions, and then collects the results.
*/
public void query(RpcQsQuery request, StreamObserver<RpcQsResponse> responseObserver)
{
try {
@ -55,16 +53,20 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase {
var params = QueryProtobufCodec.convertRequest(request);
var query = queryFactory.createQuery(params, ResultRankingParameters.sensibleDefaults());
RpcIndexQuery indexRequest = QueryProtobufCodec.convertQuery(request, query);
List<RpcDecoratedResultItem> bestItems = executeQueries(indexRequest, request.getQueryLimits().getResultsTotal());
var indexRequest = QueryProtobufCodec.convertQuery(request, query);
// Execute the query on the index partitions
List<RpcDecoratedResultItem> bestItems = indexClient.executeQueries(indexRequest);
// Convert results to response and send it back
var responseBuilder = RpcQsResponse.newBuilder()
.addAllResults(bestItems)
.setSpecs(indexRequest)
.addAllSearchTermsHuman(query.searchTermsHuman);
if (query.domain != null)
if (query.domain != null) {
responseBuilder.setDomain(query.domain);
}
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();
@ -75,44 +77,19 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase {
}
}
private static final Comparator<RpcDecoratedResultItem> comparator =
Comparator.comparing(RpcDecoratedResultItem::getRankingScore);
private boolean isBlacklisted(RpcDecoratedResultItem item) {
return blacklist.isBlacklisted(UrlIdCodec.getDomainId(item.getRawItem().getCombinedId()));
}
public record DetailedDirectResult(ProcessedQuery processedQuery,
List<DecoratedSearchResultItem> result) {}
/** Local query execution, without GRPC. */
public DetailedDirectResult executeDirect(
String originalQuery,
QueryParams params,
ResultRankingParameters rankingParameters,
int count) {
ResultRankingParameters rankingParameters) {
var query = queryFactory.createQuery(params, rankingParameters);
var items = indexClient.executeQueries(QueryProtobufCodec.convertQuery(originalQuery, query));
var items = executeQueries(
QueryProtobufCodec.convertQuery(originalQuery, query),
count)
.stream().map(QueryProtobufCodec::convertQueryResult)
.toList();
return new DetailedDirectResult(query, items);
}
public record DetailedDirectResult(ProcessedQuery processedQuery,
List<DecoratedSearchResultItem> result) {}
@SneakyThrows
List<RpcDecoratedResultItem> executeQueries(RpcIndexQuery indexRequest, int totalSize) {
var results = indexClient.executeQueries(indexRequest);
results.sort(comparator);
results.removeIf(this::isBlacklisted);
if (results.size() > totalSize) {
results = results.subList(0, totalSize);
}
return results;
return new DetailedDirectResult(query, Lists.transform(items, QueryProtobufCodec::convertQueryResult));
}
}

View File

@ -15,6 +15,7 @@ dependencies {
implementation project(':code:common:model')
implementation project(':code:common:config')
implementation project(':code:common:service')
implementation project(':code:common:db')
implementation project(':code:libraries:message-queue')
implementation project(':code:functions:search-query:api')

View File

@ -6,6 +6,8 @@ import lombok.SneakyThrows;
import nu.marginalia.api.searchquery.IndexApiGrpc;
import nu.marginalia.api.searchquery.RpcDecoratedResultItem;
import nu.marginalia.api.searchquery.RpcIndexQuery;
import nu.marginalia.db.DomainBlacklistImpl;
import nu.marginalia.model.id.UrlIdCodec;
import nu.marginalia.service.client.GrpcChannelPoolFactory;
import nu.marginalia.service.client.GrpcMultiNodeChannelPool;
import nu.marginalia.service.discovery.property.ServiceKey;
@ -14,6 +16,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@ -22,22 +25,34 @@ import java.util.concurrent.Executors;
public class IndexClient {
private static final Logger logger = LoggerFactory.getLogger(IndexClient.class);
private final GrpcMultiNodeChannelPool<IndexApiGrpc.IndexApiBlockingStub> channelPool;
private final DomainBlacklistImpl blacklist;
private static final ExecutorService executor = Executors.newFixedThreadPool(32);
@Inject
public IndexClient(GrpcChannelPoolFactory channelPoolFactory) {
public IndexClient(GrpcChannelPoolFactory channelPoolFactory, DomainBlacklistImpl blacklist) {
this.channelPool = channelPoolFactory.createMulti(
ServiceKey.forGrpcApi(IndexApiGrpc.class, ServicePartition.multi()),
IndexApiGrpc::newBlockingStub);
this.blacklist = blacklist;
}
private static final Comparator<RpcDecoratedResultItem> comparator =
Comparator.comparing(RpcDecoratedResultItem::getRankingScore);
/** Execute a query on the index partitions and return the combined results. */
@SneakyThrows
public List<RpcDecoratedResultItem> executeQueries(RpcIndexQuery indexRequest) {
var futures =
channelPool.call(IndexApiGrpc.IndexApiBlockingStub::query)
.async(executor)
.runEach(indexRequest);
List<RpcDecoratedResultItem> results = new ArrayList<>();
final int resultsTotal = indexRequest.getQueryLimits().getResultsTotal();
final int resultsUpperBound = resultsTotal * channelPool.getNumNodes();
List<RpcDecoratedResultItem> results = new ArrayList<>(resultsUpperBound);
for (var future : futures) {
try {
future.get().forEachRemaining(results::add);
@ -47,7 +62,20 @@ public class IndexClient {
}
}
// Sort the results by ranking score and remove blacklisted domains
results.sort(comparator);
results.removeIf(this::isBlacklisted);
// Keep only as many results as were requested
if (results.size() > resultsTotal) {
results = results.subList(0, resultsTotal);
}
return results;
}
private boolean isBlacklisted(RpcDecoratedResultItem item) {
return blacklist.isBlacklisted(UrlIdCodec.getDomainId(item.getRawItem().getCombinedId()));
}
}

View File

@ -48,10 +48,9 @@ public class QueryBasicInterface {
domainCount, count, 250, 8192
), set);
var detailedDirectResult = queryGRPCService.executeDirect(queryParams,
params,
ResultRankingParameters.sensibleDefaults(),
count);
var detailedDirectResult = queryGRPCService.executeDirect(
queryParams, params, ResultRankingParameters.sensibleDefaults()
);
var results = detailedDirectResult.result();
@ -85,10 +84,9 @@ public class QueryBasicInterface {
var rankingParams = rankingParamsFromRequest(request);
var detailedDirectResult = queryGRPCService.executeDirect(queryString,
queryParams,
rankingParams,
count);
var detailedDirectResult = queryGRPCService.executeDirect(
queryString, queryParams, rankingParams
);
var results = detailedDirectResult.result();