From 73f973cc065d92242171ae695116acda24657a7f Mon Sep 17 00:00:00 2001 From: Viktor Lofgren Date: Wed, 25 Sep 2024 12:56:38 +0200 Subject: [PATCH] (search-query) Add pagination to search query API and the direct query-service interface --- .../api/src/main/protobuf/query-api.proto | 15 +++++ .../searchquery/QueryGRPCService.java | 29 ++++++--- .../nu/marginalia/index/api/IndexClient.java | 33 +++++++--- .../marginalia/query/QueryBasicInterface.java | 60 ++++++++++++++----- .../resources/templates/search.hdb | 14 +++++ 5 files changed, 120 insertions(+), 31 deletions(-) diff --git a/code/functions/search-query/api/src/main/protobuf/query-api.proto b/code/functions/search-query/api/src/main/protobuf/query-api.proto index 1504d46f..b505600b 100644 --- a/code/functions/search-query/api/src/main/protobuf/query-api.proto +++ b/code/functions/search-query/api/src/main/protobuf/query-api.proto @@ -30,6 +30,8 @@ message RpcQsQuery { string searchSetIdentifier = 14; string queryStrategy = 15; // Named query configuration RpcTemporalBias temporalBias = 16; + + RpcQsQueryPagination pagination = 17; } /* Query service query response */ @@ -39,6 +41,19 @@ message RpcQsResponse { repeated string searchTermsHuman = 3; repeated string problems = 4; string domain = 5; + + RpcQsResultPagination pagination = 6; +} + +message RpcQsQueryPagination { + int32 page = 1; + int32 pageSize = 2; +} + +message RpcQsResultPagination { + int32 page = 1; + int32 pageSize = 2; + int32 totalResults = 3; } message RpcTemporalBias { diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java index e4bac6e2..41951a3c 100644 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/QueryGRPCService.java @@ -8,13 +8,13 @@ import io.prometheus.client.Histogram; import nu.marginalia.api.searchquery.*; import nu.marginalia.api.searchquery.model.query.ProcessedQuery; import nu.marginalia.api.searchquery.model.query.QueryParams; +import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.api.IndexClient; -import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.*; +import java.util.List; @Singleton public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { @@ -54,12 +54,23 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { var indexRequest = QueryProtobufCodec.convertQuery(request, query); + var requestPagination = request.getPagination(); + + IndexClient.Pagination pagination = new IndexClient.Pagination( + requestPagination.getPage(), + requestPagination.getPageSize()); + // Execute the query on the index partitions - List bestItems = indexClient.executeQueries(indexRequest); + IndexClient.AggregateQueryResponse response = indexClient.executeQueries(indexRequest, pagination); // Convert results to response and send it back var responseBuilder = RpcQsResponse.newBuilder() - .addAllResults(bestItems) + .addAllResults(response.results()) + .setPagination( + RpcQsResultPagination.newBuilder() + .setPage(response.page()) + .setTotalResults(response.totalResults()) + ) .setSpecs(indexRequest) .addAllSearchTermsHuman(query.searchTermsHuman); @@ -77,18 +88,22 @@ public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { } public record DetailedDirectResult(ProcessedQuery processedQuery, - List result) {} + List result, + int totalResults) {} /** Local query execution, without GRPC. */ public DetailedDirectResult executeDirect( String originalQuery, QueryParams params, + IndexClient.Pagination pagination, ResultRankingParameters rankingParameters) { var query = queryFactory.createQuery(params, rankingParameters); - var items = indexClient.executeQueries(QueryProtobufCodec.convertQuery(originalQuery, query)); + IndexClient.AggregateQueryResponse response = indexClient.executeQueries(QueryProtobufCodec.convertQuery(originalQuery, query), pagination); - return new DetailedDirectResult(query, Lists.transform(items, QueryProtobufCodec::convertQueryResult)); + return new DetailedDirectResult(query, + Lists.transform(response.results(), QueryProtobufCodec::convertQueryResult), + response.totalResults()); } } diff --git a/code/index/api/java/nu/marginalia/index/api/IndexClient.java b/code/index/api/java/nu/marginalia/index/api/IndexClient.java index e0383a27..ddd16584 100644 --- a/code/index/api/java/nu/marginalia/index/api/IndexClient.java +++ b/code/index/api/java/nu/marginalia/index/api/IndexClient.java @@ -17,10 +17,14 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Comparator; +import java.util.Iterator; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import static java.lang.Math.clamp; + @Singleton public class IndexClient { private static final Logger logger = LoggerFactory.getLogger(IndexClient.class); @@ -39,17 +43,23 @@ public class IndexClient { private static final Comparator comparator = Comparator.comparing(RpcDecoratedResultItem::getRankingScore); + public record Pagination(int page, int pageSize) {} + + public record AggregateQueryResponse(List results, + int page, + int totalResults + ) {} /** Execute a query on the index partitions and return the combined results. */ @SneakyThrows - public List executeQueries(RpcIndexQuery indexRequest) { - var futures = + public AggregateQueryResponse executeQueries(RpcIndexQuery indexRequest, Pagination pagination) { + List>> futures = channelPool.call(IndexApiGrpc.IndexApiBlockingStub::query) .async(executor) .runEach(indexRequest); - final int resultsTotal = indexRequest.getQueryLimits().getResultsTotal(); - final int resultsUpperBound = resultsTotal * channelPool.getNumNodes(); + final int requestedMaxResults = indexRequest.getQueryLimits().getResultsTotal(); + final int resultsUpperBound = requestedMaxResults * channelPool.getNumNodes(); List results = new ArrayList<>(resultsUpperBound); @@ -66,12 +76,17 @@ public class IndexClient { results.sort(comparator); results.removeIf(this::isBlacklisted); - // Keep only as many results as were requested - if (results.size() > resultsTotal) { - results = results.subList(0, resultsTotal); - } + int numReceivedResults = results.size(); - return results; + // pagination is typically 1-indexed, so we need to adjust the start and end indices + int indexStart = (pagination.page - 1) * pagination.pageSize; + int indexEnd = (pagination.page) * pagination.pageSize; + + results = results.subList( + clamp(indexStart, 0, results.size() - 1), // from is inclusive, so subtract 1 from size() + clamp(indexEnd, 0, results.size())); + + return new AggregateQueryResponse(results, pagination.page(), numReceivedResults); } private boolean isBlacklisted(RpcDecoratedResultItem item) { diff --git a/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java b/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java index 937b80d7..836f4daa 100644 --- a/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java +++ b/code/services-core/query-service/java/nu/marginalia/query/QueryBasicInterface.java @@ -7,6 +7,7 @@ import nu.marginalia.api.searchquery.model.query.QueryParams; import nu.marginalia.api.searchquery.model.results.Bm25Parameters; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.functions.searchquery.QueryGRPCService; +import nu.marginalia.index.api.IndexClient; import nu.marginalia.index.query.limit.QueryLimits; import nu.marginalia.model.gson.GsonFactory; import nu.marginalia.renderer.MustacheRenderer; @@ -15,8 +16,14 @@ import spark.Request; import spark.Response; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import static java.lang.Integer.min; +import static java.lang.Integer.parseInt; +import static java.util.Objects.requireNonNullElse; + public class QueryBasicInterface { private final MustacheRenderer basicRenderer; private final MustacheRenderer qdebugRenderer; @@ -34,38 +41,53 @@ public class QueryBasicInterface { this.queryGRPCService = queryGRPCService; } + /** Handle the basic search endpoint exposed in the bare-bones search interface. */ public Object handleBasic(Request request, Response response) { - String queryParams = request.queryParams("q"); - if (queryParams == null) { + String queryString = request.queryParams("q"); + if (queryString == null) { return basicRenderer.render(new Object()); } - int count = request.queryParams("count") == null ? 10 : Integer.parseInt(request.queryParams("count")); - int domainCount = request.queryParams("domainCount") == null ? 5 : Integer.parseInt(request.queryParams("domainCount")); - String set = request.queryParams("set") == null ? "" : request.queryParams("set"); + int count = parseInt(requireNonNullElse(request.queryParams("count"), "10")); + int page = parseInt(requireNonNullElse(request.queryParams("page"), "1")); + int domainCount = parseInt(requireNonNullElse(request.queryParams("domainCount"), "5")); + String set = requireNonNullElse(request.queryParams("set"), ""); - var params = new QueryParams(queryParams, new QueryLimits( - domainCount, count, 250, 8192 + var params = new QueryParams(queryString, new QueryLimits( + domainCount, min(100, count * 10), 250, 8192 ), set); + var pagination = new IndexClient.Pagination(page, count); + var detailedDirectResult = queryGRPCService.executeDirect( - queryParams, params, ResultRankingParameters.sensibleDefaults() + queryString, + params, + pagination, + ResultRankingParameters.sensibleDefaults() ); var results = detailedDirectResult.result(); + List paginationInfo = new ArrayList<>(); + + for (int i = 1; i <= detailedDirectResult.totalResults() / pagination.pageSize(); i++) { + paginationInfo.add(new PaginationInfoPage(i, i == pagination.page())); + } + if (request.headers("Accept").contains("application/json")) { response.type("application/json"); return gson.toJson(results); } else { return basicRenderer.render( - Map.of("query", queryParams, + Map.of("query", queryString, + "pages", paginationInfo, "results", results) ); } } + /** Handle the qdebug endpoint, which allows for query debugging and ranking parameter tuning. */ public Object handleAdvanced(Request request, Response response) { String queryString = request.queryParams("q"); if (queryString == null) { @@ -74,18 +96,24 @@ public class QueryBasicInterface { ); } - int count = request.queryParams("count") == null ? 10 : Integer.parseInt(request.queryParams("count")); - int domainCount = request.queryParams("domainCount") == null ? 5 : Integer.parseInt(request.queryParams("domainCount")); - String set = request.queryParams("set") == null ? "" : request.queryParams("set"); + int count = parseInt(requireNonNullElse(request.queryParams("count"), "10")); + int page = parseInt(requireNonNullElse(request.queryParams("page"), "1")); + int domainCount = parseInt(requireNonNullElse(request.queryParams("domainCount"), "5")); + String set = requireNonNullElse(request.queryParams("set"), ""); var queryParams = new QueryParams(queryString, new QueryLimits( - domainCount, count, 250, 8192 + domainCount, min(100, count * 10), 250, 8192 ), set); + var pagination = new IndexClient.Pagination(page, count); + var rankingParams = debugRankingParamsFromRequest(request); var detailedDirectResult = queryGRPCService.executeDirect( - queryString, queryParams, rankingParams + queryString, + queryParams, + pagination, + rankingParams ); var results = detailedDirectResult.result(); @@ -127,10 +155,12 @@ public class QueryBasicInterface { } int intFromRequest(Request request, String param, int defaultValue) { - return Strings.isNullOrEmpty(request.queryParams(param)) ? defaultValue : Integer.parseInt(request.queryParams(param)); + return Strings.isNullOrEmpty(request.queryParams(param)) ? defaultValue : parseInt(request.queryParams(param)); } String stringFromRequest(Request request, String param, String defaultValue) { return Strings.isNullOrEmpty(request.queryParams(param)) ? defaultValue : request.queryParams(param); } + + record PaginationInfoPage(int number, boolean current) {} } diff --git a/code/services-core/query-service/resources/templates/search.hdb b/code/services-core/query-service/resources/templates/search.hdb index 14bbf2b5..86bcca2a 100644 --- a/code/services-core/query-service/resources/templates/search.hdb +++ b/code/services-core/query-service/resources/templates/search.hdb @@ -24,6 +24,20 @@
{{url}}

{{description}}

+{{/each}} + +{{#each pages}} + {{/each}} {{/if}}