From a860f8f1a89ceb219f6729845bffe1354ddaf1a6 Mon Sep 17 00:00:00 2001 From: Viktor Lofgren Date: Tue, 24 Oct 2023 11:09:12 +0200 Subject: [PATCH] (index/qs) GRPC API for better query peformance --- code/api/index-api/build.gradle | 34 +++- .../index/client/IndexProtobufCodec.java | 117 ++++++++++++ .../client/model/query/SearchSubquery.java | 4 + .../results/ResultRankingParameters.java | 4 +- .../src/main/protobuf/index-api.proto | 139 +++++++++++++++ .../index/client/IndexProtobufCodecTest.java | 50 ++++++ code/api/query-api/build.gradle | 2 + .../marginalia/query/QueryProtobufCodec.java | 166 ++++++++++++++++++ .../marginalia/query/client/QueryClient.java | 39 +++- .../marginalia/query/model/QueryParams.java | 2 + code/services-core/index-service/build.gradle | 1 + .../nu/marginalia/index/IndexService.java | 10 +- .../index/svc/IndexQueryService.java | 71 +++++++- .../index/svc/SearchParameters.java | 33 ++++ code/services-core/query-service/build.gradle | 1 + .../nu/marginalia/query/QueryGRPCService.java | 148 ++++++++++++++++ .../nu/marginalia/query/QueryService.java | 13 +- settings.gradle | 9 +- 18 files changed, 822 insertions(+), 21 deletions(-) create mode 100644 code/api/index-api/src/main/java/nu/marginalia/index/client/IndexProtobufCodec.java create mode 100644 code/api/index-api/src/main/protobuf/index-api.proto create mode 100644 code/api/index-api/src/test/java/nu/marginalia/index/client/IndexProtobufCodecTest.java create mode 100644 code/api/query-api/src/main/java/nu/marginalia/query/QueryProtobufCodec.java create mode 100644 code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java diff --git a/code/api/index-api/build.gradle b/code/api/index-api/build.gradle index 001d16c0..37d76aa0 100644 --- a/code/api/index-api/build.gradle +++ b/code/api/index-api/build.gradle @@ -1,7 +1,7 @@ plugins { id 'java' - + id "com.google.protobuf" version "0.9.4" id 'jvm-test-suite' } @@ -10,7 +10,13 @@ java { languageVersion.set(JavaLanguageVersion.of(21)) } } - +sourceSets { + main { + proto { + srcDir 'src/main/protobuf' + } + } +} dependencies { implementation project(':code:common:model') implementation project(':code:common:config') @@ -26,10 +32,32 @@ dependencies { implementation libs.guice implementation libs.rxjava implementation libs.protobuf - implementation libs.bundles.gson implementation libs.fastutil + implementation libs.javax.annotation + implementation libs.bundles.gson + implementation libs.bundles.grpc testImplementation libs.bundles.slf4j.test testImplementation libs.bundles.junit testImplementation libs.mockito } + +protobuf { + protoc { + artifact = "com.google.protobuf:protoc:3.0.2" + } + plugins { + grpc { + artifact = 'io.grpc:protoc-gen-grpc-java:1.1.2' + } + } + + generateProtoTasks { + all().each { task -> + task.plugins { + grpc {} + } + } + } +} + diff --git a/code/api/index-api/src/main/java/nu/marginalia/index/client/IndexProtobufCodec.java b/code/api/index-api/src/main/java/nu/marginalia/index/client/IndexProtobufCodec.java new file mode 100644 index 00000000..1178ea2d --- /dev/null +++ b/code/api/index-api/src/main/java/nu/marginalia/index/client/IndexProtobufCodec.java @@ -0,0 +1,117 @@ +package nu.marginalia.index.client; + +import nu.marginalia.index.api.*; +import nu.marginalia.index.client.model.query.SearchSubquery; +import nu.marginalia.index.client.model.results.Bm25Parameters; +import nu.marginalia.index.client.model.results.ResultRankingParameters; +import nu.marginalia.index.query.limit.QueryLimits; +import nu.marginalia.index.query.limit.SpecificationLimit; +import nu.marginalia.index.query.limit.SpecificationLimitType; + +import java.util.ArrayList; +import java.util.List; + +public class IndexProtobufCodec { + + public static SpecificationLimit convertSpecLimit(RpcSpecLimit limit) { + return new SpecificationLimit( + SpecificationLimitType.valueOf(limit.getType().name()), + limit.getValue() + ); + } + + public static RpcSpecLimit convertSpecLimit(SpecificationLimit limit) { + return RpcSpecLimit.newBuilder() + .setType(RpcSpecLimit.TYPE.valueOf(limit.type().name())) + .setValue(limit.value()) + .build(); + } + + public static QueryLimits convertQueryLimits(RpcQueryLimits queryLimits) { + return new QueryLimits( + queryLimits.getResultsByDomain(), + queryLimits.getResultsTotal(), + queryLimits.getTimeoutMs(), + queryLimits.getFetchSize() + ); + } + + public static RpcQueryLimits convertQueryLimits(QueryLimits queryLimits) { + return RpcQueryLimits.newBuilder() + .setResultsByDomain(queryLimits.resultsByDomain()) + .setResultsTotal(queryLimits.resultsTotal()) + .setTimeoutMs(queryLimits.timeoutMs()) + .setFetchSize(queryLimits.fetchSize()) + .build(); + } + + public static SearchSubquery convertSearchSubquery(RpcSubquery subquery) { + List> coherences = new ArrayList<>(); + + for (int j = 0; j < subquery.getCoherencesCount(); j++) { + var coh = subquery.getCoherences(j); + coherences.add(new ArrayList<>(coh.getCoherencesList())); + } + + return new SearchSubquery( + subquery.getIncludeList(), + subquery.getExcludeList(), + subquery.getAdviceList(), + subquery.getPriorityList(), + coherences + ); + } + + public static RpcSubquery convertSearchSubquery(SearchSubquery searchSubquery) { + var subqueryBuilder = + RpcSubquery.newBuilder() + .addAllAdvice(searchSubquery.getSearchTermsAdvice()) + .addAllExclude(searchSubquery.getSearchTermsExclude()) + .addAllInclude(searchSubquery.getSearchTermsInclude()) + .addAllPriority(searchSubquery.getSearchTermsPriority()); + for (var coherences : searchSubquery.searchTermCoherences) { + subqueryBuilder.addCoherencesBuilder().addAllCoherences(coherences); + } + return subqueryBuilder.build(); + } + + public static ResultRankingParameters convertRankingParameterss(RpcResultRankingParameters params) { + return new ResultRankingParameters( + new Bm25Parameters(params.getFullK(), params.getFullB()), + new Bm25Parameters(params.getPrioK(), params.getPrioB()), + params.getShortDocumentThreshold(), + params.getShortDocumentPenalty(), + params.getDomainRankBonus(), + params.getQualityPenalty(), + params.getShortSentenceThreshold(), + params.getShortSentencePenalty(), + params.getBm25FullWeight(), + params.getBm25PrioWeight(), + params.getTcfWeight(), + ResultRankingParameters.TemporalBias.valueOf(params.getTemporalBias().name()), + params.getTemporalBiasWeight() + ); + }; + + public static RpcResultRankingParameters convertRankingParameterss(ResultRankingParameters rankingParams) { + return + RpcResultRankingParameters.newBuilder() + .setFullB(rankingParams.fullParams.b()) + .setFullK(rankingParams.fullParams.k()) + .setPrioB(rankingParams.prioParams.b()) + .setPrioK(rankingParams.prioParams.k()) + .setShortDocumentThreshold(rankingParams.shortDocumentThreshold) + .setShortDocumentPenalty(rankingParams.shortDocumentPenalty) + .setDomainRankBonus(rankingParams.domainRankBonus) + .setQualityPenalty(rankingParams.qualityPenalty) + .setShortSentenceThreshold(rankingParams.shortSentenceThreshold) + .setShortSentencePenalty(rankingParams.shortSentencePenalty) + .setBm25FullWeight(rankingParams.bm25FullWeight) + .setBm25PrioWeight(rankingParams.bm25PrioWeight) + .setTcfWeight(rankingParams.tcfWeight) + .setTemporalBias(RpcResultRankingParameters.TEMPORAL_BIAS.valueOf(rankingParams.temporalBias.name())) + .setTemporalBiasWeight(rankingParams.temporalBiasWeight) + .build(); + } + +} diff --git a/code/api/index-api/src/main/java/nu/marginalia/index/client/model/query/SearchSubquery.java b/code/api/index-api/src/main/java/nu/marginalia/index/client/model/query/SearchSubquery.java index 1cc1edd8..18eb34e7 100644 --- a/code/api/index-api/src/main/java/nu/marginalia/index/client/model/query/SearchSubquery.java +++ b/code/api/index-api/src/main/java/nu/marginalia/index/client/model/query/SearchSubquery.java @@ -1,6 +1,7 @@ package nu.marginalia.index.client.model.query; import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.With; @@ -10,6 +11,7 @@ import java.util.stream.Collectors; @Getter @AllArgsConstructor @With +@EqualsAndHashCode public class SearchSubquery { /** These terms must be present in the document and are used in ranking*/ @@ -27,6 +29,7 @@ public class SearchSubquery { /** Terms that we require to be in the same sentence */ public final List> searchTermCoherences; + @Deprecated // why does this exist? private double value = 0; public SearchSubquery() { @@ -49,6 +52,7 @@ public class SearchSubquery { this.searchTermCoherences = searchTermCoherences; } + @Deprecated // why does this exist? public SearchSubquery setValue(double value) { if (Double.isInfinite(value) || Double.isNaN(value)) { this.value = Double.MAX_VALUE; diff --git a/code/api/index-api/src/main/java/nu/marginalia/index/client/model/results/ResultRankingParameters.java b/code/api/index-api/src/main/java/nu/marginalia/index/client/model/results/ResultRankingParameters.java index ff28c5d5..a77d9a9a 100644 --- a/code/api/index-api/src/main/java/nu/marginalia/index/client/model/results/ResultRankingParameters.java +++ b/code/api/index-api/src/main/java/nu/marginalia/index/client/model/results/ResultRankingParameters.java @@ -2,8 +2,10 @@ package nu.marginalia.index.client.model.results; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.ToString; -@Builder @AllArgsConstructor +@Builder @AllArgsConstructor @ToString @EqualsAndHashCode public class ResultRankingParameters { /** Tuning for BM25 when applied to full document matches */ diff --git a/code/api/index-api/src/main/protobuf/index-api.proto b/code/api/index-api/src/main/protobuf/index-api.proto new file mode 100644 index 00000000..53c7cd7b --- /dev/null +++ b/code/api/index-api/src/main/protobuf/index-api.proto @@ -0,0 +1,139 @@ +syntax="proto3"; +package actorapi; + +option java_package="nu.marginalia.index.api"; +option java_multiple_files=true; + +service QueryApi { + rpc query(RpcQsQuery) returns (RpcQsResponse) {} +} +service IndexApi { + rpc query(RpcIndexQuery) returns (RpcSearchResultSet) {} +} + +message Empty {} + +message RpcQsQuery { + string humanQuery = 1; + string nearDomain = 2; + repeated string tacitIncludes = 3; + repeated string tacitExcludes = 4; + repeated string tacitPriority = 5; + repeated string tacitAdvice = 6; + RpcSpecLimit quality = 7; + RpcSpecLimit year = 8; + RpcSpecLimit size = 9; + RpcSpecLimit rank = 10; + repeated int32 domainIds = 11; + RpcQueryLimits queryLimits = 12; + string searchSetIdentifier = 13; +} + +message RpcQsResponse { + RpcIndexQuery specs = 1; + repeated RpcDecoratedResultItem results = 2; + repeated string searchTermsHuman = 3; + repeated string problems = 4; + string domain = 5; +} + +message RpcIndexQuery { + repeated RpcSubquery subqueries = 1; + repeated int32 domains = 2; + string searchSetIdentifier = 3; + string humanQuery = 4; + RpcSpecLimit quality = 5; + RpcSpecLimit year = 6; + RpcSpecLimit size = 7; + RpcSpecLimit rank = 8; + RpcQueryLimits queryLimits = 9; + string queryStrategy = 10; + RpcResultRankingParameters parameters = 11; +} + +message RpcSpecLimit { + int32 value = 1; + TYPE type = 2; + + enum TYPE { + NONE = 0; + EQUALS = 1; + LESS_THAN = 2; + GREATER_THAN = 3; + }; +} + +message RpcSearchResultSet { + repeated RpcDecoratedResultItem items = 1; +} + +message RpcDecoratedResultItem { + RpcRawResultItem rawItem = 1; + string url = 2; + string title = 3; + string description = 4; + double urlQuality = 5; + string format = 6; + int32 features = 7; + int32 pubYear = 8; + int64 dataHash = 9; + int32 wordsTotal = 10; + double rankingScore = 11; +} + +message RpcRawResultItem { + int64 combinedId = 1; + int32 resultsFromDomain = 2; + repeated RpcResultKeywordScore keywordScores = 3; +} + +message RpcResultKeywordScore { + int32 subquery = 1; + string keyword = 2; + int64 encodedWordMetadata = 3; + int64 encodedDocMetadata = 4; + bool hasPriorityTerms = 5; + int32 htmlFeatures = 6; +} + +message RpcQueryLimits { + int32 resultsByDomain = 1; + int32 resultsTotal = 2; + int32 timeoutMs = 3; + int32 fetchSize = 4; +} + +message RpcResultRankingParameters { + double fullK = 1; + double fullB = 2; + double prioK = 3; + double prioB = 4; + int32 shortDocumentThreshold = 5; + double shortDocumentPenalty = 6; + double domainRankBonus = 7; + double qualityPenalty = 8; + int32 shortSentenceThreshold = 9; + double shortSentencePenalty = 10; + double bm25FullWeight = 11; + double bm25PrioWeight = 12; + double tcfWeight = 13; + TEMPORAL_BIAS temporalBias = 14; + double temporalBiasWeight = 15; + + enum TEMPORAL_BIAS { + NONE = 0; + RECENT = 1; + OLD = 2; + } +} +message RpcSubquery { + repeated string include = 1; + repeated string exclude = 2; + repeated string advice = 3; + repeated string priority = 4; + repeated RpcCoherences coherences = 5; +} + +message RpcCoherences { + repeated string coherences = 1; +} diff --git a/code/api/index-api/src/test/java/nu/marginalia/index/client/IndexProtobufCodecTest.java b/code/api/index-api/src/test/java/nu/marginalia/index/client/IndexProtobufCodecTest.java new file mode 100644 index 00000000..36e85429 --- /dev/null +++ b/code/api/index-api/src/test/java/nu/marginalia/index/client/IndexProtobufCodecTest.java @@ -0,0 +1,50 @@ +package nu.marginalia.index.client; + +import nu.marginalia.index.client.model.query.SearchSubquery; +import nu.marginalia.index.client.model.results.ResultRankingParameters; +import nu.marginalia.index.query.limit.QueryLimits; +import nu.marginalia.index.query.limit.SpecificationLimit; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.function.Function; + +import static org.junit.jupiter.api.Assertions.*; + +class IndexProtobufCodecTest { + @Test + public void testSpecLimit() { + verifyIsIdentityTransformation(SpecificationLimit.none(), l -> IndexProtobufCodec.convertSpecLimit(IndexProtobufCodec.convertSpecLimit(l))); + verifyIsIdentityTransformation(SpecificationLimit.equals(1), l -> IndexProtobufCodec.convertSpecLimit(IndexProtobufCodec.convertSpecLimit(l))); + verifyIsIdentityTransformation(SpecificationLimit.greaterThan(1), l -> IndexProtobufCodec.convertSpecLimit(IndexProtobufCodec.convertSpecLimit(l))); + verifyIsIdentityTransformation(SpecificationLimit.lessThan(1), l -> IndexProtobufCodec.convertSpecLimit(IndexProtobufCodec.convertSpecLimit(l))); + } + + @Test + public void testRankingParameters() { + verifyIsIdentityTransformation(ResultRankingParameters.sensibleDefaults(), + p -> IndexProtobufCodec.convertRankingParameterss(IndexProtobufCodec.convertRankingParameterss(p))); + } + + @Test + public void testQueryLimits() { + verifyIsIdentityTransformation(new QueryLimits(1,2,3,4), + l -> IndexProtobufCodec.convertQueryLimits(IndexProtobufCodec.convertQueryLimits(l)) + ); + } + @Test + public void testSubqery() { + verifyIsIdentityTransformation(new SearchSubquery( + List.of("a", "b"), + List.of("c", "d"), + List.of("e", "f"), + List.of("g", "h"), + List.of(List.of("i", "j"), List.of("k")) + ), + s -> IndexProtobufCodec.convertSearchSubquery(IndexProtobufCodec.convertSearchSubquery(s)) + ); + } + private void verifyIsIdentityTransformation(T val, Function transformation) { + assertEquals(val, transformation.apply(val), val.toString()); + } +} \ No newline at end of file diff --git a/code/api/query-api/build.gradle b/code/api/query-api/build.gradle index 470c177c..524d21df 100644 --- a/code/api/query-api/build.gradle +++ b/code/api/query-api/build.gradle @@ -25,6 +25,8 @@ dependencies { implementation libs.guice implementation libs.rxjava implementation libs.gson + implementation libs.bundles.grpc + implementation libs.protobuf testImplementation libs.bundles.slf4j.test testImplementation libs.bundles.junit diff --git a/code/api/query-api/src/main/java/nu/marginalia/query/QueryProtobufCodec.java b/code/api/query-api/src/main/java/nu/marginalia/query/QueryProtobufCodec.java new file mode 100644 index 00000000..b8cd4fec --- /dev/null +++ b/code/api/query-api/src/main/java/nu/marginalia/query/QueryProtobufCodec.java @@ -0,0 +1,166 @@ +package nu.marginalia.query; + +import lombok.SneakyThrows; +import nu.marginalia.index.api.*; +import nu.marginalia.index.client.IndexProtobufCodec; +import nu.marginalia.index.client.model.query.SearchSetIdentifier; +import nu.marginalia.index.client.model.query.SearchSpecification; +import nu.marginalia.index.client.model.query.SearchSubquery; +import nu.marginalia.index.client.model.results.DecoratedSearchResultItem; +import nu.marginalia.index.client.model.results.SearchResultItem; +import nu.marginalia.index.client.model.results.SearchResultKeywordScore; +import nu.marginalia.index.query.limit.QueryStrategy; +import nu.marginalia.model.EdgeUrl; +import nu.marginalia.query.model.ProcessedQuery; +import nu.marginalia.query.model.QueryParams; +import nu.marginalia.query.model.QueryResponse; + +import java.util.ArrayList; +import java.util.List; + +import static nu.marginalia.index.client.IndexProtobufCodec.*; + +public class QueryProtobufCodec { + + public static RpcIndexQuery convertQuery(RpcQsQuery request, ProcessedQuery query) { + var builder = RpcIndexQuery.newBuilder(); + + builder.addAllDomains(request.getDomainIdsList()); + + for (var subquery : query.specs.subqueries) { + builder.addSubqueries(IndexProtobufCodec.convertSearchSubquery(subquery)); + } + + builder.setSearchSetIdentifier(query.specs.searchSetIdentifier.name()); + builder.setHumanQuery(request.getHumanQuery()); + + builder.setQuality(convertSpecLimit(query.specs.quality)); + builder.setYear(convertSpecLimit(query.specs.year)); + builder.setSize(convertSpecLimit(query.specs.size)); + builder.setRank(convertSpecLimit(query.specs.rank)); + + builder.setQueryLimits(IndexProtobufCodec.convertQueryLimits(query.specs.queryLimits)); + builder.setQueryStrategy(query.specs.queryStrategy.name()); + builder.setParameters(IndexProtobufCodec.convertRankingParameterss(query.specs.rankingParams)); + + return builder.build(); + } + + public static QueryParams convertRequest(RpcQsQuery request) { + return new QueryParams( + request.getHumanQuery(), + request.getNearDomain(), + request.getTacitIncludesList(), + request.getTacitExcludesList(), + request.getTacitPriorityList(), + request.getTacitAdviceList(), + convertSpecLimit(request.getQuality()), + convertSpecLimit(request.getYear()), + convertSpecLimit(request.getSize()), + convertSpecLimit(request.getRank()), + request.getDomainIdsList(), + IndexProtobufCodec.convertQueryLimits(request.getQueryLimits()), + SearchSetIdentifier.valueOf(request.getSearchSetIdentifier())); + } + + + public static QueryResponse convertQueryResponse(RpcQsResponse query) { + var results = new ArrayList(query.getResultsCount()); + + for (int i = 0; i < query.getResultsCount(); i++) + results.add(convertDecoratedResult(query.getResults(i))); + + return new QueryResponse( + convertSearchSpecification(query.getSpecs()), + results, + query.getSearchTermsHumanList(), + query.getProblemsList(), + query.getDomain() + ); + } + + @SneakyThrows + private static DecoratedSearchResultItem convertDecoratedResult(RpcDecoratedResultItem results) { + return new DecoratedSearchResultItem( + convertRawResult(results.getRawItem()), + new EdgeUrl(results.getUrl()), + results.getTitle(), + results.getDescription(), + results.getUrlQuality(), + results.getFormat(), + results.getFeatures(), + results.getPubYear(), // ??, + results.getDataHash(), + results.getWordsTotal(), + results.getRankingScore() + ); + } + + private static SearchResultItem convertRawResult(RpcRawResultItem rawItem) { + var keywordScores = new ArrayList(rawItem.getKeywordScoresCount()); + + for (int i = 0; i < rawItem.getKeywordScoresCount(); i++) + keywordScores.add(convertKeywordScore(rawItem.getKeywordScores(i))); + + return new SearchResultItem( + rawItem.getCombinedId(), + keywordScores, + rawItem.getResultsFromDomain(), + null + ); + } + + private static SearchResultKeywordScore convertKeywordScore(RpcResultKeywordScore keywordScores) { + return new SearchResultKeywordScore( + keywordScores.getSubquery(), + keywordScores.getKeyword(), + keywordScores.getEncodedWordMetadata(), + keywordScores.getEncodedDocMetadata(), + keywordScores.getHtmlFeatures(), + keywordScores.getHasPriorityTerms() + ); + } + + private static SearchSpecification convertSearchSpecification(RpcIndexQuery specs) { + List subqueries = new ArrayList<>(specs.getSubqueriesCount()); + + for (int i = 0; i < specs.getSubqueriesCount(); i++) { + subqueries.add(convertSearchSubquery(specs.getSubqueries(i))); + } + + return new SearchSpecification( + subqueries, + specs.getDomainsList(), + SearchSetIdentifier.valueOf(specs.getSearchSetIdentifier()), + specs.getHumanQuery(), + IndexProtobufCodec.convertSpecLimit(specs.getQuality()), + IndexProtobufCodec.convertSpecLimit(specs.getYear()), + IndexProtobufCodec.convertSpecLimit(specs.getSize()), + IndexProtobufCodec.convertSpecLimit(specs.getRank()), + IndexProtobufCodec.convertQueryLimits(specs.getQueryLimits()), + QueryStrategy.valueOf(specs.getQueryStrategy()), + convertRankingParameterss(specs.getParameters()) + ); + } + + public static RpcQsQuery convertQueryParams(QueryParams params) { + var builder = RpcQsQuery.newBuilder() + .addAllDomainIds(params.domainIds()) + .addAllTacitAdvice(params.tacitAdvice()) + .addAllTacitExcludes(params.tacitExcludes()) + .addAllTacitIncludes(params.tacitIncludes()) + .addAllTacitPriority(params.tacitPriority()) + .setHumanQuery(params.humanQuery()) + .setQueryLimits(convertQueryLimits(params.limits())) + .setQuality(convertSpecLimit(params.quality())) + .setYear(convertSpecLimit(params.year())) + .setSize(convertSpecLimit(params.size())) + .setRank(convertSpecLimit(params.rank())) + .setSearchSetIdentifier(params.identifier().name()); + + if (params.nearDomain() != null) + builder.setNearDomain(params.nearDomain()); + + return builder.build(); + } +} diff --git a/code/api/query-api/src/main/java/nu/marginalia/query/client/QueryClient.java b/code/api/query-api/src/main/java/nu/marginalia/query/client/QueryClient.java index 4241ef76..8aabaa42 100644 --- a/code/api/query-api/src/main/java/nu/marginalia/query/client/QueryClient.java +++ b/code/api/query-api/src/main/java/nu/marginalia/query/client/QueryClient.java @@ -2,15 +2,16 @@ package nu.marginalia.query.client; import com.google.inject.Inject; import com.google.inject.Singleton; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; import io.prometheus.client.Summary; -import nu.marginalia.WmsaHome; import nu.marginalia.client.AbstractDynamicClient; import nu.marginalia.client.Context; +import nu.marginalia.index.api.QueryApiGrpc; import nu.marginalia.index.client.model.query.SearchSpecification; import nu.marginalia.index.client.model.results.SearchResultSet; import nu.marginalia.model.gson.GsonFactory; -import nu.marginalia.mq.MessageQueueFactory; -import nu.marginalia.mq.outbox.MqOutbox; +import nu.marginalia.query.QueryProtobufCodec; import nu.marginalia.query.model.QueryParams; import nu.marginalia.query.model.QueryResponse; import nu.marginalia.service.descriptor.ServiceDescriptors; @@ -19,7 +20,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.CheckReturnValue; -import java.util.UUID; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; @Singleton public class QueryClient extends AbstractDynamicClient { @@ -27,6 +29,30 @@ public class QueryClient extends AbstractDynamicClient { private static final Summary wmsa_search_index_api_delegate_time = Summary.build().name("wmsa_search_index_api_delegate_time").help("-").register(); private static final Summary wmsa_search_index_api_search_time = Summary.build().name("wmsa_search_index_api_search_time").help("-").register(); + private final Map channels = new ConcurrentHashMap<>(); + private final Map queryApis = new ConcurrentHashMap<>(); + + record ServiceAndNode(String service, int node) { + public String getHostName() { + return service; + } + } + private ManagedChannel getChannel(ServiceAndNode serviceAndNode) { + return channels.computeIfAbsent(serviceAndNode, + san -> ManagedChannelBuilder + .forAddress(serviceAndNode.getHostName(), 81) + .usePlaintext() + .build()); + } + + public QueryApiGrpc.QueryApiBlockingStub queryApi(int node) { + return queryApis.computeIfAbsent(new ServiceAndNode("query-service", node), n -> + QueryApiGrpc.newBlockingStub( + getChannel(n) + ) + ); + } + private final Logger logger = LoggerFactory.getLogger(getClass()); @Inject @@ -42,11 +68,10 @@ public class QueryClient extends AbstractDynamicClient { () -> this.postGet(ctx, 0, "/delegate/", specs, SearchResultSet.class).blockingFirst() ); } + @CheckReturnValue public QueryResponse search(Context ctx, QueryParams params) { - return wmsa_search_index_api_search_time.time( - () -> this.postGet(ctx, 0, "/search/", params, QueryResponse.class).blockingFirst() - ); + return QueryProtobufCodec.convertQueryResponse(queryApi(0).query(QueryProtobufCodec.convertQueryParams(params))); } } diff --git a/code/api/query-api/src/main/java/nu/marginalia/query/model/QueryParams.java b/code/api/query-api/src/main/java/nu/marginalia/query/model/QueryParams.java index 2a912324..6e74d90c 100644 --- a/code/api/query-api/src/main/java/nu/marginalia/query/model/QueryParams.java +++ b/code/api/query-api/src/main/java/nu/marginalia/query/model/QueryParams.java @@ -5,10 +5,12 @@ import nu.marginalia.index.client.model.query.SearchSpecification; import nu.marginalia.index.query.limit.QueryLimits; import nu.marginalia.index.query.limit.SpecificationLimit; +import javax.annotation.Nullable; import java.util.List; public record QueryParams( String humanQuery, + @Nullable String nearDomain, List tacitIncludes, List tacitExcludes, diff --git a/code/services-core/index-service/build.gradle b/code/services-core/index-service/build.gradle index 2b9f9aff..4523dc27 100644 --- a/code/services-core/index-service/build.gradle +++ b/code/services-core/index-service/build.gradle @@ -57,6 +57,7 @@ dependencies { implementation libs.trove implementation libs.fastutil implementation libs.bundles.gson + implementation libs.bundles.grpc implementation libs.bundles.mariadb testImplementation libs.bundles.slf4j.test diff --git a/code/services-core/index-service/src/main/java/nu/marginalia/index/IndexService.java b/code/services-core/index-service/src/main/java/nu/marginalia/index/IndexService.java index b178686a..ddfd43c9 100644 --- a/code/services-core/index-service/src/main/java/nu/marginalia/index/IndexService.java +++ b/code/services-core/index-service/src/main/java/nu/marginalia/index/IndexService.java @@ -2,6 +2,7 @@ package nu.marginalia.index; import com.google.gson.Gson; import com.google.inject.Inject; +import io.grpc.ServerBuilder; import io.reactivex.rxjava3.schedulers.Schedulers; import lombok.SneakyThrows; import nu.marginalia.IndexLocations; @@ -23,6 +24,7 @@ import spark.Request; import spark.Response; import spark.Spark; +import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.concurrent.TimeUnit; @@ -49,8 +51,7 @@ public class IndexService extends Service { SearchIndex searchIndex, FileStorageService fileStorageService, LinkdbReader linkdbReader, - ServiceEventLog eventLog) - { + ServiceEventLog eventLog) throws IOException { super(params); this.opsService = opsService; @@ -63,6 +64,11 @@ public class IndexService extends Service { this.init = params.initialization; + var grpcServer = ServerBuilder.forPort(params.configuration.port() + 1) + .addService(indexQueryService) + .build(); + grpcServer.start(); + Spark.post("/search/", indexQueryService::search, gson::toJson); Spark.get("/public/debug/docmeta", indexQueryService::debugEndpointDocMetadata, gson::toJson); diff --git a/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/IndexQueryService.java b/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/IndexQueryService.java index badd292e..6553afd9 100644 --- a/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/IndexQueryService.java +++ b/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/IndexQueryService.java @@ -9,6 +9,9 @@ import io.prometheus.client.Counter; import io.prometheus.client.Gauge; import io.prometheus.client.Histogram; import lombok.SneakyThrows; +import nu.marginalia.index.api.*; +import nu.marginalia.index.api.IndexApiGrpc.IndexApiImplBase; +import nu.marginalia.index.client.model.query.SearchSetIdentifier; import nu.marginalia.index.client.model.query.SearchSubquery; import nu.marginalia.index.client.model.results.ResultRankingParameters; import nu.marginalia.index.client.model.results.SearchResultItem; @@ -41,8 +44,10 @@ import java.sql.SQLException; import java.util.*; import java.util.stream.Collectors; +import static io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall; + @Singleton -public class IndexQueryService { +public class IndexQueryService extends IndexApiImplBase { private final Logger logger = LoggerFactory.getLogger(getClass()); @@ -142,6 +147,61 @@ public class IndexQueryService { } } + // GRPC endpoint + @SneakyThrows + public void query(nu.marginalia.index.api.RpcIndexQuery request, + io.grpc.stub.StreamObserver responseObserver) { + + try { + var params = new SearchParameters(request, getSearchSet(request)); + + SearchResultSet results = executeSearch(params); + RpcSearchResultSet.Builder retBuilder = RpcSearchResultSet.newBuilder(); + for (var result : results.results) { + + var rawResult = result.rawIndexResult; + + var rawItem = RpcRawResultItem.newBuilder(); + rawItem.setCombinedId(rawResult.combinedId); + rawItem.setResultsFromDomain(rawResult.resultsFromDomain); + + for (var score : rawResult.keywordScores) { + rawItem.addKeywordScores( + RpcResultKeywordScore.newBuilder() + .setEncodedDocMetadata(score.encodedDocMetadata()) + .setEncodedWordMetadata(score.encodedWordMetadata()) + .setKeyword(score.keyword) + .setHtmlFeatures(score.htmlFeatures()) + .setHasPriorityTerms(score.hasPriorityTerms()) + .setSubquery(score.subquery) + ); + } + + var decoratedBuilder = RpcDecoratedResultItem.newBuilder() + .setDataHash(result.dataHash) + .setDescription(result.description) + .setFeatures(result.features) + .setFormat(result.format) + .setRankingScore(result.rankingScore) + .setTitle(result.title) + .setUrl(result.url.toString()) + .setWordsTotal(result.wordsTotal) + .setRawItem(rawItem); + + if (result.pubYear != null) { + decoratedBuilder.setPubYear(result.pubYear); + } + retBuilder.addItems(decoratedBuilder.build()); + } + responseObserver.onNext(retBuilder.build()); + responseObserver.onCompleted(); + } + catch (Exception ex) { + logger.error("Error in handling request", ex); + responseObserver.onError(ex); + } + } + // exists for test access @SneakyThrows SearchResultSet justQuery(SearchSpecification specsSet) { @@ -156,7 +216,16 @@ public class IndexQueryService { return searchSetsService.getSearchSetByName(specsSet.searchSetIdentifier); } + private SearchSet getSearchSet(RpcIndexQuery request) { + if (request.getDomainsCount() > 0) { + return new SmallSearchSet(request.getDomainsList()); + } + + return searchSetsService.getSearchSetByName( + SearchSetIdentifier.valueOf(request.getSearchSetIdentifier()) + ); + } private SearchResultSet executeSearch(SearchParameters params) throws SQLException { var rankingContext = createRankingContext(params.rankingParams, params.subqueries); diff --git a/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/SearchParameters.java b/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/SearchParameters.java index d8d04818..141dc32d 100644 --- a/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/SearchParameters.java +++ b/code/services-core/index-service/src/main/java/nu/marginalia/index/svc/SearchParameters.java @@ -1,16 +1,25 @@ package nu.marginalia.index.svc; import gnu.trove.set.hash.TLongHashSet; +import nu.marginalia.index.api.RpcIndexQuery; +import nu.marginalia.index.api.RpcSpecLimit; +import nu.marginalia.index.client.IndexProtobufCodec; import nu.marginalia.index.client.model.query.SearchSpecification; import nu.marginalia.index.client.model.query.SearchSubquery; +import nu.marginalia.index.client.model.results.Bm25Parameters; import nu.marginalia.index.client.model.results.ResultRankingParameters; import nu.marginalia.index.index.SearchIndex; import nu.marginalia.index.index.SearchIndexSearchTerms; import nu.marginalia.index.query.IndexQuery; import nu.marginalia.index.query.IndexQueryParams; import nu.marginalia.index.query.IndexSearchBudget; +import nu.marginalia.index.query.limit.QueryLimits; +import nu.marginalia.index.query.limit.QueryStrategy; +import nu.marginalia.index.query.limit.SpecificationLimit; +import nu.marginalia.index.query.limit.SpecificationLimitType; import nu.marginalia.index.searchset.SearchSet; +import java.util.ArrayList; import java.util.List; public class SearchParameters { @@ -62,6 +71,30 @@ public class SearchParameters { rankingParams = specsSet.rankingParams; } + public SearchParameters(RpcIndexQuery request, SearchSet searchSet) { + var limits = IndexProtobufCodec.convertQueryLimits(request.getQueryLimits()); + + this.fetchSize = limits.fetchSize(); + this.budget = new IndexSearchBudget(limits.timeoutMs()); + this.subqueries = new ArrayList<>(request.getSubqueriesCount()); + for (int i = 0; i < request.getSubqueriesCount(); i++) { + this.subqueries.add(IndexProtobufCodec.convertSearchSubquery(request.getSubqueries(i))); + } + this.limitByDomain = limits.resultsByDomain(); + this.limitTotal = limits.resultsTotal(); + + this.consideredUrlIds = CachedObjects.getConsideredUrlsMap(); + + queryParams = new IndexQueryParams( + IndexProtobufCodec.convertSpecLimit(request.getQuality()), + IndexProtobufCodec.convertSpecLimit(request.getYear()), + IndexProtobufCodec.convertSpecLimit(request.getSize()), + IndexProtobufCodec.convertSpecLimit(request.getRank()), + searchSet, + QueryStrategy.valueOf(request.getQueryStrategy())); + + rankingParams = IndexProtobufCodec.convertRankingParameterss(request.getParameters()); + } List createIndexQueries(SearchIndex index, SearchIndexSearchTerms terms) { return index.createQueries(terms, queryParams, consideredUrlIds::add); diff --git a/code/services-core/query-service/build.gradle b/code/services-core/query-service/build.gradle index 5477d634..4907a86e 100644 --- a/code/services-core/query-service/build.gradle +++ b/code/services-core/query-service/build.gradle @@ -45,6 +45,7 @@ dependencies { implementation libs.protobuf implementation libs.rxjava implementation libs.bundles.mariadb + implementation libs.bundles.grpc testImplementation libs.bundles.slf4j.test testImplementation libs.bundles.junit diff --git a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java new file mode 100644 index 00000000..190305ee --- /dev/null +++ b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryGRPCService.java @@ -0,0 +1,148 @@ +package nu.marginalia.query; + +import com.google.inject.Inject; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import nu.marginalia.db.DomainBlacklist; +import nu.marginalia.index.api.*; +import nu.marginalia.model.id.UrlIdCodec; +import nu.marginalia.query.svc.NodeConfigurationWatcher; +import nu.marginalia.query.svc.QueryFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +public class QueryGRPCService extends QueryApiGrpc.QueryApiImplBase { + + private final Logger logger = LoggerFactory.getLogger(QueryGRPCService.class); + + private final Map channels + = new ConcurrentHashMap<>(); + private final Map actorRpcApis + = new ConcurrentHashMap<>(); + + private ManagedChannel getChannel(ServiceAndNode serviceAndNode) { + return channels.computeIfAbsent(serviceAndNode, + san -> ManagedChannelBuilder + .forAddress(serviceAndNode.getHostName(), 81) + .usePlaintext() + .build()); + } + + public IndexApiGrpc.IndexApiFutureStub indexApi(int node) { + return actorRpcApis.computeIfAbsent(new ServiceAndNode("index-service", node), n -> + IndexApiGrpc.newFutureStub( + getChannel(n) + ) + ); + } + + record ServiceAndNode(String service, int node) { + public String getHostName() { + return service+"-"+node; + } + } + + private final QueryFactory queryFactory; + private final DomainBlacklist blacklist; + private final NodeConfigurationWatcher nodeConfigurationWatcher; + + @Inject + public QueryGRPCService(QueryFactory queryFactory, DomainBlacklist blacklist, NodeConfigurationWatcher nodeConfigurationWatcher) { + this.queryFactory = queryFactory; + this.blacklist = blacklist; + this.nodeConfigurationWatcher = nodeConfigurationWatcher; + } + + public void query(nu.marginalia.index.api.RpcQsQuery request, + io.grpc.stub.StreamObserver responseObserver) + { + try { + var params = QueryProtobufCodec.convertRequest(request); + var query = queryFactory.createQuery(params); + + RpcIndexQuery indexRequest = QueryProtobufCodec.convertQuery(request, query); + List bestItems = executeQueries(indexRequest, request.getQueryLimits().getResultsTotal()); + + var responseBuilder = RpcQsResponse.newBuilder() + .addAllResults(bestItems) + .setSpecs(indexRequest) + .addAllSearchTermsHuman(query.searchTermsHuman); + + if (query.domain != null) + responseBuilder.setDomain(query.domain); + + responseObserver.onNext(responseBuilder.build()); + + responseObserver.onCompleted(); + } catch (Exception e) { + logger.error("Exception", e); + responseObserver.onError(e); + } + } + + private List executeQueries(RpcIndexQuery indexRequest, int totalSize) throws InterruptedException + { + + final List bestItems = new ArrayList<>(2 * totalSize); + + LinkedList> resultSets = new LinkedList<>(); + for (var node : nodeConfigurationWatcher.getQueryNodes()) { + resultSets.add(indexApi(node).query(indexRequest)); + } + + long start = System.currentTimeMillis(); + long timeout = start + 500; + + while (!resultSets.isEmpty() && System.currentTimeMillis() < timeout) + { + resultSets.removeIf(f -> switch(f.state()) { + case CANCELLED -> true; + case FAILED -> { + logger.error("Error in query", f.exceptionNow()); + yield true; + } + case SUCCESS -> { + mergeResults(bestItems, f.resultNow(), totalSize); + yield true; + } + case RUNNING -> false; + }); + + if (!resultSets.isEmpty()) { + // yield + TimeUnit.MILLISECONDS.sleep(1); + } + } + return bestItems; + } + + private static final Comparator comparator = + Comparator.comparing(RpcDecoratedResultItem::getRankingScore); + private void mergeResults(List bestItems, + RpcSearchResultSet result, + int totalSize) + { + for (int i = 0; i < result.getItemsCount(); i++) { + var item = result.getItems(i); + if (isBlacklisted(item)) { + continue; + } + bestItems.add(result.getItems(i)); + } + + bestItems.sort(comparator); + + if (bestItems.size() > totalSize) { + bestItems.subList(totalSize, bestItems.size()).clear(); + } + } + + private boolean isBlacklisted(RpcDecoratedResultItem item) { + return blacklist.isBlacklisted(UrlIdCodec.getDomainId(item.getRawItem().getCombinedId())); + } +} diff --git a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryService.java b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryService.java index 00a220fe..d950a8d0 100644 --- a/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryService.java +++ b/code/services-core/query-service/src/main/java/nu/marginalia/query/QueryService.java @@ -2,6 +2,7 @@ package nu.marginalia.query; import com.google.gson.Gson; import com.google.inject.Inject; +import io.grpc.ServerBuilder; import nu.marginalia.client.Context; import nu.marginalia.db.DomainBlacklist; import nu.marginalia.index.client.IndexClient; @@ -19,6 +20,7 @@ import spark.Request; import spark.Response; import spark.Spark; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; @@ -31,16 +33,14 @@ public class QueryService extends Service { private final DomainBlacklist blacklist; private final QueryFactory queryFactory; - private volatile List nodes = new ArrayList<>(); - @Inject public QueryService(BaseServiceParams params, IndexClient indexClient, NodeConfigurationWatcher nodeWatcher, + QueryGRPCService queryGRPCService, Gson gson, DomainBlacklist blacklist, - QueryFactory queryFactory) - { + QueryFactory queryFactory) throws IOException { super(params); this.indexClient = indexClient; this.nodeWatcher = nodeWatcher; @@ -48,6 +48,11 @@ public class QueryService extends Service { this.blacklist = blacklist; this.queryFactory = queryFactory; + var grpcServer = ServerBuilder.forPort(params.configuration.port() + 1) + .addService(queryGRPCService) + .build(); + grpcServer.start(); + Spark.post("/delegate/", this::delegateToIndex, gson::toJson); Spark.post("/search/", this::search, gson::toJson); } diff --git a/settings.gradle b/settings.gradle index 423077f7..2fd33ceb 100644 --- a/settings.gradle +++ b/settings.gradle @@ -46,6 +46,7 @@ include 'code:features-index:index-forward' include 'code:features-index:index-reverse' include 'code:features-index:domain-ranking' +include 'code:api:actor-api' include 'code:api:query-api' include 'code:api:index-api' include 'code:api:assistant-api' @@ -125,7 +126,9 @@ dependencyResolutionManagement { library('guice', 'com.google.inject', 'guice').version('7.0.0') library('guava', 'com.google.guava', 'guava').version('32.0.1-jre') library('protobuf', 'com.google.protobuf', 'protobuf-java').version('3.0.0') - + library('grpc-protobuf', 'io.grpc', 'grpc-protobuf').version('1.49.2') + library('grpc-stub', 'io.grpc', 'grpc-stub').version('1.49.2') + library('grpc-netty', 'io.grpc', 'grpc-netty-shaded').version('1.49.2') library('rxjava', 'io.reactivex.rxjava3', 'rxjava').version('3.1.6') library('prometheus', 'io.prometheus', 'simpleclient').version('0.16.0') @@ -189,7 +192,7 @@ dependencyResolutionManagement { library('handlebars.markdown','com.github.jknack','handlebars-markdown').version('4.2.1') library('sqlite','org.xerial','sqlite-jdbc').version('3.41.2.1') - + library('javax.annotation','javax.annotation','javax.annotation-api').version('1.3.2') library('parquet-column', 'org.apache.parquet','parquet-column').version('1.13.1') library('parquet-hadoop', 'org.apache.parquet','parquet-hadoop').version('1.13.1') @@ -200,7 +203,7 @@ dependencyResolutionManagement { bundle('nlp', ['stanford.corenlp', 'opennlp', 'fasttext']) bundle('selenium', ['selenium.chrome', 'selenium.java']) bundle('handlebars', ['handlebars', 'handlebars.markdown']) - + bundle('grpc', ['protobuf', 'grpc-stub', 'grpc-protobuf', 'grpc-netty']) bundle('gson', ['gson', 'gson-type-adapter']) bundle('httpcomponents', ['httpcomponents.core', 'httpcomponents.client']) bundle('parquet', ['parquet-column', 'parquet-hadoop'])