diff --git a/code/common/config/java/nu/marginalia/LanguageModels.java b/code/common/config/java/nu/marginalia/LanguageModels.java index 04ab0aa0..d1854963 100644 --- a/code/common/config/java/nu/marginalia/LanguageModels.java +++ b/code/common/config/java/nu/marginalia/LanguageModels.java @@ -1,9 +1,11 @@ package nu.marginalia; +import lombok.Builder; + import java.nio.file.Path; +@Builder public class LanguageModels { - public final Path ngramBloomFilter; public final Path termFrequencies; public final Path openNLPSentenceDetectionData; @@ -11,20 +13,21 @@ public class LanguageModels { public final Path posDict; public final Path openNLPTokenData; public final Path fasttextLanguageModel; + public final Path segments; - public LanguageModels(Path ngramBloomFilter, - Path termFrequencies, + public LanguageModels(Path termFrequencies, Path openNLPSentenceDetectionData, Path posRules, Path posDict, Path openNLPTokenData, - Path fasttextLanguageModel) { - this.ngramBloomFilter = ngramBloomFilter; + Path fasttextLanguageModel, + Path segments) { this.termFrequencies = termFrequencies; this.openNLPSentenceDetectionData = openNLPSentenceDetectionData; this.posRules = posRules; this.posDict = posDict; this.openNLPTokenData = openNLPTokenData; this.fasttextLanguageModel = fasttextLanguageModel; + this.segments = segments; } } diff --git a/code/common/config/java/nu/marginalia/WmsaHome.java b/code/common/config/java/nu/marginalia/WmsaHome.java index b61ee4dd..b5378afc 100644 --- a/code/common/config/java/nu/marginalia/WmsaHome.java +++ b/code/common/config/java/nu/marginalia/WmsaHome.java @@ -85,13 +85,14 @@ public class WmsaHome { final Path home = getHomePath(); return new LanguageModels( - home.resolve("model/ngrams.bin"), home.resolve("model/tfreq-new-algo3.bin"), home.resolve("model/opennlp-sentence.bin"), home.resolve("model/English.RDR"), home.resolve("model/English.DICT"), home.resolve("model/opennlp-tok.bin"), - home.resolve("model/lid.176.ftz")); + 
home.resolve("model/lid.176.ftz"), + home.resolve("model/segments.bin") + ); } public static Path getAtagsPath() { diff --git a/code/common/model/java/nu/marginalia/model/idx/WordFlags.java b/code/common/model/java/nu/marginalia/model/idx/WordFlags.java index dc627715..db54df77 100644 --- a/code/common/model/java/nu/marginalia/model/idx/WordFlags.java +++ b/code/common/model/java/nu/marginalia/model/idx/WordFlags.java @@ -50,6 +50,10 @@ public enum WordFlags { return (asBit() & value) > 0; } + public boolean isAbsent(long value) { + return (asBit() & value) == 0; + } + public static EnumSet decode(long encodedValue) { EnumSet ret = EnumSet.noneOf(WordFlags.class); @@ -61,4 +65,5 @@ public enum WordFlags { return ret; } + } diff --git a/code/execution/api/java/nu/marginalia/executor/client/ExecutorExportClient.java b/code/execution/api/java/nu/marginalia/executor/client/ExecutorExportClient.java index a3286a1b..e12fa0d3 100644 --- a/code/execution/api/java/nu/marginalia/executor/client/ExecutorExportClient.java +++ b/code/execution/api/java/nu/marginalia/executor/client/ExecutorExportClient.java @@ -2,10 +2,7 @@ package nu.marginalia.executor.client; import com.google.inject.Inject; import com.google.inject.Singleton; -import nu.marginalia.functions.execution.api.Empty; -import nu.marginalia.functions.execution.api.ExecutorExportApiGrpc; -import nu.marginalia.functions.execution.api.RpcExportSampleData; -import nu.marginalia.functions.execution.api.RpcFileStorageId; +import nu.marginalia.functions.execution.api.*; import nu.marginalia.service.client.GrpcChannelPoolFactory; import nu.marginalia.service.client.GrpcMultiNodeChannelPool; import nu.marginalia.service.discovery.property.ServiceKey; @@ -55,6 +52,7 @@ public class ExecutorExportClient { .setFileStorageId(fid.id()) .build()); } + public void exportTermFrequencies(int node, FileStorageId fid) { channelPool.call(ExecutorExportApiBlockingStub::exportTermFrequencies) .forNode(node) @@ -69,6 +67,14 @@ public 
class ExecutorExportClient { .run(Empty.getDefaultInstance()); } + public void exportSegmentationModel(int node, String path) { + channelPool.call(ExecutorExportApiBlockingStub::exportSegmentationModel) + .forNode(node) + .run(RpcExportSegmentationModel + .newBuilder() + .setSourcePath(path) + .build()); + } } diff --git a/code/execution/api/src/main/protobuf/executor-api.proto b/code/execution/api/src/main/protobuf/executor-api.proto index 31cffe9b..565770ac 100644 --- a/code/execution/api/src/main/protobuf/executor-api.proto +++ b/code/execution/api/src/main/protobuf/executor-api.proto @@ -38,6 +38,7 @@ service ExecutorSideloadApi { service ExecutorExportApi { rpc exportAtags(RpcFileStorageId) returns (Empty) {} + rpc exportSegmentationModel(RpcExportSegmentationModel) returns (Empty) {} rpc exportSampleData(RpcExportSampleData) returns (Empty) {} rpc exportRssFeeds(RpcFileStorageId) returns (Empty) {} rpc exportTermFrequencies(RpcFileStorageId) returns (Empty) {} @@ -61,6 +62,9 @@ message RpcSideloadEncyclopedia { string sourcePath = 1; string baseUrl = 2; } +message RpcExportSegmentationModel { + string sourcePath = 1; +} message RpcSideloadDirtree { string sourcePath = 1; } diff --git a/code/execution/build.gradle b/code/execution/build.gradle index 74449214..3824a8c1 100644 --- a/code/execution/build.gradle +++ b/code/execution/build.gradle @@ -32,8 +32,10 @@ dependencies { implementation project(':third-party:commons-codec') implementation project(':code:libraries:message-queue') + implementation project(':code:libraries:term-frequency-dict') implementation project(':code:functions:link-graph:api') + implementation project(':code:functions:search-query') implementation project(':code:execution:api') implementation project(':code:process-models:crawl-spec') diff --git a/code/execution/java/nu/marginalia/actor/ExecutorActor.java b/code/execution/java/nu/marginalia/actor/ExecutorActor.java index ee7fb1d3..d04b3eaa 100644 --- 
a/code/execution/java/nu/marginalia/actor/ExecutorActor.java +++ b/code/execution/java/nu/marginalia/actor/ExecutorActor.java @@ -12,6 +12,7 @@ public enum ExecutorActor { ADJACENCY_CALCULATION, CRAWL_JOB_EXTRACTOR, EXPORT_DATA, + EXPORT_SEGMENTATION_MODEL, EXPORT_ATAGS, EXPORT_TERM_FREQUENCIES, EXPORT_FEEDS, diff --git a/code/execution/java/nu/marginalia/actor/ExecutorActorControlService.java b/code/execution/java/nu/marginalia/actor/ExecutorActorControlService.java index 53abdfe3..6f37d7ab 100644 --- a/code/execution/java/nu/marginalia/actor/ExecutorActorControlService.java +++ b/code/execution/java/nu/marginalia/actor/ExecutorActorControlService.java @@ -47,6 +47,7 @@ public class ExecutorActorControlService { ExportFeedsActor exportFeedsActor, ExportSampleDataActor exportSampleDataActor, ExportTermFreqActor exportTermFrequenciesActor, + ExportSegmentationModelActor exportSegmentationModelActor, DownloadSampleActor downloadSampleActor, ExecutorActorStateMachines stateMachines) { this.messageQueueFactory = messageQueueFactory; @@ -76,6 +77,7 @@ public class ExecutorActorControlService { register(ExecutorActor.EXPORT_FEEDS, exportFeedsActor); register(ExecutorActor.EXPORT_SAMPLE_DATA, exportSampleDataActor); register(ExecutorActor.EXPORT_TERM_FREQUENCIES, exportTermFrequenciesActor); + register(ExecutorActor.EXPORT_SEGMENTATION_MODEL, exportSegmentationModelActor); register(ExecutorActor.DOWNLOAD_SAMPLE, downloadSampleActor); } diff --git a/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java b/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java new file mode 100644 index 00000000..98cf114e --- /dev/null +++ b/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java @@ -0,0 +1,55 @@ +package nu.marginalia.actor.task; + +import com.google.gson.Gson; +import com.google.inject.Inject; +import com.google.inject.Singleton; +import nu.marginalia.actor.prototype.RecordActorPrototype; +import 
nu.marginalia.actor.state.ActorStep; +import nu.marginalia.segmentation.NgramExtractorMain; +import nu.marginalia.storage.FileStorageService; +import nu.marginalia.storage.model.FileStorageType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.file.Path; +import java.time.LocalDateTime; + +@Singleton +public class ExportSegmentationModelActor extends RecordActorPrototype { + + private final FileStorageService storageService; + private final Logger logger = LoggerFactory.getLogger(getClass()); + + public record Export(String zimFile) implements ActorStep {} + + @Override + public ActorStep transition(ActorStep self) throws Exception { + return switch(self) { + case Export(String zimFile) -> { + + var storage = storageService.allocateStorage(FileStorageType.EXPORT, "segmentation-model", "Segmentation Model Export " + LocalDateTime.now()); + + Path countsFile = storage.asPath().resolve("ngram-counts.bin"); + + NgramExtractorMain.dumpCounts(Path.of(zimFile), countsFile); + + yield new End(); + } + default -> new Error(); + }; + } + + @Override + public String describe() { + return "Generate a query segmentation model from a ZIM file."; + } + + @Inject + public ExportSegmentationModelActor(Gson gson, + FileStorageService storageService) + { + super(gson); + this.storageService = storageService; + } + +} diff --git a/code/execution/java/nu/marginalia/execution/ExecutorExportGrpcService.java b/code/execution/java/nu/marginalia/execution/ExecutorExportGrpcService.java index 41c8bb8b..3c5a8d5b 100644 --- a/code/execution/java/nu/marginalia/execution/ExecutorExportGrpcService.java +++ b/code/execution/java/nu/marginalia/execution/ExecutorExportGrpcService.java @@ -6,10 +6,7 @@ import io.grpc.stub.StreamObserver; import nu.marginalia.actor.ExecutorActor; import nu.marginalia.actor.ExecutorActorControlService; import nu.marginalia.actor.task.*; -import nu.marginalia.functions.execution.api.Empty; -import 
nu.marginalia.functions.execution.api.ExecutorExportApiGrpc; -import nu.marginalia.functions.execution.api.RpcExportSampleData; -import nu.marginalia.functions.execution.api.RpcFileStorageId; +import nu.marginalia.functions.execution.api.*; import nu.marginalia.storage.model.FileStorageId; @Singleton @@ -92,4 +89,20 @@ public class ExecutorExportGrpcService extends ExecutorExportApiGrpc.ExecutorExp responseObserver.onError(e); } } + + @Override + public void exportSegmentationModel(RpcExportSegmentationModel request, StreamObserver responseObserver) { + try { + actorControlService.startFrom(ExecutorActor.EXPORT_SEGMENTATION_MODEL, + new ExportSegmentationModelActor.Export(request.getSourcePath()) + ); + + responseObserver.onNext(Empty.getDefaultInstance()); + responseObserver.onCompleted(); + } + catch (Exception e) { + responseObserver.onError(e); + } + } + } diff --git a/code/features-convert/anchor-keywords/build.gradle b/code/features-convert/anchor-keywords/build.gradle index 880ce467..ae92b066 100644 --- a/code/features-convert/anchor-keywords/build.gradle +++ b/code/features-convert/anchor-keywords/build.gradle @@ -19,6 +19,7 @@ dependencies { implementation project(':code:common:process') implementation project(':code:features-convert:keyword-extraction') implementation project(':code:libraries:language-processing') + implementation project(':code:libraries:term-frequency-dict') implementation libs.bundles.slf4j diff --git a/code/features-convert/anchor-keywords/test/nu/marginalia/atags/DomainAnchorTagsImplTest.java b/code/features-convert/anchor-keywords/test/nu/marginalia/atags/DomainAnchorTagsImplTest.java index ee555ca5..17443c51 100644 --- a/code/features-convert/anchor-keywords/test/nu/marginalia/atags/DomainAnchorTagsImplTest.java +++ b/code/features-convert/anchor-keywords/test/nu/marginalia/atags/DomainAnchorTagsImplTest.java @@ -5,6 +5,7 @@ import nu.marginalia.keyword.KeywordExtractor; import nu.marginalia.language.sentence.SentenceExtractor; 
import nu.marginalia.model.EdgeDomain; import nu.marginalia.model.EdgeUrl; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.util.TestLanguageModels; import org.junit.jupiter.api.Test; diff --git a/code/features-convert/anchor-keywords/test/nu/marginalia/util/TestLanguageModels.java b/code/features-convert/anchor-keywords/test/nu/marginalia/util/TestLanguageModels.java index 5efd2025..a4cc012b 100644 --- a/code/features-convert/anchor-keywords/test/nu/marginalia/util/TestLanguageModels.java +++ b/code/features-convert/anchor-keywords/test/nu/marginalia/util/TestLanguageModels.java @@ -26,13 +26,13 @@ public class TestLanguageModels { var languageModelsHome = getLanguageModelsPath(); return new LanguageModels( - languageModelsHome.resolve("ngrams.bin"), languageModelsHome.resolve("tfreq-new-algo3.bin"), languageModelsHome.resolve("opennlp-sentence.bin"), languageModelsHome.resolve("English.RDR"), languageModelsHome.resolve("English.DICT"), languageModelsHome.resolve("opennlp-tokens.bin"), - languageModelsHome.resolve("lid.176.ftz") + languageModelsHome.resolve("lid.176.ftz"), + languageModelsHome.resolve("segments.bin") ); } } diff --git a/code/features-convert/data-extractors/build.gradle b/code/features-convert/data-extractors/build.gradle index 73aebd49..69ae1388 100644 --- a/code/features-convert/data-extractors/build.gradle +++ b/code/features-convert/data-extractors/build.gradle @@ -21,6 +21,7 @@ dependencies { implementation project(':code:common:model') implementation project(':code:libraries:language-processing') implementation project(':code:libraries:term-frequency-dict') + implementation project(':code:libraries:blocking-thread-pool') implementation project(':code:features-crawl:link-parser') implementation project(':code:features-convert:anchor-keywords') implementation project(':code:process-models:crawling-model') diff --git a/code/features-convert/data-extractors/java/nu/marginalia/extractor/TermFrequencyExporter.java 
b/code/features-convert/data-extractors/java/nu/marginalia/extractor/TermFrequencyExporter.java index df1e56a9..1e1a2cd5 100644 --- a/code/features-convert/data-extractors/java/nu/marginalia/extractor/TermFrequencyExporter.java +++ b/code/features-convert/data-extractors/java/nu/marginalia/extractor/TermFrequencyExporter.java @@ -14,6 +14,7 @@ import nu.marginalia.process.log.WorkLog; import nu.marginalia.storage.FileStorageService; import nu.marginalia.storage.model.FileStorage; import nu.marginalia.storage.model.FileStorageId; +import nu.marginalia.util.SimpleBlockingThreadPool; import org.jsoup.Jsoup; import org.jsoup.nodes.Document; import org.slf4j.Logger; @@ -53,27 +54,23 @@ public class TermFrequencyExporter implements ExporterIf { TLongIntHashMap counts = new TLongIntHashMap(100_000_000, 0.7f, -1, -1); AtomicInteger docCount = new AtomicInteger(); - try (ForkJoinPool fjp = new ForkJoinPool(Math.max(2, Runtime.getRuntime().availableProcessors() / 2))) { + SimpleBlockingThreadPool sjp = new SimpleBlockingThreadPool("exporter", Math.clamp(2, 16, Runtime.getRuntime().availableProcessors() / 2), 4); + Path crawlerLogFile = inputDir.resolve("crawler.log"); - Path crawlerLogFile = inputDir.resolve("crawler.log"); + for (var item : WorkLog.iterable(crawlerLogFile)) { + if (Thread.interrupted()) { + sjp.shutDownNow(); - for (var item : WorkLog.iterable(crawlerLogFile)) { - if (Thread.interrupted()) { - fjp.shutdownNow(); - - throw new InterruptedException(); - } - - Path crawlDataPath = inputDir.resolve(item.relPath()); - fjp.execute(() -> processFile(crawlDataPath, counts, docCount, se.get())); + throw new InterruptedException(); } - while (!fjp.isQuiescent()) { - if (fjp.awaitQuiescence(10, TimeUnit.SECONDS)) - break; - } + Path crawlDataPath = inputDir.resolve(item.relPath()); + sjp.submitQuietly(() -> processFile(crawlDataPath, counts, docCount, se.get())); } + sjp.shutDown(); + sjp.awaitTermination(10, TimeUnit.DAYS); + var tmpFile = 
Files.createTempFile(destStorage.asPath(), "freqs", ".dat.tmp", PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rw-r--r--"))); @@ -127,6 +124,10 @@ public class TermFrequencyExporter implements ExporterIf { for (var word : sent) { words.add(longHash(word.stemmed().getBytes(StandardCharsets.UTF_8))); } + + for (var ngram : sent.ngramStemmed) { + words.add(longHash(ngram.getBytes())); + } } synchronized (counts) { diff --git a/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/DocumentKeywordExtractor.java b/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/DocumentKeywordExtractor.java index 8feb5fd8..aaad9800 100644 --- a/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/DocumentKeywordExtractor.java +++ b/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/DocumentKeywordExtractor.java @@ -1,5 +1,6 @@ package nu.marginalia.keyword; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.keyword.extractors.*; import nu.marginalia.keyword.model.DocumentKeywordsBuilder; import nu.marginalia.language.model.DocumentLanguageData; @@ -15,11 +16,13 @@ public class DocumentKeywordExtractor { private final KeywordExtractor keywordExtractor; private final TermFrequencyDict dict; + private final NgramLexicon ngramLexicon; @Inject - public DocumentKeywordExtractor(TermFrequencyDict dict) { + public DocumentKeywordExtractor(TermFrequencyDict dict, NgramLexicon ngramLexicon) { this.dict = dict; + this.ngramLexicon = ngramLexicon; this.keywordExtractor = new KeywordExtractor(); } @@ -131,6 +134,17 @@ public class DocumentKeywordExtractor { wordsBuilder.add(rep.word, meta); } + + for (int i = 0; i < sent.ngrams.length; i++) { + var ngram = sent.ngrams[i]; + var ngramStemmed = sent.ngramStemmed[i]; + + long meta = metadata.getMetadataForWord(ngramStemmed); + assert meta != 0L : "Missing meta for " + ngram; + + wordsBuilder.add(ngram, meta); + } + } } diff --git 
a/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/extractors/KeywordPositionBitmask.java b/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/extractors/KeywordPositionBitmask.java index b402c9f6..230c895f 100644 --- a/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/extractors/KeywordPositionBitmask.java +++ b/code/features-convert/keyword-extraction/java/nu/marginalia/keyword/extractors/KeywordPositionBitmask.java @@ -14,7 +14,9 @@ public class KeywordPositionBitmask { private static final int unmodulatedPortion = 16; @Inject - public KeywordPositionBitmask(KeywordExtractor keywordExtractor, DocumentLanguageData dld) { + public KeywordPositionBitmask(KeywordExtractor keywordExtractor, + DocumentLanguageData dld) + { // Mark the title words as position 0 for (var sent : dld.titleSentences) { @@ -24,6 +26,10 @@ public class KeywordPositionBitmask { positionMask.merge(word.stemmed(), posBit, this::bitwiseOr); } + for (var ngram : sent.ngramStemmed) { + positionMask.merge(ngram, posBit, this::bitwiseOr); + } + for (var span : keywordExtractor.getKeywordsFromSentence(sent)) { positionMask.merge(sent.constructStemmedWordFromSpan(span), posBit, this::bitwiseOr); } @@ -43,6 +49,10 @@ public class KeywordPositionBitmask { positionMask.merge(word.stemmed(), posBit, this::bitwiseOr); } + for (var ngram : sent.ngramStemmed) { + positionMask.merge(ngram, posBit, this::bitwiseOr); + } + for (var span : keywordExtractor.getKeywordsFromSentence(sent)) { positionMask.merge(sent.constructStemmedWordFromSpan(span), posBit, this::bitwiseOr); } diff --git a/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/DocumentKeywordExtractorTest.java b/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/DocumentKeywordExtractorTest.java index 8a4f3b6b..54577f80 100644 --- a/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/DocumentKeywordExtractorTest.java +++ 
b/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/DocumentKeywordExtractorTest.java @@ -5,6 +5,7 @@ import nu.marginalia.converting.processor.logic.dom.DomPruningFilter; import nu.marginalia.language.sentence.SentenceExtractor; import nu.marginalia.model.EdgeUrl; import nu.marginalia.model.idx.WordMetadata; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.term_frequency_dict.TermFrequencyDict; import org.jsoup.Jsoup; import org.junit.jupiter.api.Assertions; @@ -20,7 +21,9 @@ import java.util.Set; class DocumentKeywordExtractorTest { - DocumentKeywordExtractor extractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels())); + DocumentKeywordExtractor extractor = new DocumentKeywordExtractor( + new TermFrequencyDict(WmsaHome.getLanguageModels()), + new NgramLexicon(WmsaHome.getLanguageModels())); SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels()); @Test @@ -56,6 +59,22 @@ class DocumentKeywordExtractorTest { } + @Test + public void testKeyboards2() throws IOException, URISyntaxException { + var resource = Objects.requireNonNull(ClassLoader.getSystemResourceAsStream("test-data/keyboards.html"), + "Could not load word frequency table"); + String html = new String(resource.readAllBytes(), Charset.defaultCharset()); + var doc = Jsoup.parse(html); + doc.filter(new DomPruningFilter(0.5)); + + var keywords = extractor.extractKeywords(se.extractSentences(doc), new EdgeUrl("https://pmortensen.eu/world2/2021/12/24/rapoo-mechanical-keyboards-gotchas-and-setup/")); + + keywords.getWords().forEach((k, v) -> { + if (k.contains("_")) { + System.out.println(k + " " + new WordMetadata(v)); + } + }); + } @Test public void testKeyboards() throws IOException, URISyntaxException { var resource = Objects.requireNonNull(ClassLoader.getSystemResourceAsStream("test-data/keyboards.html"), @@ -119,7 +138,9 @@ class DocumentKeywordExtractorTest { var doc = Jsoup.parse(html); doc.filter(new 
DomPruningFilter(0.5)); - DocumentKeywordExtractor extractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels())); + DocumentKeywordExtractor extractor = new DocumentKeywordExtractor( + new TermFrequencyDict(WmsaHome.getLanguageModels()), + new NgramLexicon(WmsaHome.getLanguageModels())); SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels()); var keywords = extractor.extractKeywords(se.extractSentences(doc), new EdgeUrl("https://math.byu.edu/wiki/index.php/All_You_Need_To_Know_About_Earning_Money_Online")); diff --git a/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/SentenceExtractorTest.java b/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/SentenceExtractorTest.java index dabad6d1..bfc78a9c 100644 --- a/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/SentenceExtractorTest.java +++ b/code/features-convert/keyword-extraction/test/nu/marginalia/keyword/SentenceExtractorTest.java @@ -3,6 +3,7 @@ package nu.marginalia.keyword; import lombok.SneakyThrows; import nu.marginalia.LanguageModels; import nu.marginalia.language.sentence.SentenceExtractor; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.term_frequency_dict.TermFrequencyDict; import nu.marginalia.WmsaHome; import nu.marginalia.model.EdgeUrl; @@ -20,9 +21,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @Tag("slow") class SentenceExtractorTest { - final LanguageModels lm = TestLanguageModels.getLanguageModels(); + static final LanguageModels lm = TestLanguageModels.getLanguageModels(); - SentenceExtractor se = new SentenceExtractor(lm); + static NgramLexicon ngramLexicon = new NgramLexicon(lm); + static SentenceExtractor se = new SentenceExtractor(lm); @SneakyThrows public static void main(String... 
args) throws IOException { @@ -32,11 +34,9 @@ class SentenceExtractorTest { System.out.println("Running"); - SentenceExtractor se = new SentenceExtractor(lm); - var dict = new TermFrequencyDict(lm); var url = new EdgeUrl("https://memex.marginalia.nu/"); - DocumentKeywordExtractor documentKeywordExtractor = new DocumentKeywordExtractor(dict); + DocumentKeywordExtractor documentKeywordExtractor = new DocumentKeywordExtractor(dict, ngramLexicon); for (;;) { long total = 0; diff --git a/code/features-convert/keyword-extraction/test/nu/marginalia/test/util/TestLanguageModels.java b/code/features-convert/keyword-extraction/test/nu/marginalia/test/util/TestLanguageModels.java index 0675559a..d857c048 100644 --- a/code/features-convert/keyword-extraction/test/nu/marginalia/test/util/TestLanguageModels.java +++ b/code/features-convert/keyword-extraction/test/nu/marginalia/test/util/TestLanguageModels.java @@ -26,13 +26,13 @@ public class TestLanguageModels { var languageModelsHome = getLanguageModelsPath(); return new LanguageModels( - languageModelsHome.resolve("ngrams.bin"), languageModelsHome.resolve("tfreq-new-algo3.bin"), languageModelsHome.resolve("opennlp-sentence.bin"), languageModelsHome.resolve("English.RDR"), languageModelsHome.resolve("English.DICT"), languageModelsHome.resolve("opennlp-tokens.bin"), - languageModelsHome.resolve("lid.176.ftz") + languageModelsHome.resolve("lid.176.ftz"), + languageModelsHome.resolve("segments.bin") ); } } diff --git a/code/features-convert/summary-extraction/test/nu/marginalia/summary/SummaryExtractorTest.java b/code/features-convert/summary-extraction/test/nu/marginalia/summary/SummaryExtractorTest.java index c1a326da..cabe558f 100644 --- a/code/features-convert/summary-extraction/test/nu/marginalia/summary/SummaryExtractorTest.java +++ b/code/features-convert/summary-extraction/test/nu/marginalia/summary/SummaryExtractorTest.java @@ -5,6 +5,7 @@ import nu.marginalia.WmsaHome; import 
nu.marginalia.keyword.DocumentKeywordExtractor; import nu.marginalia.language.sentence.SentenceExtractor; import nu.marginalia.model.EdgeUrl; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.summary.heuristic.*; import nu.marginalia.term_frequency_dict.TermFrequencyDict; import org.jsoup.Jsoup; @@ -25,7 +26,9 @@ class SummaryExtractorTest { @BeforeEach public void setUp() { - keywordExtractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels())); + keywordExtractor = new DocumentKeywordExtractor( + new TermFrequencyDict(WmsaHome.getLanguageModels()), + new NgramLexicon(WmsaHome.getLanguageModels())); setenceExtractor = new SentenceExtractor(WmsaHome.getLanguageModels()); summaryExtractor = new SummaryExtractor(255, diff --git a/code/functions/search-query/api/build.gradle b/code/functions/search-query/api/build.gradle index 727b5b86..1a8d55d2 100644 --- a/code/functions/search-query/api/build.gradle +++ b/code/functions/search-query/api/build.gradle @@ -30,6 +30,7 @@ dependencies { implementation libs.notnull implementation libs.guice implementation libs.gson + implementation libs.commons.lang3 implementation libs.bundles.protobuf implementation libs.bundles.grpc implementation libs.fastutil diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/IndexProtobufCodec.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/IndexProtobufCodec.java index 4b2f0032..4d2cf7a6 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/IndexProtobufCodec.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/IndexProtobufCodec.java @@ -1,7 +1,6 @@ package nu.marginalia.api.searchquery; -import nu.marginalia.api.searchquery.*; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.Bm25Parameters; import 
nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.query.limit.QueryLimits; @@ -45,33 +44,37 @@ public class IndexProtobufCodec { .build(); } - public static SearchSubquery convertSearchSubquery(RpcSubquery subquery) { + public static SearchQuery convertRpcQuery(RpcQuery query) { List> coherences = new ArrayList<>(); - for (int j = 0; j < subquery.getCoherencesCount(); j++) { - var coh = subquery.getCoherences(j); + for (int j = 0; j < query.getCoherencesCount(); j++) { + var coh = query.getCoherences(j); coherences.add(new ArrayList<>(coh.getCoherencesList())); } - return new SearchSubquery( - subquery.getIncludeList(), - subquery.getExcludeList(), - subquery.getAdviceList(), - subquery.getPriorityList(), + return new SearchQuery( + query.getCompiledQuery(), + query.getIncludeList(), + query.getExcludeList(), + query.getAdviceList(), + query.getPriorityList(), coherences ); } - public static RpcSubquery convertSearchSubquery(SearchSubquery searchSubquery) { + public static RpcQuery convertRpcQuery(SearchQuery searchQuery) { var subqueryBuilder = - RpcSubquery.newBuilder() - .addAllAdvice(searchSubquery.getSearchTermsAdvice()) - .addAllExclude(searchSubquery.getSearchTermsExclude()) - .addAllInclude(searchSubquery.getSearchTermsInclude()) - .addAllPriority(searchSubquery.getSearchTermsPriority()); - for (var coherences : searchSubquery.searchTermCoherences) { + RpcQuery.newBuilder() + .setCompiledQuery(searchQuery.compiledQuery) + .addAllInclude(searchQuery.getSearchTermsInclude()) + .addAllAdvice(searchQuery.getSearchTermsAdvice()) + .addAllExclude(searchQuery.getSearchTermsExclude()) + .addAllPriority(searchQuery.getSearchTermsPriority()); + + for (var coherences : searchQuery.searchTermCoherences) { subqueryBuilder.addCoherencesBuilder().addAllCoherences(coherences); } + return subqueryBuilder.build(); } diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/QueryProtobufCodec.java 
b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/QueryProtobufCodec.java index 28d14c82..2907992d 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/QueryProtobufCodec.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/QueryProtobufCodec.java @@ -2,7 +2,6 @@ package nu.marginalia.api.searchquery; import lombok.SneakyThrows; import nu.marginalia.api.searchquery.model.query.SearchSpecification; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.api.searchquery.model.results.SearchResultItem; @@ -14,7 +13,6 @@ import nu.marginalia.api.searchquery.model.query.QueryParams; import nu.marginalia.api.searchquery.model.query.QueryResponse; import java.util.ArrayList; -import java.util.List; public class QueryProtobufCodec { @@ -23,9 +21,7 @@ public class QueryProtobufCodec { builder.addAllDomains(request.getDomainIdsList()); - for (var subquery : query.specs.subqueries) { - builder.addSubqueries(IndexProtobufCodec.convertSearchSubquery(subquery)); - } + builder.setQuery(IndexProtobufCodec.convertRpcQuery(query.specs.query)); builder.setSearchSetIdentifier(query.specs.searchSetIdentifier); builder.setHumanQuery(request.getHumanQuery()); @@ -51,9 +47,7 @@ public class QueryProtobufCodec { public static RpcIndexQuery convertQuery(String humanQuery, ProcessedQuery query) { var builder = RpcIndexQuery.newBuilder(); - for (var subquery : query.specs.subqueries) { - builder.addSubqueries(IndexProtobufCodec.convertSearchSubquery(subquery)); - } + builder.setQuery(IndexProtobufCodec.convertRpcQuery(query.specs.query)); builder.setSearchSetIdentifier(query.specs.searchSetIdentifier); builder.setHumanQuery(humanQuery); @@ -127,6 +121,7 @@ public class QueryProtobufCodec { results.getPubYear(), // ??, results.getDataHash(), 
results.getWordsTotal(), + results.getBestPositions(), results.getRankingScore() ); } @@ -139,31 +134,26 @@ public class QueryProtobufCodec { return new SearchResultItem( rawItem.getCombinedId(), + rawItem.getEncodedDocMetadata(), + rawItem.getHtmlFeatures(), keywordScores, rawItem.getResultsFromDomain(), + rawItem.getHasPriorityTerms(), Double.NaN // Not set ); } private static SearchResultKeywordScore convertKeywordScore(RpcResultKeywordScore keywordScores) { return new SearchResultKeywordScore( - keywordScores.getSubquery(), keywordScores.getKeyword(), - keywordScores.getEncodedWordMetadata(), - keywordScores.getEncodedDocMetadata(), - keywordScores.getHtmlFeatures() + -1, // termId is internal to index service + keywordScores.getEncodedWordMetadata() ); } private static SearchSpecification convertSearchSpecification(RpcIndexQuery specs) { - List subqueries = new ArrayList<>(specs.getSubqueriesCount()); - - for (int i = 0; i < specs.getSubqueriesCount(); i++) { - subqueries.add(IndexProtobufCodec.convertSearchSubquery(specs.getSubqueries(i))); - } - return new SearchSpecification( - subqueries, + IndexProtobufCodec.convertRpcQuery(specs.getQuery()), specs.getDomainsList(), specs.getSearchSetIdentifier(), specs.getHumanQuery(), @@ -182,7 +172,6 @@ public class QueryProtobufCodec { .addAllDomainIds(params.domainIds()) .addAllTacitAdvice(params.tacitAdvice()) .addAllTacitExcludes(params.tacitExcludes()) - .addAllTacitIncludes(params.tacitIncludes()) .addAllTacitPriority(params.tacitPriority()) .setHumanQuery(params.humanQuery()) .setQueryLimits(IndexProtobufCodec.convertQueryLimits(params.limits())) @@ -215,6 +204,7 @@ public class QueryProtobufCodec { rpcDecoratedResultItem.getPubYear(), rpcDecoratedResultItem.getDataHash(), rpcDecoratedResultItem.getWordsTotal(), + rpcDecoratedResultItem.getBestPositions(), rpcDecoratedResultItem.getRankingScore() ); } diff --git 
a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQuery.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQuery.java new file mode 100644 index 00000000..356a1d86 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQuery.java @@ -0,0 +1,80 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import org.jetbrains.annotations.NotNull; + +import java.util.Iterator; +import java.util.function.*; +import java.util.stream.IntStream; +import java.util.stream.Stream; + + +/** A compiled index service query. The class separates the topology of the query from the data, + * and it's possible to create new queries supplanting the data */ +public class CompiledQuery implements Iterable { + + /** The root expression, conveys the topology of the query */ + public final CqExpression root; + + private final CqData data; + + public CompiledQuery(CqExpression root, CqData data) { + this.root = root; + this.data = data; + } + + public CompiledQuery(CqExpression root, T[] data) { + this.root = root; + this.data = new CqData<>(data); + } + + /** Exists for testing, creates a simple query that ANDs all the provided items */ + public static CompiledQuery just(T... 
item) { + return new CompiledQuery<>(new CqExpression.And( + IntStream.range(0, item.length).mapToObj(CqExpression.Word::new).toList() + ), item); + } + + /** Create a new CompiledQuery mapping the leaf nodes using the provided mapper */ + public CompiledQuery map(Class clazz, Function mapper) { + return new CompiledQuery<>( + root, + data.map(clazz, mapper) + ); + } + + public CompiledQueryLong mapToLong(ToLongFunction mapper) { + return new CompiledQueryLong(root, data.mapToLong(mapper)); + } + + public CompiledQueryLong mapToInt(ToIntFunction mapper) { + return new CompiledQueryLong(root, data.mapToInt(mapper)); + } + + public CqExpression root() { + return root; + } + + public Stream stream() { + return data.stream(); + } + + public IntStream indices() { + return IntStream.range(0, data.size()); + } + + public T at(int index) { + return data.get(index); + } + + @NotNull + @Override + public Iterator iterator() { + return stream().iterator(); + } + + public int size() { + return data.size(); + } + + +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryInt.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryInt.java new file mode 100644 index 00000000..9e26c35c --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryInt.java @@ -0,0 +1,44 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import java.util.stream.IntStream; + + +/** A compiled index service query */ +public class CompiledQueryInt { + private final CqExpression root; + private final CqDataInt data; + + public CompiledQueryInt(CqExpression root, CqDataInt data) { + this.root = root; + this.data = data; + } + + + public CqExpression root() { + return root; + } + + public IntStream stream() { + return data.stream(); + } + + public IntStream indices() { + return IntStream.range(0, data.size()); + } + + public long at(int index) 
{ + return data.get(index); + } + + public int[] copyData() { + return data.copyData(); + } + + public boolean isEmpty() { + return data.size() == 0; + } + + public int size() { + return data.size(); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java new file mode 100644 index 00000000..718aaca7 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryLong.java @@ -0,0 +1,54 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import org.jetbrains.annotations.NotNull; + +import java.util.Iterator; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + + +/** A compiled index service query */ +public class CompiledQueryLong implements Iterable { + public final CqExpression root; + public final CqDataLong data; + + public CompiledQueryLong(CqExpression root, CqDataLong data) { + this.root = root; + this.data = data; + } + + + public CqExpression root() { + return root; + } + + public LongStream stream() { + return data.stream(); + } + + public IntStream indices() { + return IntStream.range(0, data.size()); + } + + public long at(int index) { + return data.get(index); + } + + @NotNull + @Override + public Iterator iterator() { + return stream().iterator(); + } + + public long[] copyData() { + return data.copyData(); + } + + public boolean isEmpty() { + return data.size() == 0; + } + + public int size() { + return data.size(); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryParser.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CompiledQueryParser.java new file mode 100644 index 00000000..ae197fb9 --- /dev/null +++ 
/** Parser for a compiled index query.
 *
 * Grammar: space-separated tokens where {@code (} and {@code )} group,
 * {@code |} separates or-alternatives, and any other token is a word leaf.
 * Adjacency implies AND.  Word leaves are numbered by first occurrence; the
 * returned CompiledQuery's data array maps leaf index back to the word.
 */
public class CompiledQueryParser {

    public static CompiledQuery<String> parse(String query) {
        List<String> parts = tokenize(query);

        if (parts.isEmpty()) {
            // Empty input yields an empty Or expression and no leaf data.
            return new CompiledQuery<>(
                    CqExpression.empty(),
                    new CqData<>(new String[0])
            );
        }

        // We aren't interested in a binary tree representation, but an n-ary tree one,
        // so a somewhat unusual parsing technique is used to avoid having an additional
        // flattening step at the end.

        // This is only possible due to the trivial and unambiguous grammar of the compiled queries

        // Stack of open paren groups; the last element is the innermost open group.
        List<AndOrState> parenState = new ArrayList<>();
        parenState.add(new AndOrState());

        // Assigns each distinct word a leaf index in first-occurrence order;
        // repeated words share one index (and thus one data slot).
        Map<String, Integer> wordIds = new HashMap<>();

        for (var part : parts) {
            // Captured BEFORE any push/pop: on ')' this is the group being closed.
            var head = parenState.getLast();

            if (part.equals("|")) {
                head.or();
            }
            else if (part.equals("(")) {
                parenState.addLast(new AndOrState());
            }
            else if (part.equals(")")) {
                if (parenState.size() < 2) {
                    throw new IllegalStateException("Mismatched parentheses in expression: " + query);
                }
                // Close the innermost group and fold it into its parent's and-list.
                parenState.removeLast();
                parenState.getLast().and(head.closeOr());
            }
            else {
                head.and(
                        new CqExpression.Word(
                                wordIds.computeIfAbsent(part, p -> wordIds.size())
                        )
                );
            }
        }

        // Unclosed '(' leaves extra states on the stack.
        if (parenState.size() != 1)
            throw new IllegalStateException("Mismatched parentheses in expression: " + query);

        // Construct the CompiledQuery object with String:s as leaves
        var root = parenState.getLast().closeOr();

        // Invert the word -> index map into the index -> word data array.
        String[] cqData = new String[wordIds.size()];
        wordIds.forEach((w, i) -> cqData[i] = w);
        return new CompiledQuery<>(root, new CqData<>(cqData));

    }

    /** Accumulates one paren group: an and-list under construction plus the
     * or-alternatives already completed within the group. */
    private static class AndOrState {
        private List<CqExpression> andState = new ArrayList<>();
        private List<CqExpression> orState = new ArrayList<>();

        /** Add a new item to the and-list */
        public void and(CqExpression e) {
            andState.add(e);
        }

        /** Turn the and-list into an expression on the or-list, and then start a new and-list */
        public void or() {
            closeAnd();

            andState = new ArrayList<>();
        }

        /** Turn the and-list into an And-expression in the or-list */
        private void closeAnd() {
            // A single-element and-list is unwrapped rather than wrapped in And;
            // an empty and-list contributes nothing.
            if (andState.size() == 1)
                orState.add(andState.getFirst());
            else if (!andState.isEmpty())
                orState.add(new CqExpression.And(andState));
        }

        /** Finalize the current and-list, then turn the or-list into an Or-expression */
        public CqExpression closeOr() {
            closeAnd();

            if (orState.isEmpty())
                return CqExpression.empty();
            if (orState.size() == 1)
                return orState.getFirst();

            return new CqExpression.Or(orState);
        }
    }

    private static List<String> tokenize(String query) {
        // Each token is guaranteed to be separated by one or more space characters

        return Arrays.stream(StringUtils.split(query, ' '))
                .filter(StringUtils::isNotBlank)
                .toList();
    }

}
mapper.apply((T) data[i]); + } + + return new CqData<>(newData); + } + + public CqDataLong mapToLong(ToLongFunction mapper) { + long[] newData = new long[data.length]; + for (int i = 0; i < data.length; i++) { + newData[i] = mapper.applyAsLong((T) data[i]); + } + + return new CqDataLong(newData); + } + + public CqDataLong mapToInt(ToIntFunction mapper) { + long[] newData = new long[data.length]; + for (int i = 0; i < data.length; i++) { + newData[i] = mapper.applyAsInt((T) data[i]); + } + + return new CqDataLong(newData); + } + + public T get(int i) { + return data[i]; + } + + public T get(CqExpression.Word w) { + return data[w.idx()]; + } + + public Stream stream() { + return Arrays.stream(data); + } + + public int size() { + return data.length; + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataInt.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataInt.java new file mode 100644 index 00000000..24991686 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataInt.java @@ -0,0 +1,31 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import java.util.Arrays; +import java.util.stream.IntStream; + +public class CqDataInt { + private final int[] data; + + public CqDataInt(int[] data) { + this.data = data; + } + + public int get(int i) { + return data[i]; + } + public int get(CqExpression.Word w) { + return data[w.idx()]; + } + + public IntStream stream() { + return Arrays.stream(data); + } + + public int size() { + return data.length; + } + + public int[] copyData() { + return Arrays.copyOf(data, data.length); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java new file mode 100644 index 00000000..24f76b13 --- /dev/null +++ 
b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqDataLong.java @@ -0,0 +1,31 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import java.util.Arrays; +import java.util.stream.LongStream; + +public class CqDataLong { + private final long[] data; + + public CqDataLong(long[] data) { + this.data = data; + } + + public long get(int i) { + return data[i]; + } + public long get(CqExpression.Word w) { + return data[w.idx()]; + } + + public LongStream stream() { + return Arrays.stream(data); + } + + public int size() { + return data.length; + } + + public long[] copyData() { + return Arrays.copyOf(data, data.length); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqExpression.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqExpression.java new file mode 100644 index 00000000..e9972526 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/CqExpression.java @@ -0,0 +1,170 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import java.util.List; +import java.util.StringJoiner; +import java.util.stream.Stream; + +/** Expression in a parsed index service query + * + */ +public sealed interface CqExpression { + + Stream stream(); + + /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */ + long visit(LongVisitor visitor); + /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */ + double visit(DoubleVisitor visitor); + /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */ + int visit(IntVisitor visitor); + /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */ + boolean visit(BoolVisitor visitor); + + T visit(ObjectVisitor visitor); + + static CqExpression empty() { + return new Or(List.of()); + } + + + record And(List parts) implements 
/** Expression in a parsed index service query.
 *
 * The tree is an n-ary and/or structure whose leaves ({@link Word}) carry an
 * index into the accompanying CqData array.  Aggregation is done through the
 * primitive-specialized visitor interfaces below.
 */
public sealed interface CqExpression {

    /** All leaves of this subtree, in left-to-right order. */
    Stream<Word> stream();

    /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
    long visit(LongVisitor visitor);
    /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
    double visit(DoubleVisitor visitor);
    /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
    int visit(IntVisitor visitor);
    /** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
    boolean visit(BoolVisitor visitor);

    <T> T visit(ObjectVisitor<T> visitor);

    /** Canonical empty expression: an Or with no alternatives. */
    static CqExpression empty() {
        return new Or(List.of());
    }


    record And(List<? extends CqExpression> parts) implements CqExpression {
        @Override
        public Stream<Word> stream() {
            return parts.stream().flatMap(p -> p.stream());
        }

        @Override
        public long visit(LongVisitor v) { return v.onAnd(parts); }

        @Override
        public double visit(DoubleVisitor v) { return v.onAnd(parts); }

        @Override
        public int visit(IntVisitor v) { return v.onAnd(parts); }

        @Override
        public boolean visit(BoolVisitor v) { return v.onAnd(parts); }

        @Override
        public <T> T visit(ObjectVisitor<T> v) { return v.onAnd(parts); }

        public String toString() {
            StringJoiner sj = new StringJoiner(", ", "And[ ", "]");
            for (CqExpression part : parts) {
                sj.add(part.toString());
            }
            return sj.toString();
        }
    }

    record Or(List<? extends CqExpression> parts) implements CqExpression {
        @Override
        public Stream<Word> stream() {
            return parts.stream().flatMap(p -> p.stream());
        }

        @Override
        public long visit(LongVisitor v) { return v.onOr(parts); }

        @Override
        public double visit(DoubleVisitor v) { return v.onOr(parts); }

        @Override
        public int visit(IntVisitor v) { return v.onOr(parts); }

        @Override
        public boolean visit(BoolVisitor v) { return v.onOr(parts); }

        @Override
        public <T> T visit(ObjectVisitor<T> v) { return v.onOr(parts); }

        public String toString() {
            StringJoiner sj = new StringJoiner(", ", "Or[ ", "]");
            for (CqExpression part : parts) {
                sj.add(part.toString());
            }
            return sj.toString();
        }
    }

    /** Leaf node: {@code idx} indexes the query's data array. */
    record Word(int idx) implements CqExpression {
        @Override
        public Stream<Word> stream() {
            return Stream.of(this);
        }

        @Override
        public long visit(LongVisitor v) { return v.onLeaf(idx); }

        @Override
        public double visit(DoubleVisitor v) { return v.onLeaf(idx); }

        @Override
        public int visit(IntVisitor v) { return v.onLeaf(idx); }

        @Override
        public boolean visit(BoolVisitor v) { return v.onLeaf(idx); }

        @Override
        public <T> T visit(ObjectVisitor<T> v) { return v.onLeaf(idx); }

        @Override
        public String toString() {
            return Integer.toString(idx);
        }
    }

    interface LongVisitor {
        long onAnd(List<? extends CqExpression> parts);
        long onOr(List<? extends CqExpression> parts);
        long onLeaf(int idx);
    }

    interface IntVisitor {
        int onAnd(List<? extends CqExpression> parts);
        int onOr(List<? extends CqExpression> parts);
        int onLeaf(int idx);
    }

    interface BoolVisitor {
        boolean onAnd(List<? extends CqExpression> parts);
        boolean onOr(List<? extends CqExpression> parts);
        boolean onLeaf(int idx);
    }

    interface DoubleVisitor {
        double onAnd(List<? extends CqExpression> parts);
        double onOr(List<? extends CqExpression> parts);
        double onLeaf(int idx);
    }

    interface ObjectVisitor<T> {
        T onAnd(List<? extends CqExpression> parts);
        T onOr(List<? extends CqExpression> parts);
        T onLeaf(int idx);
    }

}
/** Contains methods for aggregating across a CompiledQuery or CompiledQueryLong.
 * All methods delegate to a visitor over the query's root CqExpression. */
public class CompiledQueryAggregates {
    /** Compiled query aggregate that for a single boolean that treats or-branches as logical OR,
     * and and-branches as logical AND operations. Will return true if there exists a path through
     * the query where the provided predicate returns true for each item.
     */
    static public <T> boolean booleanAggregate(CompiledQuery<T> query, Predicate<T> predicate) {
        return query.root.visit(new CqBooleanAggregate(query, predicate));
    }
    static public boolean booleanAggregate(CompiledQueryLong query, LongPredicate predicate) {
        return query.root.visit(new CqBooleanAggregate(query, predicate));
    }


    /** Compiled query aggregate that for a 64b bitmask that treats or-branches as logical OR,
     * and and-branches as logical AND operations.
     */
    public static <T> long longBitmaskAggregate(CompiledQuery<T> query, ToLongFunction<T> operator) {
        return query.root.visit(new CqLongBitmaskOperator(query, operator));
    }
    public static long longBitmaskAggregate(CompiledQueryLong query, LongUnaryOperator operator) {
        return query.root.visit(new CqLongBitmaskOperator(query, operator));
    }

    /** Apply the operator to each leaf node, then return the highest minimum value found along any path */
    public static <T> int intMaxMinAggregate(CompiledQuery<T> query, ToIntFunction<T> operator) {
        return query.root.visit(new CqIntMaxMinOperator(query, operator));
    }

    /** Apply the operator to each leaf node, then return the highest minimum value found along any path */
    public static int intMaxMinAggregate(CompiledQueryLong query, LongToIntFunction operator) {
        return query.root.visit(new CqIntMaxMinOperator(query, operator));
    }

    /** Apply the operator to each leaf node, and then return the highest sum of values possible
     * through each branch in the compiled query.
     */
    public static <T> double doubleSumAggregate(CompiledQuery<T> query, ToDoubleFunction<T> operator) {
        return query.root.visit(new CqDoubleSumOperator(query, operator));
    }

    /** Enumerate all possible paths through the compiled query */
    public static List<LongSet> queriesAggregate(CompiledQueryLong query) {
        return new ArrayList<>(query.root().visit(new CqQueryPathsOperator(query)));
    }

    /** Using the bitwise AND operator, aggregate all possible combined values of the long generated by the provided operator */
    public static <T> LongSet positionsAggregate(CompiledQuery<T> query, ToLongFunction<T> operator) {
        return query.root().visit(new CqPositionsOperator(query, operator));
    }

    /** Using the bitwise AND operator, aggregate all possible combined values of the long generated by the provided operator */
    public static LongSet positionsAggregate(CompiledQueryLong query, LongUnaryOperator operator) {
        return query.root().visit(new CqPositionsOperator(query, operator));
    }
}
CqBooleanAggregate(CompiledQuery query, Predicate objPred) { + this.predicate = idx -> objPred.test(query.at(idx)); + } + + public CqBooleanAggregate(CompiledQueryLong query, LongPredicate longPredicate) { + this.predicate = idx -> longPredicate.test(query.at(idx)); + } + + @Override + public boolean onAnd(List parts) { + for (var part : parts) { + if (!part.visit(this)) // short-circuit + return false; + } + return true; + } + + @Override + public boolean onOr(List parts) { + for (var part : parts) { + if (part.visit(this)) // short-circuit + return true; + } + return false; + } + + @Override + public boolean onLeaf(int idx) { + return predicate.test(idx); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqDoubleSumOperator.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqDoubleSumOperator.java new file mode 100644 index 00000000..082de29e --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqDoubleSumOperator.java @@ -0,0 +1,46 @@ +package nu.marginalia.api.searchquery.model.compiled.aggregate; + +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; + +import java.util.List; +import java.util.function.IntToDoubleFunction; +import java.util.function.LongToDoubleFunction; +import java.util.function.ToDoubleFunction; + +public class CqDoubleSumOperator implements CqExpression.DoubleVisitor { + + private final IntToDoubleFunction operator; + + public CqDoubleSumOperator(CompiledQuery query, ToDoubleFunction operator) { + this.operator = idx -> operator.applyAsDouble(query.at(idx)); + } + + public CqDoubleSumOperator(IntToDoubleFunction operator) { + this.operator = operator; + } + + @Override + public double onAnd(List parts) { + double value = 
0; + for (var part : parts) { + value += part.visit(this); + } + return value; + } + + @Override + public double onOr(List parts) { + double value = parts.getFirst().visit(this); + for (int i = 1; i < parts.size(); i++) { + value = Math.max(value, parts.get(i).visit(this)); + } + return value; + } + + @Override + public double onLeaf(int idx) { + return operator.applyAsDouble(idx); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqIntMaxMinOperator.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqIntMaxMinOperator.java new file mode 100644 index 00000000..621dff73 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqIntMaxMinOperator.java @@ -0,0 +1,47 @@ +package nu.marginalia.api.searchquery.model.compiled.aggregate; + +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; + +import java.util.List; +import java.util.function.IntUnaryOperator; +import java.util.function.LongToIntFunction; +import java.util.function.ToIntFunction; + +public class CqIntMaxMinOperator implements CqExpression.IntVisitor { + + private final IntUnaryOperator operator; + + + public CqIntMaxMinOperator(CompiledQuery query, ToIntFunction operator) { + this.operator = idx -> operator.applyAsInt(query.at(idx)); + } + + public CqIntMaxMinOperator(CompiledQueryLong query, LongToIntFunction operator) { + this.operator = idx -> operator.applyAsInt(query.at(idx)); + } + + @Override + public int onAnd(List parts) { + int value = parts.getFirst().visit(this); + for (int i = 1; i < parts.size(); i++) { + value = Math.min(value, parts.get(i).visit(this)); + } + return value; + } + + @Override + public int onOr(List parts) { + int value = parts.getFirst().visit(this); + for (int 
i = 1; i < parts.size(); i++) { + value = Math.max(value, parts.get(i).visit(this)); + } + return value; + } + + @Override + public int onLeaf(int idx) { + return operator.applyAsInt(idx); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqLongBitmaskOperator.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqLongBitmaskOperator.java new file mode 100644 index 00000000..b64029c1 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqLongBitmaskOperator.java @@ -0,0 +1,45 @@ +package nu.marginalia.api.searchquery.model.compiled.aggregate; + +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; + +import java.util.List; +import java.util.function.IntToLongFunction; +import java.util.function.LongUnaryOperator; +import java.util.function.ToLongFunction; + +public class CqLongBitmaskOperator implements CqExpression.LongVisitor { + + private final IntToLongFunction operator; + + public CqLongBitmaskOperator(CompiledQuery query, ToLongFunction operator) { + this.operator = idx-> operator.applyAsLong(query.at(idx)); + } + public CqLongBitmaskOperator(CompiledQueryLong query, LongUnaryOperator operator) { + this.operator = idx-> operator.applyAsLong(query.at(idx)); + } + + @Override + public long onAnd(List parts) { + long value = ~0L; + for (var part : parts) { + value &= part.visit(this); + } + return value; + } + + @Override + public long onOr(List parts) { + long value = 0L; + for (var part : parts) { + value |= part.visit(this); + } + return value; + } + + @Override + public long onLeaf(int idx) { + return operator.applyAsLong(idx); + } +} diff --git 
a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqPositionsOperator.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqPositionsOperator.java new file mode 100644 index 00000000..715c4cb2 --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqPositionsOperator.java @@ -0,0 +1,85 @@ +package nu.marginalia.api.searchquery.model.compiled.aggregate; + +import it.unimi.dsi.fastutil.longs.LongArraySet; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongSet; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; + +import java.util.List; +import java.util.function.IntToLongFunction; +import java.util.function.LongUnaryOperator; +import java.util.function.ToLongFunction; + +public class CqPositionsOperator implements CqExpression.ObjectVisitor { + private final IntToLongFunction operator; + + public CqPositionsOperator(CompiledQuery query, ToLongFunction operator) { + this.operator = idx -> operator.applyAsLong(query.at(idx)); + } + + public CqPositionsOperator(CompiledQueryLong query, LongUnaryOperator operator) { + this.operator = idx -> operator.applyAsLong(query.at(idx)); + } + + @Override + public LongSet onAnd(List parts) { + LongSet ret = new LongArraySet(); + + for (var part : parts) { + ret = comineSets(ret, part.visit(this)); + } + + return ret; + } + + private LongSet comineSets(LongSet a, LongSet b) { + if (a.isEmpty()) + return b; + if (b.isEmpty()) + return a; + + LongSet ret = newSet(a.size() * b.size()); + + var ai = a.longIterator(); + + while (ai.hasNext()) { + long aval = ai.nextLong(); + + var bi = b.longIterator(); + while (bi.hasNext()) { + ret.add(aval & bi.nextLong()); + } + } + + return ret; + } + + 
@Override + public LongSet onOr(List parts) { + LongSet ret = newSet(parts.size()); + + for (var part : parts) { + ret.addAll(part.visit(this)); + } + + return ret; + } + + @Override + public LongSet onLeaf(int idx) { + var set = newSet(1); + set.add(operator.applyAsLong(idx)); + return set; + } + + /** Allocate a new set suitable for a collection with the provided cardinality */ + private LongSet newSet(int cardinality) { + if (cardinality < 8) + return new LongArraySet(cardinality); + else + return new LongOpenHashSet(cardinality); + } + +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqQueryPathsOperator.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqQueryPathsOperator.java new file mode 100644 index 00000000..2339104e --- /dev/null +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/compiled/aggregate/CqQueryPathsOperator.java @@ -0,0 +1,75 @@ +package nu.marginalia.api.searchquery.model.compiled.aggregate; + +import it.unimi.dsi.fastutil.longs.LongArraySet; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongSet; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; + +import java.util.ArrayList; +import java.util.List; + +public class CqQueryPathsOperator implements CqExpression.ObjectVisitor> { + private final CompiledQueryLong query; + + public CqQueryPathsOperator(CompiledQueryLong query) { + this.query = query; + } + + @Override + public List onAnd(List parts) { + return parts.stream() + .map(expr -> expr.visit(this)) + .reduce(List.of(), this::combineAnd); + } + + private List combineAnd(List a, List b) { + // No-op cases + if (a.isEmpty()) + return b; + if (b.isEmpty()) + return a; + + // Simple cases + if (a.size() == 1) { + b.forEach(set -> set.addAll(a.getFirst())); + return b; + } + else 
if (b.size() == 1) { + a.forEach(set -> set.addAll(b.getFirst())); + return a; + } + + // Case where we AND two ORs + List ret = new ArrayList<>(); + + for (var aPart : a) { + for (var bPart : b) { + LongSet set = new LongOpenHashSet(aPart.size() + bPart.size()); + set.addAll(aPart); + set.addAll(bPart); + ret.add(set); + } + } + + return ret; + } + + @Override + public List onOr(List parts) { + List ret = new ArrayList<>(); + + for (var part : parts) { + ret.addAll(part.visit(this)); + } + + return ret; + } + + @Override + public List onLeaf(int idx) { + var set = new LongArraySet(1); + set.add(query.at(idx)); + return List.of(set); + } +} diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/QueryResponse.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/QueryResponse.java index 80e5b61a..1834c08f 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/QueryResponse.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/QueryResponse.java @@ -13,10 +13,6 @@ public record QueryResponse(SearchSpecification specs, String domain) { public Set getAllKeywords() { - Set keywords = new HashSet<>(100); - for (var sq : specs.subqueries) { - keywords.addAll(sq.searchTermsInclude); - } - return keywords; + return new HashSet<>(specs.query.searchTermsInclude); } } diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSubquery.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchQuery.java similarity index 75% rename from code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSubquery.java rename to code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchQuery.java index 3798ae89..ffe02868 100644 --- 
a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSubquery.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchQuery.java @@ -13,9 +13,12 @@ import java.util.stream.Collectors; @AllArgsConstructor @With @EqualsAndHashCode -public class SearchSubquery { +public class SearchQuery { - /** These terms must be present in the document and are used in ranking*/ + /** An infix style expression that encodes the required terms in the query */ + public final String compiledQuery; + + /** All terms that appear in {@see compiledQuery} */ public final List searchTermsInclude; /** These terms must be absent from the document */ @@ -33,7 +36,8 @@ public class SearchSubquery { @Deprecated // why does this exist? private double value = 0; - public SearchSubquery() { + public SearchQuery() { + this.compiledQuery = ""; this.searchTermsInclude = new ArrayList<>(); this.searchTermsExclude = new ArrayList<>(); this.searchTermsAdvice = new ArrayList<>(); @@ -41,11 +45,13 @@ public class SearchSubquery { this.searchTermCoherences = new ArrayList<>(); } - public SearchSubquery(List searchTermsInclude, - List searchTermsExclude, - List searchTermsAdvice, - List searchTermsPriority, - List> searchTermCoherences) { + public SearchQuery(String compiledQuery, + List searchTermsInclude, + List searchTermsExclude, + List searchTermsAdvice, + List searchTermsPriority, + List> searchTermCoherences) { + this.compiledQuery = compiledQuery; this.searchTermsInclude = searchTermsInclude; this.searchTermsExclude = searchTermsExclude; this.searchTermsAdvice = searchTermsAdvice; @@ -54,7 +60,7 @@ public class SearchSubquery { } @Deprecated // why does this exist? 
- public SearchSubquery setValue(double value) { + public SearchQuery setValue(double value) { if (Double.isInfinite(value) || Double.isNaN(value)) { this.value = Double.MAX_VALUE; } else { @@ -66,7 +72,7 @@ public class SearchSubquery { @Override public String toString() { StringBuilder sb = new StringBuilder(); - if (!searchTermsInclude.isEmpty()) sb.append("include=").append(searchTermsInclude.stream().collect(Collectors.joining(",", "[", "] "))); + if (!compiledQuery.isEmpty()) sb.append("compiledQuery=").append(compiledQuery).append(", "); if (!searchTermsExclude.isEmpty()) sb.append("exclude=").append(searchTermsExclude.stream().collect(Collectors.joining(",", "[", "] "))); if (!searchTermsAdvice.isEmpty()) sb.append("advice=").append(searchTermsAdvice.stream().collect(Collectors.joining(",", "[", "] "))); if (!searchTermsPriority.isEmpty()) sb.append("priority=").append(searchTermsPriority.stream().collect(Collectors.joining(",", "[", "] "))); diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSpecification.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSpecification.java index be2a6895..bbb5b7ae 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSpecification.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/query/SearchSpecification.java @@ -10,7 +10,7 @@ import java.util.List; @ToString @Getter @Builder @With @AllArgsConstructor public class SearchSpecification { - public List subqueries; + public SearchQuery query; /** If present and not empty, limit the search to these domain IDs */ public List domains; diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/DecoratedSearchResultItem.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/DecoratedSearchResultItem.java index b099dc01..df48ea64 100644 --- 
a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/DecoratedSearchResultItem.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/DecoratedSearchResultItem.java @@ -30,6 +30,7 @@ public class DecoratedSearchResultItem implements Comparable fullCounts = new Object2IntOpenHashMap<>(10, 0.5f); - private final Object2IntOpenHashMap priorityCounts = new Object2IntOpenHashMap<>(10, 0.5f); + /** CqDataInt associated with frequency information of the terms in the query + * in the full index. The dataset is indexed by the compiled query. */ + public final CqDataInt fullCounts; + + /** CqDataInt associated with frequency information of the terms in the query + * in the full index. The dataset is indexed by the compiled query. */ + public final CqDataInt priorityCounts; public ResultRankingContext(int docCount, ResultRankingParameters params, - Map fullCounts, - Map prioCounts - ) { + CqDataInt fullCounts, + CqDataInt prioCounts) + { this.docCount = docCount; this.params = params; - this.fullCounts.putAll(fullCounts); - this.priorityCounts.putAll(prioCounts); + this.fullCounts = fullCounts; + this.priorityCounts = prioCounts; } public int termFreqDocCount() { return docCount; } - public int frequency(String keyword) { - return fullCounts.getOrDefault(keyword, 1); - } - - public int priorityFrequency(String keyword) { - return priorityCounts.getOrDefault(keyword, 1); - } } diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultItem.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultItem.java index cc02ae28..ad8b8cb1 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultItem.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultItem.java @@ -15,15 +15,30 @@ public class SearchResultItem implements Comparable { 
* probably not what you want, use getDocumentId() instead */ public final long combinedId; + /** Encoded document metadata */ + public final long encodedDocMetadata; + + /** Encoded html features of document */ + + public final int htmlFeatures; + /** How did the subqueries match against the document ? */ public final List keywordScores; /** How many other potential results existed in the same domain */ public int resultsFromDomain; - public SearchResultItem(long combinedId, int scoresCount) { + public boolean hasPrioTerm; + + public SearchResultItem(long combinedId, + long encodedDocMetadata, + int htmlFeatures, + boolean hasPrioTerm) { this.combinedId = combinedId; - this.keywordScores = new ArrayList<>(scoresCount); + this.encodedDocMetadata = encodedDocMetadata; + this.keywordScores = new ArrayList<>(); + this.htmlFeatures = htmlFeatures; + this.hasPrioTerm = hasPrioTerm; } @@ -76,4 +91,6 @@ public class SearchResultItem implements Comparable { return Long.compare(this.combinedId, o.combinedId); } + + } diff --git a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultKeywordScore.java b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultKeywordScore.java index b84dad0b..212b2302 100644 --- a/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultKeywordScore.java +++ b/code/functions/search-query/api/java/nu/marginalia/api/searchquery/model/results/SearchResultKeywordScore.java @@ -2,41 +2,27 @@ package nu.marginalia.api.searchquery.model.results; import nu.marginalia.model.idx.WordFlags; import nu.marginalia.model.idx.WordMetadata; -import nu.marginalia.model.idx.DocumentMetadata; import java.util.Objects; public final class SearchResultKeywordScore { - public final int subquery; + public final long termId; public final String keyword; private final long encodedWordMetadata; - private final long encodedDocMetadata; - private final int 
htmlFeatures; - - public SearchResultKeywordScore(int subquery, - String keyword, - long encodedWordMetadata, - long encodedDocMetadata, - int htmlFeatures) { - this.subquery = subquery; + public SearchResultKeywordScore(String keyword, + long termId, + long encodedWordMetadata) { + this.termId = termId; this.keyword = keyword; this.encodedWordMetadata = encodedWordMetadata; - this.encodedDocMetadata = encodedDocMetadata; - this.htmlFeatures = htmlFeatures; } public boolean hasTermFlag(WordFlags flag) { return WordMetadata.hasFlags(encodedWordMetadata, flag.asBit()); } - public int positionCount() { - return Long.bitCount(positions()); - } - public int subquery() { - return subquery; - } public long positions() { return WordMetadata.decodePositions(encodedWordMetadata); } @@ -45,46 +31,28 @@ public final class SearchResultKeywordScore { return keyword.contains(":") || hasTermFlag(WordFlags.Synthetic); } - public boolean isKeywordRegular() { - return !keyword.contains(":") - && !hasTermFlag(WordFlags.Synthetic); - } - public long encodedWordMetadata() { return encodedWordMetadata; } - public long encodedDocMetadata() { - return encodedDocMetadata; - } - - public int htmlFeatures() { - return htmlFeatures; - } - @Override public boolean equals(Object obj) { if (obj == this) return true; if (obj == null || obj.getClass() != this.getClass()) return false; var that = (SearchResultKeywordScore) obj; - return this.subquery == that.subquery && - Objects.equals(this.keyword, that.keyword) && - this.encodedWordMetadata == that.encodedWordMetadata && - this.encodedDocMetadata == that.encodedDocMetadata; + return Objects.equals(this.termId, that.termId); } @Override public int hashCode() { - return Objects.hash(subquery, keyword, encodedWordMetadata, encodedDocMetadata); + return Objects.hash(termId); } @Override public String toString() { return "SearchResultKeywordScore[" + - "set=" + subquery + ", " + "keyword=" + keyword + ", " + - "encodedWordMetadata=" + new 
WordMetadata(encodedWordMetadata) + ", " + - "encodedDocMetadata=" + new DocumentMetadata(encodedDocMetadata) + ']'; + "encodedWordMetadata=" + new WordMetadata(encodedWordMetadata) + ']'; } } diff --git a/code/functions/search-query/api/src/main/protobuf/query-api.proto b/code/functions/search-query/api/src/main/protobuf/query-api.proto index f5ec5e8d..bae06e66 100644 --- a/code/functions/search-query/api/src/main/protobuf/query-api.proto +++ b/code/functions/search-query/api/src/main/protobuf/query-api.proto @@ -52,7 +52,7 @@ message RpcTemporalBias { /* Index service query request */ message RpcIndexQuery { - repeated RpcSubquery subqueries = 1; + RpcQuery query = 1; repeated int32 domains = 2; // (optional) A list of domain IDs to consider string searchSetIdentifier = 3; // (optional) A named set of domains to consider string humanQuery = 4; // The search query as the user entered it @@ -91,23 +91,23 @@ message RpcDecoratedResultItem { int64 dataHash = 9; int32 wordsTotal = 10; double rankingScore = 11; // The ranking score of this search result item, lower is better + int64 bestPositions = 12; } /** A raw index-service view of a search result */ message RpcRawResultItem { int64 combinedId = 1; // raw ID with bit-encoded ranking information still present int32 resultsFromDomain = 2; // number of other results from the same domain - repeated RpcResultKeywordScore keywordScores = 3; + int64 encodedDocMetadata = 3; // bit encoded document metadata + int32 htmlFeatures = 4; // bitmask encoding features of the document + repeated RpcResultKeywordScore keywordScores = 5; + bool hasPriorityTerms = 6; // true if this word is important to the document } /* Information about how well a keyword matches a query */ message RpcResultKeywordScore { - int32 subquery = 1; // index of the subquery this keyword relates to - string keyword = 2; // the keyword - int64 encodedWordMetadata = 3; // bit encoded word metadata - int64 encodedDocMetadata = 4; // bit encoded document 
metadata - bool hasPriorityTerms = 5; // true if this word is important to the document - int32 htmlFeatures = 6; // bit encoded document features + string keyword = 1; // the keyword + int64 encodedWordMetadata = 2; // bit encoded word metadata } /* Query execution parameters */ @@ -137,12 +137,13 @@ message RpcResultRankingParameters { } /* Defines a single subquery */ -message RpcSubquery { +message RpcQuery { repeated string include = 1; // These terms must be present repeated string exclude = 2; // These terms must be absent repeated string advice = 3; // These terms must be present, but do not affect ranking repeated string priority = 4; // These terms are not mandatory, but affect ranking positively if they are present repeated RpcCoherences coherences = 5; // Groups of terms that must exist in proximity of each other + string compiledQuery = 6; // Compiled query in infix notation } /* Defines a group of search terms that must exist in close proximity within the document */ diff --git a/code/functions/search-query/api/test/nu/marginalia/api/searchquery/model/compiled/CompiledQueryParserTest.java b/code/functions/search-query/api/test/nu/marginalia/api/searchquery/model/compiled/CompiledQueryParserTest.java new file mode 100644 index 00000000..47983820 --- /dev/null +++ b/code/functions/search-query/api/test/nu/marginalia/api/searchquery/model/compiled/CompiledQueryParserTest.java @@ -0,0 +1,79 @@ +package nu.marginalia.api.searchquery.model.compiled; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class CompiledQueryParserTest { + + @Test + public void testEmpty() { + assertEquals(CqExpression.empty(), CompiledQueryParser.parse("").root); + assertEquals(CqExpression.empty(), CompiledQueryParser.parse("( )").root); + assertEquals(CqExpression.empty(), CompiledQueryParser.parse("( | )").root); + assertEquals(CqExpression.empty(), CompiledQueryParser.parse("| ( | ) |").root); + } + + @Test 
+ public void testSingleWord() { + CompiledQuery q = CompiledQueryParser.parse("foo"); + assertEquals(w(q, "foo"), q.root); + } + + @Test + public void testAndTwoWords() { + CompiledQuery q = CompiledQueryParser.parse("foo bar"); + assertEquals(and(w(q, "foo"), w(q,"bar")), q.root); + } + + @Test + public void testOrTwoWords() { + CompiledQuery q = CompiledQueryParser.parse("foo | bar"); + assertEquals(or(w(q, "foo"), w(q,"bar")), q.root); + } + + @Test + public void testOrAndWords() { + CompiledQuery q = CompiledQueryParser.parse("foo | bar baz"); + assertEquals(or(w(q,"foo"), and(w(q,"bar"), w(q,"baz"))), q.root); + } + + @Test + public void testAndAndOrAndAndWords() { + CompiledQuery q = CompiledQueryParser.parse("foo foobar | bar baz"); + assertEquals(or( + and(w(q, "foo"), w(q, "foobar")), + and(w(q, "bar"), w(q, "baz"))) + , q.root); + } + @Test + public void testComplex1() { + CompiledQuery q = CompiledQueryParser.parse("foo ( bar | baz ) quux"); + assertEquals(and(w(q,"foo"), or(w(q, "bar"), w(q, "baz")), w(q, "quux")), q.root); + } + @Test + public void testComplex2() { + CompiledQuery q = CompiledQueryParser.parse("( ( ( a ) b ) c ) d"); + assertEquals(and(and(and(w(q, "a"), w(q, "b")), w(q, "c")), w(q, "d")), q.root); + } + + @Test + public void testNested() { + CompiledQuery q = CompiledQueryParser.parse("( ( ( a ) ) )"); + assertEquals(w(q,"a"), q.root); + } + + private CqExpression.Word w(CompiledQuery query, String word) { + return new CqExpression.Word(query.indices().filter(idx -> word.equals(query.at(idx))).findAny().orElseThrow()); + } + + private CqExpression and(CqExpression... parts) { + return new CqExpression.And(List.of(parts)); + } + + private CqExpression or(CqExpression... 
parts) { + return new CqExpression.Or(List.of(parts)); + } +} \ No newline at end of file diff --git a/code/functions/search-query/api/test/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregatesTest.java b/code/functions/search-query/api/test/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregatesTest.java new file mode 100644 index 00000000..c3e36180 --- /dev/null +++ b/code/functions/search-query/api/test/nu/marginalia/api/searchquery/model/compiled/aggregate/CompiledQueryAggregatesTest.java @@ -0,0 +1,35 @@ +package nu.marginalia.api.searchquery.model.compiled.aggregate; + +import static nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser.parse; +import static nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates.*; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class CompiledQueryAggregatesTest { + + @Test + void booleanAggregates() { + assertFalse(booleanAggregate(parse("false"), Boolean::parseBoolean)); + assertTrue(booleanAggregate(parse("true"), Boolean::parseBoolean)); + assertFalse(booleanAggregate(parse("false true"), Boolean::parseBoolean)); + assertTrue(booleanAggregate(parse("( true ) | ( true false )"), Boolean::parseBoolean)); + assertTrue(booleanAggregate(parse("( false ) | ( true )"), Boolean::parseBoolean)); + assertTrue(booleanAggregate(parse("( true false ) | ( true true )"), Boolean::parseBoolean)); + assertFalse(booleanAggregate(parse("( true false ) | ( true false )"), Boolean::parseBoolean)); + } + + @Test + void intMaxMinAggregates() { + assertEquals(5, intMaxMinAggregate(parse("5"), Integer::parseInt)); + assertEquals(3, intMaxMinAggregate(parse("5 3"), Integer::parseInt)); + assertEquals(6, intMaxMinAggregate(parse("5 3 | 6 7"), Integer::parseInt)); + } + + @Test + void doubleSumAggregates() { + assertEquals(5, (int) doubleSumAggregate(parse("5"), Double::parseDouble)); + assertEquals(8, (int) 
doubleSumAggregate(parse("5 3"), Double::parseDouble)); + assertEquals(13, (int) doubleSumAggregate(parse("1 ( 5 3 | 2 10 )"), Double::parseDouble)); + } +} \ No newline at end of file diff --git a/code/functions/search-query/api/test/nu/marginalia/index/client/IndexProtobufCodecTest.java b/code/functions/search-query/api/test/nu/marginalia/index/client/IndexProtobufCodecTest.java index 1782765d..e93f715c 100644 --- a/code/functions/search-query/api/test/nu/marginalia/index/client/IndexProtobufCodecTest.java +++ b/code/functions/search-query/api/test/nu/marginalia/index/client/IndexProtobufCodecTest.java @@ -1,7 +1,7 @@ package nu.marginalia.index.client; import nu.marginalia.api.searchquery.IndexProtobufCodec; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.query.limit.QueryLimits; import nu.marginalia.index.query.limit.SpecificationLimit; @@ -35,14 +35,15 @@ class IndexProtobufCodecTest { } @Test public void testSubqery() { - verifyIsIdentityTransformation(new SearchSubquery( + verifyIsIdentityTransformation(new SearchQuery( + "qs", List.of("a", "b"), List.of("c", "d"), List.of("e", "f"), List.of("g", "h"), List.of(List.of("i", "j"), List.of("k")) ), - s -> IndexProtobufCodec.convertSearchSubquery(IndexProtobufCodec.convertSearchSubquery(s)) + s -> IndexProtobufCodec.convertRpcQuery(IndexProtobufCodec.convertRpcQuery(s)) ); } private void verifyIsIdentityTransformation(T val, Function transformation) { diff --git a/code/functions/search-query/build.gradle b/code/functions/search-query/build.gradle index dc1f9c4c..7b792b48 100644 --- a/code/functions/search-query/build.gradle +++ b/code/functions/search-query/build.gradle @@ -26,6 +26,9 @@ dependencies { implementation project(':code:libraries:term-frequency-dict') implementation project(':third-party:porterstemmer') + implementation 
project(':third-party:openzim') + implementation project(':third-party:commons-codec') + implementation project(':code:libraries:language-processing') implementation project(':code:libraries:term-frequency-dict') implementation project(':code:features-convert:keyword-extraction') @@ -36,6 +39,8 @@ dependencies { implementation libs.bundles.grpc implementation libs.notnull implementation libs.guice + implementation libs.jsoup + implementation libs.commons.lang3 implementation libs.trove implementation libs.fastutil implementation libs.bundles.gson diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java new file mode 100644 index 00000000..d4e324fa --- /dev/null +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java @@ -0,0 +1,181 @@ +package nu.marginalia.functions.searchquery.query_parser; + +import ca.rmen.porterstemmer.PorterStemmer; +import com.google.inject.Inject; +import nu.marginalia.functions.searchquery.query_parser.model.QWord; +import nu.marginalia.functions.searchquery.query_parser.model.QWordGraph; +import nu.marginalia.functions.searchquery.query_parser.model.QWordPathsRenderer; +import nu.marginalia.segmentation.NgramLexicon; +import nu.marginalia.term_frequency_dict.TermFrequencyDict; +import org.apache.commons.lang3.StringUtils; + +import java.util.*; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** Responsible for expanding a query, that is creating alternative branches of query execution + * to increase the number of results + */ +public class QueryExpansion { + private static final PorterStemmer ps = new PorterStemmer(); + private final TermFrequencyDict dict; + private final NgramLexicon lexicon; + + private final List expansionStrategies = List.of( + this::joinDashes, 
+ this::splitWordNum, + this::joinTerms, + this::createSegments + ); + + @Inject + public QueryExpansion(TermFrequencyDict dict, + NgramLexicon lexicon + ) { + this.dict = dict; + this.lexicon = lexicon; + } + + public String expandQuery(List words) { + + QWordGraph graph = new QWordGraph(words); + + for (var strategy : expansionStrategies) { + strategy.expand(graph); + } + + return QWordPathsRenderer.render(graph); + } + + private static final Pattern dashPattern = Pattern.compile("-"); + private static final Pattern numWordBoundary = Pattern.compile("[0-9][a-zA-Z]|[a-zA-Z][0-9]"); + + // Turn 'lawn-chair' into 'lawnchair' + public void joinDashes(QWordGraph graph) { + for (var qw : graph) { + if (qw.word().contains("-")) { + var joined = StringUtils.join(dashPattern.split(qw.word())); + graph.addVariant(qw, joined); + } + } + } + + + // Turn 'MP3' into 'MP-3' + public void splitWordNum(QWordGraph graph) { + for (var qw : graph) { + var matcher = numWordBoundary.matcher(qw.word()); + if (matcher.matches()) { + var joined = StringUtils.join(dashPattern.split(qw.word()), '-'); + graph.addVariant(qw, joined); + } + } + } + + // Turn 'lawn chair' into 'lawnchair' + public void joinTerms(QWordGraph graph) { + QWord prev = null; + + for (var qw : graph) { + if (prev != null) { + var joinedWord = prev.word() + qw.word(); + var joinedStemmed = ps.stemWord(joinedWord); + + var scoreA = dict.getTermFreqStemmed(prev.stemmed()); + var scoreB = dict.getTermFreqStemmed(qw.stemmed()); + + var scoreCombo = dict.getTermFreqStemmed(joinedStemmed); + + if (scoreCombo > scoreA + scoreB || scoreCombo > 1000) { + graph.addVariantForSpan(prev, qw, joinedWord); + } + } + + prev = qw; + } + } + + /** Create an alternative interpretation of the query that replaces a sequence of words + * with a word n-gram. This makes it so that when possible, the order of words in the document + * matches the order of the words in the query. 
+ */ + public void createSegments(QWordGraph graph) { + List nodes = new ArrayList<>(); + + for (var qw : graph) { + nodes.add(qw); + } + + String[] words = nodes.stream().map(QWord::stemmed).toArray(String[]::new); + + // Grab all segments + + List allSegments = new ArrayList<>(); + for (int length = 2; length < Math.min(10, words.length); length++) { + allSegments.addAll(lexicon.findSegmentOffsets(length, words)); + } + allSegments.sort(Comparator.comparing(NgramLexicon.SentenceSegment::start)); + + if (allSegments.isEmpty()) { + return; + } + + Set bestSegmentation = + findBestSegmentation(allSegments); + + for (var segment : bestSegmentation) { + + int start = segment.start(); + int end = segment.start() + segment.length(); + + var word = IntStream.range(start, end) + .mapToObj(nodes::get) + .map(QWord::word) + .collect(Collectors.joining("_")); + + graph.addVariantForSpan(nodes.get(start), nodes.get(end - 1), word); + } + + } + + private Set findBestSegmentation(List allSegments) { + Set bestSet = Set.of(); + double bestScore = Double.MIN_VALUE; + + for (int i = 0; i < allSegments.size(); i++) { + Set parts = new HashSet<>(); + parts.add(allSegments.get(i)); + + outer: + for (int j = i+1; j < allSegments.size(); j++) { + var candidate = allSegments.get(j); + for (var part : parts) { + if (part.overlaps(candidate)) { + continue outer; + } + } + parts.add(candidate); + } + + double score = 0.; + for (var part : parts) { + // |s|^|s|-normalization per M Hagen et al + double normFactor = Math.pow(part.length(), part.length()); + + score += normFactor * part.count(); + } + + if (bestScore < score) { + bestScore = score; + bestSet = parts; + } + } + + return bestSet; + } + + public interface ExpansionStrategy { + void expand(QWordGraph graph); + } +} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryParser.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryParser.java index 
bbaf5c87..3f92a594 100644 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryParser.java +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryParser.java @@ -1,8 +1,7 @@ package nu.marginalia.functions.searchquery.query_parser; +import nu.marginalia.functions.searchquery.query_parser.token.QueryToken; import nu.marginalia.language.WordPatterns; -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenType; import nu.marginalia.util.transform_list.TransformList; import java.util.List; @@ -11,95 +10,126 @@ public class QueryParser { private final QueryTokenizer tokenizer = new QueryTokenizer(); - public List parse(String query) { - List basicTokens = tokenizer.tokenizeQuery(query); + public List parse(String query) { + List basicTokens = tokenizer.tokenizeQuery(query); - TransformList list = new TransformList<>(basicTokens); + TransformList list = new TransformList<>(basicTokens); list.transformEach(QueryParser::handleQuoteTokens); list.transformEach(QueryParser::trimLiterals); list.transformEachPair(QueryParser::createNegatedTerms); list.transformEachPair(QueryParser::createPriorityTerms); list.transformEach(QueryParser::handleSpecialOperations); - list.scanAndTransform(TokenType.LPAREN, TokenType.RPAREN, QueryParser::handleAdvisoryTerms); + list.scanAndTransform(QueryToken.LParen.class::isInstance, QueryToken.RParen.class::isInstance, QueryParser::handleAdvisoryTerms); + list.transformEach(QueryParser::normalizeDomainName); return list.getBackingList(); } - private static void handleQuoteTokens(TransformList.Entity entity) { - var t = entity.value(); - if (t.type == TokenType.QUOT) { - entity.replace(new Token(TokenType.QUOT_TERM, - t.str.replaceAll("\\s+", WordPatterns.WORD_TOKEN_JOINER), - t.displayStr)); - } - } - - private static void trimLiterals(TransformList.Entity entity) { + private static void 
normalizeDomainName(TransformList.Entity entity) { var t = entity.value(); - if (t.type == TokenType.LITERAL_TERM - && (t.str.endsWith(":") || t.str.endsWith(".")) - && t.str.length() > 1) { - entity.replace(new Token(TokenType.LITERAL_TERM, t.str.substring(0, t.str.length() - 1), t.displayStr)); + if (!(t instanceof QueryToken.LiteralTerm)) + return; + + if (t.str().startsWith("site:")) { + entity.replace(new QueryToken.LiteralTerm(t.str().toLowerCase(), t.displayStr())); } } - private static void createNegatedTerms(TransformList.Entity first, TransformList.Entity second) { - var t = first.value(); - var tn = second.value(); - - if (t.type == TokenType.MINUS && tn.type == TokenType.LITERAL_TERM) { - first.remove(); - second.replace(new Token(TokenType.EXCLUDE_TERM, tn.str, "-" + tn.str)); - } - } - - private static void createPriorityTerms(TransformList.Entity first, TransformList.Entity second) { - var t = first.value(); - var tn = second.value(); - - if (t.type == TokenType.QMARK && tn.type == TokenType.LITERAL_TERM) { - first.remove(); - second.replace(new Token(TokenType.PRIORTY_TERM, tn.str, "?" 
+ tn.str)); - } - } - - private static void handleSpecialOperations(TransformList.Entity entity) { + private static void handleQuoteTokens(TransformList.Entity entity) { var t = entity.value(); - if (t.type != TokenType.LITERAL_TERM) { + + if (!(t instanceof QueryToken.Quot)) { return; } - if (t.str.startsWith("q") && t.str.matches("q[=><]\\d+")) { - entity.replace(new Token(TokenType.QUALITY_TERM, t.str.substring(1), t.displayStr)); - } else if (t.str.startsWith("near:")) { - entity.replace(new Token(TokenType.NEAR_TERM, t.str.substring(5), t.displayStr)); - } else if (t.str.startsWith("year") && t.str.matches("year[=><]\\d{4}")) { - entity.replace(new Token(TokenType.YEAR_TERM, t.str.substring(4), t.displayStr)); - } else if (t.str.startsWith("size") && t.str.matches("size[=><]\\d+")) { - entity.replace(new Token(TokenType.SIZE_TERM, t.str.substring(4), t.displayStr)); - } else if (t.str.startsWith("rank") && t.str.matches("rank[=><]\\d+")) { - entity.replace(new Token(TokenType.RANK_TERM, t.str.substring(4), t.displayStr)); - } else if (t.str.startsWith("qs=")) { - entity.replace(new Token(TokenType.QS_TERM, t.str.substring(3), t.displayStr)); - } else if (t.str.contains(":")) { - entity.replace(new Token(TokenType.ADVICE_TERM, t.str, t.displayStr)); - } + entity.replace(new QueryToken.QuotTerm( + t.str().replaceAll("\\s+", WordPatterns.WORD_TOKEN_JOINER), + t.displayStr())); } - private static void handleAdvisoryTerms(TransformList.Entity entity) { + private static void trimLiterals(TransformList.Entity entity) { var t = entity.value(); - if (t.type == TokenType.LPAREN) { - entity.remove(); - } else if (t.type == TokenType.RPAREN) { - entity.remove(); - } else if (t.type == TokenType.LITERAL_TERM) { - entity.replace(new Token(TokenType.ADVICE_TERM, t.str, "(" + t.str + ")")); + + if (!(t instanceof QueryToken.LiteralTerm lt)) + return; + + String str = lt.str(); + if (str.isBlank()) + return; + + if (str.endsWith(":") || str.endsWith(".")) { + 
entity.replace(new QueryToken.LiteralTerm(str.substring(0, str.length() - 1), lt.displayStr())); + } + + } + + private static void createNegatedTerms(TransformList.Entity first, TransformList.Entity second) { + var t = first.value(); + var tn = second.value(); + + if (!(t instanceof QueryToken.Minus)) + return; + if (!(tn instanceof QueryToken.LiteralTerm) && !(tn instanceof QueryToken.AdviceTerm)) + return; + + first.remove(); + + second.replace(new QueryToken.ExcludeTerm(tn.str(), "-" + tn.displayStr())); + } + + private static void createPriorityTerms(TransformList.Entity first, TransformList.Entity second) { + var t = first.value(); + var tn = second.value(); + + if (!(t instanceof QueryToken.QMark)) + return; + if (!(tn instanceof QueryToken.LiteralTerm) && !(tn instanceof QueryToken.AdviceTerm)) + return; + + var replacement = new QueryToken.PriorityTerm(tn.str(), "?" + tn.displayStr()); + + first.remove(); + second.replace(replacement); + } + + private static void handleSpecialOperations(TransformList.Entity entity) { + var t = entity.value(); + if (!(t instanceof QueryToken.LiteralTerm)) { + return; + } + + String str = t.str(); + + if (str.startsWith("q") && str.matches("q[=><]\\d+")) { + entity.replace(new QueryToken.QualityTerm(str.substring(1))); + } else if (str.startsWith("near:")) { + entity.replace(new QueryToken.NearTerm(str.substring(5))); + } else if (str.startsWith("year") && str.matches("year[=><]\\d{4}")) { + entity.replace(new QueryToken.YearTerm(str.substring(4))); + } else if (str.startsWith("size") && str.matches("size[=><]\\d+")) { + entity.replace(new QueryToken.SizeTerm(str.substring(4))); + } else if (str.startsWith("rank") && str.matches("rank[=><]\\d+")) { + entity.replace(new QueryToken.RankTerm(str.substring(4))); + } else if (str.startsWith("qs=")) { + entity.replace(new QueryToken.QsTerm(str.substring(3))); + } else if (str.contains(":")) { + entity.replace(new QueryToken.AdviceTerm(str, t.displayStr())); } } + private static 
void handleAdvisoryTerms(TransformList.Entity entity) { + var t = entity.value(); + if (t instanceof QueryToken.LParen) { + entity.remove(); + } else if (t instanceof QueryToken.RParen) { + entity.remove(); + } else if (t instanceof QueryToken.LiteralTerm) { + entity.replace(new QueryToken.AdviceTerm(t.str(), "(" + t.displayStr() + ")")); + } + } } diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryPermutation.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryPermutation.java deleted file mode 100644 index 417ceda3..00000000 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryPermutation.java +++ /dev/null @@ -1,229 +0,0 @@ -package nu.marginalia.functions.searchquery.query_parser; - -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenType; -import nu.marginalia.language.WordPatterns; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.function.Predicate; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - -import static java.util.stream.Stream.concat; - -public class QueryPermutation { - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final QueryVariants queryVariants; - - public static final Pattern wordPattern = Pattern.compile("[#]?[_@.a-zA-Z0-9'+\\-\\u00C0-\\u00D6\\u00D8-\\u00f6\\u00f8-\\u00ff]+[#]?"); - public static final Pattern wordAppendixPattern = Pattern.compile("[.]?[0-9a-zA-Z\\u00C0-\\u00D6\\u00D8-\\u00f6\\u00f8-\\u00ff]{1,3}[0-9]?"); - - public static final Predicate wordQualitiesPredicate = wordPattern.asMatchPredicate(); - - public static final Predicate wordAppendixPredicate = wordAppendixPattern.asMatchPredicate(); - public static final 
Predicate wordPredicateEither = wordQualitiesPredicate.or(wordAppendixPredicate); - - public QueryPermutation(QueryVariants queryVariants) { - this.queryVariants = queryVariants; - } - - public List> permuteQueries(List items) { - int start = -1; - int end = items.size(); - - for (int i = 0; i < items.size(); i++) { - var token = items.get(i); - - if (start < 0) { - if (token.type == TokenType.LITERAL_TERM && wordQualitiesPredicate.test(token.str)) { - start = i; - } - } - else { - if (token.type != TokenType.LITERAL_TERM || !wordPredicateEither.test(token.str)) { - end = i; - break; - } - } - } - - if (start >= 0 && end - start > 1) { - List> permuteParts = combineSearchTerms(items.subList(start, end)); - int s = start; - int e = end; - return permuteParts.stream().map(part -> - concat(items.subList(0, s).stream(), concat(part.stream(), items.subList(e, items.size()).stream())) - .collect(Collectors.toList())) - .peek(lst -> lst.removeIf(this::isJunkWord)) - .limit(24) - .collect(Collectors.toList()); - } - else { - return List.of(items); - } - } - - - public List> permuteQueriesNew(List items) { - int start = -1; - int end = items.size(); - - for (int i = 0; i < items.size(); i++) { - var token = items.get(i); - - if (start < 0) { - if (token.type == TokenType.LITERAL_TERM && wordQualitiesPredicate.test(token.str)) { - start = i; - } - } - else { - if (token.type != TokenType.LITERAL_TERM || !wordPredicateEither.test(token.str)) { - end = i; - break; - } - } - } - - if (start >= 0 && end - start >= 1) { - var result = queryVariants.getQueryVariants(items.subList(start, end)); - - logger.debug("{}", result); - - if (result.isEmpty()) { - logger.warn("Empty variants result, falling back on old code"); - return permuteQueries(items); - } - - List> queryVariants = new ArrayList<>(); - for (var query : result.faithful) { - var tokens = query.terms.stream().map(term -> new Token(TokenType.LITERAL_TERM, term)).collect(Collectors.toList()); - 
tokens.addAll(result.nonLiterals); - - queryVariants.add(tokens); - } - for (var query : result.alternative) { - if (queryVariants.size() >= 6) - break; - - var tokens = query.terms.stream().map(term -> new Token(TokenType.LITERAL_TERM, term)).collect(Collectors.toList()); - tokens.addAll(result.nonLiterals); - - queryVariants.add(tokens); - } - - List> returnValue = new ArrayList<>(queryVariants.size()); - for (var variant: queryVariants) { - List r = new ArrayList<>(start + variant.size() + (items.size() - end)); - r.addAll(items.subList(0, start)); - r.addAll(variant); - r.addAll(items.subList(end, items.size())); - returnValue.add(r); - } - - return returnValue; - - } - else { - return List.of(items); - } - } - - private boolean isJunkWord(Token token) { - if (WordPatterns.isStopWord(token.str) && - !token.str.matches("^(\\d+|([a-z]+:.*))$")) { - return true; - } - return switch (token.str) { - case "vs", "versus", "or", "and" -> true; - default -> false; - }; - } - - private List> combineSearchTerms(List subList) { - int size = subList.size(); - if (size < 1) { - return Collections.emptyList(); - } - else if (size == 1) { - if (WordPatterns.isStopWord(subList.get(0).str)) { - return Collections.emptyList(); - } - return List.of(subList); - } - - List> results = new ArrayList<>(size*(size+1)/2); - - if (subList.size() <= 4 && subList.get(0).str.length() >= 2 && !isPrefixWord(subList.get(subList.size()-1).str)) { - results.add(List.of(joinTokens(subList))); - } - outer: for (int i = size - 1; i >= 1; i--) { - - var left = combineSearchTerms(subList.subList(0, i)); - var right = combineSearchTerms(subList.subList(i, size)); - - for (var l : left) { - if (results.size() > 48) { - break outer; - } - - for (var r : right) { - if (results.size() > 48) { - break outer; - } - - List combined = new ArrayList<>(l.size() + r.size()); - combined.addAll(l); - combined.addAll(r); - if (!results.contains(combined)) { - results.add(combined); - } - } - } - } - if 
(!results.contains(subList)) { - results.add(subList); - } - Comparator> tc = (o1, o2) -> { - int dJoininess = o2.stream().mapToInt(s->(int)Math.pow(joininess(s.str), 2)).sum() - - o1.stream().mapToInt(s->(int)Math.pow(joininess(s.str), 2)).sum(); - if (dJoininess == 0) { - return (o2.stream().mapToInt(s->(int)Math.pow(rightiness(s.str), 2)).sum() - - o1.stream().mapToInt(s->(int)Math.pow(rightiness(s.str), 2)).sum()); - } - return (int) Math.signum(dJoininess); - }; - results.sort(tc); - return results; - } - - private boolean isPrefixWord(String str) { - return switch (str) { - case "the", "of", "when" -> true; - default -> false; - }; - } - - int joininess(String s) { - return (int) s.chars().filter(c -> c == '_').count(); - } - int rightiness(String s) { - int rightiness = 0; - for (int i = 0; i < s.length(); i++) { - if (s.charAt(i) == '_') { - rightiness+=i; - } - } - return rightiness; - } - - private Token joinTokens(List subList) { - return new Token(TokenType.LITERAL_TERM, - subList.stream().map(t -> t.str).collect(Collectors.joining("_")), - subList.stream().map(t -> t.str).collect(Collectors.joining(" "))); - } -} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryTokenizer.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryTokenizer.java index b7b0a2b7..b12d68a9 100644 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryTokenizer.java +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryTokenizer.java @@ -1,7 +1,6 @@ package nu.marginalia.functions.searchquery.query_parser; -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenType; +import nu.marginalia.functions.searchquery.query_parser.token.QueryToken; import nu.marginalia.language.encoding.AsciiFlattener; import java.util.ArrayList; @@ -11,8 
+10,8 @@ import java.util.regex.Pattern; public class QueryTokenizer { private static final Pattern noisePattern = Pattern.compile("[,\\s]"); - public List tokenizeQuery(String rawQuery) { - List tokens = new ArrayList<>(); + public List tokenizeQuery(String rawQuery) { + List tokens = new ArrayList<>(); String query = AsciiFlattener.flattenUnicode(rawQuery); query = noisePattern.matcher(query).replaceAll(" "); @@ -21,26 +20,27 @@ public class QueryTokenizer { int chr = query.charAt(i); if ('(' == chr) { - tokens.add(new Token(TokenType.LPAREN, "(", "(")); + tokens.add(new QueryToken.LParen()); } else if (')' == chr) { - tokens.add(new Token(TokenType.RPAREN, ")", ")")); + tokens.add(new QueryToken.RParen()); } else if ('"' == chr) { int end = query.indexOf('"', i+1); + if (end == -1) { end = query.length(); } - tokens.add(new Token(TokenType.QUOT, - query.substring(i+1, end).toLowerCase(), - query.substring(i, Math.min(query.length(), end+1)))); + + tokens.add(new QueryToken.Quot(query.substring(i + 1, end).toLowerCase())); + i = end; } else if ('-' == chr) { - tokens.add(new Token(TokenType.MINUS, "-")); + tokens.add(new QueryToken.Minus()); } else if ('?' 
== chr) { - tokens.add(new Token(TokenType.QMARK, "?")); + tokens.add(new QueryToken.QMark()); } else if (Character.isSpaceChar(chr)) { // @@ -52,9 +52,12 @@ public class QueryTokenizer { if (query.charAt(end) == ' ' || query.charAt(end) == ')') break; } - tokens.add(new Token(TokenType.LITERAL_TERM, - query.substring(i, end).toLowerCase(), - query.substring(i, end))); + + String displayStr = query.substring(i, end); + String str = displayStr.toLowerCase(); + + tokens.add(new QueryToken.LiteralTerm(str, displayStr)); + i = end-1; } } diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryVariants.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryVariants.java deleted file mode 100644 index 9732e53f..00000000 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryVariants.java +++ /dev/null @@ -1,378 +0,0 @@ -package nu.marginalia.functions.searchquery.query_parser; - -import ca.rmen.porterstemmer.PorterStemmer; -import lombok.AllArgsConstructor; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenType; -import nu.marginalia.util.language.EnglishDictionary; -import nu.marginalia.LanguageModels; -import nu.marginalia.keyword.KeywordExtractor; -import nu.marginalia.language.sentence.SentenceExtractor; -import nu.marginalia.util.ngrams.NGramBloomFilter; -import nu.marginalia.term_frequency_dict.TermFrequencyDict; -import nu.marginalia.language.model.DocumentSentence; -import nu.marginalia.language.model.WordSpan; - -import java.util.*; -import java.util.regex.Pattern; - -public class QueryVariants { - private final KeywordExtractor keywordExtractor; - private final TermFrequencyDict dict; - private final PorterStemmer ps = new PorterStemmer(); - - private final NGramBloomFilter 
nGramBloomFilter; - private final EnglishDictionary englishDictionary; - private final ThreadLocal sentenceExtractor; - - public QueryVariants(LanguageModels lm, - TermFrequencyDict dict, - NGramBloomFilter nGramBloomFilter, - EnglishDictionary englishDictionary) { - this.nGramBloomFilter = nGramBloomFilter; - this.englishDictionary = englishDictionary; - this.keywordExtractor = new KeywordExtractor(); - this.sentenceExtractor = ThreadLocal.withInitial(() -> new SentenceExtractor(lm)); - this.dict = dict; - } - - - final Pattern numWordBoundary = Pattern.compile("[0-9][a-zA-Z]|[a-zA-Z][0-9]"); - final Pattern dashBoundary = Pattern.compile("-"); - - @AllArgsConstructor - private static class Word { - public final String stemmed; - public final String word; - public final String wordOriginal; - } - - @AllArgsConstructor @Getter @ToString @EqualsAndHashCode - public static class QueryVariant { - public final List terms; - public final double value; - } - - @Getter @ToString - public static class QueryVariantSet { - final List faithful = new ArrayList<>(); - final List alternative = new ArrayList<>(); - - final List nonLiterals = new ArrayList<>(); - - public boolean isEmpty() { - return faithful.isEmpty() && alternative.isEmpty() && nonLiterals.isEmpty(); - } - } - - public QueryVariantSet getQueryVariants(List query) { - final JoinedQueryAndNonLiteralTokens joinedQuery = joinQuery(query); - - final TreeMap> byStart = new TreeMap<>(); - - var se = sentenceExtractor.get(); - var sentence = se.extractSentence(joinedQuery.joinedQuery); - - for (int i = 0; i < sentence.posTags.length; i++) { - if (sentence.posTags[i].startsWith("N") || sentence.posTags[i].startsWith("V")) { - sentence.posTags[i] = "NNP"; - } - else if ("JJ".equals(sentence.posTags[i]) || "CD".equals(sentence.posTags[i]) || sentence.posTags[i].startsWith("P")) { - sentence.posTags[i] = "NNP"; - sentence.setIsStopWord(i, false); - } - } - - for (var kw : keywordExtractor.getKeywordsFromSentence(sentence)) 
{ - byStart.computeIfAbsent(kw.start, k -> new ArrayList<>()).add(kw); - } - - final List> livingSpans = new ArrayList<>(); - - var first = byStart.firstEntry(); - if (first == null) { - var span = new WordSpan(0, sentence.length()); - byStart.put(0, List.of(span)); - } - else if (first.getKey() > 0) { - List elongatedFirstWords = new ArrayList<>(first.getValue().size()); - - first.getValue().forEach(span -> { - elongatedFirstWords.add(new WordSpan(0, span.start)); - elongatedFirstWords.add(new WordSpan(0, span.end)); - }); - - byStart.put(0, elongatedFirstWords); - } - - final List> goodSpans = getWordSpans(byStart, sentence, livingSpans); - - List> faithfulQueries = new ArrayList<>(); - List> alternativeQueries = new ArrayList<>(); - - for (var ls : goodSpans) { - faithfulQueries.addAll(createTokens(ls)); - } - - for (var span : goodSpans) { - alternativeQueries.addAll(joinTerms(span)); - } - - for (var ls : goodSpans) { - var last = ls.get(ls.size() - 1); - - if (!last.wordOriginal.isBlank() && !Character.isUpperCase(last.wordOriginal.charAt(0))) { - var altLast = englishDictionary.getWordVariants(last.word); - for (String s : altLast) { - List newList = new ArrayList<>(ls.size()); - for (int i = 0; i < ls.size() - 1; i++) { - newList.add(ls.get(i).word); - } - newList.add(s); - alternativeQueries.add(newList); - } - } - - } - - QueryVariantSet returnValue = new QueryVariantSet(); - - returnValue.faithful.addAll(evaluateQueries(faithfulQueries)); - returnValue.alternative.addAll(evaluateQueries(alternativeQueries)); - - returnValue.faithful.sort(Comparator.comparing(QueryVariant::getValue)); - returnValue.alternative.sort(Comparator.comparing(QueryVariant::getValue)); - - returnValue.nonLiterals.addAll(joinedQuery.nonLiterals); - - return returnValue; - } - - final Pattern underscore = Pattern.compile("_"); - - private List evaluateQueries(List> queryStrings) { - Set variantsSet = new HashSet<>(); - List ret = new ArrayList<>(); - for (var lst : queryStrings) { 
- double q = 0; - for (var word : lst) { - String[] parts = underscore.split(word); - double qp = 0; - for (String part : parts) { - qp += 1./(1+ dict.getTermFreq(part)); - } - q += 1.0 / qp; - } - var qv = new QueryVariant(lst, q); - if (variantsSet.add(qv)) { - ret.add(qv); - } - } - return ret; - } - - private Collection> createTokens(List ls) { - List asTokens = new ArrayList<>(); - List> ret = new ArrayList<>(); - - - boolean dash = false; - boolean num = false; - - for (var span : ls) { - dash |= dashBoundary.matcher(span.word).find(); - num |= numWordBoundary.matcher(span.word).find(); - if (ls.size() == 1 || !isOmittableWord(span.word)) { - asTokens.add(span.word); - } - } - ret.add(asTokens); - - if (dash) { - ret.addAll(combineDashWords(ls)); - } - - if (num) { - ret.addAll(splitWordNum(ls)); - } - - return ret; - } - - private boolean isOmittableWord(String word) { - return switch (word) { - case "vs", "or", "and", "versus", "is", "the", "why", "when", "if", "who", "are", "am" -> true; - default -> false; - }; - } - - private Collection> splitWordNum(List ls) { - List asTokens2 = new ArrayList<>(); - - boolean num = false; - - for (var span : ls) { - var wordMatcher = numWordBoundary.matcher(span.word); - var stemmedMatcher = numWordBoundary.matcher(span.stemmed); - - int ws = 0; - int ss = 0; - boolean didSplit = false; - while (wordMatcher.find(ws) && stemmedMatcher.find(ss)) { - ws = wordMatcher.start()+1; - ss = stemmedMatcher.start()+1; - if (nGramBloomFilter.isKnownNGram(splitAtNumBoundary(span.word, stemmedMatcher.start(), "_")) - || nGramBloomFilter.isKnownNGram(splitAtNumBoundary(span.word, stemmedMatcher.start(), "-"))) - { - String combined = splitAtNumBoundary(span.word, wordMatcher.start(), "_"); - asTokens2.add(combined); - didSplit = true; - num = true; - } - } - - if (!didSplit) { - asTokens2.add(span.word); - } - } - - if (num) { - return List.of(asTokens2); - } - return Collections.emptyList(); - } - - private Collection> 
combineDashWords(List ls) { - List asTokens2 = new ArrayList<>(); - boolean dash = false; - - for (var span : ls) { - var matcher = dashBoundary.matcher(span.word); - if (matcher.find() && nGramBloomFilter.isKnownNGram(ps.stemWord(dashBoundary.matcher(span.word).replaceAll("")))) { - dash = true; - String combined = dashBoundary.matcher(span.word).replaceAll(""); - asTokens2.add(combined); - } - else { - asTokens2.add(span.word); - } - } - - if (dash) { - return List.of(asTokens2); - } - return Collections.emptyList(); - } - - private String splitAtNumBoundary(String in, int splitPoint, String joiner) { - return in.substring(0, splitPoint+1) + joiner + in.substring(splitPoint+1); - } - - private List> getWordSpans(TreeMap> byStart, DocumentSentence sentence, List> livingSpans) { - List> goodSpans = new ArrayList<>(); - for (int i = 0; i < 1; i++) { - var spans = byStart.get(i); - - - if (spans == null ) - continue; - - for (var span : spans) { - ArrayList fragment = new ArrayList<>(); - fragment.add(span); - livingSpans.add(fragment); - } - - if (sentence.posTags[i].startsWith("N") || sentence.posTags[i].startsWith("V")) break; - } - - - while (!livingSpans.isEmpty()) { - - final List> newLivingSpans = new ArrayList<>(livingSpans.size()); - - for (var span : livingSpans) { - int end = span.get(span.size()-1).end; - - if (end == sentence.length()) { - var gs = new ArrayList(span.size()); - for (var s : span) { - gs.add(new Word(sentence.constructStemmedWordFromSpan(s), sentence.constructWordFromSpan(s), - s.size() == 1 ? 
sentence.words[s.start] : "")); - } - goodSpans.add(gs); - } - var nextWordsKey = byStart.ceilingKey(end); - - if (null == nextWordsKey) - continue; - - for (var next : byStart.get(nextWordsKey)) { - var newSpan = new ArrayList(span.size() + 1); - newSpan.addAll(span); - newSpan.add(next); - newLivingSpans.add(newSpan); - } - } - - livingSpans.clear(); - livingSpans.addAll(newLivingSpans); - } - - return goodSpans; - } - - private List> joinTerms(List span) { - List> ret = new ArrayList<>(); - - for (int i = 0; i < span.size()-1; i++) { - var a = span.get(i); - var b = span.get(i+1); - - var stemmed = ps.stemWord(a.word + b.word); - - double scoreCombo = dict.getTermFreqStemmed(stemmed); - if (scoreCombo > 10000) { - List asTokens = new ArrayList<>(); - - for (int j = 0; j < i; j++) { - var word = span.get(j).word; - asTokens.add(word); - } - { - var word = a.word + b.word; - asTokens.add(word); - } - for (int j = i+2; j < span.size(); j++) { - var word = span.get(j).word; - asTokens.add(word); - } - - ret.add(asTokens); - } - } - - return ret; - } - - private JoinedQueryAndNonLiteralTokens joinQuery(List query) { - StringJoiner s = new StringJoiner(" "); - List leftovers = new ArrayList<>(5); - - for (var t : query) { - if (t.type == TokenType.LITERAL_TERM) { - s.add(t.displayStr); - } - else { - leftovers.add(t); - } - } - - return new JoinedQueryAndNonLiteralTokens(s.toString(), leftovers); - } - - record JoinedQueryAndNonLiteralTokens(String joinedQuery, List nonLiterals) {} -} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWord.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWord.java new file mode 100644 index 00000000..eac2e68b --- /dev/null +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWord.java @@ -0,0 +1,51 @@ +package nu.marginalia.functions.searchquery.query_parser.model; + +import 
ca.rmen.porterstemmer.PorterStemmer; + +public record QWord( + int ord, + boolean variant, + String stemmed, + String word, + String original) +{ + + // These are special words that are not in the input, but are added to the graph, + // note the space around the ^ and $, to avoid collisions with real words + private static final String BEG_MARKER = " ^ "; + private static final String END_MARKER = " $ "; + + private static final PorterStemmer ps = new PorterStemmer(); + + public boolean isBeg() { + return word.equals(BEG_MARKER); + } + + public boolean isEnd() { + return word.equals(END_MARKER); + } + + public static QWord beg() { + return new QWord(Integer.MIN_VALUE, false, BEG_MARKER, BEG_MARKER, BEG_MARKER); + } + + public static QWord end() { + return new QWord(Integer.MAX_VALUE, false, END_MARKER, END_MARKER, END_MARKER); + } + + public boolean isOriginal() { + return !variant; + } + + public QWord(int ord, String word) { + this(ord, false, ps.stemWord(word), word, word); + } + + public QWord(int ord, QWord original, String word) { + this(ord, true, ps.stemWord(word), word, original.original); + } + + public String toString() { + return STR."q{\{word}}"; + } +} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordGraph.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordGraph.java new file mode 100644 index 00000000..a8b1a768 --- /dev/null +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordGraph.java @@ -0,0 +1,267 @@ +package nu.marginalia.functions.searchquery.query_parser.model; + +import org.jetbrains.annotations.NotNull; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** Graph structure for constructing query variants. 
The graph should be a directed acyclic graph, + * with a single start node and a single end node, denoted by QWord.beg() and QWord.end() respectively. + *

+ * Naively, every path from the start to the end node should represent a valid query variant, although in + * practice it is desirable to be clever about how to evaluate the paths, to avoid a large number of queries + * being generated. + */ +public class QWordGraph implements Iterable { + + + public record QWordGraphLink(QWord from, QWord to) {} + + private final List links = new ArrayList<>(); + private final Map> fromTo = new HashMap<>(); + private final Map> toFrom = new HashMap<>(); + + private int wordId = 0; + + public QWordGraph(String... words) { + this(List.of(words)); + } + + public QWordGraph(List words) { + QWord beg = QWord.beg(); + QWord end = QWord.end(); + + var prev = beg; + + for (String s : words) { + var word = new QWord(wordId++, s); + addLink(prev, word); + prev = word; + } + + addLink(prev, end); + } + + public void addVariant(QWord original, String word) { + var siblings = getVariants(original); + if (siblings.stream().anyMatch(w -> w.word().equals(word))) + return; + + var newWord = new QWord(wordId++, original, word); + + for (var prev : getPrev(original)) + addLink(prev, newWord); + for (var next : getNext(original)) + addLink(newWord, next); + } + + public void addVariantForSpan(QWord first, QWord last, String word) { + var newWord = new QWord(wordId++, first, word); + + for (var prev : getPrev(first)) + addLink(prev, newWord); + for (var next : getNext(last)) + addLink(newWord, next); + } + + public List getVariants(QWord original) { + var prevNext = getPrev(original).stream() + .flatMap(prev -> getNext(prev).stream()) + .collect(Collectors.toSet()); + + return getNext(original).stream() + .flatMap(next -> getPrev(next).stream()) + .filter(prevNext::contains) + .collect(Collectors.toList()); + } + + + public void addLink(QWord from, QWord to) { + links.add(new QWordGraphLink(from, to)); + fromTo.computeIfAbsent(from, k -> new ArrayList<>()).add(to); + toFrom.computeIfAbsent(to, k -> new ArrayList<>()).add(from); + } + + public List 
links() { + return Collections.unmodifiableList(links); + } + + public List nodes() { + return links.stream() + .flatMap(l -> Stream.of(l.from(), l.to())) + .sorted(Comparator.comparing(QWord::ord)) + .distinct() + .collect(Collectors.toList()); + } + + public QWord node(String word) { + return nodes().stream() + .filter(n -> n.word().equals(word)) + .findFirst() + .orElseThrow(); + } + + public List getNext(QWord word) { + return fromTo.getOrDefault(word, List.of()); + } + public List getNextOriginal(QWord word) { + return fromTo.getOrDefault(word, List.of()) + .stream() + .filter(QWord::isOriginal) + .toList(); + } + + public List getPrev(QWord word) { + return toFrom.getOrDefault(word, List.of()); + } + public List getPrevOriginal(QWord word) { + return toFrom.getOrDefault(word, List.of()) + .stream() + .filter(QWord::isOriginal) + .toList(); + } + + public Map> forwardReachability() { + Map> ret = new HashMap<>(); + + Set edge = Set.of(QWord.beg()); + Set visited = new HashSet<>(); + + while (!edge.isEmpty()) { + Set next = new LinkedHashSet<>(); + + for (var w : edge) { + + for (var n : getNext(w)) { + var set = ret.computeIfAbsent(n, k -> new HashSet<>()); + + set.add(w); + set.addAll(ret.getOrDefault(w, Set.of())); + + next.add(n); + } + } + + next.removeAll(visited); + visited.addAll(next); + edge = next; + } + + return ret; + } + + public Map> reverseReachability() { + Map> ret = new HashMap<>(); + + Set edge = Set.of(QWord.end()); + Set visited = new HashSet<>(); + + while (!edge.isEmpty()) { + Set prev = new LinkedHashSet<>(); + + for (var w : edge) { + + for (var p : getPrev(w)) { + var set = ret.computeIfAbsent(p, k -> new HashSet<>()); + + set.add(w); + set.addAll(ret.getOrDefault(w, Set.of())); + + prev.add(p); + } + } + + prev.removeAll(visited); + visited.addAll(prev); + edge = prev; + } + + return ret; + } + + public record ReachabilityData(List sortedNodes, + Map sortOrder, + + Map> forward, + Map> reverse) + { + public Set forward(QWord node) { 
+ return forward.getOrDefault(node, Set.of()); + } + public Set reverse(QWord node) { + return reverse.getOrDefault(node, Set.of()); + } + + public Comparator topologicalComparator() { + Comparator comp = Comparator.comparing(sortOrder::get); + return comp.thenComparing(QWord::ord); + } + + } + + /** Gather data about graph reachability, including the topological order of nodes */ + public ReachabilityData reachability() { + var forwardReachability = forwardReachability(); + var reverseReachability = reverseReachability(); + + List nodes = new ArrayList<>(nodes()); + nodes.sort(new SetMembershipComparator<>(forwardReachability)); + + Map topologicalOrder = new HashMap<>(); + for (int i = 0; i < nodes.size(); i++) { + topologicalOrder.put(nodes.get(i), i); + } + + return new ReachabilityData(nodes, topologicalOrder, forwardReachability, reverseReachability); + } + + static class SetMembershipComparator implements Comparator { + private final Map> membership; + + SetMembershipComparator(Map> membership) { + this.membership = membership; + } + + @Override + public int compare(T o1, T o2) { + return Boolean.compare(isIn(o1, o2), isIn(o2, o1)); + } + + private boolean isIn(T a, T b) { + return membership.getOrDefault(a, Set.of()).contains(b); + } + } + + public String compileToQuery() { + return QWordPathsRenderer.render(this); + } + public String compileToDot() { + StringBuilder sb = new StringBuilder(); + sb.append("digraph {\n"); + for (var link : links) { + sb.append(STR."\"\{link.from().word()}\" -> \"\{link.to.word()}\";\n"); + } + sb.append("}\n"); + return sb.toString(); + } + + @NotNull + @Override + public Iterator iterator() { + return new Iterator<>() { + QWord pos = QWord.beg(); + + @Override + public boolean hasNext() { + return !pos.isEnd(); + } + + @Override + public QWord next() { + pos = getNextOriginal(pos).getFirst(); + return pos; + } + }; + } +} diff --git 
a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordGraphPathLister.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordGraphPathLister.java new file mode 100644 index 00000000..f26c01f7 --- /dev/null +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordGraphPathLister.java @@ -0,0 +1,57 @@ +package nu.marginalia.functions.searchquery.query_parser.model; + +import java.util.HashSet; +import java.util.LinkedList; +import java.util.Objects; +import java.util.Set; + +/** Utility class for listing each path in a {@link QWordGraph}, from the beginning node to the end. + * Normally this would be a risk for combinatorial explosion, but in practice the graph will be constructed + * in a way that avoids this risk. + * */ +public class QWordGraphPathLister { + private final QWordGraph graph; + + public QWordGraphPathLister(QWordGraph graph) { + this.graph = graph; + } + + public static Set listPaths(QWordGraph graph) { + return new QWordGraphPathLister(graph).listPaths(); + } + + Set listPaths() { + + Set paths = new HashSet<>(); + listPaths(paths, new LinkedList<>(), QWord.beg(), QWord.end()); + return paths; + } + + void listPaths(Set acc, + LinkedList stack, + QWord start, + QWord end) + { + stack.addLast(start); + + if (Objects.equals(start, end)) { + var nodes = new HashSet<>(stack); + + // Remove the start and end nodes from the path, as these are + // not part of the query but merely used to simplify the construction + // of the graph + + nodes.remove(QWord.beg()); + nodes.remove(QWord.end()); + + acc.add(new QWordPath(nodes)); + } + else { + for (var next : graph.getNext(start)) { + listPaths(acc, stack, next, end); + } + } + + stack.removeLast(); + } +} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordPath.java 
b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordPath.java new file mode 100644 index 00000000..daa2a1f1 --- /dev/null +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordPath.java @@ -0,0 +1,68 @@ +package nu.marginalia.functions.searchquery.query_parser.model; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** Represents a path of QWords in a QWordGraph. Since the order of operations when + * evaluating a query does not affect its semantics, only performance, the order of the + * nodes in the path is not significant; thus the path is represented with a set. + */ +public class QWordPath { + private final Set nodes; + + QWordPath(Collection nodes) { + this.nodes = new HashSet<>(nodes); + } + + public boolean contains(QWord node) { + return nodes.contains(node); + } + + /** Construct a new path by removing a word from the path. 
*/ + public QWordPath without(QWord word) { + Set newNodes = new HashSet<>(nodes); + newNodes.remove(word); + return new QWordPath(newNodes); + } + + public Stream stream() { + return nodes.stream(); + } + + /** Construct a new path by projecting the path onto a set of nodes, such that + * the nodes in the new set is a strict subset of the provided nodes */ + public QWordPath project(Set nodes) { + return new QWordPath(this.nodes.stream().filter(nodes::contains).collect(Collectors.toSet())); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + QWordPath wordPath = (QWordPath) o; + + return nodes.equals(wordPath.nodes); + } + + public boolean isEmpty() { + return nodes.isEmpty(); + } + + public int size() { + return nodes.size(); + } + + @Override + public int hashCode() { + return nodes.hashCode(); + } + + @Override + public String toString() { + return STR."WordPath{nodes=\{nodes}\{'}'}"; + } +} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordPathsRenderer.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordPathsRenderer.java new file mode 100644 index 00000000..b1ee7956 --- /dev/null +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/model/QWordPathsRenderer.java @@ -0,0 +1,149 @@ +package nu.marginalia.functions.searchquery.query_parser.model; + +import java.util.*; +import java.util.stream.Collectors; + +/** Renders a set of QWordPaths into a human-readable infix-style expression. It's not guaranteed to find + * the globally optimal expression, but rather uses a greedy algorithm as a tradeoff in effort to outcome. 
+ */ +public class QWordPathsRenderer { + private final Set paths; + + private QWordPathsRenderer(Collection paths) { + this.paths = Set.copyOf(paths); + } + + private QWordPathsRenderer(QWordGraph graph) { + this.paths = Set.copyOf(QWordGraphPathLister.listPaths(graph)); + } + + public static String render(QWordGraph graph) { + return new QWordPathsRenderer(graph).render(graph.reachability()); + } + + + private static String render(Collection paths, + QWordGraph.ReachabilityData reachability) + { + return new QWordPathsRenderer(paths).render(reachability); + } + + /** Render the paths into a human-readable infix-style expression. + *

+ * This method is recursive, but the recursion depth is limited by the + * maximum length of the paths, which is hard limited to a value typically around 10, + * so we don't need to worry about stack overflows here... + */ + String render(QWordGraph.ReachabilityData reachability) { + if (paths.size() == 1) { + return paths.iterator().next().stream().map(QWord::word).collect(Collectors.joining(" ")); + } + + // Find the commonality of words in the paths + + Map commonality = nodeCommonality(paths); + + // Break the words into two categories: those that are common to all paths, and those that are not + + List commonToAll = new ArrayList<>(); + Set notCommonToAll = new HashSet<>(); + commonality.forEach((k, v) -> { + if (v == paths.size()) { + commonToAll.add(k); + } else { + notCommonToAll.add(k); + } + }); + + StringJoiner resultJoiner = new StringJoiner(" "); + + if (!commonToAll.isEmpty()) { // Case where one or more words are common to all paths + commonToAll.sort(reachability.topologicalComparator()); + + for (var word : commonToAll) { + resultJoiner.add(word.word()); + } + + // Deal portion of the paths that do not all share a common word + if (!notCommonToAll.isEmpty()) { + + List nonOverlappingPortions = new ArrayList<>(); + + // Create a new path for each path that does not contain the common words we just printed + for (var path : paths) { + var np = path.project(notCommonToAll); + if (np.isEmpty()) + continue; + nonOverlappingPortions.add(np); + } + + // Recurse into the non-overlapping portions + resultJoiner.add(render(nonOverlappingPortions, reachability)); + } + } else if (commonality.size() > 1) { // The case where no words are common to all paths + + + // Sort the words by commonality, so that we can consider the most common words first + Map> pathsByCommonWord = new HashMap<>(); + + // Mutable copy of the paths + List allDivergentPaths = new ArrayList<>(paths); + + // Break the paths into branches by the first common word they contain, in order of 
decreasing commonality + while (!allDivergentPaths.isEmpty()) { + QWord mostCommon = mostCommonQWord(allDivergentPaths); + + var iter = allDivergentPaths.iterator(); + while (iter.hasNext()) { + var path = iter.next(); + + if (!path.contains(mostCommon)) { + continue; + } + + // Remove the common word from the path + var newPath = path.without(mostCommon); + + pathsByCommonWord + .computeIfAbsent(mostCommon, k -> new ArrayList<>()) + .add(newPath); + + // Remove the path from the list of divergent paths since we've now accounted for it and + // we don't want redundant branches: + iter.remove(); + } + } + + var branches = pathsByCommonWord.entrySet().stream() + .sorted(Map.Entry.comparingByKey(reachability.topologicalComparator())) // Sort by topological order to ensure consistent output + .map(e -> { + String commonWord = e.getKey().word(); + + // Recurse into the branches: + String branchPart = render(e.getValue(), reachability); + + return STR."\{commonWord} \{branchPart}"; + }) + .collect(Collectors.joining(" | ", " ( ", " ) ")); + + resultJoiner.add(branches); + } + + // Remove any double spaces that may have been introduced + return resultJoiner.toString().replaceAll("\\s+", " ").trim(); + } + + /** Compute how many paths each word is part of */ + private static Map nodeCommonality(Collection paths) { + return paths.stream().flatMap(QWordPath::stream) + .collect(Collectors.groupingBy(w -> w, Collectors.summingInt(w -> 1))); + } + private static QWord mostCommonQWord(Collection paths) { + assert !paths.isEmpty(); + + return nodeCommonality(paths).entrySet().stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .orElseThrow(); + } +} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/QueryToken.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/QueryToken.java new file mode 100644 index 00000000..b11fe370 --- /dev/null +++ 
b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/QueryToken.java @@ -0,0 +1,86 @@ +package nu.marginalia.functions.searchquery.query_parser.token; + + +public sealed interface QueryToken { + String str(); + String displayStr(); + + record LiteralTerm(String str, String displayStr) implements QueryToken {} + record QuotTerm(String str, String displayStr) implements QueryToken {} + record ExcludeTerm(String str, String displayStr) implements QueryToken {} + record AdviceTerm(String str, String displayStr) implements QueryToken {} + record PriorityTerm(String str, String displayStr) implements QueryToken {} + + record QualityTerm(String str) implements QueryToken { + public String displayStr() { + return "q" + str; + } + } + record YearTerm(String str) implements QueryToken { + public String displayStr() { + return "year" + str; + } + } + record SizeTerm(String str) implements QueryToken { + public String displayStr() { + return "size" + str; + } + } + record RankTerm(String str) implements QueryToken { + public String displayStr() { + return "rank" + str; + } + } + record NearTerm(String str) implements QueryToken { + public String displayStr() { + return "near:" + str; + } + } + + record QsTerm(String str) implements QueryToken { + public String displayStr() { + return "qs" + str; + } + } + + record Quot(String str) implements QueryToken { + public String displayStr() { + return "\"" + str + "\""; + } + } + record Minus() implements QueryToken { + public String str() { + return "-"; + } + public String displayStr() { + return "-"; + } + } + record QMark() implements QueryToken { + public String str() { + return "?"; + } + public String displayStr() { + return "?"; + } + } + record LParen() implements QueryToken { + public String str() { + return "("; + } + public String displayStr() { + return "("; + } + } + record RParen() implements QueryToken { + public String str() { + return ")"; + } + public String displayStr() { + 
return ")"; + } + } + + record Ignore(String str, String displayStr) implements QueryToken {} + +} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/Token.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/Token.java deleted file mode 100644 index 06c28972..00000000 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/Token.java +++ /dev/null @@ -1,49 +0,0 @@ -package nu.marginalia.functions.searchquery.query_parser.token; - -import lombok.EqualsAndHashCode; -import lombok.ToString; -import lombok.With; - -@ToString -@EqualsAndHashCode -@With -public class Token { - public TokenType type; - public String str; - public final String displayStr; - - public Token(TokenType type, String str, String displayStr) { - this.type = type; - this.str = str; - this.displayStr = safeString(displayStr); - } - - - public Token(TokenType type, String str) { - this.type = type; - this.str = str; - this.displayStr = safeString(str); - } - - private static String safeString(String s) { - return s.replaceAll("<", "<") - .replaceAll(">", ">"); - } - - public void visit(TokenVisitor visitor) { - switch (type) { - case QUOT_TERM: visitor.onQuotTerm(this); break; - case EXCLUDE_TERM: visitor.onExcludeTerm(this); break; - case PRIORTY_TERM: visitor.onPriorityTerm(this); break; - case ADVICE_TERM: visitor.onAdviceTerm(this); break; - case LITERAL_TERM: visitor.onLiteralTerm(this); break; - - case YEAR_TERM: visitor.onYearTerm(this); break; - case RANK_TERM: visitor.onRankTerm(this); break; - case SIZE_TERM: visitor.onSizeTerm(this); break; - case QS_TERM: visitor.onQsTerm(this); break; - - case QUALITY_TERM: visitor.onQualityTerm(this); break; - } - } -} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/TokenType.java 
b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/TokenType.java deleted file mode 100644 index 85d55c35..00000000 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/TokenType.java +++ /dev/null @@ -1,34 +0,0 @@ -package nu.marginalia.functions.searchquery.query_parser.token; - -import java.util.function.Predicate; - -public enum TokenType implements Predicate { - TERM, - - - LITERAL_TERM, - QUOT_TERM, - EXCLUDE_TERM, - ADVICE_TERM, - PRIORTY_TERM, - - QUALITY_TERM, - YEAR_TERM, - SIZE_TERM, - RANK_TERM, - NEAR_TERM, - - QS_TERM, - - QUOT, - MINUS, - QMARK, - LPAREN, - RPAREN, - - IGNORE; - - public boolean test(Token t) { - return t.type == this; - } -} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/TokenVisitor.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/TokenVisitor.java deleted file mode 100644 index 2e14f837..00000000 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/token/TokenVisitor.java +++ /dev/null @@ -1,14 +0,0 @@ -package nu.marginalia.functions.searchquery.query_parser.token; - -public interface TokenVisitor { - void onLiteralTerm(Token token); - void onQuotTerm(Token token); - void onExcludeTerm(Token token); - void onPriorityTerm(Token token); - void onAdviceTerm(Token token); - void onYearTerm(Token token); - void onSizeTerm(Token token); - void onRankTerm(Token token); - void onQualityTerm(Token token); - void onQsTerm(Token token); -} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryFactory.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryFactory.java index ac7ce2b2..15596d5c 100644 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryFactory.java +++ 
b/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryFactory.java @@ -2,72 +2,42 @@ package nu.marginalia.functions.searchquery.svc; import com.google.inject.Inject; import com.google.inject.Singleton; -import nu.marginalia.LanguageModels; import nu.marginalia.api.searchquery.model.query.SearchSpecification; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; -import nu.marginalia.util.language.EnglishDictionary; +import nu.marginalia.functions.searchquery.query_parser.QueryExpansion; +import nu.marginalia.functions.searchquery.query_parser.token.QueryToken; +import nu.marginalia.index.query.limit.QueryStrategy; +import nu.marginalia.index.query.limit.SpecificationLimit; import nu.marginalia.language.WordPatterns; -import nu.marginalia.util.ngrams.NGramBloomFilter; import nu.marginalia.api.searchquery.model.query.QueryParams; import nu.marginalia.api.searchquery.model.query.ProcessedQuery; import nu.marginalia.functions.searchquery.query_parser.QueryParser; -import nu.marginalia.functions.searchquery.query_parser.QueryPermutation; -import nu.marginalia.functions.searchquery.query_parser.QueryVariants; -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenType; -import nu.marginalia.term_frequency_dict.TermFrequencyDict; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; @Singleton public class QueryFactory { private final Logger logger = LoggerFactory.getLogger(getClass()); - private static final int RETAIN_QUERY_VARIANT_COUNT = 5; - private final ThreadLocal queryVariants; private final QueryParser queryParser = new QueryParser(); + private final 
QueryExpansion queryExpansion; @Inject - public QueryFactory(LanguageModels lm, - TermFrequencyDict dict, - EnglishDictionary englishDictionary, - NGramBloomFilter nGramBloomFilter) { - this.queryVariants = ThreadLocal.withInitial(() -> new QueryVariants(lm ,dict, nGramBloomFilter, englishDictionary)); + public QueryFactory(QueryExpansion queryExpansion) + { + this.queryExpansion = queryExpansion; } - public QueryPermutation getQueryPermutation() { - return new QueryPermutation(queryVariants.get()); - } public ProcessedQuery createQuery(QueryParams params) { - final var processedQuery = createQuery(getQueryPermutation(), params); - final List subqueries = processedQuery.specs.subqueries; - - // There used to be a piece of logic here that would try to figure out which one of these subqueries were the "best", - // it's gone for the moment, but it would be neat if it resurrected somehow - - trimArray(subqueries, RETAIN_QUERY_VARIANT_COUNT); - - return processedQuery; - } - - private void trimArray(List arr, int maxSize) { - if (arr.size() > maxSize) { - arr.subList(0, arr.size() - maxSize).clear(); - } - } - - public ProcessedQuery createQuery(QueryPermutation queryPermutation, - QueryParams params) - { final var query = params.humanQuery(); if (query.length() > 1000) { @@ -77,42 +47,86 @@ public class QueryFactory { List searchTermsHuman = new ArrayList<>(); List problems = new ArrayList<>(); - String domain = null; - - var basicQuery = queryParser.parse(query); + List basicQuery = queryParser.parse(query); if (basicQuery.size() >= 12) { problems.add("Your search query is too long"); basicQuery.clear(); } + List searchTermsExclude = new ArrayList<>(); + List searchTermsInclude = new ArrayList<>(); + List searchTermsAdvice = new ArrayList<>(); + List searchTermsPriority = new ArrayList<>(); + List> searchTermCoherences = new ArrayList<>(); - QueryLimitsAccumulator qualityLimits = new QueryLimitsAccumulator(params); + SpecificationLimit qualityLimit = 
SpecificationLimit.none(); + SpecificationLimit year = SpecificationLimit.none(); + SpecificationLimit size = SpecificationLimit.none(); + SpecificationLimit rank = SpecificationLimit.none(); + QueryStrategy queryStrategy = QueryStrategy.AUTO; - for (Token t : basicQuery) { - if (t.type == TokenType.QUOT_TERM || t.type == TokenType.LITERAL_TERM) { - if (t.str.startsWith("site:")) { - t.str = normalizeDomainName(t.str); + String domain = null; + + for (QueryToken t : basicQuery) { + switch (t) { + case QueryToken.QuotTerm(String str, String displayStr) -> { + analyzeSearchTerm(problems, str, displayStr); + searchTermsHuman.addAll(Arrays.asList(displayStr.replace("\"", "").split("\\s+"))); + + String[] parts = StringUtils.split(str, '_'); + + // Checking for stop words here is a bit of a stop-gap to fix the issue of stop words being + // required in the query (which is a problem because they are not indexed). How to do this + // in a clean way is a bit of an open problem that may not get resolved until query-parsing is + // improved. 
+ + if (parts.length > 1 && !anyPartIsStopWord(parts)) { + // Prefer that the actual n-gram is present + searchTermsAdvice.add(str); + + // Require that the terms appear in the same sentence + searchTermCoherences.add(Arrays.asList(parts)); + + // Require that each term exists in the document + // (needed for ranking) + searchTermsInclude.addAll(Arrays.asList(parts)); + } + else { + searchTermsInclude.add(str); + } + } + case QueryToken.LiteralTerm(String str, String displayStr) -> { + analyzeSearchTerm(problems, str, displayStr); + searchTermsHuman.addAll(Arrays.asList(displayStr.split("\\s+"))); + + searchTermsInclude.add(str); } - searchTermsHuman.addAll(toHumanSearchTerms(t)); - analyzeSearchTerm(problems, t); - } - t.visit(qualityLimits); + case QueryToken.ExcludeTerm(String str, String displayStr) -> searchTermsExclude.add(str); + case QueryToken.PriorityTerm(String str, String displayStr) -> searchTermsPriority.add(str); + case QueryToken.AdviceTerm(String str, String displayStr) -> { + searchTermsAdvice.add(str); + + if (str.toLowerCase().startsWith("site:")) { + domain = str.substring("site:".length()); + } + } + + case QueryToken.YearTerm(String str) -> year = parseSpecificationLimit(str); + case QueryToken.SizeTerm(String str) -> size = parseSpecificationLimit(str); + case QueryToken.RankTerm(String str) -> rank = parseSpecificationLimit(str); + case QueryToken.QualityTerm(String str) -> qualityLimit = parseSpecificationLimit(str); + case QueryToken.QsTerm(String str) -> queryStrategy = parseQueryStrategy(str); + + default -> {} + } } - var queryPermutations = queryPermutation.permuteQueriesNew(basicQuery); - List subqueries = new ArrayList<>(); - - for (var parts : queryPermutations) { - QuerySearchTermsAccumulator termsAccumulator = new QuerySearchTermsAccumulator(parts); - - SearchSubquery subquery = termsAccumulator.createSubquery(); - - domain = termsAccumulator.domain; - - subqueries.add(subquery); + if (searchTermsInclude.isEmpty() && 
!searchTermsAdvice.isEmpty()) { + searchTermsInclude.addAll(searchTermsAdvice); + searchTermsAdvice.clear(); } List domainIds = params.domainIds(); @@ -123,55 +137,85 @@ public class QueryFactory { limits = limits.forSingleDomain(); } + var searchQuery = new SearchQuery( + queryExpansion.expandQuery( + searchTermsInclude + ), + searchTermsInclude, + searchTermsExclude, + searchTermsAdvice, + searchTermsPriority, + searchTermCoherences + ); + var specsBuilder = SearchSpecification.builder() - .subqueries(subqueries) + .query(searchQuery) .humanQuery(query) - .quality(qualityLimits.qualityLimit) - .year(qualityLimits.year) - .size(qualityLimits.size) - .rank(qualityLimits.rank) + .quality(qualityLimit) + .year(year) + .size(size) + .rank(rank) .domains(domainIds) .queryLimits(limits) .searchSetIdentifier(params.identifier()) .rankingParams(ResultRankingParameters.sensibleDefaults()) - .queryStrategy(qualityLimits.queryStrategy); + .queryStrategy(queryStrategy); SearchSpecification specs = specsBuilder.build(); - for (var sq : specs.subqueries) { - sq.searchTermsAdvice.addAll(params.tacitAdvice()); - sq.searchTermsPriority.addAll(params.tacitPriority()); - sq.searchTermsInclude.addAll(params.tacitIncludes()); - sq.searchTermsExclude.addAll(params.tacitExcludes()); - } + specs.query.searchTermsAdvice.addAll(params.tacitAdvice()); + specs.query.searchTermsPriority.addAll(params.tacitPriority()); + specs.query.searchTermsExclude.addAll(params.tacitExcludes()); return new ProcessedQuery(specs, searchTermsHuman, domain); } - private String normalizeDomainName(String str) { - return str.toLowerCase(); - } - - private List toHumanSearchTerms(Token t) { - if (t.type == TokenType.LITERAL_TERM) { - return Arrays.asList(t.displayStr.split("\\s+")); - } - else if (t.type == TokenType.QUOT_TERM) { - return Arrays.asList(t.displayStr.replace("\"", "").split("\\s+")); - - } - return Collections.emptyList(); - } - - private void analyzeSearchTerm(List problems, Token term) { - final 
String word = term.str; + private void analyzeSearchTerm(List problems, String str, String displayStr) { + final String word = str; if (word.length() < WordPatterns.MIN_WORD_LENGTH) { - problems.add("Search term \"" + term.displayStr + "\" too short"); + problems.add("Search term \"" + displayStr + "\" too short"); } if (!word.contains("_") && word.length() >= WordPatterns.MAX_WORD_LENGTH) { - problems.add("Search term \"" + term.displayStr + "\" too long"); + problems.add("Search term \"" + displayStr + "\" too long"); + } + } + private SpecificationLimit parseSpecificationLimit(String str) { + int startChar = str.charAt(0); + + int val = Integer.parseInt(str.substring(1)); + if (startChar == '=') { + return SpecificationLimit.equals(val); + } else if (startChar == '<') { + return SpecificationLimit.lessThan(val); + } else if (startChar == '>') { + return SpecificationLimit.greaterThan(val); + } else { + return SpecificationLimit.none(); } } + private QueryStrategy parseQueryStrategy(String str) { + return switch (str.toUpperCase()) { + case "RF_TITLE" -> QueryStrategy.REQUIRE_FIELD_TITLE; + case "RF_SUBJECT" -> QueryStrategy.REQUIRE_FIELD_SUBJECT; + case "RF_SITE" -> QueryStrategy.REQUIRE_FIELD_SITE; + case "RF_URL" -> QueryStrategy.REQUIRE_FIELD_URL; + case "RF_DOMAIN" -> QueryStrategy.REQUIRE_FIELD_DOMAIN; + case "RF_LINK" -> QueryStrategy.REQUIRE_FIELD_LINK; + case "SENTENCE" -> QueryStrategy.SENTENCE; + case "TOPIC" -> QueryStrategy.TOPIC; + default -> QueryStrategy.AUTO; + }; + } + + + private boolean anyPartIsStopWord(String[] parts) { + for (String part : parts) { + if (WordPatterns.isStopWord(part)) { + return true; + } + } + return false; + } } diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryLimitsAccumulator.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryLimitsAccumulator.java deleted file mode 100644 index 1b49bab3..00000000 --- 
a/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QueryLimitsAccumulator.java +++ /dev/null @@ -1,93 +0,0 @@ -package nu.marginalia.functions.searchquery.svc; - -import nu.marginalia.api.searchquery.model.query.QueryParams; -import nu.marginalia.index.query.limit.QueryStrategy; -import nu.marginalia.index.query.limit.SpecificationLimit; -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenVisitor; - -public class QueryLimitsAccumulator implements TokenVisitor { - public SpecificationLimit qualityLimit; - public SpecificationLimit year; - public SpecificationLimit size; - public SpecificationLimit rank; - - public QueryStrategy queryStrategy = QueryStrategy.AUTO; - - public QueryLimitsAccumulator(QueryParams params) { - qualityLimit = params.quality(); - year = params.year(); - size = params.size(); - rank = params.rank(); - } - - private SpecificationLimit parseSpecificationLimit(String str) { - int startChar = str.charAt(0); - - int val = Integer.parseInt(str.substring(1)); - if (startChar == '=') { - return SpecificationLimit.equals(val); - } else if (startChar == '<') { - return SpecificationLimit.lessThan(val); - } else if (startChar == '>') { - return SpecificationLimit.greaterThan(val); - } else { - return SpecificationLimit.none(); - } - } - - private QueryStrategy parseQueryStrategy(String str) { - return switch (str.toUpperCase()) { - case "RF_TITLE" -> QueryStrategy.REQUIRE_FIELD_TITLE; - case "RF_SUBJECT" -> QueryStrategy.REQUIRE_FIELD_SUBJECT; - case "RF_SITE" -> QueryStrategy.REQUIRE_FIELD_SITE; - case "RF_URL" -> QueryStrategy.REQUIRE_FIELD_URL; - case "RF_DOMAIN" -> QueryStrategy.REQUIRE_FIELD_DOMAIN; - case "RF_LINK" -> QueryStrategy.REQUIRE_FIELD_LINK; - case "SENTENCE" -> QueryStrategy.SENTENCE; - case "TOPIC" -> QueryStrategy.TOPIC; - default -> QueryStrategy.AUTO; - }; - } - - @Override - public void onYearTerm(Token token) { - year = 
parseSpecificationLimit(token.str); - } - - @Override - public void onSizeTerm(Token token) { - size = parseSpecificationLimit(token.str); - } - - @Override - public void onRankTerm(Token token) { - rank = parseSpecificationLimit(token.str); - } - - @Override - public void onQualityTerm(Token token) { - qualityLimit = parseSpecificationLimit(token.str); - } - - @Override - public void onQsTerm(Token token) { - queryStrategy = parseQueryStrategy(token.str); - } - - - @Override - public void onLiteralTerm(Token token) {} - - @Override - public void onQuotTerm(Token token) {} - - @Override - public void onExcludeTerm(Token token) {} - - @Override - public void onPriorityTerm(Token token) {} - - @Override - public void onAdviceTerm(Token token) {} -} diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QuerySearchTermsAccumulator.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QuerySearchTermsAccumulator.java deleted file mode 100644 index e4def0d0..00000000 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/svc/QuerySearchTermsAccumulator.java +++ /dev/null @@ -1,109 +0,0 @@ -package nu.marginalia.functions.searchquery.svc; - -import nu.marginalia.api.searchquery.model.query.SearchSubquery; -import nu.marginalia.language.WordPatterns; -import nu.marginalia.functions.searchquery.query_parser.token.Token; -import nu.marginalia.functions.searchquery.query_parser.token.TokenVisitor; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** @see SearchSubquery */ -public class QuerySearchTermsAccumulator implements TokenVisitor { - public List searchTermsExclude = new ArrayList<>(); - public List searchTermsInclude = new ArrayList<>(); - public List searchTermsAdvice = new ArrayList<>(); - public List searchTermsPriority = new ArrayList<>(); - public List> searchTermCoherences = new ArrayList<>(); - - public String domain; - - public SearchSubquery 
createSubquery() { - return new SearchSubquery(searchTermsInclude, searchTermsExclude, searchTermsAdvice, searchTermsPriority, searchTermCoherences); - } - - public QuerySearchTermsAccumulator(List parts) { - for (Token t : parts) { - t.visit(this); - } - - if (searchTermsInclude.isEmpty() && !searchTermsAdvice.isEmpty()) { - searchTermsInclude.addAll(searchTermsAdvice); - searchTermsAdvice.clear(); - } - - } - - @Override - public void onLiteralTerm(Token token) { - searchTermsInclude.add(token.str); - } - - @Override - public void onQuotTerm(Token token) { - String[] parts = token.str.split("_"); - - // HACK (2023-05-02 vlofgren) - // - // Checking for stop words here is a bit of a stop-gap to fix the issue of stop words being - // required in the query (which is a problem because they are not indexed). How to do this - // in a clean way is a bit of an open problem that may not get resolved until query-parsing is - // improved. - - if (parts.length > 1 && !anyPartIsStopWord(parts)) { - // Prefer that the actual n-gram is present - searchTermsAdvice.add(token.str); - - // Require that the terms appear in the same sentence - searchTermCoherences.add(Arrays.asList(parts)); - - // Require that each term exists in the document - // (needed for ranking) - searchTermsInclude.addAll(Arrays.asList(parts)); - } - else { - searchTermsInclude.add(token.str); - - } - } - - private boolean anyPartIsStopWord(String[] parts) { - for (String part : parts) { - if (WordPatterns.isStopWord(part)) { - return true; - } - } - return false; - } - - @Override - public void onExcludeTerm(Token token) { - searchTermsExclude.add(token.str); - } - - @Override - public void onPriorityTerm(Token token) { - searchTermsPriority.add(token.str); - } - - @Override - public void onAdviceTerm(Token token) { - searchTermsAdvice.add(token.str); - - if (token.str.toLowerCase().startsWith("site:")) { - domain = token.str.substring("site:".length()); - } - } - - @Override - public void onYearTerm(Token 
token) {} - @Override - public void onSizeTerm(Token token) {} - @Override - public void onRankTerm(Token token) {} - @Override - public void onQualityTerm(Token token) {} - @Override - public void onQsTerm(Token token) {} -} diff --git a/code/functions/search-query/java/nu/marginalia/util/ngrams/DenseBitMap.java b/code/functions/search-query/java/nu/marginalia/util/ngrams/DenseBitMap.java deleted file mode 100644 index 008b17b3..00000000 --- a/code/functions/search-query/java/nu/marginalia/util/ngrams/DenseBitMap.java +++ /dev/null @@ -1,69 +0,0 @@ -package nu.marginalia.util.ngrams; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; -import java.util.BitSet; - -// It's unclear why this exists, we should probably use a BitSet instead? -// Chesterton's fence? -public class DenseBitMap { - public static final long MAX_CAPACITY_2GB_16BN_ITEMS=(1L<<34)-8; - - public final long cardinality; - private final ByteBuffer buffer; - - public DenseBitMap(long cardinality) { - this.cardinality = cardinality; - - boolean misaligned = (cardinality & 7) > 0; - this.buffer = ByteBuffer.allocateDirect((int)((cardinality / 8) + (misaligned ? 
1 : 0))); - } - - public static DenseBitMap loadFromFile(Path file) throws IOException { - long size = Files.size(file); - var dbm = new DenseBitMap(size/8); - - try (var bc = Files.newByteChannel(file)) { - while (dbm.buffer.position() < dbm.buffer.capacity()) { - bc.read(dbm.buffer); - } - } - dbm.buffer.clear(); - - return dbm; - } - - public void writeToFile(Path file) throws IOException { - - try (var bc = Files.newByteChannel(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE)) { - while (buffer.position() < buffer.capacity()) { - bc.write(buffer); - } - } - - buffer.clear(); - } - - public boolean get(long pos) { - return (buffer.get((int)(pos >>> 3)) & ((byte)1 << (int)(pos & 7))) != 0; - } - - /** Set the bit indexed by pos, returns - * its previous value. - */ - public boolean set(long pos) { - int offset = (int) (pos >>> 3); - int oldVal = buffer.get(offset); - int mask = (byte) 1 << (int) (pos & 7); - buffer.put(offset, (byte) (oldVal | mask)); - return (oldVal & mask) != 0; - } - - public void clear(long pos) { - int offset = (int)(pos >>> 3); - buffer.put(offset, (byte)(buffer.get(offset) & ~(byte)(1 << (int)(pos & 7)))); - } -} diff --git a/code/functions/search-query/java/nu/marginalia/util/ngrams/NGramBloomFilter.java b/code/functions/search-query/java/nu/marginalia/util/ngrams/NGramBloomFilter.java deleted file mode 100644 index 3326956d..00000000 --- a/code/functions/search-query/java/nu/marginalia/util/ngrams/NGramBloomFilter.java +++ /dev/null @@ -1,64 +0,0 @@ -package nu.marginalia.util.ngrams; - -import ca.rmen.porterstemmer.PorterStemmer; -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hashing; -import com.google.inject.Inject; -import nu.marginalia.LanguageModels; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.regex.Pattern; - -public class 
NGramBloomFilter { - private final DenseBitMap bitMap; - private static final PorterStemmer ps = new PorterStemmer(); - private static final HashFunction hasher = Hashing.murmur3_128(0); - - private static final Logger logger = LoggerFactory.getLogger(NGramBloomFilter.class); - - @Inject - public NGramBloomFilter(LanguageModels lm) throws IOException { - this(loadSafely(lm.ngramBloomFilter)); - } - - private static DenseBitMap loadSafely(Path path) throws IOException { - if (Files.isRegularFile(path)) { - return DenseBitMap.loadFromFile(path); - } - else { - logger.warn("NGrams file missing " + path); - return new DenseBitMap(1); - } - } - - public NGramBloomFilter(DenseBitMap bitMap) { - this.bitMap = bitMap; - } - - public boolean isKnownNGram(String word) { - long bit = bitForWord(word, bitMap.cardinality); - - return bitMap.get(bit); - } - - public static NGramBloomFilter load(Path file) throws IOException { - return new NGramBloomFilter(DenseBitMap.loadFromFile(file)); - } - - private static final Pattern underscore = Pattern.compile("_"); - - private static long bitForWord(String s, long n) { - String[] parts = underscore.split(s); - long hc = 0; - for (String part : parts) { - hc = hc * 31 + hasher.hashString(ps.stemWord(part), StandardCharsets.UTF_8).padToLong(); - } - return (hc & 0x7FFF_FFFF_FFFF_FFFFL) % n; - } - -} diff --git a/code/functions/search-query/java/nu/marginalia/util/transform_list/TransformList.java b/code/functions/search-query/java/nu/marginalia/util/transform_list/TransformList.java index 08bc428e..62dd2e0a 100644 --- a/code/functions/search-query/java/nu/marginalia/util/transform_list/TransformList.java +++ b/code/functions/search-query/java/nu/marginalia/util/transform_list/TransformList.java @@ -80,6 +80,15 @@ public class TransformList { iter.remove(); } } + else if (firstEntity.action == Action.NO_OP) { + if (secondEntry.action == Action.REPLACE) { + backingList.set(iter.nextIndex(), secondEntry.value); + } + else if 
(secondEntry.action == Action.REMOVE) { + iter.next(); + iter.remove(); + } + } } } diff --git a/code/functions/search-query/test/nu/marginalia/functions/searchquery/query_parser/model/QWordGraphTest.java b/code/functions/search-query/test/nu/marginalia/functions/searchquery/query_parser/model/QWordGraphTest.java new file mode 100644 index 00000000..f985cd13 --- /dev/null +++ b/code/functions/search-query/test/nu/marginalia/functions/searchquery/query_parser/model/QWordGraphTest.java @@ -0,0 +1,115 @@ +package nu.marginalia.functions.searchquery.query_parser.model; + +import org.junit.jupiter.api.Test; + +import java.util.Comparator; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class QWordGraphTest { + + @Test + void forwardReachability() { + // Construct a graph like + + // ^ - a - b - c - $ + // \- d -/ + QWordGraph graph = new QWordGraph("a", "b", "c"); + graph.addVariant(graph.node("b"), "d"); + + var reachability = graph.forwardReachability(); + + System.out.println(reachability.get(graph.node("a"))); + System.out.println(reachability.get(graph.node("b"))); + System.out.println(reachability.get(graph.node("c"))); + System.out.println(reachability.get(graph.node("d"))); + + assertEquals(Set.of(graph.node(" ^ ")), reachability.get(graph.node("a"))); + assertEquals(Set.of(graph.node(" ^ "), graph.node("a")), reachability.get(graph.node("b"))); + assertEquals(Set.of(graph.node(" ^ "), graph.node("a")), reachability.get(graph.node("d"))); + assertEquals(Set.of(graph.node(" ^ "), graph.node("a"), graph.node("b"), graph.node("d")), reachability.get(graph.node("c"))); + assertEquals(Set.of(graph.node(" ^ "), graph.node("a"), graph.node("b"), graph.node("d"), graph.node("c")), reachability.get(graph.node(" $ "))); + } + + + @Test + void reverseReachability() { + // Construct a graph like + + // ^ - a - b - c - $ + // \- d -/ + QWordGraph graph = new QWordGraph("a", "b", "c"); + graph.addVariant(graph.node("b"), "d"); + + var 
reachability = graph.reverseReachability(); + + System.out.println(reachability.get(graph.node("a"))); + System.out.println(reachability.get(graph.node("b"))); + System.out.println(reachability.get(graph.node("c"))); + System.out.println(reachability.get(graph.node("d"))); + + assertEquals(Set.of(graph.node(" $ ")), reachability.get(graph.node("c"))); + assertEquals(Set.of(graph.node(" $ "), graph.node("c")), reachability.get(graph.node("b"))); + assertEquals(Set.of(graph.node(" $ "), graph.node("c")), reachability.get(graph.node("d"))); + assertEquals(Set.of(graph.node(" $ "), graph.node("c"), graph.node("b"), graph.node("d")), reachability.get(graph.node("a"))); + assertEquals(Set.of(graph.node(" $ "), graph.node("c"), graph.node("b"), graph.node("d"), graph.node("a")), reachability.get(graph.node(" ^ "))); + } + + @Test + void testCompile1() { + // Construct a graph like + + // ^ - a - b - c - $ + // \- d -/ + QWordGraph graph = new QWordGraph("a", "b", "c"); + graph.addVariant(graph.node("b"), "d"); + + assertEquals("a c ( b | d )", graph.compileToQuery()); + } + + @Test + void testCompile2() { + // Construct a graph like + + // ^ - a - b - c - $ + QWordGraph graph = new QWordGraph("a", "b", "c"); + + assertEquals("a b c", graph.compileToQuery()); + } + + @Test + void testCompile3() { + // Construct a graph like + + // ^ - a - b - c - $ + // \- d -/ + QWordGraph graph = new QWordGraph("a", "b", "c"); + graph.addVariant(graph.node("a"), "d"); + assertEquals("b c ( a | d )", graph.compileToQuery()); + } + + @Test + void testCompile4() { + // Construct a graph like + + // ^ - a - b - c - $ + // \- d -/ + QWordGraph graph = new QWordGraph("a", "b", "c"); + graph.addVariant(graph.node("c"), "d"); + assertEquals("a b ( c | d )", graph.compileToQuery()); + } + + @Test + void testCompile5() { + // Construct a graph like + + // /- e -\ + // ^ - a - b - c - $ + // \- d -/ + QWordGraph graph = new QWordGraph("a", "b", "c"); + graph.addVariant(graph.node("c"), "d"); + 
graph.addVariant(graph.node("b"), "e"); + assertEquals("a ( c ( b | e ) | d ( b | e ) )", graph.compileToQuery()); + } +} \ No newline at end of file diff --git a/code/functions/search-query/test/nu/marginalia/query/svc/QueryFactoryTest.java b/code/functions/search-query/test/nu/marginalia/query/svc/QueryFactoryTest.java index fe93a1f6..622130b7 100644 --- a/code/functions/search-query/test/nu/marginalia/query/svc/QueryFactoryTest.java +++ b/code/functions/search-query/test/nu/marginalia/query/svc/QueryFactoryTest.java @@ -3,19 +3,21 @@ package nu.marginalia.query.svc; import nu.marginalia.WmsaHome; import nu.marginalia.api.searchquery.model.query.SearchSpecification; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; +import nu.marginalia.functions.searchquery.query_parser.QueryExpansion; import nu.marginalia.functions.searchquery.svc.QueryFactory; import nu.marginalia.index.query.limit.QueryLimits; import nu.marginalia.index.query.limit.QueryStrategy; import nu.marginalia.index.query.limit.SpecificationLimit; import nu.marginalia.index.query.limit.SpecificationLimitType; -import nu.marginalia.util.language.EnglishDictionary; -import nu.marginalia.util.ngrams.NGramBloomFilter; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.api.searchquery.model.query.QueryParams; import nu.marginalia.term_frequency_dict.TermFrequencyDict; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -28,12 +30,9 @@ public class QueryFactoryTest { public static void setUpAll() throws IOException { var lm = WmsaHome.getLanguageModels(); - var tfd = new TermFrequencyDict(lm); - queryFactory = new QueryFactory(lm, - tfd, - new EnglishDictionary(tfd), - new NGramBloomFilter(lm) + queryFactory = new QueryFactory( + new QueryExpansion(new 
TermFrequencyDict(lm), new NgramLexicon(lm)) ); } @@ -55,6 +54,21 @@ public class QueryFactoryTest { ResultRankingParameters.TemporalBias.NONE)).specs; } + + @Test + void qsec10() { + try (var lines = Files.lines(Path.of("/home/vlofgren/Exports/qsec10/webis-qsec-10-training-set/webis-qsec-10-training-set-queries.txt"))) { + lines.limit(1000).forEach(line -> { + String[] parts = line.split("\t"); + if (parts.length == 2) { + System.out.println(parseAndGetSpecs(parts[1]).getQuery().compiledQuery); + } + }); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @Test public void testParseNoSpecials() { var year = parseAndGetSpecs("in the year 2000").year; @@ -114,17 +128,15 @@ public class QueryFactoryTest { { // the is a stopword, so it should generate an ngram search term var specs = parseAndGetSpecs("\"the shining\""); - assertEquals(List.of("the_shining"), specs.subqueries.iterator().next().searchTermsInclude); - assertEquals(List.of(), specs.subqueries.iterator().next().searchTermsAdvice); - assertEquals(List.of(), specs.subqueries.iterator().next().searchTermCoherences); + assertEquals("the_shining", specs.query.compiledQuery); } { // tde isn't a stopword, so we should get the normal behavior var specs = parseAndGetSpecs("\"tde shining\""); - assertEquals(List.of("tde", "shining"), specs.subqueries.iterator().next().searchTermsInclude); - assertEquals(List.of("tde_shining"), specs.subqueries.iterator().next().searchTermsAdvice); - assertEquals(List.of(List.of("tde", "shining")), specs.subqueries.iterator().next().searchTermCoherences); + assertEquals("tde shining", specs.query.compiledQuery); + assertEquals(List.of("tde_shining"), specs.query.searchTermsAdvice); + assertEquals(List.of(List.of("tde", "shining")), specs.query.searchTermCoherences); } } @@ -152,8 +164,18 @@ public class QueryFactoryTest { @Test public void testPriorityTerm() { - var subquery = parseAndGetSpecs("physics ?tld:edu").subqueries.iterator().next(); + var subquery = 
parseAndGetSpecs("physics ?tld:edu").query; assertEquals(List.of("tld:edu"), subquery.searchTermsPriority); - assertEquals(List.of("physics"), subquery.searchTermsInclude); + assertEquals("physics", subquery.compiledQuery); + } + + @Test + public void testExpansion() { + + long start = System.currentTimeMillis(); + var subquery = parseAndGetSpecs("elden ring mechanical keyboard slackware linux duke nukem 3d").query; + System.out.println("Time: " + (System.currentTimeMillis() - start)); + System.out.println(subquery.compiledQuery); + } } \ No newline at end of file diff --git a/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexEntrySource.java b/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexEntrySource.java index 37c79941..851bf9ab 100644 --- a/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexEntrySource.java +++ b/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexEntrySource.java @@ -7,6 +7,7 @@ import nu.marginalia.index.query.EntrySource; import static java.lang.Math.min; public class ReverseIndexEntrySource implements EntrySource { + private final String name; private final BTreeReader reader; int pos; @@ -15,9 +16,11 @@ public class ReverseIndexEntrySource implements EntrySource { final int entrySize; private final long wordId; - public ReverseIndexEntrySource(BTreeReader reader, + public ReverseIndexEntrySource(String name, + BTreeReader reader, int entrySize, long wordId) { + this.name = name; this.reader = reader; this.entrySize = entrySize; this.wordId = wordId; @@ -46,7 +49,7 @@ public class ReverseIndexEntrySource implements EntrySource { return; for (int ri = entrySize, wi=1; ri < buffer.end ; ri+=entrySize, wi++) { - buffer.data[wi] = buffer.data[ri]; + buffer.data.set(wi, buffer.data.get(ri)); } buffer.end /= entrySize; @@ -60,6 +63,6 @@ public class ReverseIndexEntrySource implements EntrySource { @Override public String indexName() { - return "Full:" + Long.toHexString(wordId); + return name + ":" 
+ Long.toHexString(wordId); } } diff --git a/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexReader.java b/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexReader.java index f37420dd..72feb7fd 100644 --- a/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexReader.java +++ b/code/index/index-reverse/java/nu/marginalia/index/ReverseIndexReader.java @@ -25,8 +25,11 @@ public class ReverseIndexReader { private final long wordsDataOffset; private final Logger logger = LoggerFactory.getLogger(getClass()); private final BTreeReader wordsBTreeReader; + private final String name; + + public ReverseIndexReader(String name, Path words, Path documents) throws IOException { + this.name = name; - public ReverseIndexReader(Path words, Path documents) throws IOException { if (!Files.exists(words) || !Files.exists(documents)) { this.words = null; this.documents = null; @@ -65,8 +68,12 @@ public class ReverseIndexReader { } - long wordOffset(long wordId) { - long idx = wordsBTreeReader.findEntry(wordId); + /** Calculate the offset of the word in the documents. + * If the return-value is negative, the term does not exist + * in the index. 
+ */ + long wordOffset(long termId) { + long idx = wordsBTreeReader.findEntry(termId); if (idx < 0) return -1L; @@ -74,37 +81,43 @@ public class ReverseIndexReader { return words.get(wordsDataOffset + idx + 1); } - public EntrySource documents(long wordId) { + public EntrySource documents(long termId) { if (null == words) { logger.warn("Reverse index is not ready, dropping query"); return new EmptyEntrySource(); } - long offset = wordOffset(wordId); + long offset = wordOffset(termId); - if (offset < 0) return new EmptyEntrySource(); + if (offset < 0) // No documents + return new EmptyEntrySource(); - return new ReverseIndexEntrySource(createReaderNew(offset), 2, wordId); + return new ReverseIndexEntrySource(name, createReaderNew(offset), 2, termId); } - public QueryFilterStepIf also(long wordId) { - long offset = wordOffset(wordId); + /** Create a filter step requiring the specified termId to exist in the documents */ + public QueryFilterStepIf also(long termId) { + long offset = wordOffset(termId); - if (offset < 0) return new QueryFilterNoPass(); + if (offset < 0) // No documents + return new QueryFilterNoPass(); - return new ReverseIndexRetainFilter(createReaderNew(offset), "full", wordId); + return new ReverseIndexRetainFilter(createReaderNew(offset), name, termId); } - public QueryFilterStepIf not(long wordId) { - long offset = wordOffset(wordId); + /** Create a filter step requiring the specified termId to be absent from the documents */ + public QueryFilterStepIf not(long termId) { + long offset = wordOffset(termId); - if (offset < 0) return new QueryFilterLetThrough(); + if (offset < 0) // No documents + return new QueryFilterLetThrough(); return new ReverseIndexRejectFilter(createReaderNew(offset)); } - public int numDocuments(long wordId) { - long offset = wordOffset(wordId); + /** Return the number of documents with the termId in the index */ + public int numDocuments(long termId) { + long offset = wordOffset(termId); if (offset < 0) return 0; @@ -112,15 
+125,20 @@ public class ReverseIndexReader { return createReaderNew(offset).numEntries(); } + /** Create a BTreeReader for the document offset associated with a termId */ private BTreeReader createReaderNew(long offset) { - return new BTreeReader(documents, ReverseIndexParameters.docsBTreeContext, offset); + return new BTreeReader( + documents, + ReverseIndexParameters.docsBTreeContext, + offset); } - public long[] getTermMeta(long wordId, long[] docIds) { - long offset = wordOffset(wordId); + public long[] getTermMeta(long termId, long[] docIds) { + long offset = wordOffset(termId); if (offset < 0) { - logger.debug("Missing offset for word {}", wordId); + // This is likely a bug in the code, but we can't throw an exception here + logger.debug("Missing offset for word {}", termId); return new long[docIds.length]; } @@ -133,10 +151,9 @@ public class ReverseIndexReader { private boolean isUniqueAndSorted(long[] ids) { if (ids.length == 0) return true; - long prev = ids[0]; for (int i = 1; i < ids.length; i++) { - if(ids[i] <= prev) + if(ids[i] <= ids[i-1]) return false; } diff --git a/code/index/index-reverse/test/nu/marginalia/index/ReverseIndexReaderTest.java b/code/index/index-reverse/test/nu/marginalia/index/ReverseIndexReaderTest.java index e6b76249..ed8b4193 100644 --- a/code/index/index-reverse/test/nu/marginalia/index/ReverseIndexReaderTest.java +++ b/code/index/index-reverse/test/nu/marginalia/index/ReverseIndexReaderTest.java @@ -102,7 +102,7 @@ class ReverseIndexReaderTest { preindex.finalizeIndex(docsFile, wordsFile); preindex.delete(); - return new ReverseIndexReader(wordsFile, docsFile); + return new ReverseIndexReader("test", wordsFile, docsFile); } } \ No newline at end of file diff --git a/code/index/java/nu/marginalia/index/IndexFactory.java b/code/index/java/nu/marginalia/index/IndexFactory.java index 48911546..a1d2f5a5 100644 --- a/code/index/java/nu/marginalia/index/IndexFactory.java +++ b/code/index/java/nu/marginalia/index/IndexFactory.java @@ 
-41,14 +41,14 @@ public class IndexFactory { public ReverseIndexReader getReverseIndexReader() throws IOException { - return new ReverseIndexReader( + return new ReverseIndexReader("full", ReverseIndexFullFileNames.resolve(liveStorage, ReverseIndexFullFileNames.FileIdentifier.WORDS, ReverseIndexFullFileNames.FileVersion.CURRENT), ReverseIndexFullFileNames.resolve(liveStorage, ReverseIndexFullFileNames.FileIdentifier.DOCS, ReverseIndexFullFileNames.FileVersion.CURRENT) ); } public ReverseIndexReader getReverseIndexPrioReader() throws IOException { - return new ReverseIndexReader( + return new ReverseIndexReader("prio", ReverseIndexPrioFileNames.resolve(liveStorage, ReverseIndexPrioFileNames.FileIdentifier.WORDS, ReverseIndexPrioFileNames.FileVersion.CURRENT), ReverseIndexPrioFileNames.resolve(liveStorage, ReverseIndexPrioFileNames.FileIdentifier.DOCS, ReverseIndexPrioFileNames.FileVersion.CURRENT) ); diff --git a/code/index/java/nu/marginalia/index/IndexGrpcService.java b/code/index/java/nu/marginalia/index/IndexGrpcService.java index a47c4684..4810d625 100644 --- a/code/index/java/nu/marginalia/index/IndexGrpcService.java +++ b/code/index/java/nu/marginalia/index/IndexGrpcService.java @@ -9,14 +9,15 @@ import io.prometheus.client.Histogram; import it.unimi.dsi.fastutil.longs.LongArrayList; import lombok.SneakyThrows; import nu.marginalia.api.searchquery.*; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqDataInt; import nu.marginalia.api.searchquery.model.query.SearchSpecification; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; import nu.marginalia.api.searchquery.model.results.*; import nu.marginalia.array.buffer.LongQueryBuffer; import nu.marginalia.index.index.StatefulIndex; import nu.marginalia.index.model.SearchParameters; import nu.marginalia.index.model.SearchTerms; -import 
nu.marginalia.index.model.SearchTermsUtil; import nu.marginalia.index.query.IndexQuery; import nu.marginalia.index.query.IndexSearchBudget; import nu.marginalia.index.results.IndexResultValuatorService; @@ -135,15 +136,15 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { var rawItem = RpcRawResultItem.newBuilder(); rawItem.setCombinedId(rawResult.combinedId); rawItem.setResultsFromDomain(rawResult.resultsFromDomain); + rawItem.setHtmlFeatures(rawResult.htmlFeatures); + rawItem.setEncodedDocMetadata(rawResult.encodedDocMetadata); + rawItem.setHasPriorityTerms(rawResult.hasPrioTerm); for (var score : rawResult.keywordScores) { rawItem.addKeywordScores( RpcResultKeywordScore.newBuilder() - .setEncodedDocMetadata(score.encodedDocMetadata()) .setEncodedWordMetadata(score.encodedWordMetadata()) .setKeyword(score.keyword) - .setHtmlFeatures(score.htmlFeatures()) - .setSubquery(score.subquery) ); } @@ -156,6 +157,7 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { .setTitle(result.title) .setUrl(result.url.toString()) .setWordsTotal(result.wordsTotal) + .setBestPositions(result.bestPositions) .setRawItem(rawItem); if (result.pubYear != null) { @@ -203,7 +205,7 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { return new SearchResultSet(List.of()); } - ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.subqueries); + ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.compiledQueryIds); var queryExecution = new QueryExecution(rankingContext, params.fetchSize); @@ -255,14 +257,10 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { /** Execute a search query */ public SearchResultSet run(SearchParameters parameters) throws SQLException, InterruptedException { - for (var subquery : parameters.subqueries) { - var terms = new SearchTerms(subquery); - if (terms.isEmpty()) - continue; + var terms = new 
SearchTerms(parameters.query, parameters.compiledQueryIds); - for (var indexQuery : index.createQueries(terms, parameters.queryParams)) { - workerPool.execute(new IndexLookup(indexQuery, parameters.budget)); - } + for (var indexQuery : index.createQueries(terms, parameters.queryParams)) { + workerPool.execute(new IndexLookup(indexQuery, parameters.budget)); } for (int i = 0; i < indexValuationThreads; i++) { @@ -327,7 +325,9 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { buffer.reset(); query.getMoreResults(buffer); - results.addElements(0, buffer.data, 0, buffer.end); + for (int i = 0; i < buffer.end; i++) { + results.add(buffer.data.get(i)); + } if (results.size() < 512) { enqueueResults(new CombinedDocIdList(results)); @@ -413,18 +413,23 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase { } - private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams, List subqueries) { - final var termToId = SearchTermsUtil.getAllIncludeTerms(subqueries); - final Map termFrequencies = new HashMap<>(termToId.size()); - final Map prioFrequencies = new HashMap<>(termToId.size()); + private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams, + CompiledQueryLong compiledQueryIds) + { - termToId.forEach((key, id) -> termFrequencies.put(key, index.getTermFrequency(id))); - termToId.forEach((key, id) -> prioFrequencies.put(key, index.getTermFrequencyPrio(id))); + int[] full = new int[compiledQueryIds.size()]; + int[] prio = new int[compiledQueryIds.size()]; + + for (int idx = 0; idx < compiledQueryIds.size(); idx++) { + long id = compiledQueryIds.at(idx); + full[idx] = index.getTermFrequency(id); + prio[idx] = index.getTermFrequencyPrio(id); + } return new ResultRankingContext(index.getTotalDocCount(), rankingParams, - termFrequencies, - prioFrequencies); + new CqDataInt(full), + new CqDataInt(prio)); } } diff --git 
a/code/index/java/nu/marginalia/index/index/CombinedIndexReader.java b/code/index/java/nu/marginalia/index/index/CombinedIndexReader.java index ea78739c..27a631f5 100644 --- a/code/index/java/nu/marginalia/index/index/CombinedIndexReader.java +++ b/code/index/java/nu/marginalia/index/index/CombinedIndexReader.java @@ -38,6 +38,14 @@ public class CombinedIndexReader { return new IndexQueryBuilderImpl(reverseIndexFullReader, reverseIndexPriorityReader, query); } + public QueryFilterStepIf hasWordFull(long termId) { + return reverseIndexFullReader.also(termId); + } + + public QueryFilterStepIf hasWordPrio(long termId) { + return reverseIndexPriorityReader.also(termId); + } + /** Creates a query builder for terms in the priority index */ public IndexQueryBuilder findPriorityWord(long wordId) { diff --git a/code/index/java/nu/marginalia/index/index/IndexQueryBuilderImpl.java b/code/index/java/nu/marginalia/index/index/IndexQueryBuilderImpl.java index 825728ae..0f63fdbc 100644 --- a/code/index/java/nu/marginalia/index/index/IndexQueryBuilderImpl.java +++ b/code/index/java/nu/marginalia/index/index/IndexQueryBuilderImpl.java @@ -1,9 +1,11 @@ package nu.marginalia.index.index; +import java.util.List; import gnu.trove.set.hash.TLongHashSet; import nu.marginalia.index.ReverseIndexReader; import nu.marginalia.index.query.IndexQuery; import nu.marginalia.index.query.IndexQueryBuilder; +import nu.marginalia.index.query.filter.QueryFilterAnyOf; import nu.marginalia.index.query.filter.QueryFilterStepIf; public class IndexQueryBuilderImpl implements IndexQueryBuilder { @@ -34,7 +36,7 @@ public class IndexQueryBuilderImpl implements IndexQueryBuilder { return this; } - public IndexQueryBuilder alsoFull(long termId) { + public IndexQueryBuilder also(long termId) { if (alreadyConsideredTerms.add(termId)) { query.addInclusionFilter(reverseIndexFullReader.also(termId)); @@ -43,16 +45,7 @@ public class IndexQueryBuilderImpl implements IndexQueryBuilder { return this; } - public 
IndexQueryBuilder alsoPrio(long termId) { - - if (alreadyConsideredTerms.add(termId)) { - query.addInclusionFilter(reverseIndexPrioReader.also(termId)); - } - - return this; - } - - public IndexQueryBuilder notFull(long termId) { + public IndexQueryBuilder not(long termId) { query.addInclusionFilter(reverseIndexFullReader.not(termId)); @@ -66,6 +59,20 @@ public class IndexQueryBuilderImpl implements IndexQueryBuilder { return this; } + public IndexQueryBuilder addInclusionFilterAny(List filterSteps) { + if (filterSteps.isEmpty()) + return this; + + if (filterSteps.size() == 1) { + query.addInclusionFilter(filterSteps.getFirst()); + } + else { + query.addInclusionFilter(new QueryFilterAnyOf(filterSteps)); + } + + return this; + } + public IndexQuery build() { return query; } diff --git a/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java b/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java new file mode 100644 index 00000000..ffaa5176 --- /dev/null +++ b/code/index/java/nu/marginalia/index/index/QueryBranchWalker.java @@ -0,0 +1,108 @@ +package nu.marginalia.index.index; + +import it.unimi.dsi.fastutil.longs.LongArrayList; +import it.unimi.dsi.fastutil.longs.LongArraySet; +import it.unimi.dsi.fastutil.longs.LongSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +/** Helper class for index query construction */ +public class QueryBranchWalker { + private static final Logger logger = LoggerFactory.getLogger(QueryBranchWalker.class); + public final long[] priorityOrder; + public final List paths; + public final long termId; + + private QueryBranchWalker(long[] priorityOrder, List paths, long termId) { + this.priorityOrder = priorityOrder; + this.paths = paths; + this.termId = termId; + } + + public boolean atEnd() { + return priorityOrder.length == 0; + } + + /** Group the provided paths by the lowest termId they contain per the provided 
priorityOrder, + * into a list of QueryBranchWalkers. This can be performed iteratively on the resultant QBW:s + * to traverse the tree via the next() method. + *

+ * The paths can be extracted through the {@link nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates CompiledQueryAggregates} + * queriesAggregate method. + */ + public static List create(long[] priorityOrder, List paths) { + if (paths.isEmpty()) + return List.of(); + + List ret = new ArrayList<>(); + List remainingPaths = new LinkedList<>(paths); + remainingPaths.removeIf(LongSet::isEmpty); + + List pathsForPrio = new ArrayList<>(); + + for (int i = 0; i < priorityOrder.length; i++) { + long termId = priorityOrder[i]; + + var it = remainingPaths.iterator(); + + while (it.hasNext()) { + var path = it.next(); + + if (path.contains(termId)) { + // Remove the current termId from the path + path.remove(termId); + + // Add it to the set of paths associated with the termId + pathsForPrio.add(path); + + // Remove it from consideration + it.remove(); + } + } + + if (!pathsForPrio.isEmpty()) { + long[] newPrios = keepRelevantPriorities(priorityOrder, pathsForPrio); + ret.add(new QueryBranchWalker(newPrios, new ArrayList<>(pathsForPrio), termId)); + pathsForPrio.clear(); + } + } + + // This happens if the priorityOrder array doesn't contain all items in the paths, + // in practice only when an index doesn't contain all the search terms, so we can just + // skip those paths + if (!remainingPaths.isEmpty()) { + logger.debug("Dropping: {}", remainingPaths); + } + + return ret; + } + + /** From the provided priorityOrder array, keep the elements that are present in any set in paths */ + private static long[] keepRelevantPriorities(long[] priorityOrder, List paths) { + LongArrayList remainingPrios = new LongArrayList(paths.size()); + + // these sets are typically very small so array set is a good choice + LongSet allElements = new LongArraySet(priorityOrder.length); + for (var path : paths) { + allElements.addAll(path); + } + + for (var p : priorityOrder) { + if (allElements.contains(p)) + remainingPrios.add(p); + } + + return 
remainingPrios.elements(); + } + + /** Convenience method that applies the create() method + * to the priority order and paths associated with this instance */ + public List next() { + return create(priorityOrder, paths); + } + +} diff --git a/code/index/java/nu/marginalia/index/index/StatefulIndex.java b/code/index/java/nu/marginalia/index/index/StatefulIndex.java index a49e740e..ae7b1353 100644 --- a/code/index/java/nu/marginalia/index/index/StatefulIndex.java +++ b/code/index/java/nu/marginalia/index/index/StatefulIndex.java @@ -2,6 +2,12 @@ package nu.marginalia.index.index; import com.google.inject.Inject; import com.google.inject.Singleton; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongSet; +import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates; +import nu.marginalia.index.query.filter.QueryFilterAllOf; +import nu.marginalia.index.query.filter.QueryFilterAnyOf; +import nu.marginalia.index.query.filter.QueryFilterStepIf; import nu.marginalia.index.results.model.ids.CombinedDocIdList; import nu.marginalia.index.results.model.ids.DocMetadataList; import nu.marginalia.index.model.QueryParams; @@ -14,12 +20,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; +import java.util.*; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Predicate; /** This class delegates SearchIndexReader and deals with the stateful nature of the index, * i.e. it may be possible to reconstruct the index and load a new set of data. 
@@ -87,7 +92,6 @@ public class StatefulIndex { logger.error("Uncaught exception", ex); } finally { - lock.unlock(); } @@ -105,7 +109,6 @@ public class StatefulIndex { return combinedIndexReader != null && combinedIndexReader.isLoaded(); } - public List createQueries(SearchTerms terms, QueryParams params) { if (!isLoaded()) { @@ -113,55 +116,106 @@ public class StatefulIndex { return Collections.emptyList(); } - final long[] orderedIncludes = terms.sortedDistinctIncludes(this::compareKeywords); - final long[] orderedIncludesPrio = terms.sortedDistinctIncludes(this::compareKeywordsPrio); - List queryHeads = new ArrayList<>(10); - List queries = new ArrayList<>(10); - // To ensure that good results are discovered, create separate query heads for the priority index that - // filter for terms that contain pairs of two search terms - if (orderedIncludesPrio.length > 1) { - for (int i = 0; i + 1 < orderedIncludesPrio.length; i++) { - for (int j = i + 1; j < orderedIncludesPrio.length; j++) { - var entrySource = combinedIndexReader - .findPriorityWord(orderedIncludesPrio[i]) - .alsoPrio(orderedIncludesPrio[j]); - queryHeads.add(entrySource); + final long[] termPriority = terms.sortedDistinctIncludes(this::compareKeywords); + List paths = CompiledQueryAggregates.queriesAggregate(terms.compiledQuery()); + + // Remove any paths that do not contain all prioritized terms, as this means + // the term is missing from the index and can never be found + paths.removeIf(containsAll(termPriority).negate()); + + List walkers = QueryBranchWalker.create(termPriority, paths); + + for (var walker : walkers) { + for (var builder : List.of( + combinedIndexReader.findPriorityWord(walker.termId), + combinedIndexReader.findFullWord(walker.termId) + )) + { + queryHeads.add(builder); + + if (walker.atEnd()) + continue; // Single term search query + + // Add filter steps for the remaining combinations of terms + List filterSteps = new ArrayList<>(); + for (var step : walker.next()) { + 
filterSteps.add(createFilter(step, 0)); } + builder.addInclusionFilterAny(filterSteps); } } - // Next consider entries that appear only once in the priority index - for (var wordId : orderedIncludesPrio) { - queryHeads.add(combinedIndexReader.findPriorityWord(wordId)); - } - - // Finally consider terms in the full index - queryHeads.add(combinedIndexReader.findFullWord(orderedIncludes[0])); + // Add additional conditions to the query heads for (var query : queryHeads) { - if (query == null) { - return Collections.emptyList(); - } - // Note that we can add all includes as filters, even though - // they may not be present in the query head, as the query builder - // will ignore redundant include filters: - for (long orderedInclude : orderedIncludes) { - query = query.alsoFull(orderedInclude); + // Advice terms are a special case, mandatory but not ranked, and exempt from re-writing + for (long term : terms.advice()) { + query = query.also(term); } for (long term : terms.excludes()) { - query = query.notFull(term); + query = query.not(term); } // Run these filter steps last, as they'll worst-case cause as many page faults as there are // items in the buffer - queries.add(query.addInclusionFilter(combinedIndexReader.filterForParams(params)).build()); + query.addInclusionFilter(combinedIndexReader.filterForParams(params)); } - return queries; + return queryHeads + .stream() + .map(IndexQueryBuilder::build) + .toList(); + } + + /** Recursively create a filter step based on the QBW and its children */ + private QueryFilterStepIf createFilter(QueryBranchWalker walker, int depth) { + + // Create a filter for the current termId + final QueryFilterStepIf ownFilterCondition = ownFilterCondition(walker, depth); + + var childSteps = walker.next(); + if (childSteps.isEmpty()) // no children, and so we're satisfied with just a single filter condition + return ownFilterCondition; + + // If there are children, we append the filter conditions for each child as an anyOf condition + // 
to the current filter condition + + List combinedFilters = new ArrayList<>(); + + for (var step : childSteps) { + // Recursion will be limited to a fairly shallow stack depth due to how the queries are constructed. + var childFilter = createFilter(step, depth+1); + combinedFilters.add(new QueryFilterAllOf(ownFilterCondition, childFilter)); + } + + // Flatten the filter conditions if there's only one branch + if (combinedFilters.size() == 1) + return combinedFilters.getFirst(); + else + return new QueryFilterAnyOf(combinedFilters); + } + + /** Create a filter condition based on the termId associated with the QBW */ + private QueryFilterStepIf ownFilterCondition(QueryBranchWalker walker, int depth) { + if (depth < 2) { + // At shallow depths we prioritize terms that appear in the priority index, + // to increase the odds we find "good" results before the execution timer runs out + return new QueryFilterAnyOf( + combinedIndexReader.hasWordPrio(walker.termId), + combinedIndexReader.hasWordFull(walker.termId) + ); + } else { + return combinedIndexReader.hasWordFull(walker.termId); + } + } + + private Predicate containsAll(long[] permitted) { + LongSet permittedTerms = new LongOpenHashSet(permitted); + return permittedTerms::containsAll; } private int compareKeywords(long a, long b) { @@ -171,13 +225,6 @@ public class StatefulIndex { ); } - private int compareKeywordsPrio(long a, long b) { - return Long.compare( - combinedIndexReader.numHitsPrio(a), - combinedIndexReader.numHitsPrio(b) - ); - } - /** Return an array of encoded document metadata longs corresponding to the * document identifiers provided; with metadata for termId. The input array * docs[] *must* be sorted. 
diff --git a/code/index/java/nu/marginalia/index/model/SearchParameters.java b/code/index/java/nu/marginalia/index/model/SearchParameters.java index 7db25341..f0e851e5 100644 --- a/code/index/java/nu/marginalia/index/model/SearchParameters.java +++ b/code/index/java/nu/marginalia/index/model/SearchParameters.java @@ -2,16 +2,16 @@ package nu.marginalia.index.model; import nu.marginalia.api.searchquery.IndexProtobufCodec; import nu.marginalia.api.searchquery.RpcIndexQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser; import nu.marginalia.api.searchquery.model.query.SearchSpecification; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.query.IndexSearchBudget; import nu.marginalia.index.query.limit.QueryStrategy; import nu.marginalia.index.searchset.SearchSet; -import java.util.ArrayList; -import java.util.List; - import static nu.marginalia.api.searchquery.IndexProtobufCodec.convertSpecLimit; public class SearchParameters { @@ -21,13 +21,16 @@ public class SearchParameters { */ public final int fetchSize; public final IndexSearchBudget budget; - public final List subqueries; + public final SearchQuery query; public final QueryParams queryParams; public final ResultRankingParameters rankingParams; public final int limitByDomain; public final int limitTotal; + public final CompiledQuery compiledQuery; + public final CompiledQueryLong compiledQueryIds; + // mutable: /** @@ -40,7 +43,7 @@ public class SearchParameters { this.fetchSize = limits.fetchSize(); this.budget = new IndexSearchBudget(limits.timeoutMs()); - this.subqueries = specsSet.subqueries; + this.query = specsSet.query; this.limitByDomain = 
limits.resultsByDomain(); this.limitTotal = limits.resultsTotal(); @@ -52,6 +55,9 @@ public class SearchParameters { searchSet, specsSet.queryStrategy); + compiledQuery = CompiledQueryParser.parse(this.query.compiledQuery); + compiledQueryIds = compiledQuery.mapToLong(SearchTermsUtil::getWordId); + rankingParams = specsSet.rankingParams; } @@ -63,11 +69,8 @@ public class SearchParameters { // The time budget is halved because this is the point when we start to // wrap up the search and return the results. this.budget = new IndexSearchBudget(limits.timeoutMs() / 2); + this.query = IndexProtobufCodec.convertRpcQuery(request.getQuery()); - this.subqueries = new ArrayList<>(request.getSubqueriesCount()); - for (int i = 0; i < request.getSubqueriesCount(); i++) { - this.subqueries.add(IndexProtobufCodec.convertSearchSubquery(request.getSubqueries(i))); - } this.limitByDomain = limits.resultsByDomain(); this.limitTotal = limits.resultsTotal(); @@ -79,9 +82,13 @@ public class SearchParameters { searchSet, QueryStrategy.valueOf(request.getQueryStrategy())); + compiledQuery = CompiledQueryParser.parse(this.query.compiledQuery); + compiledQueryIds = compiledQuery.mapToLong(SearchTermsUtil::getWordId); + rankingParams = IndexProtobufCodec.convertRankingParameterss(request.getParameters()); } + public long getDataCost() { return dataCost; } diff --git a/code/index/java/nu/marginalia/index/model/SearchTerms.java b/code/index/java/nu/marginalia/index/model/SearchTerms.java index c32b1aa3..8115c109 100644 --- a/code/index/java/nu/marginalia/index/model/SearchTerms.java +++ b/code/index/java/nu/marginalia/index/model/SearchTerms.java @@ -3,49 +3,36 @@ package nu.marginalia.index.model; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongComparator; import it.unimi.dsi.fastutil.longs.LongList; -import it.unimi.dsi.fastutil.longs.LongOpenHashSet; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import 
nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import static nu.marginalia.index.model.SearchTermsUtil.getWordId; public final class SearchTerms { - private final LongList includes; + private final LongList advice; private final LongList excludes; private final LongList priority; private final List coherences; - public SearchTerms( - LongList includes, - LongList excludes, - LongList priority, - List coherences - ) { - this.includes = includes; - this.excludes = excludes; - this.priority = priority; - this.coherences = coherences; - } + private final CompiledQueryLong compiledQueryIds; - public SearchTerms(SearchSubquery subquery) { - this(new LongArrayList(), - new LongArrayList(), - new LongArrayList(), - new ArrayList<>()); + public SearchTerms(SearchQuery query, + CompiledQueryLong compiledQueryIds) + { + this.excludes = new LongArrayList(); + this.priority = new LongArrayList(); + this.coherences = new ArrayList<>(); + this.advice = new LongArrayList(); + this.compiledQueryIds = compiledQueryIds; - for (var word : subquery.searchTermsInclude) { - includes.add(getWordId(word)); - } - for (var word : subquery.searchTermsAdvice) { - // This looks like a bug, but it's not - includes.add(getWordId(word)); + for (var word : query.searchTermsAdvice) { + advice.add(getWordId(word)); } - - for (var coherence : subquery.searchTermCoherences) { + for (var coherence : query.searchTermCoherences) { LongList parts = new LongArrayList(coherence.size()); for (var word : coherence) { @@ -55,39 +42,32 @@ public final class SearchTerms { coherences.add(parts); } - for (var word : subquery.searchTermsExclude) { + for (var word : query.searchTermsExclude) { excludes.add(getWordId(word)); } - for (var word : subquery.searchTermsPriority) { + + for (var word : query.searchTermsPriority) { priority.add(getWordId(word)); } } 
public boolean isEmpty() { - return includes.isEmpty(); + return compiledQueryIds.isEmpty(); } public long[] sortedDistinctIncludes(LongComparator comparator) { - if (includes.isEmpty()) - return includes.toLongArray(); - - LongList list = new LongArrayList(new LongOpenHashSet(includes)); + LongList list = new LongArrayList(compiledQueryIds.copyData()); list.sort(comparator); return list.toLongArray(); } - public int size() { - return includes.size() + excludes.size() + priority.size(); - } - - public LongList includes() { - return includes; - } public LongList excludes() { return excludes; } - + public LongList advice() { + return advice; + } public LongList priority() { return priority; } @@ -96,29 +76,6 @@ public final class SearchTerms { return coherences; } - @Override - public boolean equals(Object obj) { - if (obj == this) return true; - if (obj == null || obj.getClass() != this.getClass()) return false; - var that = (SearchTerms) obj; - return Objects.equals(this.includes, that.includes) && - Objects.equals(this.excludes, that.excludes) && - Objects.equals(this.priority, that.priority) && - Objects.equals(this.coherences, that.coherences); - } - - @Override - public int hashCode() { - return Objects.hash(includes, excludes, priority, coherences); - } - - @Override - public String toString() { - return "SearchTerms[" + - "includes=" + includes + ", " + - "excludes=" + excludes + ", " + - "priority=" + priority + ", " + - "coherences=" + coherences + ']'; - } + public CompiledQueryLong compiledQuery() { return compiledQueryIds; } } diff --git a/code/index/java/nu/marginalia/index/model/SearchTermsUtil.java b/code/index/java/nu/marginalia/index/model/SearchTermsUtil.java index 9797ca95..fa516565 100644 --- a/code/index/java/nu/marginalia/index/model/SearchTermsUtil.java +++ b/code/index/java/nu/marginalia/index/model/SearchTermsUtil.java @@ -1,29 +1,9 @@ package nu.marginalia.index.model; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; import 
nu.marginalia.hash.MurmurHash3_128; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - public class SearchTermsUtil { - /** Extract all include-terms from the specified subqueries, - * and a return a map of the terms and their termIds. - */ - public static Map getAllIncludeTerms(List subqueries) { - Map ret = new HashMap<>(); - - for (var subquery : subqueries) { - for (var include : subquery.searchTermsInclude) { - ret.computeIfAbsent(include, i -> getWordId(include)); - } - } - - return ret; - } - private static final MurmurHash3_128 hasher = new MurmurHash3_128(); /** Translate the word to a unique id. */ diff --git a/code/index/java/nu/marginalia/index/results/IndexMetadataService.java b/code/index/java/nu/marginalia/index/results/IndexMetadataService.java index 1932a5a4..ce23c3f2 100644 --- a/code/index/java/nu/marginalia/index/results/IndexMetadataService.java +++ b/code/index/java/nu/marginalia/index/results/IndexMetadataService.java @@ -4,7 +4,8 @@ import com.google.inject.Inject; import gnu.trove.map.hash.TObjectLongHashMap; import it.unimi.dsi.fastutil.longs.Long2ObjectArrayMap; import it.unimi.dsi.fastutil.longs.LongArrayList; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.index.index.StatefulIndex; import nu.marginalia.index.model.SearchTermsUtil; import nu.marginalia.index.results.model.QuerySearchTerms; @@ -13,9 +14,6 @@ import nu.marginalia.index.results.model.TermMetadataForCombinedDocumentIds; import nu.marginalia.index.results.model.ids.CombinedDocIdList; import nu.marginalia.index.results.model.ids.TermIdList; -import java.util.ArrayList; -import java.util.List; - import static nu.marginalia.index.results.model.TermCoherenceGroupList.TermCoherenceGroup; import static 
nu.marginalia.index.results.model.TermMetadataForCombinedDocumentIds.DocumentsWithMetadata; @@ -42,43 +40,47 @@ public class IndexMetadataService { return new TermMetadataForCombinedDocumentIds(termdocToMeta); } - public QuerySearchTerms getSearchTerms(List searchTermVariants) { + public QuerySearchTerms getSearchTerms(CompiledQuery compiledQuery, SearchQuery searchQuery) { LongArrayList termIdsList = new LongArrayList(); + LongArrayList termIdsPrio = new LongArrayList(); TObjectLongHashMap termToId = new TObjectLongHashMap<>(10, 0.75f, -1); - for (var subquery : searchTermVariants) { - for (var term : subquery.searchTermsInclude) { - if (termToId.containsKey(term)) { - continue; - } + for (String word : compiledQuery) { + long id = SearchTermsUtil.getWordId(word); + termIdsList.add(id); + termToId.put(word, id); + } - long id = SearchTermsUtil.getWordId(term); - termIdsList.add(id); - termToId.put(term, id); + for (var term : searchQuery.searchTermsAdvice) { + if (termToId.containsKey(term)) { + continue; } + + long id = SearchTermsUtil.getWordId(term); + termIdsList.add(id); + termToId.put(term, id); + } + + for (var term : searchQuery.searchTermsPriority) { + if (termToId.containsKey(term)) { + continue; + } + + long id = SearchTermsUtil.getWordId(term); + termIdsList.add(id); + termIdsPrio.add(id); + termToId.put(term, id); } return new QuerySearchTerms(termToId, new TermIdList(termIdsList), - getTermCoherences(searchTermVariants)); - } - - - private TermCoherenceGroupList getTermCoherences(List searchTermVariants) { - List coherences = new ArrayList<>(); - - for (var subquery : searchTermVariants) { - for (var coh : subquery.searchTermCoherences) { - coherences.add(new TermCoherenceGroup(coh)); - } - - // It's assumed each subquery has identical coherences - break; - } - - return new TermCoherenceGroupList(coherences); + new TermIdList(termIdsPrio), + new TermCoherenceGroupList( + searchQuery.searchTermCoherences.stream().map(TermCoherenceGroup::new).toList() 
+ ) + ); } } diff --git a/code/index/java/nu/marginalia/index/results/IndexResultValuationContext.java b/code/index/java/nu/marginalia/index/results/IndexResultValuationContext.java index 967a600f..a9d6b4a6 100644 --- a/code/index/java/nu/marginalia/index/results/IndexResultValuationContext.java +++ b/code/index/java/nu/marginalia/index/results/IndexResultValuationContext.java @@ -1,10 +1,12 @@ package nu.marginalia.index.results; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.compiled.*; +import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates; import nu.marginalia.api.searchquery.model.results.ResultRankingContext; import nu.marginalia.api.searchquery.model.results.SearchResultItem; import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore; import nu.marginalia.index.index.StatefulIndex; +import nu.marginalia.index.model.SearchParameters; import nu.marginalia.index.results.model.ids.CombinedDocIdList; import nu.marginalia.index.model.QueryParams; import nu.marginalia.index.results.model.QuerySearchTerms; @@ -23,7 +25,6 @@ import java.util.List; * reasons to cache this data, and performs the calculations */ public class IndexResultValuationContext { private final StatefulIndex statefulIndex; - private final List> searchTermVariants; private final QueryParams queryParams; private final TermMetadataForCombinedDocumentIds termMetadataForCombinedDocumentIds; @@ -31,24 +32,28 @@ public class IndexResultValuationContext { private final ResultRankingContext rankingContext; private final ResultValuator searchResultValuator; + private final CompiledQuery compiledQuery; + private final CompiledQueryLong compiledQueryIds; public IndexResultValuationContext(IndexMetadataService metadataService, ResultValuator searchResultValuator, CombinedDocIdList ids, StatefulIndex statefulIndex, ResultRankingContext rankingContext, - List subqueries, - QueryParams queryParams + 
SearchParameters params ) { this.statefulIndex = statefulIndex; this.rankingContext = rankingContext; this.searchResultValuator = searchResultValuator; - this.searchTermVariants = subqueries.stream().map(sq -> sq.searchTermsInclude).distinct().toList(); - this.queryParams = queryParams; + this.queryParams = params.queryParams; + this.compiledQuery = params.compiledQuery; + this.compiledQueryIds = params.compiledQueryIds; - this.searchTerms = metadataService.getSearchTerms(subqueries); - this.termMetadataForCombinedDocumentIds = metadataService.getTermMetadataForDocuments(ids, searchTerms.termIdsAll); + this.searchTerms = metadataService.getSearchTerms(params.compiledQuery, params.query); + + this.termMetadataForCombinedDocumentIds = metadataService.getTermMetadataForDocuments(ids, + searchTerms.termIdsAll); } private final long flagsFilterMask = @@ -65,110 +70,97 @@ public class IndexResultValuationContext { long docMetadata = statefulIndex.getDocumentMetadata(docId); int htmlFeatures = statefulIndex.getHtmlFeatures(docId); - int maxFlagsCount = 0; - boolean anyAllSynthetic = false; - int maxPositionsSet = 0; - SearchResultItem searchResult = new SearchResultItem(docId, - searchTermVariants.stream().mapToInt(List::size).sum()); + docMetadata, + htmlFeatures, + hasPrioTerm(combinedId)); - for (int querySetId = 0; - querySetId < searchTermVariants.size(); - querySetId++) - { - var termList = searchTermVariants.get(querySetId); + long[] wordMetas = new long[compiledQuery.size()]; + SearchResultKeywordScore[] scores = new SearchResultKeywordScore[compiledQuery.size()]; - SearchResultKeywordScore[] termScoresForSet = new SearchResultKeywordScore[termList.size()]; + for (int i = 0; i < wordMetas.length; i++) { + final long termId = compiledQueryIds.at(i); + final String term = compiledQuery.at(i); - boolean synthetic = true; - - for (int termIdx = 0; termIdx < termList.size(); termIdx++) { - String searchTerm = termList.get(termIdx); - - long termMetadata = 
termMetadataForCombinedDocumentIds.getTermMetadata( - searchTerms.getIdForTerm(searchTerm), - combinedId - ); - - var score = new SearchResultKeywordScore( - querySetId, - searchTerm, - termMetadata, - docMetadata, - htmlFeatures - ); - - synthetic &= WordFlags.Synthetic.isPresent(termMetadata); - - searchResult.keywordScores.add(score); - - termScoresForSet[termIdx] = score; - } - - if (!meetsQueryStrategyRequirements(termScoresForSet, queryParams.queryStrategy())) { - continue; - } - - int minFlagsCount = 8; - int minPositionsSet = 4; - - for (var termScore : termScoresForSet) { - final int flagCount = Long.bitCount(termScore.encodedWordMetadata() & flagsFilterMask); - minFlagsCount = Math.min(minFlagsCount, flagCount); - minPositionsSet = Math.min(minPositionsSet, termScore.positionCount()); - } - - maxFlagsCount = Math.max(maxFlagsCount, minFlagsCount); - maxPositionsSet = Math.max(maxPositionsSet, minPositionsSet); - anyAllSynthetic |= synthetic; + wordMetas[i] = termMetadataForCombinedDocumentIds.getTermMetadata(termId, combinedId); + scores[i] = new SearchResultKeywordScore(term, termId, wordMetas[i]); } - if (maxFlagsCount == 0 && !anyAllSynthetic && maxPositionsSet == 0) + + // DANGER: IndexResultValuatorService assumes that searchResult.keywordScores has this specific order, as it needs + // to be able to re-construct its own CompiledQuery for re-ranking the results. This is + // a very flimsy assumption. 
+ searchResult.keywordScores.addAll(List.of(scores)); + + CompiledQueryLong wordMetasQuery = new CompiledQueryLong(compiledQuery.root, new CqDataLong(wordMetas)); + + boolean allSynthetic = !CompiledQueryAggregates.booleanAggregate(wordMetasQuery, WordFlags.Synthetic::isAbsent); + int flagsCount = CompiledQueryAggregates.intMaxMinAggregate(wordMetasQuery, wordMeta -> Long.bitCount(wordMeta & flagsFilterMask)); + int positionsCount = CompiledQueryAggregates.intMaxMinAggregate(wordMetasQuery, wordMeta -> Long.bitCount(WordMetadata.decodePositions(wordMeta))); + + if (!meetsQueryStrategyRequirements(wordMetasQuery, queryParams.queryStrategy())) { + return null; + } + + if (flagsCount == 0 && !allSynthetic && positionsCount == 0) return null; - double score = searchResultValuator.calculateSearchResultValue(searchResult.keywordScores, + double score = searchResultValuator.calculateSearchResultValue( + wordMetasQuery, + docMetadata, + htmlFeatures, 5000, // use a dummy value here as it's not present in the index rankingContext); + if (searchResult.hasPrioTerm) { + score = 0.75 * score; + } + searchResult.setScore(score); return searchResult; } - private boolean meetsQueryStrategyRequirements(SearchResultKeywordScore[] termSet, QueryStrategy queryStrategy) { + private boolean hasPrioTerm(long combinedId) { + for (var term : searchTerms.termIdsPrio.array()) { + if (termMetadataForCombinedDocumentIds.hasTermMeta(term, combinedId)) { + return true; + } + } + return false; + } + + private boolean meetsQueryStrategyRequirements(CompiledQueryLong queryGraphScores, + QueryStrategy queryStrategy) + { if (queryStrategy == QueryStrategy.AUTO || queryStrategy == QueryStrategy.SENTENCE || queryStrategy == QueryStrategy.TOPIC) { return true; } - for (var keyword : termSet) { - if (!meetsQueryStrategyRequirements(keyword, queryParams.queryStrategy())) { - return false; - } - } - - return true; + return CompiledQueryAggregates.booleanAggregate(queryGraphScores, + docs -> 
meetsQueryStrategyRequirements(docs, queryParams.queryStrategy())); } - private boolean meetsQueryStrategyRequirements(SearchResultKeywordScore termScore, QueryStrategy queryStrategy) { + private boolean meetsQueryStrategyRequirements(long wordMeta, QueryStrategy queryStrategy) { if (queryStrategy == QueryStrategy.REQUIRE_FIELD_SITE) { - return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.Site.asBit()); + return WordFlags.Site.isPresent(wordMeta); } else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_SUBJECT) { - return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.Subjects.asBit()); + return WordFlags.Subjects.isPresent(wordMeta); } else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_TITLE) { - return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.Title.asBit()); + return WordFlags.Title.isPresent(wordMeta); } else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_URL) { - return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.UrlPath.asBit()); + return WordFlags.UrlPath.isPresent(wordMeta); } else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_DOMAIN) { - return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.UrlDomain.asBit()); + return WordFlags.UrlDomain.isPresent(wordMeta); } else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_LINK) { - return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.ExternalLink.asBit()); + return WordFlags.ExternalLink.isPresent(wordMeta); } return true; } diff --git a/code/index/java/nu/marginalia/index/results/IndexResultValuatorService.java b/code/index/java/nu/marginalia/index/results/IndexResultValuatorService.java index 51e59c63..2fa44c31 100644 --- a/code/index/java/nu/marginalia/index/results/IndexResultValuatorService.java +++ b/code/index/java/nu/marginalia/index/results/IndexResultValuatorService.java @@ -4,7 +4,12 @@ import com.google.inject.Inject; import com.google.inject.Singleton; import 
gnu.trove.list.TLongList; import gnu.trove.list.array.TLongArrayList; -import it.unimi.dsi.fastutil.longs.LongArrayList; +import it.unimi.dsi.fastutil.longs.LongSet; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqDataInt; +import nu.marginalia.api.searchquery.model.compiled.CqDataLong; +import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates; import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem; import nu.marginalia.api.searchquery.model.results.ResultRankingContext; import nu.marginalia.api.searchquery.model.results.SearchResultItem; @@ -13,14 +18,13 @@ import nu.marginalia.index.model.SearchParameters; import nu.marginalia.index.results.model.ids.CombinedDocIdList; import nu.marginalia.linkdb.docs.DocumentDbReader; import nu.marginalia.linkdb.model.DocdbUrlDetail; +import nu.marginalia.model.idx.WordMetadata; import nu.marginalia.ranking.results.ResultValuator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.sql.SQLException; import java.util.*; -import java.util.function.Consumer; -import java.util.stream.Collectors; @Singleton public class IndexResultValuatorService { @@ -44,8 +48,8 @@ public class IndexResultValuatorService { } public List rankResults(SearchParameters params, - ResultRankingContext rankingContext, - CombinedDocIdList resultIds) + ResultRankingContext rankingContext, + CombinedDocIdList resultIds) { final var evaluator = createValuationContext(params, rankingContext, resultIds); @@ -70,8 +74,7 @@ public class IndexResultValuatorService { resultIds, statefulIndex, rankingContext, - params.subqueries, - params.queryParams); + params); } @@ -96,12 +99,13 @@ public class IndexResultValuatorService { item.resultsFromDomain = domainCountFilter.getCount(item); } - return decorateAndRerank(resultsList, rankingContext); + return 
decorateAndRerank(resultsList, params.compiledQuery, rankingContext); } /** Decorate the result items with additional information from the link database * and calculate an updated ranking with the additional information */ public List decorateAndRerank(List rawResults, + CompiledQuery compiledQuery, ResultRankingContext rankingContext) throws SQLException { @@ -125,13 +129,31 @@ IndexResultValuatorService { continue; } - resultItems.add(createCombinedItem(result, docData, rankingContext)); + // Reconstruct the CompiledQuery for re-valuation + // + // CAVEAT: This hinges on a very fragile assumption that IndexResultValuationContext puts them in the same + // order as the data for the CompiledQuery. + long[] wordMetas = new long[compiledQuery.size()]; + + for (int i = 0; i < compiledQuery.size(); i++) { + var score = result.keywordScores.get(i); + wordMetas[i] = score.encodedWordMetadata(); + } + + CompiledQueryLong metaQuery = new CompiledQueryLong(compiledQuery.root, new CqDataLong(wordMetas)); + + resultItems.add(createCombinedItem( + result, + docData, + metaQuery, + rankingContext)); } return resultItems; } private DecoratedSearchResultItem createCombinedItem(SearchResultItem result, DocdbUrlDetail docData, + CompiledQueryLong wordMetas, ResultRankingContext rankingContext) { return new DecoratedSearchResultItem( result, @@ -144,8 +166,33 @@ docData.pubYear(), docData.dataHash(), docData.wordsTotal(), - resultValuator.calculateSearchResultValue(result.keywordScores, docData.wordsTotal(), rankingContext) - ); + bestPositions(wordMetas), + resultValuator.calculateSearchResultValue(wordMetas, + result.encodedDocMetadata, + result.htmlFeatures, + docData.wordsTotal(), + rankingContext) + ); + } + + private long bestPositions(CompiledQueryLong wordMetas) { + LongSet positionsSet = CompiledQueryAggregates.positionsAggregate(wordMetas, WordMetadata::decodePositions); + + int bestPc = 0; + long bestPositions = 0; + + var li = 
positionsSet.longIterator(); + + while (li.hasNext()) { + long pos = li.nextLong(); + int pc = Long.bitCount(pos); + if (pc > bestPc) { + bestPc = pc; + bestPositions = pos; + } + } + + return bestPositions; } } diff --git a/code/index/java/nu/marginalia/index/results/model/QuerySearchTerms.java b/code/index/java/nu/marginalia/index/results/model/QuerySearchTerms.java index d72e0ea9..bbb7cf30 100644 --- a/code/index/java/nu/marginalia/index/results/model/QuerySearchTerms.java +++ b/code/index/java/nu/marginalia/index/results/model/QuerySearchTerms.java @@ -6,14 +6,17 @@ import nu.marginalia.index.results.model.ids.TermIdList; public class QuerySearchTerms { private final TObjectLongHashMap termToId; public final TermIdList termIdsAll; + public final TermIdList termIdsPrio; public final TermCoherenceGroupList coherences; public QuerySearchTerms(TObjectLongHashMap termToId, TermIdList termIdsAll, + TermIdList termIdsPrio, TermCoherenceGroupList coherences) { this.termToId = termToId; this.termIdsAll = termIdsAll; + this.termIdsPrio = termIdsPrio; this.coherences = coherences; } diff --git a/code/index/java/nu/marginalia/index/results/model/TermMetadataForCombinedDocumentIds.java b/code/index/java/nu/marginalia/index/results/model/TermMetadataForCombinedDocumentIds.java index 9068dd69..3ef2f7ab 100644 --- a/code/index/java/nu/marginalia/index/results/model/TermMetadataForCombinedDocumentIds.java +++ b/code/index/java/nu/marginalia/index/results/model/TermMetadataForCombinedDocumentIds.java @@ -18,12 +18,21 @@ public class TermMetadataForCombinedDocumentIds { public long getTermMetadata(long termId, long combinedId) { var metaByCombinedId = termdocToMeta.get(termId); if (metaByCombinedId == null) { - logger.warn("Missing meta for term {}", termId); return 0; } return metaByCombinedId.get(combinedId); } + public boolean hasTermMeta(long termId, long combinedId) { + var metaByCombinedId = termdocToMeta.get(termId); + + if (metaByCombinedId == null) { + return false; + } 
+ + return metaByCombinedId.get(combinedId) != 0; + } + public record DocumentsWithMetadata(Long2LongOpenHashMap data) { public DocumentsWithMetadata(CombinedDocIdList combinedDocIdsAll, DocMetadataList metadata) { this(new Long2LongOpenHashMap(combinedDocIdsAll.array(), metadata.array())); diff --git a/code/index/java/nu/marginalia/ranking/results/ResultKeywordSet.java b/code/index/java/nu/marginalia/ranking/results/ResultKeywordSet.java deleted file mode 100644 index 19405dcb..00000000 --- a/code/index/java/nu/marginalia/ranking/results/ResultKeywordSet.java +++ /dev/null @@ -1,26 +0,0 @@ -package nu.marginalia.ranking.results; - - -import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore; - -import java.util.List; - -public record ResultKeywordSet(List keywords) { - - public int length() { - return keywords.size(); - } - public boolean isEmpty() { return length() == 0; } - public boolean hasNgram() { - for (var word : keywords) { - if (word.keyword.contains("_")) { - return true; - } - } - return false; - } - @Override - public String toString() { - return "%s[%s]".formatted(getClass().getSimpleName(), keywords); - } -} diff --git a/code/index/java/nu/marginalia/ranking/results/ResultValuator.java b/code/index/java/nu/marginalia/ranking/results/ResultValuator.java index 6c67559d..4d257349 100644 --- a/code/index/java/nu/marginalia/ranking/results/ResultValuator.java +++ b/code/index/java/nu/marginalia/ranking/results/ResultValuator.java @@ -1,8 +1,8 @@ package nu.marginalia.ranking.results; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; import nu.marginalia.api.searchquery.model.results.ResultRankingContext; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; -import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore; import nu.marginalia.model.crawl.HtmlFeature; import nu.marginalia.model.crawl.PubDate; import nu.marginalia.model.idx.DocumentFlags; @@ -14,33 +14,32 @@ import 
com.google.inject.Singleton; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; -import java.util.List; - @Singleton public class ResultValuator { final static double scalingFactor = 500.; - private final Bm25Factor bm25Factor; private final TermCoherenceFactor termCoherenceFactor; private static final Logger logger = LoggerFactory.getLogger(ResultValuator.class); @Inject - public ResultValuator(Bm25Factor bm25Factor, - TermCoherenceFactor termCoherenceFactor) { - this.bm25Factor = bm25Factor; + public ResultValuator(TermCoherenceFactor termCoherenceFactor) { this.termCoherenceFactor = termCoherenceFactor; } - public double calculateSearchResultValue(List scores, + public double calculateSearchResultValue(CompiledQueryLong wordMeta, + long documentMetadata, + int features, int length, ResultRankingContext ctx) { - int sets = numberOfSets(scores); + if (wordMeta.isEmpty()) + return Double.MAX_VALUE; + + if (length < 0) { + length = 5000; + } - long documentMetadata = documentMetadata(scores); - int features = htmlFeatures(scores); var rankingParams = ctx.params; int rank = DocumentMetadata.decodeRank(documentMetadata); @@ -75,32 +74,17 @@ public class ResultValuator { + temporalBias + flagsPenalty; - double bestTcf = 0; - double bestBM25F = 0; - double bestBM25P = 0; - double bestBM25PN = 0; - - for (int set = 0; set < sets; set++) { - ResultKeywordSet keywordSet = createKeywordSet(scores, set); - - if (keywordSet.isEmpty()) - continue; - - bestTcf = Math.max(bestTcf, rankingParams.tcfWeight * termCoherenceFactor.calculate(keywordSet)); - bestBM25P = Math.max(bestBM25P, rankingParams.bm25PrioWeight * bm25Factor.calculateBm25Prio(rankingParams.prioParams, keywordSet, ctx)); - bestBM25F = Math.max(bestBM25F, rankingParams.bm25FullWeight * bm25Factor.calculateBm25(rankingParams.fullParams, keywordSet, length, ctx)); - if (keywordSet.hasNgram()) { - bestBM25PN = Math.max(bestBM25PN, rankingParams.bm25PrioWeight * 
bm25Factor.calculateBm25Prio(rankingParams.prioParams, keywordSet, ctx)); - } - } + double bestTcf = rankingParams.tcfWeight * termCoherenceFactor.calculate(wordMeta); + double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(new Bm25FullGraphVisitor(rankingParams.fullParams, wordMeta.data, length, ctx)); + double bestBM25P = rankingParams.bm25PrioWeight * wordMeta.root.visit(new Bm25PrioGraphVisitor(rankingParams.prioParams, wordMeta.data, ctx)); double overallPartPositive = Math.max(0, overallPart); double overallPartNegative = -Math.min(0, overallPart); // Renormalize to 0...15, where 0 is the best possible score; // this is a historical artifact of the original ranking function - return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + 0.25 * bestBM25PN + overallPartPositive, overallPartNegative); + return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + overallPartPositive, overallPartNegative); } private double calculateQualityPenalty(int size, int quality, ResultRankingParameters rankingParams) { @@ -159,51 +143,6 @@ public class ResultValuator { return (int) -penalty; } - private long documentMetadata(List rawScores) { - for (var score : rawScores) { - return score.encodedDocMetadata(); - } - return 0; - } - - private int htmlFeatures(List rawScores) { - for (var score : rawScores) { - return score.htmlFeatures(); - } - return 0; - } - - private ResultKeywordSet createKeywordSet(List rawScores, - int thisSet) - { - List scoresList = new ArrayList<>(); - - for (var score : rawScores) { - if (score.subquery != thisSet) - continue; - - // Don't consider synthetic keywords for ranking, these are keywords that don't - // have counts. E.g. 
"tld:edu" - if (score.isKeywordSpecial()) - continue; - - scoresList.add(score); - } - - return new ResultKeywordSet(scoresList); - - } - - private int numberOfSets(List scores) { - int maxSet = 0; - - for (var score : scores) { - maxSet = Math.max(maxSet, score.subquery); - } - - return 1 + maxSet; - } - public static double normalize(double value, double penalty) { if (value < 0) value = 0; diff --git a/code/index/java/nu/marginalia/ranking/results/factors/Bm25Factor.java b/code/index/java/nu/marginalia/ranking/results/factors/Bm25Factor.java deleted file mode 100644 index 335b5fa8..00000000 --- a/code/index/java/nu/marginalia/ranking/results/factors/Bm25Factor.java +++ /dev/null @@ -1,122 +0,0 @@ -package nu.marginalia.ranking.results.factors; - -import nu.marginalia.api.searchquery.model.results.Bm25Parameters; -import nu.marginalia.api.searchquery.model.results.ResultRankingContext; -import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore; -import nu.marginalia.model.idx.WordFlags; -import nu.marginalia.ranking.results.ResultKeywordSet; - -public class Bm25Factor { - private static final int AVG_LENGTH = 5000; - - /** This is an estimation of BM-25. - * - * @see Bm25Parameters - */ - public double calculateBm25(Bm25Parameters bm25Parameters, ResultKeywordSet keywordSet, int length, ResultRankingContext ctx) { - final int docCount = ctx.termFreqDocCount(); - - if (length <= 0) - length = AVG_LENGTH; - - double sum = 0.; - - for (var keyword : keywordSet.keywords()) { - double count = keyword.positionCount(); - - int freq = ctx.frequency(keyword.keyword); - - sum += invFreq(docCount, freq) * f(bm25Parameters.k(), bm25Parameters.b(), count, length); - } - - return sum; - } - - /** Bm25 calculation, except instead of counting positions in the document, - * the number of relevance signals for the term is counted instead. 
- */ - public double calculateBm25Prio(Bm25Parameters bm25Parameters, ResultKeywordSet keywordSet, ResultRankingContext ctx) { - final int docCount = ctx.termFreqDocCount(); - - double sum = 0.; - - for (var keyword : keywordSet.keywords()) { - double count = evaluatePriorityScore(keyword); - - int freq = ctx.priorityFrequency(keyword.keyword); - - // note we override b to zero for priority terms as they are independent of document length - sum += invFreq(docCount, freq) * f(bm25Parameters.k(), 0, count, 0); - } - - return sum; - } - - private static double evaluatePriorityScore(SearchResultKeywordScore keyword) { - int pcount = keyword.positionCount(); - - double qcount = 0.; - - if ((keyword.encodedWordMetadata() & WordFlags.ExternalLink.asBit()) != 0) { - - qcount += 2.5; - - if ((keyword.encodedWordMetadata() & WordFlags.UrlDomain.asBit()) != 0) - qcount += 2.5; - else if ((keyword.encodedWordMetadata() & WordFlags.UrlPath.asBit()) != 0) - qcount += 1.5; - - if ((keyword.encodedWordMetadata() & WordFlags.Site.asBit()) != 0) - qcount += 1.25; - if ((keyword.encodedWordMetadata() & WordFlags.SiteAdjacent.asBit()) != 0) - qcount += 1.25; - } - else { - if ((keyword.encodedWordMetadata() & WordFlags.UrlDomain.asBit()) != 0) - qcount += 3; - else if ((keyword.encodedWordMetadata() & WordFlags.UrlPath.asBit()) != 0) - qcount += 1; - - if ((keyword.encodedWordMetadata() & WordFlags.Site.asBit()) != 0) - qcount += 0.5; - if ((keyword.encodedWordMetadata() & WordFlags.SiteAdjacent.asBit()) != 0) - qcount += 0.5; - } - - if ((keyword.encodedWordMetadata() & WordFlags.Title.asBit()) != 0) - qcount += 1.5; - - if (pcount > 2) { - if ((keyword.encodedWordMetadata() & WordFlags.Subjects.asBit()) != 0) - qcount += 1.25; - if ((keyword.encodedWordMetadata() & WordFlags.NamesWords.asBit()) != 0) - qcount += 0.25; - if ((keyword.encodedWordMetadata() & WordFlags.TfIdfHigh.asBit()) != 0) - qcount += 0.5; - } - - return qcount; - } - - /** - * - * @param docCount Number of 
documents - * @param freq Number of matching documents - */ - private double invFreq(int docCount, int freq) { - return Math.log(1.0 + (docCount - freq + 0.5) / (freq + 0.5)); - } - - /** - * - * @param k determines the size of the impact of a single term - * @param b determines the magnitude of the length normalization - * @param count number of occurrences in the document - * @param length document length - */ - private double f(double k, double b, double count, int length) { - final double lengthRatio = (double) length / AVG_LENGTH; - - return (count * (k + 1)) / (count + k * (1 - b + b * lengthRatio)); - } -} diff --git a/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java b/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java new file mode 100644 index 00000000..9c46261d --- /dev/null +++ b/code/index/java/nu/marginalia/ranking/results/factors/Bm25FullGraphVisitor.java @@ -0,0 +1,81 @@ +package nu.marginalia.ranking.results.factors; + +import nu.marginalia.api.searchquery.model.compiled.CqDataInt; +import nu.marginalia.api.searchquery.model.compiled.CqDataLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; +import nu.marginalia.api.searchquery.model.results.Bm25Parameters; +import nu.marginalia.api.searchquery.model.results.ResultRankingContext; +import nu.marginalia.model.idx.WordMetadata; + +import java.util.List; + +public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor { + private static final long AVG_LENGTH = 5000; + + private final CqDataLong wordMetaData; + private final CqDataInt frequencies; + private final Bm25Parameters bm25Parameters; + + private final int docCount; + private final int length; + + public Bm25FullGraphVisitor(Bm25Parameters bm25Parameters, + CqDataLong wordMetaData, + int length, + ResultRankingContext ctx) { + this.length = length; + this.bm25Parameters = bm25Parameters; + this.docCount = ctx.termFreqDocCount(); + this.wordMetaData = 
wordMetaData; + this.frequencies = ctx.fullCounts; + } + + @Override + public double onAnd(List parts) { + double value = 0; + for (var part : parts) { + value += part.visit(this); + } + return value; + } + + @Override + public double onOr(List parts) { + double value = 0; + for (var part : parts) { + value = Math.max(value, part.visit(this)); + } + return value; + } + + @Override + public double onLeaf(int idx) { + double count = Long.bitCount(WordMetadata.decodePositions(wordMetaData.get(idx))); + + int freq = frequencies.get(idx); + + return invFreq(docCount, freq) * f(bm25Parameters.k(), bm25Parameters.b(), count, length); + } + + /** + * + * @param docCount Number of documents + * @param freq Number of matching documents + */ + private double invFreq(int docCount, int freq) { + return Math.log(1.0 + (docCount - freq + 0.5) / (freq + 0.5)); + } + + /** + * + * @param k determines the size of the impact of a single term + * @param b determines the magnitude of the length normalization + * @param count number of occurrences in the document + * @param length document length + */ + private double f(double k, double b, double count, int length) { + final double lengthRatio = (double) length / AVG_LENGTH; + + return (count * (k + 1)) / (count + k * (1 - b + b * lengthRatio)); + } +} diff --git a/code/index/java/nu/marginalia/ranking/results/factors/Bm25PrioGraphVisitor.java b/code/index/java/nu/marginalia/ranking/results/factors/Bm25PrioGraphVisitor.java new file mode 100644 index 00000000..1fb26f6b --- /dev/null +++ b/code/index/java/nu/marginalia/ranking/results/factors/Bm25PrioGraphVisitor.java @@ -0,0 +1,127 @@ +package nu.marginalia.ranking.results.factors; + +import nu.marginalia.api.searchquery.model.compiled.CqDataInt; +import nu.marginalia.api.searchquery.model.compiled.CqDataLong; +import nu.marginalia.api.searchquery.model.compiled.CqExpression; +import nu.marginalia.api.searchquery.model.results.Bm25Parameters; +import 
nu.marginalia.api.searchquery.model.results.ResultRankingContext; +import nu.marginalia.model.idx.WordFlags; +import nu.marginalia.model.idx.WordMetadata; + +import java.util.List; + +public class Bm25PrioGraphVisitor implements CqExpression.DoubleVisitor { + private static final long AVG_LENGTH = 5000; + + private final CqDataLong wordMetaData; + private final CqDataInt frequencies; + private final Bm25Parameters bm25Parameters; + + private final int docCount; + + public Bm25PrioGraphVisitor(Bm25Parameters bm25Parameters, + CqDataLong wordMetaData, + ResultRankingContext ctx) { + this.bm25Parameters = bm25Parameters; + this.docCount = ctx.termFreqDocCount(); + this.wordMetaData = wordMetaData; + this.frequencies = ctx.fullCounts; + } + + @Override + public double onAnd(List parts) { + double value = 0; + for (var part : parts) { + value += part.visit(this); + } + return value; + } + + @Override + public double onOr(List parts) { + double value = 0; + for (var part : parts) { + value = Math.max(value, part.visit(this)); + } + return value; + } + + @Override + public double onLeaf(int idx) { + double count = evaluatePriorityScore(wordMetaData.get(idx)); + + int freq = frequencies.get(idx); + + // note we override b to zero for priority terms as they are independent of document length + return invFreq(docCount, freq) * f(bm25Parameters.k(), 0, count, 0); + } + + private static double evaluatePriorityScore(long wordMeta) { + int pcount = Long.bitCount(WordMetadata.decodePositions(wordMeta)); + + double qcount = 0.; + + if ((wordMeta & WordFlags.ExternalLink.asBit()) != 0) { + + qcount += 2.5; + + if ((wordMeta & WordFlags.UrlDomain.asBit()) != 0) + qcount += 2.5; + else if ((wordMeta & WordFlags.UrlPath.asBit()) != 0) + qcount += 1.5; + + if ((wordMeta & WordFlags.Site.asBit()) != 0) + qcount += 1.25; + if ((wordMeta & WordFlags.SiteAdjacent.asBit()) != 0) + qcount += 1.25; + } + else { + if ((wordMeta & WordFlags.UrlDomain.asBit()) != 0) + qcount += 3; + else if 
((wordMeta & WordFlags.UrlPath.asBit()) != 0) + qcount += 1; + + if ((wordMeta & WordFlags.Site.asBit()) != 0) + qcount += 0.5; + if ((wordMeta & WordFlags.SiteAdjacent.asBit()) != 0) + qcount += 0.5; + } + + if ((wordMeta & WordFlags.Title.asBit()) != 0) + qcount += 1.5; + + if (pcount > 2) { + if ((wordMeta & WordFlags.Subjects.asBit()) != 0) + qcount += 1.25; + if ((wordMeta & WordFlags.NamesWords.asBit()) != 0) + qcount += 0.25; + if ((wordMeta & WordFlags.TfIdfHigh.asBit()) != 0) + qcount += 0.5; + } + + return qcount; + } + + + /** + * + * @param docCount Number of documents + * @param freq Number of matching documents + */ + private double invFreq(int docCount, int freq) { + return Math.log(1.0 + (docCount - freq + 0.5) / (freq + 0.5)); + } + + /** + * + * @param k determines the size of the impact of a single term + * @param b determines the magnitude of the length normalization + * @param count number of occurrences in the document + * @param length document length + */ + private double f(double k, double b, double count, int length) { + final double lengthRatio = (double) length / AVG_LENGTH; + + return (count * (k + 1)) / (count + k * (1 - b + b * lengthRatio)); + } +} diff --git a/code/index/java/nu/marginalia/ranking/results/factors/TermCoherenceFactor.java b/code/index/java/nu/marginalia/ranking/results/factors/TermCoherenceFactor.java index f956ce88..e617549d 100644 --- a/code/index/java/nu/marginalia/ranking/results/factors/TermCoherenceFactor.java +++ b/code/index/java/nu/marginalia/ranking/results/factors/TermCoherenceFactor.java @@ -1,14 +1,16 @@ package nu.marginalia.ranking.results.factors; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates; import nu.marginalia.model.idx.WordMetadata; -import nu.marginalia.ranking.results.ResultKeywordSet; /** Rewards documents where terms appear frequently within the same sentences */ public class 
TermCoherenceFactor { - public double calculate(ResultKeywordSet keywordSet) { - long mask = combinedMask(keywordSet); + public double calculate(CompiledQueryLong wordMetadataQuery) { + long mask = CompiledQueryAggregates.longBitmaskAggregate(wordMetadataQuery, + score -> score >>> WordMetadata.POSITIONS_SHIFT); return bitsSetFactor(mask); } @@ -19,14 +21,5 @@ public class TermCoherenceFactor { return Math.pow(bitsSetInMask/(float) WordMetadata.POSITIONS_COUNT, 0.25); } - long combinedMask(ResultKeywordSet keywordSet) { - long mask = WordMetadata.POSITIONS_MASK; - - for (var keyword : keywordSet.keywords()) { - mask &= keyword.positions(); - } - - return mask; - } } \ No newline at end of file diff --git a/code/index/query/java/nu/marginalia/index/query/IndexQueryBuilder.java b/code/index/query/java/nu/marginalia/index/query/IndexQueryBuilder.java index 68a88625..855309fa 100644 --- a/code/index/query/java/nu/marginalia/index/query/IndexQueryBuilder.java +++ b/code/index/query/java/nu/marginalia/index/query/IndexQueryBuilder.java @@ -2,6 +2,8 @@ package nu.marginalia.index.query; import nu.marginalia.index.query.filter.QueryFilterStepIf; +import java.util.List; + /** Builds a query. *

* Note: The query builder may omit predicates that are deemed redundant. @@ -9,18 +11,14 @@ import nu.marginalia.index.query.filter.QueryFilterStepIf; public interface IndexQueryBuilder { /** Filters documents that also contain termId, within the full index. */ - IndexQueryBuilder alsoFull(long termId); - - /** - * Filters documents that also contain the termId, within the priority index. - */ - IndexQueryBuilder alsoPrio(long termIds); + IndexQueryBuilder also(long termId); /** Excludes documents that contain termId, within the full index */ - IndexQueryBuilder notFull(long termId); + IndexQueryBuilder not(long termId); IndexQueryBuilder addInclusionFilter(QueryFilterStepIf filterStep); + IndexQueryBuilder addInclusionFilterAny(List filterStep); IndexQuery build(); } diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java new file mode 100644 index 00000000..e9725179 --- /dev/null +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAllOf.java @@ -0,0 +1,71 @@ +package nu.marginalia.index.query.filter; + +import nu.marginalia.array.buffer.LongQueryBuffer; + +import java.util.ArrayList; +import java.util.List; +import java.util.StringJoiner; + +public class QueryFilterAllOf implements QueryFilterStepIf { + private final List steps; + + public QueryFilterAllOf(List steps) { + this.steps = new ArrayList<>(steps.size()); + + for (var step : steps) { + if (step instanceof QueryFilterAllOf allOf) { + this.steps.addAll(allOf.steps); + } + else { + this.steps.add(step); + } + } + } + + public QueryFilterAllOf(QueryFilterStepIf... 
steps) { + this(List.of(steps)); + } + + public double cost() { + double prod = 1.; + + for (var step : steps) { + double cost = step.cost(); + if (cost > 1.0) { + prod *= Math.log(cost); + } + else { + prod += cost; + } + } + + return prod; + } + + @Override + public boolean test(long value) { + for (var step : steps) { + if (!step.test(value)) + return false; + } + return true; + } + + + public void apply(LongQueryBuffer buffer) { + if (steps.isEmpty()) + return; + + for (var step : steps) { + step.apply(buffer); + } + } + + public String describe() { + StringJoiner sj = new StringJoiner(",", "[All Of: ", "]"); + for (var step : steps) { + sj.add(step.describe()); + } + return sj.toString(); + } +} diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java index c9ee2c6e..bea62194 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterAnyOf.java @@ -2,19 +2,31 @@ package nu.marginalia.index.query.filter; import nu.marginalia.array.buffer.LongQueryBuffer; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import java.util.StringJoiner; public class QueryFilterAnyOf implements QueryFilterStepIf { - private final List steps; + private final List steps; public QueryFilterAnyOf(List steps) { - this.steps = steps; + this.steps = new ArrayList<>(steps.size()); + + for (var step : steps) { + if (step instanceof QueryFilterAnyOf anyOf) { + this.steps.addAll(anyOf.steps); + } else { + this.steps.add(step); + } + } + } + + public QueryFilterAnyOf(QueryFilterStepIf... 
steps) { + this(List.of(steps)); } public double cost() { - return steps.stream().mapToDouble(QueryFilterStepIf::cost).average().orElse(0.); + return steps.stream().mapToDouble(QueryFilterStepIf::cost).sum(); } @Override @@ -31,31 +43,37 @@ public class QueryFilterAnyOf implements QueryFilterStepIf { if (steps.isEmpty()) return; - int start; - int end = buffer.end; - - steps.getFirst().apply(buffer); - - // The filter functions will partition the data in the buffer from 0 to END, - // and update END to the length of the retained items, keeping the retained - // items sorted but making no guarantees about the rejected half - // - // Therefore, we need to re-sort the rejected side, and to satisfy the - // constraint that the data is sorted up to END, finally sort it again. - // - // This sorting may seem like it's slower, but filter.apply(...) is - // typically much faster than iterating over filter.test(...); so this - // is more than made up for - - for (int fi = 1; fi < steps.size(); fi++) - { - start = buffer.end; - Arrays.sort(buffer.data, start, end); - buffer.startFilterForRange(start, end); - steps.get(fi).apply(buffer); + if (steps.size() == 1) { + steps.getFirst().apply(buffer); + return; } - Arrays.sort(buffer.data, 0, buffer.end); + int start = 0; + final int endOfValidData = buffer.end; // End of valid data range + + // The filters act as a partitioning function, where anything before buffer.end + // is "in", and is guaranteed to be sorted; and anything after buffer.end is "out" + // but no sorting guaranteed is provided. 
+ + // To provide a conditional filter, we re-sort the "out" range, slice it and apply filtering to the slice + + for (var step : steps) + { + var slice = buffer.slice(start, endOfValidData); + slice.data.quickSort(0, slice.size()); + + step.apply(slice); + start += slice.end; + } + + // After we're done, read and write pointers should be 0 and "end" should be the length of valid data, + // normally done through buffer.finalizeFiltering(); but that won't work here + buffer.reset(); + buffer.end = start; + + // After all filters have been applied, we must re-sort all the retained data + // to uphold the sortedness contract + buffer.data.quickSort(0, buffer.end); } public String describe() { diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterLetThrough.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterLetThrough.java index ed02dd6d..77f503cf 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterLetThrough.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterLetThrough.java @@ -16,7 +16,7 @@ public class QueryFilterLetThrough implements QueryFilterStepIf { } public double cost() { - return 0.; + return 1.; } public String describe() { diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterNoPass.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterNoPass.java index 1bcd04ae..502e7c4c 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterNoPass.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterNoPass.java @@ -15,7 +15,7 @@ public class QueryFilterNoPass implements QueryFilterStepIf { } public double cost() { - return 0.; + return 1.; } public String describe() { diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepExcludeFromPredicate.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepExcludeFromPredicate.java index 
92c8c972..0d715863 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepExcludeFromPredicate.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepExcludeFromPredicate.java @@ -16,7 +16,7 @@ public class QueryFilterStepExcludeFromPredicate implements QueryFilterStepIf { @Override public double cost() { - return 0; + return 1; } @Override diff --git a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepFromPredicate.java b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepFromPredicate.java index 56f08b71..9cd51d7a 100644 --- a/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepFromPredicate.java +++ b/code/index/query/java/nu/marginalia/index/query/filter/QueryFilterStepFromPredicate.java @@ -16,7 +16,7 @@ public class QueryFilterStepFromPredicate implements QueryFilterStepIf { @Override public double cost() { - return 0; + return 1; } @Override diff --git a/code/index/query/test/nu/marginalia/index/query/filter/QueryFilterStepIfTest.java b/code/index/query/test/nu/marginalia/index/query/filter/QueryFilterStepIfTest.java index a7450b11..b2ef1bdb 100644 --- a/code/index/query/test/nu/marginalia/index/query/filter/QueryFilterStepIfTest.java +++ b/code/index/query/test/nu/marginalia/index/query/filter/QueryFilterStepIfTest.java @@ -55,6 +55,32 @@ class QueryFilterStepIfTest { assertArrayEquals(new long[]{8, 10}, buffer.copyData()); } + @Test + public void testSuccessiveApplicationWithAllOf() { + var buffer = createBuffer(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + var filter1 = new QueryFilterStepFromPredicate(value -> value % 2 == 0); + var filter2 = new QueryFilterStepExcludeFromPredicate(value -> value <= 6); + new QueryFilterAllOf(List.of(filter1, filter2)).apply(buffer); + assertArrayEquals(new long[]{8, 10}, buffer.copyData()); + } + @Test + public void testCombinedOrAnd() { + var buffer = createBuffer(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + var filter1 = new 
QueryFilterStepFromPredicate(value -> value % 2 == 0); + var filter2 = new QueryFilterStepFromPredicate(value -> value <= 5); + var filter1_2 = new QueryFilterAllOf(List.of(filter1, filter2)); + + var filter3 = new QueryFilterStepFromPredicate(value -> value % 2 == 1); + var filter4 = new QueryFilterStepFromPredicate(value -> value > 5); + var filter3_4 = new QueryFilterAllOf(List.of(filter3, filter4)); + + var filter12_34 = new QueryFilterAnyOf(List.of(filter1_2, filter3_4)); + + filter12_34.apply(buffer); + + assertArrayEquals(new long[]{2, 4, 7, 9}, buffer.copyData()); + } @Test public void testCombinedApplication() { var buffer = createBuffer(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); diff --git a/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationSmokeTest.java b/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationSmokeTest.java index 634481f4..7b0a6a24 100644 --- a/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationSmokeTest.java +++ b/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationSmokeTest.java @@ -5,7 +5,7 @@ import com.google.inject.Inject; import lombok.SneakyThrows; import nu.marginalia.IndexLocations; import nu.marginalia.api.searchquery.model.query.SearchSpecification; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.index.StatefulIndex; import nu.marginalia.process.control.FakeProcessHeartbeat; @@ -123,9 +123,10 @@ public class IndexQueryServiceIntegrationSmokeTest { .rankingParams(ResultRankingParameters.sensibleDefaults()) .domains(new ArrayList<>()) .searchSetIdentifier("NONE") - .subqueries(List.of(new SearchSubquery( + .query(new SearchQuery( + "2 3 5", List.of("3", "5", "2"), List.of("4"), Collections.emptyList(), Collections.emptyList(), - Collections.emptyList()))).build()); + Collections.emptyList())).build()); int[] 
idxes = new int[] { 30, 510, 90, 150, 210, 270, 330, 390, 450 }; long[] ids = IntStream.of(idxes).mapToLong(this::fullId).toArray(); @@ -166,9 +167,13 @@ public class IndexQueryServiceIntegrationSmokeTest { .rankingParams(ResultRankingParameters.sensibleDefaults()) .queryStrategy(QueryStrategy.SENTENCE) .domains(List.of(2)) - .subqueries(List.of(new SearchSubquery( - List.of("3", "5", "2"), List.of("4"), Collections.emptyList(), Collections.emptyList(), - Collections.emptyList()))).build()); + .query(new SearchQuery( + "2 3 5", + List.of("3", "5", "2"), + List.of("4"), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList())).build()); int[] idxes = new int[] { 210, 270 }; long[] ids = IntStream.of(idxes).mapToLong(id -> UrlIdCodec.encodeId(id/100, id)).toArray(); long[] actual = rsp.results.stream().mapToLong(i -> i.rawIndexResult.getDocumentId()).toArray(); @@ -202,18 +207,15 @@ public class IndexQueryServiceIntegrationSmokeTest { .queryStrategy(QueryStrategy.SENTENCE) .searchSetIdentifier("NONE") .rankingParams(ResultRankingParameters.sensibleDefaults()) - .subqueries(List.of(new SearchSubquery( - List.of("4"), Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), - Collections.emptyList())) + .query( + new SearchQuery("4", List.of("4"), Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), Collections.emptyList()) ).build()); Set years = new HashSet<>(); for (var res : rsp.results) { - for (var score : res.rawIndexResult.getKeywordScores()) { - years.add(DocumentMetadata.decodeYear(score.encodedDocMetadata())); - } + years.add(DocumentMetadata.decodeYear(res.rawIndexResult.encodedDocMetadata)); } assertEquals(Set.of(1998), years); diff --git a/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationTest.java b/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationTest.java index 6def5bbc..e29f8751 100644 --- 
a/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationTest.java +++ b/code/index/test/nu/marginalia/index/IndexQueryServiceIntegrationTest.java @@ -4,7 +4,7 @@ import com.google.inject.Guice; import com.google.inject.Inject; import nu.marginalia.IndexLocations; import nu.marginalia.api.searchquery.model.query.SearchSpecification; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.index.StatefulIndex; import nu.marginalia.storage.FileStorageService; @@ -35,6 +35,7 @@ import nu.marginalia.process.control.ProcessHeartbeat; import nu.marginalia.index.domainrankings.DomainRankings; import nu.marginalia.service.control.ServiceHeartbeat; import nu.marginalia.service.server.Initialization; +import org.apache.logging.log4j.util.Strings; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -108,7 +109,7 @@ public class IndexQueryServiceIntegrationTest { w("world", WordFlags.Title) ).load(); - var query = basicQuery(builder -> builder.subqueries(justInclude("hello", "world"))); + var query = basicQuery(builder -> builder.query(justInclude("hello", "world"))); executeSearch(query) .expectDocumentsInOrder(d(1,1)); @@ -127,57 +128,51 @@ public class IndexQueryServiceIntegrationTest { ).load(); var queryMissingExclude = basicQuery(builder -> - builder.subqueries(includeAndExclude("hello", "missing"))); + builder.query(includeAndExclude("hello", "missing"))); executeSearch(queryMissingExclude) .expectDocumentsInOrder(d(1,1)); var queryMissingInclude = basicQuery(builder -> - builder.subqueries(justInclude("missing"))); + builder.query(justInclude("missing"))); executeSearch(queryMissingInclude) .expectCount(0); var queryMissingPriority = basicQuery(builder -> - builder.subqueries( - List.of( - new SearchSubquery( - 
List.of("hello"), - List.of(), - List.of(), - List.of("missing"), - List.of() - ) - ))); + builder.query(new SearchQuery( + "hello", + List.of("hello"), + List.of(), + List.of(), + List.of("missing"), + List.of()) + )); executeSearch(queryMissingPriority) .expectCount(1); var queryMissingAdvice = basicQuery(builder -> - builder.subqueries( - List.of( - new SearchSubquery( - List.of("hello"), - List.of(), - List.of("missing"), - List.of(), - List.of() - ) + builder.query( + new SearchQuery("hello", + List.of("hello"), + List.of(), + List.of("missing"), + List.of(), + List.of() ))); executeSearch(queryMissingAdvice) .expectCount(0); var queryMissingCoherence = basicQuery(builder -> - builder.subqueries( - List.of( - new SearchSubquery( - List.of("hello"), - List.of(), - List.of(), - List.of(), - List.of(List.of("missing", "hello")) - ) + builder.query( + new SearchQuery("hello", + List.of("hello"), + List.of(), + List.of(), + List.of(), + List.of(List.of("missing", "hello")) ))); executeSearch(queryMissingCoherence) @@ -202,7 +197,7 @@ public class IndexQueryServiceIntegrationTest { ).load(); - var query = basicQuery(builder -> builder.subqueries(justInclude("hello", "world"))); + var query = basicQuery(builder -> builder.query(justInclude("hello", "world"))); executeSearch(query) .expectDocumentsInOrder(d(1,1)); @@ -234,15 +229,15 @@ public class IndexQueryServiceIntegrationTest { var beforeY2K = basicQuery(builder -> - builder.subqueries(justInclude("hello", "world")) + builder.query(justInclude("hello", "world")) .year(SpecificationLimit.lessThan(2000)) ); var atY2K = basicQuery(builder -> - builder.subqueries(justInclude("hello", "world")) + builder.query(justInclude("hello", "world")) .year(SpecificationLimit.equals(2000)) ); var afterY2K = basicQuery(builder -> - builder.subqueries(justInclude("hello", "world")) + builder.query(justInclude("hello", "world")) .year(SpecificationLimit.greaterThan(2000)) ); @@ -296,11 +291,11 @@ public class 
IndexQueryServiceIntegrationTest { var domain1 = basicQuery(builder -> - builder.subqueries(justInclude("hello", "world")) + builder.query(justInclude("hello", "world")) .domains(List.of(1)) ); var domain2 = basicQuery(builder -> - builder.subqueries(justInclude("hello", "world")) + builder.query(justInclude("hello", "world")) .domains(List.of(2)) ); @@ -334,7 +329,7 @@ public class IndexQueryServiceIntegrationTest { ).load(); var query = basicQuery(builder -> - builder.subqueries(includeAndExclude("hello", "my_darling")) + builder.query(includeAndExclude("hello", "my_darling")) ); executeSearch(query) @@ -403,7 +398,7 @@ public class IndexQueryServiceIntegrationTest { .load(); var rsp = queryService.justQuery( - basicQuery(builder -> builder.subqueries( + basicQuery(builder -> builder.query( // note coherence requriement includeAndCohere("hello", "world") ))); @@ -424,50 +419,53 @@ public class IndexQueryServiceIntegrationTest { .rank(SpecificationLimit.none()) .rankingParams(ResultRankingParameters.sensibleDefaults()) .domains(new ArrayList<>()) - .searchSetIdentifier("NONE") - .subqueries(List.of()); + .searchSetIdentifier("NONE"); return mutator.apply(builder).build(); } - List justInclude(String... includes) { - return List.of(new SearchSubquery( + SearchQuery justInclude(String... 
includes) { + return new SearchQuery( + Strings.join(List.of(includes), ' '), List.of(includes), List.of(), List.of(), List.of(), List.of() - )); + ); } - List includeAndExclude(List includes, List excludes) { - return List.of(new SearchSubquery( + SearchQuery includeAndExclude(List includes, List excludes) { + return new SearchQuery( + Strings.join(List.of(includes), ' '), includes, excludes, List.of(), List.of(), List.of() - )); + ); } - List includeAndExclude(String include, String exclude) { - return List.of(new SearchSubquery( + SearchQuery includeAndExclude(String include, String exclude) { + return new SearchQuery( + include, List.of(include), List.of(exclude), List.of(), List.of(), List.of() - )); + ); } - List includeAndCohere(String... includes) { - return List.of(new SearchSubquery( + SearchQuery includeAndCohere(String... includes) { + return new SearchQuery( + Strings.join(List.of(includes), ' '), List.of(includes), List.of(), List.of(), List.of(), List.of(List.of(includes)) - )); + ); } private MockDataDocument d(int domainId, int ordinal) { return new MockDataDocument(domainId, ordinal); diff --git a/code/index/test/nu/marginalia/index/index/QueryBranchWalkerTest.java b/code/index/test/nu/marginalia/index/index/QueryBranchWalkerTest.java new file mode 100644 index 00000000..8d2f45c8 --- /dev/null +++ b/code/index/test/nu/marginalia/index/index/QueryBranchWalkerTest.java @@ -0,0 +1,59 @@ +package nu.marginalia.index.index; + +import it.unimi.dsi.fastutil.longs.LongArraySet; +import it.unimi.dsi.fastutil.longs.LongSet; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.*; + +class QueryBranchWalkerTest { + @Test + public void testNoOverlap() { + var paths = QueryBranchWalker.create( + new long[] { 1, 2 }, + List.of(set(1), set(2)) + ); + assertEquals(2, paths.size()); + 
assertEquals(Set.of(1L, 2L), paths.stream().map(path -> path.termId).collect(Collectors.toSet())); + } + + @Test + public void testCond() { + var paths = QueryBranchWalker.create( + new long[] { 1, 2, 3, 4 }, + List.of(set(1,2,3), set(1,4,3)) + ); + assertEquals(1, paths.size()); + assertEquals(Set.of(1L), paths.stream().map(path -> path.termId).collect(Collectors.toSet())); + System.out.println(Arrays.toString(paths.getFirst().priorityOrder)); + assertArrayEquals(new long[] { 2, 3, 4 }, paths.getFirst().priorityOrder); + + var next = paths.getFirst().next(); + assertEquals(2, next.size()); + assertEquals(Set.of(2L, 3L), next.stream().map(path -> path.termId).collect(Collectors.toSet())); + Map byId = next.stream().collect(Collectors.toMap(w -> w.termId, w->w)); + assertArrayEquals(new long[] { 3L }, byId.get(2L).priorityOrder ); + assertArrayEquals(new long[] { 4L }, byId.get(3L).priorityOrder ); + } + + @Test + public void testNoOverlapFirst() { + var paths = QueryBranchWalker.create( + new long[] { 1, 2, 3 }, + List.of(set(1, 2), set(1, 3)) + ); + assertEquals(1, paths.size()); + assertArrayEquals(new long[] { 2, 3 }, paths.getFirst().priorityOrder); + assertEquals(Set.of(1L), paths.stream().map(path -> path.termId).collect(Collectors.toSet())); + } + + LongSet set(long... 
args) { + return new LongArraySet(args); + } +} \ No newline at end of file diff --git a/code/index/test/nu/marginalia/index/results/IndexResultDomainDeduplicatorTest.java b/code/index/test/nu/marginalia/index/results/IndexResultDomainDeduplicatorTest.java index 4f5a12cd..21f6312e 100644 --- a/code/index/test/nu/marginalia/index/results/IndexResultDomainDeduplicatorTest.java +++ b/code/index/test/nu/marginalia/index/results/IndexResultDomainDeduplicatorTest.java @@ -2,9 +2,10 @@ package nu.marginalia.index.results; import nu.marginalia.api.searchquery.model.results.SearchResultItem; import nu.marginalia.model.id.UrlIdCodec; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import java.util.List; + import static org.junit.jupiter.api.Assertions.*; class IndexResultDomainDeduplicatorTest { @@ -24,7 +25,7 @@ class IndexResultDomainDeduplicatorTest { } SearchResultItem forId(int domain, int ordinal) { - return new SearchResultItem(UrlIdCodec.encodeId(domain, ordinal), 4); + return new SearchResultItem(UrlIdCodec.encodeId(domain, ordinal), 0, 0, List.of(), 4, Double.NaN, false); } } \ No newline at end of file diff --git a/code/index/test/nu/marginalia/ranking/results/ResultValuatorTest.java b/code/index/test/nu/marginalia/ranking/results/ResultValuatorTest.java index 8f8f7eaa..a1b66b04 100644 --- a/code/index/test/nu/marginalia/ranking/results/ResultValuatorTest.java +++ b/code/index/test/nu/marginalia/ranking/results/ResultValuatorTest.java @@ -1,5 +1,8 @@ package nu.marginalia.ranking.results; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong; +import nu.marginalia.api.searchquery.model.compiled.CqDataInt; import nu.marginalia.api.searchquery.model.results.ResultRankingContext; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore; @@ -30,30 +33,27 @@ class 
ResultValuatorTest { when(dict.docCount()).thenReturn(100_000); valuator = new ResultValuator( - new Bm25Factor(), new TermCoherenceFactor() ); } - List titleOnlyLowCountSet = List.of( - new SearchResultKeywordScore(0, "bob", - wordMetadata(Set.of(1), EnumSet.of(WordFlags.Title)), - docMetadata(0, 2010, 5, EnumSet.noneOf(DocumentFlags.class)), - 0) - ); - List highCountNoTitleSet = List.of( - new SearchResultKeywordScore(0, "bob", - wordMetadata(Set.of(1,3,4,6,7,9,10,11,12,14,15,16), EnumSet.of(WordFlags.TfIdfHigh)), - docMetadata(0, 2010, 5, EnumSet.noneOf(DocumentFlags.class)), - 0) - ); - List highCountSubjectSet = List.of( - new SearchResultKeywordScore(0, "bob", - wordMetadata(Set.of(1,3,4,6,7,9,10,11,12,14,15,16), EnumSet.of(WordFlags.TfIdfHigh, WordFlags.Subjects)), - docMetadata(0, 2010, 5, EnumSet.noneOf(DocumentFlags.class)), - 0) - ); + CqDataInt frequencyData = new CqDataInt(new int[] { 10 }); + + CompiledQueryLong titleOnlyLowCountSet = CompiledQuery.just( + new SearchResultKeywordScore("bob", 1, + wordMetadata(Set.of(1), EnumSet.of(WordFlags.Title))) + ).mapToLong(SearchResultKeywordScore::encodedWordMetadata); + + CompiledQueryLong highCountNoTitleSet = CompiledQuery.just( + new SearchResultKeywordScore("bob", 1, + wordMetadata(Set.of(1,3,4,6,7,9,10,11,12,14,15,16), EnumSet.of(WordFlags.TfIdfHigh))) + ).mapToLong(SearchResultKeywordScore::encodedWordMetadata);; + + CompiledQueryLong highCountSubjectSet = CompiledQuery.just( + new SearchResultKeywordScore("bob", 1, + wordMetadata(Set.of(1,3,4,6,7,9,10,11,12,14,15,16), EnumSet.of(WordFlags.TfIdfHigh, WordFlags.Subjects))) + ).mapToLong(SearchResultKeywordScore::encodedWordMetadata);; @Test @@ -62,12 +62,16 @@ class ResultValuatorTest { when(dict.getTermFreq("bob")).thenReturn(10); ResultRankingContext context = new ResultRankingContext(100000, ResultRankingParameters.sensibleDefaults(), - Map.of("bob", 10), Collections.emptyMap()); + frequencyData, + frequencyData); - double titleOnlyLowCount = 
valuator.calculateSearchResultValue(titleOnlyLowCountSet, 10_000, context); - double titleLongOnlyLowCount = valuator.calculateSearchResultValue(titleOnlyLowCountSet, 10_000, context); - double highCountNoTitle = valuator.calculateSearchResultValue(highCountNoTitleSet, 10_000, context); - double highCountSubject = valuator.calculateSearchResultValue(highCountSubjectSet, 10_000, context); + long docMeta = docMetadata(0, 2010, 5, EnumSet.noneOf(DocumentFlags.class)); + int features = 0; + + double titleOnlyLowCount = valuator.calculateSearchResultValue(titleOnlyLowCountSet, docMeta, features, 10_000, context); + double titleLongOnlyLowCount = valuator.calculateSearchResultValue(titleOnlyLowCountSet, docMeta, features, 10_000, context); + double highCountNoTitle = valuator.calculateSearchResultValue(highCountNoTitleSet, docMeta, features, 10_000, context); + double highCountSubject = valuator.calculateSearchResultValue(highCountSubjectSet, docMeta, features, 10_000, context); System.out.println(titleOnlyLowCount); System.out.println(titleLongOnlyLowCount); @@ -75,7 +79,10 @@ class ResultValuatorTest { System.out.println(highCountSubject); } - private long docMetadata(int topology, int year, int quality, EnumSet flags) { + private long docMetadata(int topology, + int year, + int quality, + EnumSet flags) { return new DocumentMetadata(topology, PubDate.toYearByte(year), quality, flags).encode(); } diff --git a/code/index/test/nu/marginalia/ranking/results/factors/TermCoherenceFactorTest.java b/code/index/test/nu/marginalia/ranking/results/factors/TermCoherenceFactorTest.java index a5bca54e..d0abe443 100644 --- a/code/index/test/nu/marginalia/ranking/results/factors/TermCoherenceFactorTest.java +++ b/code/index/test/nu/marginalia/ranking/results/factors/TermCoherenceFactorTest.java @@ -1,9 +1,10 @@ package nu.marginalia.ranking.results.factors; +import nu.marginalia.api.searchquery.model.compiled.CompiledQuery; +import 
nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates; import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore; import nu.marginalia.bbpc.BrailleBlockPunchCards; import nu.marginalia.model.idx.WordMetadata; -import nu.marginalia.ranking.results.ResultKeywordSet; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -17,14 +18,23 @@ class TermCoherenceFactorTest { @Test public void testAllBitsSet() { var allPositionsSet = createSet( - WordMetadata.POSITIONS_MASK, WordMetadata.POSITIONS_MASK + ~0L, + ~0L ); - long mask = termCoherenceFactor.combinedMask(allPositionsSet); + long mask = CompiledQueryAggregates.longBitmaskAggregate( + allPositionsSet, + SearchResultKeywordScore::positions + ); assertEquals(1.0, termCoherenceFactor.bitsSetFactor(mask), 0.01); - assertEquals(1.0, termCoherenceFactor.calculate(allPositionsSet)); + assertEquals(1.0, + termCoherenceFactor.calculate( + allPositionsSet.mapToLong(SearchResultKeywordScore::encodedWordMetadata) + ) + ); + } @Test @@ -33,11 +43,11 @@ class TermCoherenceFactorTest { 0, 0 ); - long mask = termCoherenceFactor.combinedMask(allPositionsSet); + long mask = CompiledQueryAggregates.longBitmaskAggregate(allPositionsSet, score -> score.positions() & WordMetadata.POSITIONS_MASK); assertEquals(0, termCoherenceFactor.bitsSetFactor(mask), 0.01); - assertEquals(0, termCoherenceFactor.calculate(allPositionsSet)); + assertEquals(0, termCoherenceFactor.calculate(allPositionsSet.mapToLong(SearchResultKeywordScore::encodedWordMetadata))); } @Test @SuppressWarnings("unchecked") @@ -46,7 +56,7 @@ class TermCoherenceFactorTest { List.of(0, 1, 2, 3), List.of(0, 1, 2, 3) ); - long mask = termCoherenceFactor.combinedMask(positions); + long mask = CompiledQueryAggregates.longBitmaskAggregate(positions, score -> score.positions() & WordMetadata.POSITIONS_MASK); printMask(mask); } @@ -57,7 +67,7 @@ class TermCoherenceFactorTest { List.of(55, 54, 53, 52), List.of(55, 54, 53, 52) ); - 
long mask = termCoherenceFactor.combinedMask(positions); + long mask = CompiledQueryAggregates.longBitmaskAggregate(positions, score -> score.positions() & WordMetadata.POSITIONS_MASK); printMask(mask); } @@ -72,7 +82,7 @@ class TermCoherenceFactorTest { System.out.println(BrailleBlockPunchCards.printBits(mask, 48)); } - ResultKeywordSet createSet(List... maskPositions) { + CompiledQuery createSet(List... maskPositions) { long[] positions = new long[maskPositions.length]; for (int i = 0; i < maskPositions.length; i++) { @@ -84,14 +94,14 @@ class TermCoherenceFactorTest { return createSet(positions); } - ResultKeywordSet createSet(long... positionMasks) { + CompiledQuery createSet(long... positionMasks) { List keywords = new ArrayList<>(); for (int i = 0; i < positionMasks.length; i++) { - keywords.add(new SearchResultKeywordScore(0, "", - new WordMetadata(positionMasks[i], (byte) 0).encode(), 0, 0)); + keywords.add(new SearchResultKeywordScore("", 0, + new WordMetadata(positionMasks[i] & WordMetadata.POSITIONS_MASK, (byte) 0).encode())); } - return new ResultKeywordSet(keywords); + return CompiledQuery.just(keywords.toArray(SearchResultKeywordScore[]::new)); } } \ No newline at end of file diff --git a/code/libraries/array/java/nu/marginalia/array/algo/LongArrayBase.java b/code/libraries/array/java/nu/marginalia/array/algo/LongArrayBase.java index 39d9bff7..ab7f18bd 100644 --- a/code/libraries/array/java/nu/marginalia/array/algo/LongArrayBase.java +++ b/code/libraries/array/java/nu/marginalia/array/algo/LongArrayBase.java @@ -1,5 +1,7 @@ package nu.marginalia.array.algo; +import nu.marginalia.array.LongArray; + import java.io.IOException; import java.nio.LongBuffer; import java.nio.channels.FileChannel; @@ -61,6 +63,12 @@ public interface LongArrayBase extends BulkTransferArray { } } + default void get(long start, long end, LongArray buffer, int bufferStart) { + for (int i = 0; i < (end-start); i++) { + buffer.set(i + bufferStart, get(start + i)); + } + } + default 
void get(long start, LongBuffer buffer) { get(start, start + buffer.remaining(), buffer, buffer.position()); } diff --git a/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java b/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java index 390325ee..a0312d36 100644 --- a/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java +++ b/code/libraries/array/java/nu/marginalia/array/buffer/LongQueryBuffer.java @@ -1,5 +1,8 @@ package nu.marginalia.array.buffer; +import nu.marginalia.array.LongArray; +import nu.marginalia.array.LongArrayFactory; + import java.util.Arrays; /** A buffer for long values that can be used to filter and manipulate the data. @@ -17,7 +20,7 @@ import java.util.Arrays; public class LongQueryBuffer { /** Direct access to the data in the buffer, * guaranteed to be populated until `end` */ - public final long[] data; + public final LongArray data; /** Number of items in the data buffer */ public int end; @@ -25,18 +28,27 @@ public class LongQueryBuffer { private int read = 0; private int write = 0; + private LongQueryBuffer(LongArray array, int size) { + this.data = array; + this.end = size; + } + public LongQueryBuffer(int size) { - this.data = new long[size]; + this.data = LongArrayFactory.onHeapConfined(size); this.end = size; } public LongQueryBuffer(long[] data, int size) { - this.data = data; + this.data = LongArrayFactory.onHeapConfined(size); + this.data.set(0, data); + this.end = size; } public long[] copyData() { - return Arrays.copyOf(data, end); + long[] copy = new long[end]; + data.forEach(0, end, (pos, val) -> copy[(int)pos]=val ); + return copy; } public boolean isEmpty() { @@ -48,7 +60,7 @@ public class LongQueryBuffer { } public void reset() { - end = data.length; + end = (int) data.size(); read = 0; write = 0; } @@ -59,12 +71,16 @@ public class LongQueryBuffer { write = 0; } + public LongQueryBuffer slice(int start, int end) { + return new LongQueryBuffer(data.range(start, 
end), end - start); + } + /* == Filtering methods == */ /** Returns the current value at the read pointer. */ public long currentValue() { - return data[read]; + return data.get(read); } /** Advances the read pointer and returns true if there are more values to read. */ @@ -79,9 +95,9 @@ public class LongQueryBuffer { */ public boolean retainAndAdvance() { if (read != write) { - long tmp = data[write]; - data[write] = data[read]; - data[read] = tmp; + long tmp = data.get(write); + data.set(write, data.get(read)); + data.set(read, tmp); } write++; @@ -117,11 +133,6 @@ public class LongQueryBuffer { write = 0; } - public void startFilterForRange(int pos, int end) { - read = write = pos; - this.end = end; - } - /** Retain only unique values in the buffer, and update the end pointer to the new length. *

* The buffer is assumed to be sorted up until the end pointer. @@ -153,7 +164,7 @@ public class LongQueryBuffer { "read = " + read + ",write = " + write + ",end = " + end + - ",data = [" + Arrays.toString(Arrays.copyOf(data, end)) + "]]"; + ",data = [" + Arrays.toString(copyData()) + "]]"; } diff --git a/code/libraries/array/test/nu/marginalia/array/algo/LongArraySearchTest.java b/code/libraries/array/test/nu/marginalia/array/algo/LongArraySearchTest.java index a515917b..fa50045e 100644 --- a/code/libraries/array/test/nu/marginalia/array/algo/LongArraySearchTest.java +++ b/code/libraries/array/test/nu/marginalia/array/algo/LongArraySearchTest.java @@ -143,7 +143,7 @@ class LongArraySearchTest { assertEquals(43, buffer.size()); for (int i = 0; i < 43; i++) { - assertEquals(buffer.data[i], i*3); + assertEquals(buffer.data.get(i), i*3); } } @@ -160,7 +160,7 @@ class LongArraySearchTest { int j = 0; for (int i = 0; i < 43; i++) { if (++j % 3 == 0) j++; - assertEquals(buffer.data[i], j); + assertEquals(buffer.data.get(i), j); } } } \ No newline at end of file diff --git a/code/libraries/btree/java/nu/marginalia/btree/BTreeReader.java b/code/libraries/btree/java/nu/marginalia/btree/BTreeReader.java index 048e0301..bc40bb43 100644 --- a/code/libraries/btree/java/nu/marginalia/btree/BTreeReader.java +++ b/code/libraries/btree/java/nu/marginalia/btree/BTreeReader.java @@ -109,8 +109,8 @@ public class BTreeReader { return ip.findData(key); } - public void readData(long[] buf, int n, long pos) { - data.get(pos, pos + n, buf); + public void readData(LongArray buf, int n, long pos) { + data.get(pos, pos + n, buf, 0); } /** Used for querying interlaced data in the btree. 
diff --git a/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithIndexTest.java b/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithIndexTest.java index 8b65753d..be24de10 100644 --- a/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithIndexTest.java +++ b/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithIndexTest.java @@ -32,7 +32,8 @@ public class BTreeReaderRejectRetainWithIndexTest { @Test public void testRetain() { LongQueryBuffer odds = new LongQueryBuffer(50); - Arrays.setAll(odds.data, i -> 2L*i + 1); + for (int i = 0; i < 50; i++) + odds.data.set(i, 2L*i + 1); BTreeReader reader = new BTreeReader(array, ctx, 0); reader.retainEntries(odds); @@ -46,7 +47,8 @@ public class BTreeReaderRejectRetainWithIndexTest { @Test public void testReject() { LongQueryBuffer odds = new LongQueryBuffer(50); - Arrays.setAll(odds.data, i -> 2L*i + 1); + for (int i = 0; i < 50; i++) + odds.data.set(i, 2L*i + 1); BTreeReader reader = new BTreeReader(array, ctx, 0); reader.rejectEntries(odds); diff --git a/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithoutIndexTest.java b/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithoutIndexTest.java index e5d4dc79..fc3b71df 100644 --- a/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithoutIndexTest.java +++ b/code/libraries/btree/test/nu/marginalia/btree/BTreeReaderRejectRetainWithoutIndexTest.java @@ -32,7 +32,8 @@ public class BTreeReaderRejectRetainWithoutIndexTest { @Test public void testRetain() { LongQueryBuffer odds = new LongQueryBuffer(50); - Arrays.setAll(odds.data, i -> 2L*i + 1); + for (int i = 0; i < 50; i++) + odds.data.set(i, 2L*i + 1); BTreeReader reader = new BTreeReader(array, ctx, 0); reader.retainEntries(odds); @@ -46,7 +47,9 @@ public class BTreeReaderRejectRetainWithoutIndexTest { @Test public void testReject() { LongQueryBuffer odds = new LongQueryBuffer(50); 
- Arrays.setAll(odds.data, i -> 2L*i + 1); + for (int i = 0; i < 50; i++) + odds.data.set(i, 2L*i + 1); + BTreeReader reader = new BTreeReader(array, ctx, 0); reader.rejectEntries(odds); diff --git a/code/libraries/language-processing/java/nu/marginalia/language/model/DocumentSentence.java b/code/libraries/language-processing/java/nu/marginalia/language/model/DocumentSentence.java index ef5bc0a9..b9b4abce 100644 --- a/code/libraries/language-processing/java/nu/marginalia/language/model/DocumentSentence.java +++ b/code/libraries/language-processing/java/nu/marginalia/language/model/DocumentSentence.java @@ -16,12 +16,24 @@ public class DocumentSentence implements Iterable{ public final String[] wordsLowerCase; public final String[] posTags; public final String[] stemmedWords; + public final String[] ngrams; + public final String[] ngramStemmed; private final BitSet isStopWord; + public SoftReference keywords; - public DocumentSentence(String originalSentence, String[] words, int[] separators, String[] wordsLowerCase, String[] posTags, String[] stemmedWords) { + public DocumentSentence(String originalSentence, + String[] words, + int[] separators, + String[] wordsLowerCase, + String[] posTags, + String[] stemmedWords, + String[] ngrams, + String[] ngramsStemmed + ) + { this.originalSentence = originalSentence; this.words = words; this.separators = separators; @@ -31,6 +43,9 @@ public class DocumentSentence implements Iterable{ isStopWord = new BitSet(words.length); + this.ngrams = ngrams; + this.ngramStemmed = ngramsStemmed; + for (int i = 0; i < words.length; i++) { if (WordPatterns.isStopWord(words[i])) isStopWord.set(i); diff --git a/code/libraries/language-processing/java/nu/marginalia/language/sentence/SentenceExtractor.java b/code/libraries/language-processing/java/nu/marginalia/language/sentence/SentenceExtractor.java index 13ba2e76..bb1e3771 100644 --- a/code/libraries/language-processing/java/nu/marginalia/language/sentence/SentenceExtractor.java +++ 
b/code/libraries/language-processing/java/nu/marginalia/language/sentence/SentenceExtractor.java @@ -4,6 +4,7 @@ import com.github.datquocnguyen.RDRPOSTagger; import gnu.trove.map.hash.TObjectIntHashMap; import lombok.SneakyThrows; import nu.marginalia.LanguageModels; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.language.model.DocumentLanguageData; import nu.marginalia.language.model.DocumentSentence; import opennlp.tools.sentdetect.SentenceDetectorME; @@ -32,6 +33,8 @@ public class SentenceExtractor { private SentenceDetectorME sentenceDetector; private static RDRPOSTagger rdrposTagger; + private static NgramLexicon ngramLexicon = null; + private final PorterStemmer porterStemmer = new PorterStemmer(); private static final Logger logger = LoggerFactory.getLogger(SentenceExtractor.class); @@ -45,7 +48,8 @@ public class SentenceExtractor { private static final int MAX_TEXT_LENGTH = 65536; @SneakyThrows @Inject - public SentenceExtractor(LanguageModels models) { + public SentenceExtractor(LanguageModels models) + { try (InputStream modelIn = new FileInputStream(models.openNLPSentenceDetectionData.toFile())) { var sentenceModel = new SentenceModel(modelIn); sentenceDetector = new SentenceDetectorME(sentenceModel); @@ -55,12 +59,17 @@ public class SentenceExtractor { logger.error("Could not initialize sentence detector", ex); } - synchronized (RDRPOSTagger.class) { - try { - rdrposTagger = new RDRPOSTagger(models.posDict, models.posRules); + synchronized (this) { + if (ngramLexicon == null) { + ngramLexicon = new NgramLexicon(models); } - catch (Exception ex) { - throw new IllegalStateException(ex); + + if (rdrposTagger == null) { + try { + rdrposTagger = new RDRPOSTagger(models.posDict, models.posRules); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } } } @@ -128,8 +137,34 @@ public class SentenceExtractor { var seps = wordsAndSeps.separators; var lc = 
SentenceExtractorStringUtils.toLowerCaseStripPossessive(wordsAndSeps.words); + List ngrams = ngramLexicon.findSegmentsStrings(2, 12, words); + + String[] ngramsWords = new String[ngrams.size()]; + String[] ngramsStemmedWords = new String[ngrams.size()]; + for (int i = 0; i < ngrams.size(); i++) { + String[] ngram = ngrams.get(i); + + StringJoiner ngramJoiner = new StringJoiner("_"); + StringJoiner stemmedJoiner = new StringJoiner("_"); + for (String s : ngram) { + ngramJoiner.add(s); + stemmedJoiner.add(porterStemmer.stem(s)); + } + + ngramsWords[i] = ngramJoiner.toString(); + ngramsStemmedWords[i] = stemmedJoiner.toString(); + } + + return new DocumentSentence( - SentenceExtractorStringUtils.sanitizeString(text), words, seps, lc, rdrposTagger.tagsForEnSentence(words), stemSentence(lc) + SentenceExtractorStringUtils.sanitizeString(text), + words, + seps, + lc, + rdrposTagger.tagsForEnSentence(words), + stemSentence(lc), + ngramsWords, + ngramsStemmedWords ); } @@ -195,7 +230,35 @@ public class SentenceExtractor { fullString = ""; } - ret[i] = new DocumentSentence(fullString, tokens[i], separators[i], tokensLc[i], posTags[i], stemmedWords[i]); + List ngrams = ngramLexicon.findSegmentsStrings(2, 12, tokens[i]); + + String[] ngramsWords = new String[ngrams.size()]; + String[] ngramsStemmedWords = new String[ngrams.size()]; + + for (int j = 0; j < ngrams.size(); j++) { + String[] ngram = ngrams.get(j); + + StringJoiner ngramJoiner = new StringJoiner("_"); + StringJoiner stemmedJoiner = new StringJoiner("_"); + for (String s : ngram) { + ngramJoiner.add(s); + stemmedJoiner.add(porterStemmer.stem(s)); + } + + ngramsWords[j] = ngramJoiner.toString(); + ngramsStemmedWords[j] = stemmedJoiner.toString(); + } + + + ret[i] = new DocumentSentence(fullString, + tokens[i], + separators[i], + tokensLc[i], + posTags[i], + stemmedWords[i], + ngramsWords, + ngramsStemmedWords + ); } return ret; } diff --git 
a/code/libraries/language-processing/test/nu/marginalia/language/filter/TestLanguageModels.java b/code/libraries/language-processing/test/nu/marginalia/language/filter/TestLanguageModels.java index 2b7bf0e2..cb31942a 100644 --- a/code/libraries/language-processing/test/nu/marginalia/language/filter/TestLanguageModels.java +++ b/code/libraries/language-processing/test/nu/marginalia/language/filter/TestLanguageModels.java @@ -26,13 +26,13 @@ public class TestLanguageModels { var languageModelsHome = getLanguageModelsPath(); return new LanguageModels( - languageModelsHome.resolve("ngrams.bin"), languageModelsHome.resolve("tfreq-new-algo3.bin"), languageModelsHome.resolve("opennlp-sentence.bin"), languageModelsHome.resolve("English.RDR"), languageModelsHome.resolve("English.DICT"), languageModelsHome.resolve("opennlp-tokens.bin"), - languageModelsHome.resolve("lid.176.ftz") + languageModelsHome.resolve("lid.176.ftz"), + languageModelsHome.resolve("segments.bin") ); } } diff --git a/code/libraries/term-frequency-dict/build.gradle b/code/libraries/term-frequency-dict/build.gradle index 901fd2e0..3a9a4d8d 100644 --- a/code/libraries/term-frequency-dict/build.gradle +++ b/code/libraries/term-frequency-dict/build.gradle @@ -16,11 +16,14 @@ apply from: "$rootProject.projectDir/srcsets.gradle" dependencies { implementation project(':third-party:rdrpostagger') implementation project(':third-party:porterstemmer') + implementation project(':third-party:commons-codec') + implementation project(':third-party:openzim') implementation project(':third-party:monkey-patch-opennlp') implementation project(':code:common:model') implementation project(':code:common:config') implementation project(':code:libraries:easy-lsh') implementation project(':code:libraries:array') + implementation project(':code:libraries:blocking-thread-pool') implementation libs.bundles.slf4j implementation libs.notnull diff --git 
a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/BasicSentenceExtractor.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/BasicSentenceExtractor.java new file mode 100644 index 00000000..cee48910 --- /dev/null +++ b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/BasicSentenceExtractor.java @@ -0,0 +1,16 @@ +package nu.marginalia.segmentation; + +import ca.rmen.porterstemmer.PorterStemmer; +import org.apache.commons.lang3.StringUtils; + +public class BasicSentenceExtractor { + + private static PorterStemmer porterStemmer = new PorterStemmer(); + public static String[] getStemmedParts(String sentence) { + String[] parts = StringUtils.split(sentence, ' '); + for (int i = 0; i < parts.length; i++) { + parts[i] = porterStemmer.stemWord(parts[i]); + } + return parts; + } +} diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/HasherGroup.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/HasherGroup.java new file mode 100644 index 00000000..2a452f75 --- /dev/null +++ b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/HasherGroup.java @@ -0,0 +1,61 @@ +package nu.marginalia.segmentation; + +import nu.marginalia.hash.MurmurHash3_128; + +/** A group of hash functions that can be used to hash a sequence of strings, + * that also has an inverse operation that can be used to remove a previously applied + * string from the sequence. 
*/ +public sealed interface HasherGroup { + /** Apply a hash to the accumulator */ + long apply(long acc, long add); + + /** Remove a hash that was added n operations ago from the accumulator, add a new one */ + long replace(long acc, long add, long rem, int n); + + /** Create a new hasher group that preserves the order of applied hash functions */ + static HasherGroup ordered() { + return new OrderedHasher(); + } + + /** Create a new hasher group that does not preserve the order of applied hash functions */ + static HasherGroup unordered() { + return new UnorderedHasher(); + } + + /** Bake the words in the sentence into a hash successively using the group's apply function */ + default long rollingHash(String[] parts) { + long code = 0; + for (String part : parts) { + code = apply(code, hash(part)); + } + return code; + } + + MurmurHash3_128 hash = new MurmurHash3_128(); + /** Calculate the hash of a string */ + static long hash(String term) { + return hash.hashNearlyASCII(term); + } + + final class UnorderedHasher implements HasherGroup { + + public long apply(long acc, long add) { + return acc ^ add; + } + + public long replace(long acc, long add, long rem, int n) { + return acc ^ rem ^ add; + } + } + + final class OrderedHasher implements HasherGroup { + + public long apply(long acc, long add) { + return Long.rotateLeft(acc, 1) ^ add; + } + + public long replace(long acc, long add, long rem, int n) { + return Long.rotateLeft(acc, 1) ^ add ^ Long.rotateLeft(rem, n); + } + } +} diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java new file mode 100644 index 00000000..b0eb6916 --- /dev/null +++ b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java @@ -0,0 +1,163 @@ +package nu.marginalia.segmentation; + +import it.unimi.dsi.fastutil.longs.*; +import 
nu.marginalia.util.SimpleBlockingThreadPool; +import org.jsoup.Jsoup; +import org.jsoup.nodes.Document; +import org.openzim.ZIMTypes.ZIMFile; +import org.openzim.ZIMTypes.ZIMReader; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class NgramExtractorMain { + public static void main(String... args) throws IOException, InterruptedException { + } + + private static List getNgramTitleTerms(String title) { + List terms = new ArrayList<>(); + + // Add the title + if (title.contains(" ")) { // Only add multi-word titles since we're chasing ngrams + terms.add(title.toLowerCase()); + } + + return cleanTerms(terms); + } + + private static List getNgramBodyTerms(Document document) { + List terms = new ArrayList<>(); + + // Grab all internal links + document.select("a[href]").forEach(e -> { + var href = e.attr("href"); + if (href.contains(":")) + return; + if (href.contains("/")) + return; + + var text = e.text().toLowerCase(); + if (!text.contains(" ")) + return; + + terms.add(text); + }); + + // Grab all italicized text + document.getElementsByTag("i").forEach(e -> { + var text = e.text().toLowerCase(); + if (!text.contains(" ")) + return; + + terms.add(text); + }); + + return cleanTerms(terms); + } + + private static List cleanTerms(List terms) { + // Trim the discovered terms + terms.replaceAll(s -> { + // Remove trailing parentheses and their contents + if (s.endsWith(")")) { + int idx = s.lastIndexOf('('); + if (idx > 0) { + return s.substring(0, idx).trim(); + } + } + + return s; + }); + + terms.replaceAll(s -> { + // Remove leading "list of " + if (s.startsWith("list of ")) { + return s.substring("list of ".length()); + } + + return s; + }); + + terms.replaceAll(s -> { + // Remove trailing punctuation + if (s.endsWith(".") || s.endsWith(",") || s.endsWith(":") || s.endsWith(";")) { + return s.substring(0, s.length() - 1); + } + + return s; + }); + + // 
Remove terms that are too short or too long + terms.removeIf(s -> { + if (!s.contains(" ")) + return true; + if (s.length() > 64) + return true; + return false; + }); + + return terms; + } + + public static void dumpCounts(Path zimInputFile, + Path countsOutputFile + ) throws IOException, InterruptedException + { + ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString())); + + NgramLexicon lexicon = new NgramLexicon(); + + var orderedHasher = HasherGroup.ordered(); + + var pool = new SimpleBlockingThreadPool("ngram-extractor", + Math.clamp(2, Runtime.getRuntime().availableProcessors(), 32), + 32 + ); + + reader.forEachTitles((title) -> { + pool.submitQuietly(() -> { + LongArrayList orderedHashesTitle = new LongArrayList(); + + String normalizedTitle = title.replace('_', ' '); + + for (var sent : getNgramTitleTerms(normalizedTitle)) { + String[] terms = BasicSentenceExtractor.getStemmedParts(sent); + orderedHashesTitle.add(orderedHasher.rollingHash(terms)); + } + synchronized (lexicon) { + for (var hash : orderedHashesTitle) { + lexicon.incOrderedTitle(hash); + } + } + }); + + }); + + reader.forEachArticles((title, body) -> { + pool.submitQuietly(() -> { + LongArrayList orderedHashesBody = new LongArrayList(); + + for (var sent : getNgramBodyTerms(Jsoup.parse(body))) { + String[] terms = BasicSentenceExtractor.getStemmedParts(sent); + orderedHashesBody.add(orderedHasher.rollingHash(terms)); + } + + synchronized (lexicon) { + for (var hash : orderedHashesBody) { + lexicon.incOrderedBody(hash); + } + } + }); + + }, p -> true); + + pool.shutDown(); + pool.awaitTermination(10, TimeUnit.DAYS); + + lexicon.saveCounts(countsOutputFile); + } + +} diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java new file mode 100644 index 00000000..9b59a84f --- /dev/null +++ 
b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java @@ -0,0 +1,214 @@ +package nu.marginalia.segmentation; + +import com.google.inject.Inject; +import com.google.inject.Singleton; +import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap; +import it.unimi.dsi.fastutil.longs.LongHash; +import nu.marginalia.LanguageModels; + +import java.io.BufferedInputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +@Singleton +public class NgramLexicon { + private final Long2IntOpenCustomHashMap counts; + + private int size; + private static final HasherGroup orderedHasher = HasherGroup.ordered(); + + @Inject + public NgramLexicon(LanguageModels models) { + try (var dis = new DataInputStream(new BufferedInputStream(Files.newInputStream(models.segments)))) { + long size = dis.readInt(); + counts = new Long2IntOpenCustomHashMap( + (int) size, + new KeyIsAlreadyHashStrategy() + ); + counts.defaultReturnValue(0); + + try { + for (int i = 0; i < size; i++) { + counts.put(dis.readLong(), dis.readInt()); + } + } + catch (IOException ex) { + ex.printStackTrace(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public NgramLexicon() { + counts = new Long2IntOpenCustomHashMap(100_000_000, new KeyIsAlreadyHashStrategy()); + counts.defaultReturnValue(0); + } + + public List findSegmentsStrings(int minLength, + int maxLength, + String... 
parts) + { + List segments = new ArrayList<>(); + + // Hash the parts + long[] hashes = new long[parts.length]; + for (int i = 0; i < hashes.length; i++) { + hashes[i] = HasherGroup.hash(parts[i]); + } + + for (int i = minLength; i <= maxLength; i++) { + findSegments(segments, i, parts, hashes); + } + + return segments; + } + + public void findSegments(List positions, + int length, + String[] parts, + long[] hashes) + { + // Don't look for ngrams longer than the sentence + if (parts.length < length) return; + + long hash = 0; + int i = 0; + + // Prepare by combining up to length hashes + for (; i < length; i++) { + hash = orderedHasher.apply(hash, hashes[i]); + } + + // Slide the window and look for matches + for (;;) { + if (counts.get(hash) > 0) { + positions.add(Arrays.copyOfRange(parts, i - length, i)); + } + + if (i < hashes.length) { + hash = orderedHasher.replace(hash, hashes[i], hashes[i - length], length); + i++; + } else { + break; + } + } + } + + public List findSegmentOffsets(int length, String... 
parts) { + // Don't look for ngrams longer than the sentence + if (parts.length < length) return List.of(); + + List positions = new ArrayList<>(); + + // Hash the parts + long[] hashes = new long[parts.length]; + for (int i = 0; i < hashes.length; i++) { + hashes[i] = HasherGroup.hash(parts[i]); + } + + long hash = 0; + int i = 0; + + // Prepare by combining up to length hashes + for (; i < length; i++) { + hash = orderedHasher.apply(hash, hashes[i]); + } + + // Slide the window and look for matches + for (;;) { + int ct = counts.get(hash); + + if (ct > 0) { + positions.add(new SentenceSegment(i - length, length, ct)); + } + + if (i < hashes.length) { + hash = orderedHasher.replace(hash, hashes[i], hashes[i - length], length); + i++; + } else { + break; + } + } + + return positions; + } + + public void incOrderedTitle(long hashOrdered) { + int value = counts.get(hashOrdered); + + if (value <= 0) { + size ++; + value = -value; + } + + value ++; + + counts.put(hashOrdered, value); + } + + public void incOrderedBody(long hashOrdered) { + int value = counts.get(hashOrdered); + + if (value <= 0) value --; + else value++; + + counts.put(hashOrdered, value); + } + + public void saveCounts(Path file) throws IOException { + try (var dos = new DataOutputStream(Files.newOutputStream(file, + StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING, + StandardOpenOption.WRITE))) { + + dos.writeInt(size); + + counts.forEach((k, v) -> { + try { + if (v > 0) { + dos.writeLong(k); + dos.writeInt(v); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + } + + public void clear() { + counts.clear(); + } + + public record SentenceSegment(int start, int length, int count) { + public String[] project(String... 
parts) { + return Arrays.copyOfRange(parts, start, start + length); + } + + public boolean overlaps(SentenceSegment other) { + return start < other.start + other.length && start + length > other.start; + } + } + + private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy { + @Override + public int hashCode(long l) { + return (int) l; + } + + @Override + public boolean equals(long l, long l1) { + return l == l1; + } + } + +} + diff --git a/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/HasherGroupTest.java b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/HasherGroupTest.java new file mode 100644 index 00000000..110b1b9b --- /dev/null +++ b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/HasherGroupTest.java @@ -0,0 +1,34 @@ +package nu.marginalia.segmentation; + +import nu.marginalia.segmentation.HasherGroup; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class HasherGroupTest { + + @Test + void ordered() { + long a = 5; + long b = 3; + long c = 2; + + var group = HasherGroup.ordered(); + assertNotEquals(group.apply(a, b), group.apply(b, a)); + assertEquals(group.apply(b,c), group.replace(group.apply(a, b), c, a, 2)); + } + + @Test + void unordered() { + long a = 5; + long b = 3; + long c = 2; + + var group = HasherGroup.unordered(); + + assertEquals(group.apply(a, b), group.apply(b, a)); + assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2)); + } + + +} diff --git a/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java new file mode 100644 index 00000000..df24ec10 --- /dev/null +++ b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java @@ -0,0 +1,44 @@ +package nu.marginalia.segmentation; + +import org.junit.jupiter.api.BeforeEach; +import 
org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class NgramLexiconTest { + NgramLexicon lexicon = new NgramLexicon(); + @BeforeEach + public void setUp() { + lexicon.clear(); + } + + void addNgram(String... ngram) { + lexicon.incOrderedTitle(HasherGroup.ordered().rollingHash(ngram)); + } + + @Test + void findSegments() { + addNgram("hello", "world"); + addNgram("rye", "bread"); + addNgram("rye", "world"); + + List segments = lexicon.findSegmentsStrings(2, 2, "hello", "world", "rye", "bread"); + + assertEquals(2, segments.size()); + + for (int i = 0; i < 2; i++) { + var segment = segments.get(i); + switch (i) { + case 0 -> { + assertArrayEquals(new String[]{"hello", "world"}, segment); + } + case 1 -> { + assertArrayEquals(new String[]{"rye", "bread"}, segment); + } + } + } + + } +} \ No newline at end of file diff --git a/code/processes/converting-process/java/nu/marginalia/converting/sideload/encyclopedia/EncyclopediaMarginaliaNuSideloader.java b/code/processes/converting-process/java/nu/marginalia/converting/sideload/encyclopedia/EncyclopediaMarginaliaNuSideloader.java index ca85455e..17c83250 100644 --- a/code/processes/converting-process/java/nu/marginalia/converting/sideload/encyclopedia/EncyclopediaMarginaliaNuSideloader.java +++ b/code/processes/converting-process/java/nu/marginalia/converting/sideload/encyclopedia/EncyclopediaMarginaliaNuSideloader.java @@ -125,7 +125,6 @@ public class EncyclopediaMarginaliaNuSideloader implements SideloadSource, AutoC fullHtml.append("

"); fullHtml.append(part); fullHtml.append("

"); - break; // Only take the first part, this improves accuracy a lot } fullHtml.append(""); diff --git a/code/processes/converting-process/test/nu/marginalia/converting/util/TestLanguageModels.java b/code/processes/converting-process/test/nu/marginalia/converting/util/TestLanguageModels.java index 4ad1e430..f28e1348 100644 --- a/code/processes/converting-process/test/nu/marginalia/converting/util/TestLanguageModels.java +++ b/code/processes/converting-process/test/nu/marginalia/converting/util/TestLanguageModels.java @@ -26,13 +26,13 @@ public class TestLanguageModels { var languageModelsHome = getLanguageModelsPath(); return new LanguageModels( - languageModelsHome.resolve("ngrams.bin"), languageModelsHome.resolve("tfreq-new-algo3.bin"), languageModelsHome.resolve("opennlp-sentence.bin"), languageModelsHome.resolve("English.RDR"), languageModelsHome.resolve("English.DICT"), languageModelsHome.resolve("opennlp-tokens.bin"), - languageModelsHome.resolve("lid.176.ftz") + languageModelsHome.resolve("lid.176.ftz"), + languageModelsHome.resolve("segments.bin") ); } } diff --git a/code/services-application/api-service/java/nu/marginalia/api/ApiSearchOperator.java b/code/services-application/api-service/java/nu/marginalia/api/ApiSearchOperator.java index 25ba4945..e979b86f 100644 --- a/code/services-application/api-service/java/nu/marginalia/api/ApiSearchOperator.java +++ b/code/services-application/api-service/java/nu/marginalia/api/ApiSearchOperator.java @@ -70,22 +70,24 @@ public class ApiSearchOperator { ApiSearchResult convert(DecoratedSearchResultItem url) { List> details = new ArrayList<>(); + + // This list-of-list construction is to avoid breaking the API, + // we'll always have just a single outer list from now on... 
+ if (url.rawIndexResult != null) { - var bySet = url.rawIndexResult.keywordScores.stream().collect(Collectors.groupingBy(SearchResultKeywordScore::subquery)); + List lst = new ArrayList<>(); + for (var entry : url.rawIndexResult.keywordScores) { + var metadata = new WordMetadata(entry.encodedWordMetadata()); - outer: - for (var entries : bySet.values()) { - List lst = new ArrayList<>(); - for (var entry : entries) { - var metadata = new WordMetadata(entry.encodedWordMetadata()); - if (metadata.isEmpty()) - continue outer; + // Skip terms that don't appear anywhere + if (metadata.isEmpty()) + continue; - Set flags = metadata.flagSet().stream().map(Object::toString).collect(Collectors.toSet()); - lst.add(new ApiSearchResultQueryDetails(entry.keyword, Long.bitCount(metadata.positions()), flags)); - } - details.add(lst); + Set flags = metadata.flagSet().stream().map(Object::toString).collect(Collectors.toSet()); + lst.add(new ApiSearchResultQueryDetails(entry.keyword, Long.bitCount(metadata.positions()), flags)); } + + details.add(lst); } return new ApiSearchResult( diff --git a/code/services-application/search-service/java/nu/marginalia/search/SearchQueryParamFactory.java b/code/services-application/search-service/java/nu/marginalia/search/SearchQueryParamFactory.java index 15c8567e..cc28b209 100644 --- a/code/services-application/search-service/java/nu/marginalia/search/SearchQueryParamFactory.java +++ b/code/services-application/search-service/java/nu/marginalia/search/SearchQueryParamFactory.java @@ -1,7 +1,7 @@ package nu.marginalia.search; import nu.marginalia.api.searchquery.model.query.SearchSetIdentifier; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.results.ResultRankingParameters; import nu.marginalia.index.query.limit.QueryLimits; import nu.marginalia.index.query.limit.QueryStrategy; @@ -14,7 +14,7 @@ import java.util.List; public 
class SearchQueryParamFactory { public QueryParams forRegularSearch(SearchParameters userParams) { - SearchSubquery prototype = new SearchSubquery(); + SearchQuery prototype = new SearchQuery(); var profile = userParams.profile(); profile.addTacitTerms(prototype); diff --git a/code/services-application/search-service/java/nu/marginalia/search/command/SearchAdtechParameter.java b/code/services-application/search-service/java/nu/marginalia/search/command/SearchAdtechParameter.java index 9e8383f3..ce3bf099 100644 --- a/code/services-application/search-service/java/nu/marginalia/search/command/SearchAdtechParameter.java +++ b/code/services-application/search-service/java/nu/marginalia/search/command/SearchAdtechParameter.java @@ -1,6 +1,6 @@ package nu.marginalia.search.command; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import javax.annotation.Nullable; import java.util.Arrays; @@ -23,7 +23,7 @@ public enum SearchAdtechParameter { return DEFAULT; } - public void addTacitTerms(SearchSubquery subquery) { + public void addTacitTerms(SearchQuery subquery) { subquery.searchTermsExclude.addAll(Arrays.asList(implictExcludeSearchTerms)); } } diff --git a/code/services-application/search-service/java/nu/marginalia/search/command/SearchJsParameter.java b/code/services-application/search-service/java/nu/marginalia/search/command/SearchJsParameter.java index 6c8634ac..8cf6aada 100644 --- a/code/services-application/search-service/java/nu/marginalia/search/command/SearchJsParameter.java +++ b/code/services-application/search-service/java/nu/marginalia/search/command/SearchJsParameter.java @@ -1,6 +1,6 @@ package nu.marginalia.search.command; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import javax.annotation.Nullable; import java.util.Arrays; @@ -25,7 +25,7 @@ public enum SearchJsParameter { return DEFAULT; } - 
public void addTacitTerms(SearchSubquery subquery) { + public void addTacitTerms(SearchQuery subquery) { subquery.searchTermsExclude.addAll(Arrays.asList(implictExcludeSearchTerms)); } } diff --git a/code/services-application/search-service/java/nu/marginalia/search/model/ClusteredUrlDetails.java b/code/services-application/search-service/java/nu/marginalia/search/model/ClusteredUrlDetails.java index 6dd7390d..faba9eb7 100644 --- a/code/services-application/search-service/java/nu/marginalia/search/model/ClusteredUrlDetails.java +++ b/code/services-application/search-service/java/nu/marginalia/search/model/ClusteredUrlDetails.java @@ -6,7 +6,6 @@ import nu.marginalia.model.idx.WordFlags; import org.jetbrains.annotations.NotNull; import java.util.*; -import java.util.stream.Collectors; /** A class to hold a list of UrlDetails, grouped by domain, where the first one is the main result * and the rest are additional results, for summary display. */ @@ -19,44 +18,46 @@ public class ClusteredUrlDetails implements Comparable { * @param details A collection of UrlDetails, which must not be empty. 
*/ public ClusteredUrlDetails(Collection details) { - var queue = new PriorityQueue<>(details); + var items = new ArrayList<>(details); - if (queue.isEmpty()) + items.sort(Comparator.naturalOrder()); + + if (items.isEmpty()) throw new IllegalArgumentException("Empty list of details"); - this.first = queue.poll(); + this.first = items.removeFirst(); + this.rest = items; - if (queue.isEmpty()) { - this.rest = Collections.emptyList(); - } - else { - double bestScore = first.termScore; - double scoreLimit = Math.min(4.0, bestScore * 1.25); + double bestScore = first.termScore; + double scoreLimit = Math.min(4.0, bestScore * 1.25); - this.rest = queue - .stream() - .filter(this::isEligbleForInclusion) - .takeWhile(next -> next.termScore <= scoreLimit) - .toList(); - } + this.rest.removeIf(urlDetail -> { + if (urlDetail.termScore > scoreLimit) + return false; + + for (var keywordScore : urlDetail.resultItem.keywordScores) { + if (keywordScore.isKeywordSpecial()) + continue; + if (keywordScore.positions() == 0) + continue; + + if (keywordScore.hasTermFlag(WordFlags.Title)) + return false; + if (keywordScore.hasTermFlag(WordFlags.ExternalLink)) + return false; + if (keywordScore.hasTermFlag(WordFlags.UrlDomain)) + return false; + if (keywordScore.hasTermFlag(WordFlags.UrlPath)) + return false; + if (keywordScore.hasTermFlag(WordFlags.Subjects)) + return false; + } + + return true; + }); } - private boolean isEligbleForInclusion(UrlDetails urlDetails) { - return urlDetails.resultItem.keywordScores.stream() - .filter(score -> !score.keyword.contains(":")) - .collect(Collectors.toMap( - score -> score.subquery, - score -> score.hasTermFlag(WordFlags.Title) - | score.hasTermFlag(WordFlags.ExternalLink) - | score.hasTermFlag(WordFlags.UrlDomain) - | score.hasTermFlag(WordFlags.UrlPath) - | score.hasTermFlag(WordFlags.Subjects) - , - (a, b) -> a && b - )) - .containsValue(Boolean.TRUE); - } public ClusteredUrlDetails(@NotNull UrlDetails onlyFirst) { this.first = onlyFirst; diff 
--git a/code/services-application/search-service/java/nu/marginalia/search/model/SearchProfile.java b/code/services-application/search-service/java/nu/marginalia/search/model/SearchProfile.java index 27d9f4aa..955c3fcb 100644 --- a/code/services-application/search-service/java/nu/marginalia/search/model/SearchProfile.java +++ b/code/services-application/search-service/java/nu/marginalia/search/model/SearchProfile.java @@ -2,7 +2,7 @@ package nu.marginalia.search.model; import nu.marginalia.index.query.limit.SpecificationLimit; import nu.marginalia.model.crawl.HtmlFeature; -import nu.marginalia.api.searchquery.model.query.SearchSubquery; +import nu.marginalia.api.searchquery.model.query.SearchQuery; import nu.marginalia.api.searchquery.model.query.SearchSetIdentifier; import java.util.Objects; @@ -47,7 +47,7 @@ public enum SearchProfile { return NO_FILTER; } - public void addTacitTerms(SearchSubquery subquery) { + public void addTacitTerms(SearchQuery subquery) { if (this == ACADEMIA) { subquery.searchTermsAdvice.add("special:academia"); } diff --git a/code/services-application/search-service/java/nu/marginalia/search/svc/SearchQueryIndexService.java b/code/services-application/search-service/java/nu/marginalia/search/svc/SearchQueryIndexService.java index 785c8952..6dc7b83b 100644 --- a/code/services-application/search-service/java/nu/marginalia/search/svc/SearchQueryIndexService.java +++ b/code/services-application/search-service/java/nu/marginalia/search/svc/SearchQueryIndexService.java @@ -88,7 +88,7 @@ public class SearchQueryIndexService { DomainIndexingState.ACTIVE, detail.rankingScore, // termScore detail.resultsFromDomain(), - getPositionsString(detail.rawIndexResult), + getPositionsString(detail), detail.rawIndexResult, detail.rawIndexResult.keywordScores )); @@ -97,27 +97,8 @@ public class SearchQueryIndexService { return ret; } - private String getPositionsString(SearchResultItem resultItem) { - Int2LongArrayMap positionsPerSet = new Int2LongArrayMap(8); 
- - for (var score : resultItem.keywordScores) { - if (!score.isKeywordRegular()) { - continue; - } - positionsPerSet.merge(score.subquery(), score.positions(), this::and); - } - - long bits = positionsPerSet.values().longStream().reduce(this::or).orElse(0); - - return BrailleBlockPunchCards.printBits(bits, 56); + private String getPositionsString(DecoratedSearchResultItem resultItem) { + return BrailleBlockPunchCards.printBits(resultItem.bestPositions, 56); } - - private long and(long a, long b) { - return a & b; - } - private long or(long a, long b) { - return a | b; - } - } diff --git a/code/services-application/search-service/test/nu/marginalia/util/TestLanguageModels.java b/code/services-application/search-service/test/nu/marginalia/util/TestLanguageModels.java index 5efd2025..a4cc012b 100644 --- a/code/services-application/search-service/test/nu/marginalia/util/TestLanguageModels.java +++ b/code/services-application/search-service/test/nu/marginalia/util/TestLanguageModels.java @@ -26,13 +26,13 @@ public class TestLanguageModels { var languageModelsHome = getLanguageModelsPath(); return new LanguageModels( - languageModelsHome.resolve("ngrams.bin"), languageModelsHome.resolve("tfreq-new-algo3.bin"), languageModelsHome.resolve("opennlp-sentence.bin"), languageModelsHome.resolve("English.RDR"), languageModelsHome.resolve("English.DICT"), languageModelsHome.resolve("opennlp-tokens.bin"), - languageModelsHome.resolve("lid.176.ftz") + languageModelsHome.resolve("lid.176.ftz"), + languageModelsHome.resolve("segments.bin") ); } } diff --git a/code/services-core/control-service/java/nu/marginalia/control/node/svc/ControlNodeActionsService.java b/code/services-core/control-service/java/nu/marginalia/control/node/svc/ControlNodeActionsService.java index 2ae09234..b711be14 100644 --- a/code/services-core/control-service/java/nu/marginalia/control/node/svc/ControlNodeActionsService.java +++ 
b/code/services-core/control-service/java/nu/marginalia/control/node/svc/ControlNodeActionsService.java @@ -76,6 +76,9 @@ public class ControlNodeActionsService { Spark.post("/public/nodes/:node/actions/sideload-stackexchange", this::sideloadStackexchange, redirectControl.renderRedirectAcknowledgement("Sideloading", "..") ); + Spark.post("/public/nodes/:node/actions/export-segmentation", this::exportSegmentationModel, + redirectControl.renderRedirectAcknowledgement("Exporting", "..") + ); Spark.post("/public/nodes/:node/actions/download-sample-data", this::downloadSampleData, redirectControl.renderRedirectAcknowledgement("Downloading", "..") ); @@ -307,6 +310,14 @@ public class ControlNodeActionsService { return ""; } + private Object exportSegmentationModel(Request req, Response rsp) { + exportClient.exportSegmentationModel( + Integer.parseInt(req.params("node")), + req.queryParams("source")); + + return ""; + } + private Object exportFromCrawlData(Request req, Response rsp) { String exportType = req.queryParams("exportType"); FileStorageId source = parseSourceFileStorageId(req.queryParams("source")); diff --git a/code/services-core/control-service/resources/templates/control/node/actions/partial-export-segmentation.hdb b/code/services-core/control-service/resources/templates/control/node/actions/partial-export-segmentation.hdb new file mode 100644 index 00000000..2ef9b180 --- /dev/null +++ b/code/services-core/control-service/resources/templates/control/node/actions/partial-export-segmentation.hdb @@ -0,0 +1,45 @@ +

Export segmentation model

+ +
+

This will generate a query segmentation model from a wikipedia ZIM file. A query segmentation model +is used to break a search query into segments corresponding to different concepts. For example, the query +"slackware linux package manager" would be segmented into "slackware linux", and "package manager"; and the +search would be performed putting higher emphasis on "package" and "manager" appearing in the same part of the document +than "linux" and "manager". +

+
+
+
+ + + {{#each uploadDirContents.items}} + + + + + + + {{/each}} + {{#unless uploadDirContents.items}} + + + + {{/unless}} +
FilenameSizeLast Modified
+ + {{#unless directory}}{{size}}{{/unless}}{{shortTimestamp lastModifiedTime}}
Nothing found in upload directory
+ +

+ + The upload directory is typically mounted to /uploads on the server. The external + directory is typically something like index-{{node.id}}/uploads. + +

+ +
+
+ +
+
+
+
\ No newline at end of file diff --git a/code/services-core/control-service/resources/templates/control/node/node-actions.hdb b/code/services-core/control-service/resources/templates/control/node/node-actions.hdb index df8ed77f..7de90949 100644 --- a/code/services-core/control-service/resources/templates/control/node/node-actions.hdb +++ b/code/services-core/control-service/resources/templates/control/node/node-actions.hdb @@ -20,6 +20,7 @@ {{#if view.sideload-warc}} {{> control/node/actions/partial-sideload-warc }} {{/if}} {{#if view.sideload-dirtree}} {{> control/node/actions/partial-sideload-dirtree }} {{/if}} {{#if view.sideload-reddit}} {{> control/node/actions/partial-sideload-reddit }} {{/if}} + {{#if view.export-segmentation}} {{> control/node/actions/partial-export-segmentation }} {{/if}} {{#if view.export-db-data}} {{> control/node/actions/partial-export-db-data }} {{/if}} {{#if view.export-from-crawl-data}} {{> control/node/actions/partial-export-from-crawl-data }} {{/if}} {{#if view.export-sample-data}} {{> control/node/actions/partial-export-sample-data }} {{/if}} diff --git a/code/services-core/control-service/resources/templates/control/node/partial-node-nav.hdb b/code/services-core/control-service/resources/templates/control/node/partial-node-nav.hdb index 23627155..ff16507d 100644 --- a/code/services-core/control-service/resources/templates/control/node/partial-node-nav.hdb +++ b/code/services-core/control-service/resources/templates/control/node/partial-node-nav.hdb @@ -30,6 +30,7 @@
  • Export Database Data
  • Export Sample Crawl Data
  • Export From Crawl Data
  • +
  • Export Segmentation Model
  • Restore Index Backup
  • diff --git a/code/tools/experiment-runner/java/nu/marginalia/tools/experiments/SentenceStatisticsExperiment.java b/code/tools/experiment-runner/java/nu/marginalia/tools/experiments/SentenceStatisticsExperiment.java index 8614d1e6..dde7a106 100644 --- a/code/tools/experiment-runner/java/nu/marginalia/tools/experiments/SentenceStatisticsExperiment.java +++ b/code/tools/experiment-runner/java/nu/marginalia/tools/experiments/SentenceStatisticsExperiment.java @@ -8,6 +8,7 @@ import nu.marginalia.crawling.model.CrawledDomain; import nu.marginalia.keyword.DocumentKeywordExtractor; import nu.marginalia.language.sentence.SentenceExtractor; import nu.marginalia.model.EdgeUrl; +import nu.marginalia.segmentation.NgramLexicon; import nu.marginalia.term_frequency_dict.TermFrequencyDict; import nu.marginalia.tools.LegacyExperiment; import org.jsoup.Jsoup; @@ -21,8 +22,10 @@ import java.nio.file.Path; public class SentenceStatisticsExperiment extends LegacyExperiment { + NgramLexicon lexicon = new NgramLexicon(WmsaHome.getLanguageModels()); SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels()); - DocumentKeywordExtractor documentKeywordExtractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels())); + DocumentKeywordExtractor documentKeywordExtractor = new DocumentKeywordExtractor( + new TermFrequencyDict(WmsaHome.getLanguageModels()), lexicon); Path filename; PrintWriter writer; diff --git a/run/setup.sh b/run/setup.sh index 3d9c5f54..3cacca75 100755 --- a/run/setup.sh +++ b/run/setup.sh @@ -26,7 +26,7 @@ download_model model/English.DICT https://raw.githubusercontent.com/datquocnguye download_model model/English.RDR https://raw.githubusercontent.com/datquocnguyen/RDRPOSTagger/master/Models/POS/English.RDR download_model model/opennlp-sentence.bin https://mirrors.estointernet.in/apache/opennlp/models/ud-models-1.0/opennlp-en-ud-ewt-sentence-1.0-1.9.3.bin download_model model/opennlp-tokens.bin 
https://mirrors.estointernet.in/apache/opennlp/models/ud-models-1.0/opennlp-en-ud-ewt-tokens-1.0-1.9.3.bin -download_model model/ngrams.bin https://downloads.marginalia.nu/model/ngrams.bin +download_model model/segments.bin https://downloads.marginalia.nu/model/segments.bin download_model model/tfreq-new-algo3.bin https://downloads.marginalia.nu/model/tfreq-new-algo3.bin download_model model/lid.176.ftz https://downloads.marginalia.nu/model/lid.176.ftz diff --git a/third-party/openzim/src/main/java/org/openzim/ZIMTypes/ZIMReader.java b/third-party/openzim/src/main/java/org/openzim/ZIMTypes/ZIMReader.java index e2fcaf6e..e9b5cf47 100644 --- a/third-party/openzim/src/main/java/org/openzim/ZIMTypes/ZIMReader.java +++ b/third-party/openzim/src/main/java/org/openzim/ZIMTypes/ZIMReader.java @@ -275,9 +275,7 @@ public class ZIMReader { } - - // Gives the minimum required information needed for the given articleName - public DirectoryEntry forEachTitles(Consumer aeConsumer, Consumer reConsumer) + public DirectoryEntry forEachTitles(Consumer titleConsumer) throws IOException { int numberOfArticles = mFile.getArticleCount(); @@ -287,26 +285,9 @@ public class ZIMReader { System.err.println(numberOfArticles); long start = System.currentTimeMillis(); - Map> data = new TreeMap<>(); - - System.err.println("Indexing"); - for (long i = beg; i < end; i+=4) { var entry = getDirectoryInfoAtTitlePosition(i); - - if (((i-beg)%100_000) == 0) { - System.err.printf("%f%%\n", ((i-beg) * 100.) / (end-beg)); - } - - if (entry.mimeType == targetMime && entry instanceof ArticleEntry) { - aeConsumer.accept((ArticleEntry) entry); - } - else if (entry.mimeType == 65535 && entry instanceof RedirectEntry) { - - reConsumer.accept((RedirectEntry) entry); - - } - + titleConsumer.accept(entry.title); } return null;