(query-segmentation) Merge pull request #89 from MarginaliaSearch/query-segmentation

The changeset cleans up the query parsing logic in the query service. It removes a lot of old and largely unmaintainable query-rewriting logic based on POS-tagging rules and replaces it with a cleaner approach. Query parsing is also refactored, and the internal APIs are updated so that document-level data is no longer duplicated across each search term.

A new query segmentation model is introduced based on a dictionary of known n-grams, with tools for extracting this dictionary from Wikipedia data. The changeset introduces a new segmentation model file, which is downloaded with the usual run/setup.sh, as well as an updated term frequency model.
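For illustration, a rough sketch of how dictionary-based segmentation works (the class and method names here are made up, not the NgramLexicon API added in this changeset): scan the token stream and greedily match the longest known n-gram at each position.

import java.util.*;

// Illustrative sketch only: greedy longest-match segmentation against a
// dictionary of known n-grams. Names are hypothetical, not the real API.
class SegmentationSketch {
    private final Set<List<String>> knownSegments;
    private final int maxNgramLength;

    SegmentationSketch(Set<List<String>> knownSegments, int maxNgramLength) {
        this.knownSegments = knownSegments;
        this.maxNgramLength = maxNgramLength;
    }

    List<List<String>> segment(List<String> tokens) {
        List<List<String>> segments = new ArrayList<>();
        for (int i = 0; i < tokens.size(); ) {
            int len = 1; // fall back to a single token if no n-gram matches
            for (int n = Math.min(maxNgramLength, tokens.size() - i); n > 1; n--) {
                if (knownSegments.contains(tokens.subList(i, i + n))) {
                    len = n;
                    break;
                }
            }
            segments.add(List.copyOf(tokens.subList(i, i + len)));
            i += len;
        }
        return segments;
    }
}

With a dictionary containing ["search", "engine"], the query "marginalia search engine" segments into [marginalia] [search engine], which is what lets the alternative queries described below be generated.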

A new intermediate representation of the query is introduced, based on a DAG with predefined start and end vertices. This makes it easier to write rules for generating alternative queries, e.g. using the new segmentation data.
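Roughly speaking (this is a hedged sketch with made-up names, not the actual graph classes in the query service), each alternative query corresponds to one path from the start vertex to the end vertex:

import java.util.*;

// Illustrative sketch only: a word DAG with dedicated start/end vertices.
// Every START -> END path spells out one alternative query.
class QueryGraphSketch {
    static final String START = "^", END = "$";
    private final Map<String, List<String>> edges = new LinkedHashMap<>();

    void addEdge(String from, String to) {
        edges.computeIfAbsent(from, k -> new ArrayList<>()).add(to);
    }

    List<String> allQueries() {
        List<String> out = new ArrayList<>();
        walk(START, new ArrayList<>(), out);
        return out;
    }

    private void walk(String node, List<String> path, List<String> out) {
        if (node.equals(END)) {
            out.add(String.join(" ", path));
            return;
        }
        for (String next : edges.getOrDefault(node, List.of())) {
            if (!next.equals(END)) path.add(next);
            walk(next, path, out);
            if (!next.equals(END)) path.remove(path.size() - 1);
        }
    }
}

Adding the edges ^ -> mechanical -> keyboard -> $ and ^ -> mechanical_keyboard -> $ yields the two alternatives "mechanical keyboard" and "mechanical_keyboard", one per path.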

The graph is converted to a basic LL(1) syntax loosely reminiscent of a regular expression, where e.g. "( wiby | marginalia | kagi ) ( search engine | searchengine )" expands to "wiby search engine", "wiby searchengine", "marginalia search engine", "marginalia searchengine", "kagi search engine" and "kagi searchengine".
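A minimal sketch of that expansion, assuming the flat form shown above (a sequence of parenthesized |-separated groups and bare words, no nesting); the real parser added in this changeset is CompiledQueryParser, further down in the diff:

import java.util.*;

// Illustrative sketch only: expand "( a | b ) ( c d | e )" into the
// cartesian product of its alternative groups. Handles only the flat,
// non-nested form used in the example above.
class ExpansionSketch {
    static List<String> expand(String expr) {
        // Split into groups: each "( ... )" contributes its |-separated
        // alternatives; a bare word is a group with a single alternative.
        List<List<String>> groups = new ArrayList<>();
        for (String part : expr.split("\\)")) {
            int open = part.indexOf('(');
            String prefix = open >= 0 ? part.substring(0, open) : part;
            for (String word : prefix.trim().split("\\s+"))
                if (!word.isEmpty()) groups.add(List.of(word));
            if (open >= 0) {
                List<String> alts = new ArrayList<>();
                for (String alt : part.substring(open + 1).split("\\|"))
                    alts.add(alt.trim());
                groups.add(alts);
            }
        }
        // Cartesian product of the groups
        List<String> results = List.of("");
        for (List<String> group : groups) {
            List<String> next = new ArrayList<>();
            for (String head : results)
                for (String alt : group)
                    next.add((head + " " + alt).trim());
            results = next;
        }
        return results;
    }

    public static void main(String[] args) {
        expand("( wiby | marginalia | kagi ) ( search engine | searchengine )")
                .forEach(System.out::println);
        // wiby search engine, wiby searchengine, marginalia search engine, ...
    }
}

Each group contributes one alternative to every result, so three alternatives times two alternatives gives the six expansions listed above.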

This compiled query is passed to the index service, which parses the expression and uses it to execute the search and rank the results.
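Sketched usage of the classes added below (CompiledQueryParser and CompiledQueryAggregates); the document term set here is made up, and multi-word segments are assumed to be underscore-joined, as the keyword extraction tests further down suggest:

import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser;
import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates;

import java.util.Set;

class CompiledQueryExample {
    public static void main(String[] args) {
        // Parse the compiled expression into an And/Or tree over word leaves
        CompiledQuery<String> cq = CompiledQueryParser.parse(
                "( wiby | marginalia | kagi ) ( search_engine | searchengine )");

        // A made-up "document": the set of terms it contains
        Set<String> documentTerms = Set.of("marginalia", "searchengine");

        // True if some path through the query is fully satisfied by the document
        boolean matches = CompiledQueryAggregates.booleanAggregate(cq, documentTerms::contains);

        System.out.println(matches); // true
    }
}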
Viktor, 2024-04-16 15:31:05 +02:00 (committed by GitHub)
commit cfd9a7187f
143 changed files with 4547 additions and 2229 deletions

View File

@ -1,9 +1,11 @@
package nu.marginalia;
import lombok.Builder;
import java.nio.file.Path;
@Builder
public class LanguageModels {
public final Path ngramBloomFilter;
public final Path termFrequencies;
public final Path openNLPSentenceDetectionData;
@ -11,20 +13,21 @@ public class LanguageModels {
public final Path posDict;
public final Path openNLPTokenData;
public final Path fasttextLanguageModel;
public final Path segments;
public LanguageModels(Path ngramBloomFilter,
Path termFrequencies,
public LanguageModels(Path termFrequencies,
Path openNLPSentenceDetectionData,
Path posRules,
Path posDict,
Path openNLPTokenData,
Path fasttextLanguageModel) {
this.ngramBloomFilter = ngramBloomFilter;
Path fasttextLanguageModel,
Path segments) {
this.termFrequencies = termFrequencies;
this.openNLPSentenceDetectionData = openNLPSentenceDetectionData;
this.posRules = posRules;
this.posDict = posDict;
this.openNLPTokenData = openNLPTokenData;
this.fasttextLanguageModel = fasttextLanguageModel;
this.segments = segments;
}
}

View File

@ -85,13 +85,14 @@ public class WmsaHome {
final Path home = getHomePath();
return new LanguageModels(
home.resolve("model/ngrams.bin"),
home.resolve("model/tfreq-new-algo3.bin"),
home.resolve("model/opennlp-sentence.bin"),
home.resolve("model/English.RDR"),
home.resolve("model/English.DICT"),
home.resolve("model/opennlp-tok.bin"),
home.resolve("model/lid.176.ftz"));
home.resolve("model/lid.176.ftz"),
home.resolve("model/segments.bin")
);
}
public static Path getAtagsPath() {

View File

@ -50,6 +50,10 @@ public enum WordFlags {
return (asBit() & value) > 0;
}
public boolean isAbsent(long value) {
return (asBit() & value) == 0;
}
public static EnumSet<WordFlags> decode(long encodedValue) {
EnumSet<WordFlags> ret = EnumSet.noneOf(WordFlags.class);
@ -61,4 +65,5 @@ public enum WordFlags {
return ret;
}
}

View File

@ -2,10 +2,7 @@ package nu.marginalia.executor.client;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.functions.execution.api.Empty;
import nu.marginalia.functions.execution.api.ExecutorExportApiGrpc;
import nu.marginalia.functions.execution.api.RpcExportSampleData;
import nu.marginalia.functions.execution.api.RpcFileStorageId;
import nu.marginalia.functions.execution.api.*;
import nu.marginalia.service.client.GrpcChannelPoolFactory;
import nu.marginalia.service.client.GrpcMultiNodeChannelPool;
import nu.marginalia.service.discovery.property.ServiceKey;
@ -55,6 +52,7 @@ public class ExecutorExportClient {
.setFileStorageId(fid.id())
.build());
}
public void exportTermFrequencies(int node, FileStorageId fid) {
channelPool.call(ExecutorExportApiBlockingStub::exportTermFrequencies)
.forNode(node)
@ -69,6 +67,14 @@ public class ExecutorExportClient {
.run(Empty.getDefaultInstance());
}
public void exportSegmentationModel(int node, String path) {
channelPool.call(ExecutorExportApiBlockingStub::exportSegmentationModel)
.forNode(node)
.run(RpcExportSegmentationModel
.newBuilder()
.setSourcePath(path)
.build());
}
}

View File

@ -38,6 +38,7 @@ service ExecutorSideloadApi {
service ExecutorExportApi {
rpc exportAtags(RpcFileStorageId) returns (Empty) {}
rpc exportSegmentationModel(RpcExportSegmentationModel) returns (Empty) {}
rpc exportSampleData(RpcExportSampleData) returns (Empty) {}
rpc exportRssFeeds(RpcFileStorageId) returns (Empty) {}
rpc exportTermFrequencies(RpcFileStorageId) returns (Empty) {}
@ -61,6 +62,9 @@ message RpcSideloadEncyclopedia {
string sourcePath = 1;
string baseUrl = 2;
}
message RpcExportSegmentationModel {
string sourcePath = 1;
}
message RpcSideloadDirtree {
string sourcePath = 1;
}

View File

@ -32,8 +32,10 @@ dependencies {
implementation project(':third-party:commons-codec')
implementation project(':code:libraries:message-queue')
implementation project(':code:libraries:term-frequency-dict')
implementation project(':code:functions:link-graph:api')
implementation project(':code:functions:search-query')
implementation project(':code:execution:api')
implementation project(':code:process-models:crawl-spec')

View File

@ -12,6 +12,7 @@ public enum ExecutorActor {
ADJACENCY_CALCULATION,
CRAWL_JOB_EXTRACTOR,
EXPORT_DATA,
EXPORT_SEGMENTATION_MODEL,
EXPORT_ATAGS,
EXPORT_TERM_FREQUENCIES,
EXPORT_FEEDS,

View File

@ -47,6 +47,7 @@ public class ExecutorActorControlService {
ExportFeedsActor exportFeedsActor,
ExportSampleDataActor exportSampleDataActor,
ExportTermFreqActor exportTermFrequenciesActor,
ExportSegmentationModelActor exportSegmentationModelActor,
DownloadSampleActor downloadSampleActor,
ExecutorActorStateMachines stateMachines) {
this.messageQueueFactory = messageQueueFactory;
@ -76,6 +77,7 @@ public class ExecutorActorControlService {
register(ExecutorActor.EXPORT_FEEDS, exportFeedsActor);
register(ExecutorActor.EXPORT_SAMPLE_DATA, exportSampleDataActor);
register(ExecutorActor.EXPORT_TERM_FREQUENCIES, exportTermFrequenciesActor);
register(ExecutorActor.EXPORT_SEGMENTATION_MODEL, exportSegmentationModelActor);
register(ExecutorActor.DOWNLOAD_SAMPLE, downloadSampleActor);
}

View File

@ -0,0 +1,55 @@
package nu.marginalia.actor.task;
import com.google.gson.Gson;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.actor.prototype.RecordActorPrototype;
import nu.marginalia.actor.state.ActorStep;
import nu.marginalia.segmentation.NgramExtractorMain;
import nu.marginalia.storage.FileStorageService;
import nu.marginalia.storage.model.FileStorageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.file.Path;
import java.time.LocalDateTime;
@Singleton
public class ExportSegmentationModelActor extends RecordActorPrototype {
private final FileStorageService storageService;
private final Logger logger = LoggerFactory.getLogger(getClass());
public record Export(String zimFile) implements ActorStep {}
@Override
public ActorStep transition(ActorStep self) throws Exception {
return switch(self) {
case Export(String zimFile) -> {
var storage = storageService.allocateStorage(FileStorageType.EXPORT, "segmentation-model", "Segmentation Model Export " + LocalDateTime.now());
Path countsFile = storage.asPath().resolve("ngram-counts.bin");
NgramExtractorMain.dumpCounts(Path.of(zimFile), countsFile);
yield new End();
}
default -> new Error();
};
}
@Override
public String describe() {
return "Generate a query segmentation model from a ZIM file.";
}
@Inject
public ExportSegmentationModelActor(Gson gson,
FileStorageService storageService)
{
super(gson);
this.storageService = storageService;
}
}

View File

@ -6,10 +6,7 @@ import io.grpc.stub.StreamObserver;
import nu.marginalia.actor.ExecutorActor;
import nu.marginalia.actor.ExecutorActorControlService;
import nu.marginalia.actor.task.*;
import nu.marginalia.functions.execution.api.Empty;
import nu.marginalia.functions.execution.api.ExecutorExportApiGrpc;
import nu.marginalia.functions.execution.api.RpcExportSampleData;
import nu.marginalia.functions.execution.api.RpcFileStorageId;
import nu.marginalia.functions.execution.api.*;
import nu.marginalia.storage.model.FileStorageId;
@Singleton
@ -92,4 +89,20 @@ public class ExecutorExportGrpcService extends ExecutorExportApiGrpc.ExecutorExp
responseObserver.onError(e);
}
}
@Override
public void exportSegmentationModel(RpcExportSegmentationModel request, StreamObserver<Empty> responseObserver) {
try {
actorControlService.startFrom(ExecutorActor.EXPORT_SEGMENTATION_MODEL,
new ExportSegmentationModelActor.Export(request.getSourcePath())
);
responseObserver.onNext(Empty.getDefaultInstance());
responseObserver.onCompleted();
}
catch (Exception e) {
responseObserver.onError(e);
}
}
}

View File

@ -19,6 +19,7 @@ dependencies {
implementation project(':code:common:process')
implementation project(':code:features-convert:keyword-extraction')
implementation project(':code:libraries:language-processing')
implementation project(':code:libraries:term-frequency-dict')
implementation libs.bundles.slf4j

View File

@ -5,6 +5,7 @@ import nu.marginalia.keyword.KeywordExtractor;
import nu.marginalia.language.sentence.SentenceExtractor;
import nu.marginalia.model.EdgeDomain;
import nu.marginalia.model.EdgeUrl;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.util.TestLanguageModels;
import org.junit.jupiter.api.Test;

View File

@ -26,13 +26,13 @@ public class TestLanguageModels {
var languageModelsHome = getLanguageModelsPath();
return new LanguageModels(
languageModelsHome.resolve("ngrams.bin"),
languageModelsHome.resolve("tfreq-new-algo3.bin"),
languageModelsHome.resolve("opennlp-sentence.bin"),
languageModelsHome.resolve("English.RDR"),
languageModelsHome.resolve("English.DICT"),
languageModelsHome.resolve("opennlp-tokens.bin"),
languageModelsHome.resolve("lid.176.ftz")
languageModelsHome.resolve("lid.176.ftz"),
languageModelsHome.resolve("segments.bin")
);
}
}

View File

@ -21,6 +21,7 @@ dependencies {
implementation project(':code:common:model')
implementation project(':code:libraries:language-processing')
implementation project(':code:libraries:term-frequency-dict')
implementation project(':code:libraries:blocking-thread-pool')
implementation project(':code:features-crawl:link-parser')
implementation project(':code:features-convert:anchor-keywords')
implementation project(':code:process-models:crawling-model')

View File

@ -14,6 +14,7 @@ import nu.marginalia.process.log.WorkLog;
import nu.marginalia.storage.FileStorageService;
import nu.marginalia.storage.model.FileStorage;
import nu.marginalia.storage.model.FileStorageId;
import nu.marginalia.util.SimpleBlockingThreadPool;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.slf4j.Logger;
@ -53,26 +54,22 @@ public class TermFrequencyExporter implements ExporterIf {
TLongIntHashMap counts = new TLongIntHashMap(100_000_000, 0.7f, -1, -1);
AtomicInteger docCount = new AtomicInteger();
try (ForkJoinPool fjp = new ForkJoinPool(Math.max(2, Runtime.getRuntime().availableProcessors() / 2))) {
SimpleBlockingThreadPool sjp = new SimpleBlockingThreadPool("exporter", Math.clamp(2, 16, Runtime.getRuntime().availableProcessors() / 2), 4);
Path crawlerLogFile = inputDir.resolve("crawler.log");
for (var item : WorkLog.iterable(crawlerLogFile)) {
if (Thread.interrupted()) {
fjp.shutdownNow();
sjp.shutDownNow();
throw new InterruptedException();
}
Path crawlDataPath = inputDir.resolve(item.relPath());
fjp.execute(() -> processFile(crawlDataPath, counts, docCount, se.get()));
sjp.submitQuietly(() -> processFile(crawlDataPath, counts, docCount, se.get()));
}
while (!fjp.isQuiescent()) {
if (fjp.awaitQuiescence(10, TimeUnit.SECONDS))
break;
}
}
sjp.shutDown();
sjp.awaitTermination(10, TimeUnit.DAYS);
var tmpFile = Files.createTempFile(destStorage.asPath(), "freqs", ".dat.tmp",
PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rw-r--r--")));
@ -127,6 +124,10 @@ public class TermFrequencyExporter implements ExporterIf {
for (var word : sent) {
words.add(longHash(word.stemmed().getBytes(StandardCharsets.UTF_8)));
}
for (var ngram : sent.ngramStemmed) {
words.add(longHash(ngram.getBytes()));
}
}
synchronized (counts) {

View File

@ -1,5 +1,6 @@
package nu.marginalia.keyword;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.keyword.extractors.*;
import nu.marginalia.keyword.model.DocumentKeywordsBuilder;
import nu.marginalia.language.model.DocumentLanguageData;
@ -15,11 +16,13 @@ public class DocumentKeywordExtractor {
private final KeywordExtractor keywordExtractor;
private final TermFrequencyDict dict;
private final NgramLexicon ngramLexicon;
@Inject
public DocumentKeywordExtractor(TermFrequencyDict dict) {
public DocumentKeywordExtractor(TermFrequencyDict dict, NgramLexicon ngramLexicon) {
this.dict = dict;
this.ngramLexicon = ngramLexicon;
this.keywordExtractor = new KeywordExtractor();
}
@ -131,6 +134,17 @@ public class DocumentKeywordExtractor {
wordsBuilder.add(rep.word, meta);
}
for (int i = 0; i < sent.ngrams.length; i++) {
var ngram = sent.ngrams[i];
var ngramStemmed = sent.ngramStemmed[i];
long meta = metadata.getMetadataForWord(ngramStemmed);
assert meta != 0L : "Missing meta for " + ngram;
wordsBuilder.add(ngram, meta);
}
}
}

View File

@ -14,7 +14,9 @@ public class KeywordPositionBitmask {
private static final int unmodulatedPortion = 16;
@Inject
public KeywordPositionBitmask(KeywordExtractor keywordExtractor, DocumentLanguageData dld) {
public KeywordPositionBitmask(KeywordExtractor keywordExtractor,
DocumentLanguageData dld)
{
// Mark the title words as position 0
for (var sent : dld.titleSentences) {
@ -24,6 +26,10 @@ public class KeywordPositionBitmask {
positionMask.merge(word.stemmed(), posBit, this::bitwiseOr);
}
for (var ngram : sent.ngramStemmed) {
positionMask.merge(ngram, posBit, this::bitwiseOr);
}
for (var span : keywordExtractor.getKeywordsFromSentence(sent)) {
positionMask.merge(sent.constructStemmedWordFromSpan(span), posBit, this::bitwiseOr);
}
@ -43,6 +49,10 @@ public class KeywordPositionBitmask {
positionMask.merge(word.stemmed(), posBit, this::bitwiseOr);
}
for (var ngram : sent.ngramStemmed) {
positionMask.merge(ngram, posBit, this::bitwiseOr);
}
for (var span : keywordExtractor.getKeywordsFromSentence(sent)) {
positionMask.merge(sent.constructStemmedWordFromSpan(span), posBit, this::bitwiseOr);
}

View File

@ -5,6 +5,7 @@ import nu.marginalia.converting.processor.logic.dom.DomPruningFilter;
import nu.marginalia.language.sentence.SentenceExtractor;
import nu.marginalia.model.EdgeUrl;
import nu.marginalia.model.idx.WordMetadata;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import org.jsoup.Jsoup;
import org.junit.jupiter.api.Assertions;
@ -20,7 +21,9 @@ import java.util.Set;
class DocumentKeywordExtractorTest {
DocumentKeywordExtractor extractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels()));
DocumentKeywordExtractor extractor = new DocumentKeywordExtractor(
new TermFrequencyDict(WmsaHome.getLanguageModels()),
new NgramLexicon(WmsaHome.getLanguageModels()));
SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels());
@Test
@ -56,6 +59,22 @@ class DocumentKeywordExtractorTest {
}
@Test
public void testKeyboards2() throws IOException, URISyntaxException {
var resource = Objects.requireNonNull(ClassLoader.getSystemResourceAsStream("test-data/keyboards.html"),
"Could not load word frequency table");
String html = new String(resource.readAllBytes(), Charset.defaultCharset());
var doc = Jsoup.parse(html);
doc.filter(new DomPruningFilter(0.5));
var keywords = extractor.extractKeywords(se.extractSentences(doc), new EdgeUrl("https://pmortensen.eu/world2/2021/12/24/rapoo-mechanical-keyboards-gotchas-and-setup/"));
keywords.getWords().forEach((k, v) -> {
if (k.contains("_")) {
System.out.println(k + " " + new WordMetadata(v));
}
});
}
@Test
public void testKeyboards() throws IOException, URISyntaxException {
var resource = Objects.requireNonNull(ClassLoader.getSystemResourceAsStream("test-data/keyboards.html"),
@ -119,7 +138,9 @@ class DocumentKeywordExtractorTest {
var doc = Jsoup.parse(html);
doc.filter(new DomPruningFilter(0.5));
DocumentKeywordExtractor extractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels()));
DocumentKeywordExtractor extractor = new DocumentKeywordExtractor(
new TermFrequencyDict(WmsaHome.getLanguageModels()),
new NgramLexicon(WmsaHome.getLanguageModels()));
SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels());
var keywords = extractor.extractKeywords(se.extractSentences(doc), new EdgeUrl("https://math.byu.edu/wiki/index.php/All_You_Need_To_Know_About_Earning_Money_Online"));

View File

@ -3,6 +3,7 @@ package nu.marginalia.keyword;
import lombok.SneakyThrows;
import nu.marginalia.LanguageModels;
import nu.marginalia.language.sentence.SentenceExtractor;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import nu.marginalia.WmsaHome;
import nu.marginalia.model.EdgeUrl;
@ -20,9 +21,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag("slow")
class SentenceExtractorTest {
final LanguageModels lm = TestLanguageModels.getLanguageModels();
static final LanguageModels lm = TestLanguageModels.getLanguageModels();
SentenceExtractor se = new SentenceExtractor(lm);
static NgramLexicon ngramLexicon = new NgramLexicon(lm);
static SentenceExtractor se = new SentenceExtractor(lm);
@SneakyThrows
public static void main(String... args) throws IOException {
@ -32,11 +34,9 @@ class SentenceExtractorTest {
System.out.println("Running");
SentenceExtractor se = new SentenceExtractor(lm);
var dict = new TermFrequencyDict(lm);
var url = new EdgeUrl("https://memex.marginalia.nu/");
DocumentKeywordExtractor documentKeywordExtractor = new DocumentKeywordExtractor(dict);
DocumentKeywordExtractor documentKeywordExtractor = new DocumentKeywordExtractor(dict, ngramLexicon);
for (;;) {
long total = 0;

View File

@ -26,13 +26,13 @@ public class TestLanguageModels {
var languageModelsHome = getLanguageModelsPath();
return new LanguageModels(
languageModelsHome.resolve("ngrams.bin"),
languageModelsHome.resolve("tfreq-new-algo3.bin"),
languageModelsHome.resolve("opennlp-sentence.bin"),
languageModelsHome.resolve("English.RDR"),
languageModelsHome.resolve("English.DICT"),
languageModelsHome.resolve("opennlp-tokens.bin"),
languageModelsHome.resolve("lid.176.ftz")
languageModelsHome.resolve("lid.176.ftz"),
languageModelsHome.resolve("segments.bin")
);
}
}

View File

@ -5,6 +5,7 @@ import nu.marginalia.WmsaHome;
import nu.marginalia.keyword.DocumentKeywordExtractor;
import nu.marginalia.language.sentence.SentenceExtractor;
import nu.marginalia.model.EdgeUrl;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.summary.heuristic.*;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import org.jsoup.Jsoup;
@ -25,7 +26,9 @@ class SummaryExtractorTest {
@BeforeEach
public void setUp() {
keywordExtractor = new DocumentKeywordExtractor(new TermFrequencyDict(WmsaHome.getLanguageModels()));
keywordExtractor = new DocumentKeywordExtractor(
new TermFrequencyDict(WmsaHome.getLanguageModels()),
new NgramLexicon(WmsaHome.getLanguageModels()));
setenceExtractor = new SentenceExtractor(WmsaHome.getLanguageModels());
summaryExtractor = new SummaryExtractor(255,

View File

@ -30,6 +30,7 @@ dependencies {
implementation libs.notnull
implementation libs.guice
implementation libs.gson
implementation libs.commons.lang3
implementation libs.bundles.protobuf
implementation libs.bundles.grpc
implementation libs.fastutil

View File

@ -1,7 +1,6 @@
package nu.marginalia.api.searchquery;
import nu.marginalia.api.searchquery.*;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.results.Bm25Parameters;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.index.query.limit.QueryLimits;
@ -45,33 +44,37 @@ public class IndexProtobufCodec {
.build();
}
public static SearchSubquery convertSearchSubquery(RpcSubquery subquery) {
public static SearchQuery convertRpcQuery(RpcQuery query) {
List<List<String>> coherences = new ArrayList<>();
for (int j = 0; j < subquery.getCoherencesCount(); j++) {
var coh = subquery.getCoherences(j);
for (int j = 0; j < query.getCoherencesCount(); j++) {
var coh = query.getCoherences(j);
coherences.add(new ArrayList<>(coh.getCoherencesList()));
}
return new SearchSubquery(
subquery.getIncludeList(),
subquery.getExcludeList(),
subquery.getAdviceList(),
subquery.getPriorityList(),
return new SearchQuery(
query.getCompiledQuery(),
query.getIncludeList(),
query.getExcludeList(),
query.getAdviceList(),
query.getPriorityList(),
coherences
);
}
public static RpcSubquery convertSearchSubquery(SearchSubquery searchSubquery) {
public static RpcQuery convertRpcQuery(SearchQuery searchQuery) {
var subqueryBuilder =
RpcSubquery.newBuilder()
.addAllAdvice(searchSubquery.getSearchTermsAdvice())
.addAllExclude(searchSubquery.getSearchTermsExclude())
.addAllInclude(searchSubquery.getSearchTermsInclude())
.addAllPriority(searchSubquery.getSearchTermsPriority());
for (var coherences : searchSubquery.searchTermCoherences) {
RpcQuery.newBuilder()
.setCompiledQuery(searchQuery.compiledQuery)
.addAllInclude(searchQuery.getSearchTermsInclude())
.addAllAdvice(searchQuery.getSearchTermsAdvice())
.addAllExclude(searchQuery.getSearchTermsExclude())
.addAllPriority(searchQuery.getSearchTermsPriority());
for (var coherences : searchQuery.searchTermCoherences) {
subqueryBuilder.addCoherencesBuilder().addAllCoherences(coherences);
}
return subqueryBuilder.build();
}

View File

@ -2,7 +2,6 @@ package nu.marginalia.api.searchquery;
import lombok.SneakyThrows;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.api.searchquery.model.results.SearchResultItem;
@ -14,7 +13,6 @@ import nu.marginalia.api.searchquery.model.query.QueryParams;
import nu.marginalia.api.searchquery.model.query.QueryResponse;
import java.util.ArrayList;
import java.util.List;
public class QueryProtobufCodec {
@ -23,9 +21,7 @@ public class QueryProtobufCodec {
builder.addAllDomains(request.getDomainIdsList());
for (var subquery : query.specs.subqueries) {
builder.addSubqueries(IndexProtobufCodec.convertSearchSubquery(subquery));
}
builder.setQuery(IndexProtobufCodec.convertRpcQuery(query.specs.query));
builder.setSearchSetIdentifier(query.specs.searchSetIdentifier);
builder.setHumanQuery(request.getHumanQuery());
@ -51,9 +47,7 @@ public class QueryProtobufCodec {
public static RpcIndexQuery convertQuery(String humanQuery, ProcessedQuery query) {
var builder = RpcIndexQuery.newBuilder();
for (var subquery : query.specs.subqueries) {
builder.addSubqueries(IndexProtobufCodec.convertSearchSubquery(subquery));
}
builder.setQuery(IndexProtobufCodec.convertRpcQuery(query.specs.query));
builder.setSearchSetIdentifier(query.specs.searchSetIdentifier);
builder.setHumanQuery(humanQuery);
@ -127,6 +121,7 @@ public class QueryProtobufCodec {
results.getPubYear(), // ??,
results.getDataHash(),
results.getWordsTotal(),
results.getBestPositions(),
results.getRankingScore()
);
}
@ -139,31 +134,26 @@ public class QueryProtobufCodec {
return new SearchResultItem(
rawItem.getCombinedId(),
rawItem.getEncodedDocMetadata(),
rawItem.getHtmlFeatures(),
keywordScores,
rawItem.getResultsFromDomain(),
rawItem.getHasPriorityTerms(),
Double.NaN // Not set
);
}
private static SearchResultKeywordScore convertKeywordScore(RpcResultKeywordScore keywordScores) {
return new SearchResultKeywordScore(
keywordScores.getSubquery(),
keywordScores.getKeyword(),
keywordScores.getEncodedWordMetadata(),
keywordScores.getEncodedDocMetadata(),
keywordScores.getHtmlFeatures()
-1, // termId is internal to index service
keywordScores.getEncodedWordMetadata()
);
}
private static SearchSpecification convertSearchSpecification(RpcIndexQuery specs) {
List<SearchSubquery> subqueries = new ArrayList<>(specs.getSubqueriesCount());
for (int i = 0; i < specs.getSubqueriesCount(); i++) {
subqueries.add(IndexProtobufCodec.convertSearchSubquery(specs.getSubqueries(i)));
}
return new SearchSpecification(
subqueries,
IndexProtobufCodec.convertRpcQuery(specs.getQuery()),
specs.getDomainsList(),
specs.getSearchSetIdentifier(),
specs.getHumanQuery(),
@ -182,7 +172,6 @@ public class QueryProtobufCodec {
.addAllDomainIds(params.domainIds())
.addAllTacitAdvice(params.tacitAdvice())
.addAllTacitExcludes(params.tacitExcludes())
.addAllTacitIncludes(params.tacitIncludes())
.addAllTacitPriority(params.tacitPriority())
.setHumanQuery(params.humanQuery())
.setQueryLimits(IndexProtobufCodec.convertQueryLimits(params.limits()))
@ -215,6 +204,7 @@ public class QueryProtobufCodec {
rpcDecoratedResultItem.getPubYear(),
rpcDecoratedResultItem.getDataHash(),
rpcDecoratedResultItem.getWordsTotal(),
rpcDecoratedResultItem.getBestPositions(),
rpcDecoratedResultItem.getRankingScore()
);
}

View File

@ -0,0 +1,80 @@
package nu.marginalia.api.searchquery.model.compiled;
import org.jetbrains.annotations.NotNull;
import java.util.Iterator;
import java.util.function.*;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/** A compiled index service query. The class separates the topology of the query from the data,
* and it's possible to create new queries supplanting the data */
public class CompiledQuery<T> implements Iterable<T> {
/** The root expression, conveys the topology of the query */
public final CqExpression root;
private final CqData<T> data;
public CompiledQuery(CqExpression root, CqData<T> data) {
this.root = root;
this.data = data;
}
public CompiledQuery(CqExpression root, T[] data) {
this.root = root;
this.data = new CqData<>(data);
}
/** Exists for testing, creates a simple query that ANDs all the provided items */
public static <T> CompiledQuery<T> just(T... item) {
return new CompiledQuery<>(new CqExpression.And(
IntStream.range(0, item.length).mapToObj(CqExpression.Word::new).toList()
), item);
}
/** Create a new CompiledQuery mapping the leaf nodes using the provided mapper */
public <T2> CompiledQuery<T2> map(Class<T2> clazz, Function<T, T2> mapper) {
return new CompiledQuery<>(
root,
data.map(clazz, mapper)
);
}
public CompiledQueryLong mapToLong(ToLongFunction<T> mapper) {
return new CompiledQueryLong(root, data.mapToLong(mapper));
}
public CompiledQueryLong mapToInt(ToIntFunction<T> mapper) {
return new CompiledQueryLong(root, data.mapToInt(mapper));
}
public CqExpression root() {
return root;
}
public Stream<T> stream() {
return data.stream();
}
public IntStream indices() {
return IntStream.range(0, data.size());
}
public T at(int index) {
return data.get(index);
}
@NotNull
@Override
public Iterator<T> iterator() {
return stream().iterator();
}
public int size() {
return data.size();
}
}

View File

@ -0,0 +1,44 @@
package nu.marginalia.api.searchquery.model.compiled;
import java.util.stream.IntStream;
/** A compiled index service query */
public class CompiledQueryInt {
private final CqExpression root;
private final CqDataInt data;
public CompiledQueryInt(CqExpression root, CqDataInt data) {
this.root = root;
this.data = data;
}
public CqExpression root() {
return root;
}
public IntStream stream() {
return data.stream();
}
public IntStream indices() {
return IntStream.range(0, data.size());
}
public long at(int index) {
return data.get(index);
}
public int[] copyData() {
return data.copyData();
}
public boolean isEmpty() {
return data.size() == 0;
}
public int size() {
return data.size();
}
}

View File

@ -0,0 +1,54 @@
package nu.marginalia.api.searchquery.model.compiled;
import org.jetbrains.annotations.NotNull;
import java.util.Iterator;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
/** A compiled index service query */
public class CompiledQueryLong implements Iterable<Long> {
public final CqExpression root;
public final CqDataLong data;
public CompiledQueryLong(CqExpression root, CqDataLong data) {
this.root = root;
this.data = data;
}
public CqExpression root() {
return root;
}
public LongStream stream() {
return data.stream();
}
public IntStream indices() {
return IntStream.range(0, data.size());
}
public long at(int index) {
return data.get(index);
}
@NotNull
@Override
public Iterator<Long> iterator() {
return stream().iterator();
}
public long[] copyData() {
return data.copyData();
}
public boolean isEmpty() {
return data.size() == 0;
}
public int size() {
return data.size();
}
}

View File

@ -0,0 +1,113 @@
package nu.marginalia.api.searchquery.model.compiled;
import org.apache.commons.lang3.StringUtils;
import java.util.*;
/** Parser for a compiled index query */
public class CompiledQueryParser {
public static CompiledQuery<String> parse(String query) {
List<String> parts = tokenize(query);
if (parts.isEmpty()) {
return new CompiledQuery<>(
CqExpression.empty(),
new CqData<>(new String[0])
);
}
// We aren't interested in a binary tree representation, but an n-ary tree one,
// so a somewhat unusual parsing technique is used to avoid having an additional
// flattening step at the end.
// This is only possible due to the trivial and unambiguous grammar of the compiled queries
List<AndOrState> parenState = new ArrayList<>();
parenState.add(new AndOrState());
Map<String, Integer> wordIds = new HashMap<>();
for (var part : parts) {
var head = parenState.getLast();
if (part.equals("|")) {
head.or();
}
else if (part.equals("(")) {
parenState.addLast(new AndOrState());
}
else if (part.equals(")")) {
if (parenState.size() < 2) {
throw new IllegalStateException("Mismatched parentheses in expression: " + query);
}
parenState.removeLast();
parenState.getLast().and(head.closeOr());
}
else {
head.and(
new CqExpression.Word(
wordIds.computeIfAbsent(part, p -> wordIds.size())
)
);
}
}
if (parenState.size() != 1)
throw new IllegalStateException("Mismatched parentheses in expression: " + query);
// Construct the CompiledQuery object with String:s as leaves
var root = parenState.getLast().closeOr();
String[] cqData = new String[wordIds.size()];
wordIds.forEach((w, i) -> cqData[i] = w);
return new CompiledQuery<>(root, new CqData<>(cqData));
}
private static class AndOrState {
private List<CqExpression> andState = new ArrayList<>();
private List<CqExpression> orState = new ArrayList<>();
/** Add a new item to the and-list */
public void and(CqExpression e) {
andState.add(e);
}
/** Turn the and-list into an expression on the or-list, and then start a new and-list */
public void or() {
closeAnd();
andState = new ArrayList<>();
}
/** Turn the and-list into an And-expression in the or-list */
private void closeAnd() {
if (andState.size() == 1)
orState.add(andState.getFirst());
else if (!andState.isEmpty())
orState.add(new CqExpression.And(andState));
}
/** Finalize the current and-list, then turn the or-list into an Or-expression */
public CqExpression closeOr() {
closeAnd();
if (orState.isEmpty())
return CqExpression.empty();
if (orState.size() == 1)
return orState.getFirst();
return new CqExpression.Or(orState);
}
}
private static List<String> tokenize(String query) {
// Each token is guaranteed to be separated by one or more space characters
return Arrays.stream(StringUtils.split(query, ' '))
.filter(StringUtils::isNotBlank)
.toList();
}
}

View File

@ -0,0 +1,60 @@
package nu.marginalia.api.searchquery.model.compiled;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.function.ToLongFunction;
import java.util.stream.Stream;
public class CqData<T> {
private final T[] data;
public CqData(T[] data) {
this.data = data;
}
@SuppressWarnings("unchecked")
public <T2> CqData<T2> map(Class<T2> clazz, Function<T, T2> mapper) {
T2[] newData = (T2[]) Array.newInstance(clazz, data.length);
for (int i = 0; i < data.length; i++) {
newData[i] = mapper.apply((T) data[i]);
}
return new CqData<>(newData);
}
public CqDataLong mapToLong(ToLongFunction<T> mapper) {
long[] newData = new long[data.length];
for (int i = 0; i < data.length; i++) {
newData[i] = mapper.applyAsLong((T) data[i]);
}
return new CqDataLong(newData);
}
public CqDataLong mapToInt(ToIntFunction<T> mapper) {
long[] newData = new long[data.length];
for (int i = 0; i < data.length; i++) {
newData[i] = mapper.applyAsInt((T) data[i]);
}
return new CqDataLong(newData);
}
public T get(int i) {
return data[i];
}
public T get(CqExpression.Word w) {
return data[w.idx()];
}
public Stream<T> stream() {
return Arrays.stream(data);
}
public int size() {
return data.length;
}
}

View File

@ -0,0 +1,31 @@
package nu.marginalia.api.searchquery.model.compiled;
import java.util.Arrays;
import java.util.stream.IntStream;
public class CqDataInt {
private final int[] data;
public CqDataInt(int[] data) {
this.data = data;
}
public int get(int i) {
return data[i];
}
public int get(CqExpression.Word w) {
return data[w.idx()];
}
public IntStream stream() {
return Arrays.stream(data);
}
public int size() {
return data.length;
}
public int[] copyData() {
return Arrays.copyOf(data, data.length);
}
}

View File

@ -0,0 +1,31 @@
package nu.marginalia.api.searchquery.model.compiled;
import java.util.Arrays;
import java.util.stream.LongStream;
public class CqDataLong {
private final long[] data;
public CqDataLong(long[] data) {
this.data = data;
}
public long get(int i) {
return data[i];
}
public long get(CqExpression.Word w) {
return data[w.idx()];
}
public LongStream stream() {
return Arrays.stream(data);
}
public int size() {
return data.length;
}
public long[] copyData() {
return Arrays.copyOf(data, data.length);
}
}

View File

@ -0,0 +1,170 @@
package nu.marginalia.api.searchquery.model.compiled;
import java.util.List;
import java.util.StringJoiner;
import java.util.stream.Stream;
/** Expression in a parsed index service query
*
*/
public sealed interface CqExpression {
Stream<Word> stream();
/** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
long visit(LongVisitor visitor);
/** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
double visit(DoubleVisitor visitor);
/** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
int visit(IntVisitor visitor);
/** @see nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates */
boolean visit(BoolVisitor visitor);
<T> T visit(ObjectVisitor<T> visitor);
static CqExpression empty() {
return new Or(List.of());
}
record And(List<? extends CqExpression> parts) implements CqExpression {
@Override
public Stream<Word> stream() {
return parts.stream().flatMap(CqExpression::stream);
}
@Override
public long visit(LongVisitor visitor) {
return visitor.onAnd(parts);
}
@Override
public double visit(DoubleVisitor visitor) {
return visitor.onAnd(parts);
}
@Override
public int visit(IntVisitor visitor) {
return visitor.onAnd(parts);
}
@Override
public boolean visit(BoolVisitor visitor) {
return visitor.onAnd(parts);
}
@Override
public <T> T visit(ObjectVisitor<T> visitor) { return visitor.onAnd(parts); }
public String toString() {
StringJoiner sj = new StringJoiner(", ", "And[ ", "]");
parts.forEach(part -> sj.add(part.toString()));
return sj.toString();
}
}
record Or(List<? extends CqExpression> parts) implements CqExpression {
@Override
public Stream<Word> stream() {
return parts.stream().flatMap(CqExpression::stream);
}
@Override
public long visit(LongVisitor visitor) {
return visitor.onOr(parts);
}
@Override
public double visit(DoubleVisitor visitor) {
return visitor.onOr(parts);
}
@Override
public int visit(IntVisitor visitor) {
return visitor.onOr(parts);
}
@Override
public boolean visit(BoolVisitor visitor) {
return visitor.onOr(parts);
}
@Override
public <T> T visit(ObjectVisitor<T> visitor) { return visitor.onOr(parts); }
public String toString() {
StringJoiner sj = new StringJoiner(", ", "Or[ ", "]");
parts.forEach(part -> sj.add(part.toString()));
return sj.toString();
}
}
record Word(int idx) implements CqExpression {
@Override
public Stream<Word> stream() {
return Stream.of(this);
}
@Override
public long visit(LongVisitor visitor) {
return visitor.onLeaf(idx);
}
@Override
public double visit(DoubleVisitor visitor) {
return visitor.onLeaf(idx);
}
@Override
public int visit(IntVisitor visitor) {
return visitor.onLeaf(idx);
}
@Override
public boolean visit(BoolVisitor visitor) {
return visitor.onLeaf(idx);
}
@Override
public <T> T visit(ObjectVisitor<T> visitor) { return visitor.onLeaf(idx); }
@Override
public String toString() {
return Integer.toString(idx);
}
}
interface LongVisitor {
long onAnd(List<? extends CqExpression> parts);
long onOr(List<? extends CqExpression> parts);
long onLeaf(int idx);
}
interface IntVisitor {
int onAnd(List<? extends CqExpression> parts);
int onOr(List<? extends CqExpression> parts);
int onLeaf(int idx);
}
interface BoolVisitor {
boolean onAnd(List<? extends CqExpression> parts);
boolean onOr(List<? extends CqExpression> parts);
boolean onLeaf(int idx);
}
interface DoubleVisitor {
double onAnd(List<? extends CqExpression> parts);
double onOr(List<? extends CqExpression> parts);
double onLeaf(int idx);
}
interface ObjectVisitor<T> {
T onAnd(List<? extends CqExpression> parts);
T onOr(List<? extends CqExpression> parts);
T onLeaf(int idx);
}
}

View File

@ -0,0 +1,67 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import it.unimi.dsi.fastutil.longs.LongSet;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import java.util.ArrayList;
import java.util.List;
import java.util.function.*;
/** Contains methods for aggregating across a CompiledQuery or CompiledQueryLong */
public class CompiledQueryAggregates {
/** Compiled query aggregate that for a single boolean that treats or-branches as logical OR,
* and and-branches as logical AND operations. Will return true if there exists a path through
* the query where the provided predicate returns true for each item.
*/
static public <T> boolean booleanAggregate(CompiledQuery<T> query, Predicate<T> predicate) {
return query.root.visit(new CqBooleanAggregate(query, predicate));
}
static public boolean booleanAggregate(CompiledQueryLong query, LongPredicate predicate) {
return query.root.visit(new CqBooleanAggregate(query, predicate));
}
/** Compiled query aggregate that for a 64b bitmask that treats or-branches as logical OR,
* and and-branches as logical AND operations.
*/
public static <T> long longBitmaskAggregate(CompiledQuery<T> query, ToLongFunction<T> operator) {
return query.root.visit(new CqLongBitmaskOperator(query, operator));
}
public static long longBitmaskAggregate(CompiledQueryLong query, LongUnaryOperator operator) {
return query.root.visit(new CqLongBitmaskOperator(query, operator));
}
/** Apply the operator to each leaf node, then return the highest minimum value found along any path */
public static <T> int intMaxMinAggregate(CompiledQuery<T> query, ToIntFunction<T> operator) {
return query.root.visit(new CqIntMaxMinOperator(query, operator));
}
/** Apply the operator to each leaf node, then return the highest minimum value found along any path */
public static int intMaxMinAggregate(CompiledQueryLong query, LongToIntFunction operator) {
return query.root.visit(new CqIntMaxMinOperator(query, operator));
}
/** Apply the operator to each leaf node, and then return the highest sum of values possible
* through each branch in the compiled query.
*
*/
public static <T> double doubleSumAggregate(CompiledQuery<T> query, ToDoubleFunction<T> operator) {
return query.root.visit(new CqDoubleSumOperator(query, operator));
}
/** Enumerate all possible paths through the compiled query */
public static List<LongSet> queriesAggregate(CompiledQueryLong query) {
return new ArrayList<>(query.root().visit(new CqQueryPathsOperator(query)));
}
/** Using the bitwise AND operator, aggregate all possible combined values of the long generated by the provided operator */
public static <T> LongSet positionsAggregate(CompiledQuery<T> query, ToLongFunction<T> operator) {
return query.root().visit(new CqPositionsOperator(query, operator));
}
/** Using the bitwise AND operator, aggregate all possible combined values of the long generated by the provided operator */
public static <T> LongSet positionsAggregate(CompiledQueryLong query, LongUnaryOperator operator) {
return query.root().visit(new CqPositionsOperator(query, operator));
}
}

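An illustrative usage sketch of the aggregates above (not part of the diff; the per-term scores are made up): and-branches sum while or-branches take the maximum, so the aggregate is the best score achievable along any path through the query.

import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser;
import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates;

import java.util.Map;

class AggregateExample {
    public static void main(String[] args) {
        CompiledQuery<String> cq = CompiledQueryParser.parse("foo ( bar | baz )");

        // Made-up per-term scores, purely for illustration
        Map<String, Double> score = Map.of("foo", 1.0, "bar", 0.25, "baz", 0.75);

        // And-branches sum, or-branches take the max: best path is foo + baz = 1.75
        double best = CompiledQueryAggregates.doubleSumAggregate(cq,
                term -> score.getOrDefault(term, 0.0));

        System.out.println(best); // 1.75
    }
}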
View File

@ -0,0 +1,46 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import java.util.List;
import java.util.function.IntPredicate;
import java.util.function.LongPredicate;
import java.util.function.Predicate;
public class CqBooleanAggregate implements CqExpression.BoolVisitor {
private final IntPredicate predicate;
public <T> CqBooleanAggregate(CompiledQuery<T> query, Predicate<T> objPred) {
this.predicate = idx -> objPred.test(query.at(idx));
}
public CqBooleanAggregate(CompiledQueryLong query, LongPredicate longPredicate) {
this.predicate = idx -> longPredicate.test(query.at(idx));
}
@Override
public boolean onAnd(List<? extends CqExpression> parts) {
for (var part : parts) {
if (!part.visit(this)) // short-circuit
return false;
}
return true;
}
@Override
public boolean onOr(List<? extends CqExpression> parts) {
for (var part : parts) {
if (part.visit(this)) // short-circuit
return true;
}
return false;
}
@Override
public boolean onLeaf(int idx) {
return predicate.test(idx);
}
}

View File

@ -0,0 +1,46 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import java.util.List;
import java.util.function.IntToDoubleFunction;
import java.util.function.LongToDoubleFunction;
import java.util.function.ToDoubleFunction;
public class CqDoubleSumOperator implements CqExpression.DoubleVisitor {
private final IntToDoubleFunction operator;
public <T> CqDoubleSumOperator(CompiledQuery<T> query, ToDoubleFunction<T> operator) {
this.operator = idx -> operator.applyAsDouble(query.at(idx));
}
public CqDoubleSumOperator(IntToDoubleFunction operator) {
this.operator = operator;
}
@Override
public double onAnd(List<? extends CqExpression> parts) {
double value = 0;
for (var part : parts) {
value += part.visit(this);
}
return value;
}
@Override
public double onOr(List<? extends CqExpression> parts) {
double value = parts.getFirst().visit(this);
for (int i = 1; i < parts.size(); i++) {
value = Math.max(value, parts.get(i).visit(this));
}
return value;
}
@Override
public double onLeaf(int idx) {
return operator.applyAsDouble(idx);
}
}

View File

@ -0,0 +1,47 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import java.util.List;
import java.util.function.IntUnaryOperator;
import java.util.function.LongToIntFunction;
import java.util.function.ToIntFunction;
public class CqIntMaxMinOperator implements CqExpression.IntVisitor {
private final IntUnaryOperator operator;
public <T> CqIntMaxMinOperator(CompiledQuery<T> query, ToIntFunction<T> operator) {
this.operator = idx -> operator.applyAsInt(query.at(idx));
}
public CqIntMaxMinOperator(CompiledQueryLong query, LongToIntFunction operator) {
this.operator = idx -> operator.applyAsInt(query.at(idx));
}
@Override
public int onAnd(List<? extends CqExpression> parts) {
int value = parts.getFirst().visit(this);
for (int i = 1; i < parts.size(); i++) {
value = Math.min(value, parts.get(i).visit(this));
}
return value;
}
@Override
public int onOr(List<? extends CqExpression> parts) {
int value = parts.getFirst().visit(this);
for (int i = 1; i < parts.size(); i++) {
value = Math.max(value, parts.get(i).visit(this));
}
return value;
}
@Override
public int onLeaf(int idx) {
return operator.applyAsInt(idx);
}
}

View File

@ -0,0 +1,45 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import java.util.List;
import java.util.function.IntToLongFunction;
import java.util.function.LongUnaryOperator;
import java.util.function.ToLongFunction;
public class CqLongBitmaskOperator implements CqExpression.LongVisitor {
private final IntToLongFunction operator;
public <T> CqLongBitmaskOperator(CompiledQuery<T> query, ToLongFunction<T> operator) {
this.operator = idx-> operator.applyAsLong(query.at(idx));
}
public CqLongBitmaskOperator(CompiledQueryLong query, LongUnaryOperator operator) {
this.operator = idx-> operator.applyAsLong(query.at(idx));
}
@Override
public long onAnd(List<? extends CqExpression> parts) {
long value = ~0L;
for (var part : parts) {
value &= part.visit(this);
}
return value;
}
@Override
public long onOr(List<? extends CqExpression> parts) {
long value = 0L;
for (var part : parts) {
value |= part.visit(this);
}
return value;
}
@Override
public long onLeaf(int idx) {
return operator.applyAsLong(idx);
}
}

View File

@ -0,0 +1,85 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import java.util.List;
import java.util.function.IntToLongFunction;
import java.util.function.LongUnaryOperator;
import java.util.function.ToLongFunction;
public class CqPositionsOperator implements CqExpression.ObjectVisitor<LongSet> {
private final IntToLongFunction operator;
public <T> CqPositionsOperator(CompiledQuery<T> query, ToLongFunction<T> operator) {
this.operator = idx -> operator.applyAsLong(query.at(idx));
}
public CqPositionsOperator(CompiledQueryLong query, LongUnaryOperator operator) {
this.operator = idx -> operator.applyAsLong(query.at(idx));
}
@Override
public LongSet onAnd(List<? extends CqExpression> parts) {
LongSet ret = new LongArraySet();
for (var part : parts) {
ret = comineSets(ret, part.visit(this));
}
return ret;
}
private LongSet comineSets(LongSet a, LongSet b) {
if (a.isEmpty())
return b;
if (b.isEmpty())
return a;
LongSet ret = newSet(a.size() * b.size());
var ai = a.longIterator();
while (ai.hasNext()) {
long aval = ai.nextLong();
var bi = b.longIterator();
while (bi.hasNext()) {
ret.add(aval & bi.nextLong());
}
}
return ret;
}
@Override
public LongSet onOr(List<? extends CqExpression> parts) {
LongSet ret = newSet(parts.size());
for (var part : parts) {
ret.addAll(part.visit(this));
}
return ret;
}
@Override
public LongSet onLeaf(int idx) {
var set = newSet(1);
set.add(operator.applyAsLong(idx));
return set;
}
/** Allocate a new set suitable for a collection with the provided cardinality */
private LongSet newSet(int cardinality) {
if (cardinality < 8)
return new LongArraySet(cardinality);
else
return new LongOpenHashSet(cardinality);
}
}

View File

@ -0,0 +1,75 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import java.util.ArrayList;
import java.util.List;
public class CqQueryPathsOperator implements CqExpression.ObjectVisitor<List<LongSet>> {
private final CompiledQueryLong query;
public CqQueryPathsOperator(CompiledQueryLong query) {
this.query = query;
}
@Override
public List<LongSet> onAnd(List<? extends CqExpression> parts) {
return parts.stream()
.map(expr -> expr.visit(this))
.reduce(List.of(), this::combineAnd);
}
private List<LongSet> combineAnd(List<LongSet> a, List<LongSet> b) {
// No-op cases
if (a.isEmpty())
return b;
if (b.isEmpty())
return a;
// Simple cases
if (a.size() == 1) {
b.forEach(set -> set.addAll(a.getFirst()));
return b;
}
else if (b.size() == 1) {
a.forEach(set -> set.addAll(b.getFirst()));
return a;
}
// Case where we AND two ORs
List<LongSet> ret = new ArrayList<>();
for (var aPart : a) {
for (var bPart : b) {
LongSet set = new LongOpenHashSet(aPart.size() + bPart.size());
set.addAll(aPart);
set.addAll(bPart);
ret.add(set);
}
}
return ret;
}
@Override
public List<LongSet> onOr(List<? extends CqExpression> parts) {
List<LongSet> ret = new ArrayList<>();
for (var part : parts) {
ret.addAll(part.visit(this));
}
return ret;
}
@Override
public List<LongSet> onLeaf(int idx) {
var set = new LongArraySet(1);
set.add(query.at(idx));
return List.of(set);
}
}

View File

@ -13,10 +13,6 @@ public record QueryResponse(SearchSpecification specs,
String domain)
{
public Set<String> getAllKeywords() {
Set<String> keywords = new HashSet<>(100);
for (var sq : specs.subqueries) {
keywords.addAll(sq.searchTermsInclude);
}
return keywords;
return new HashSet<>(specs.query.searchTermsInclude);
}
}

View File

@ -13,9 +13,12 @@ import java.util.stream.Collectors;
@AllArgsConstructor
@With
@EqualsAndHashCode
public class SearchSubquery {
public class SearchQuery {
/** These terms must be present in the document and are used in ranking*/
/** An infix style expression that encodes the required terms in the query */
public final String compiledQuery;
/** All terms that appear in {@see compiledQuery} */
public final List<String> searchTermsInclude;
/** These terms must be absent from the document */
@ -33,7 +36,8 @@ public class SearchSubquery {
@Deprecated // why does this exist?
private double value = 0;
public SearchSubquery() {
public SearchQuery() {
this.compiledQuery = "";
this.searchTermsInclude = new ArrayList<>();
this.searchTermsExclude = new ArrayList<>();
this.searchTermsAdvice = new ArrayList<>();
@ -41,11 +45,13 @@ public class SearchSubquery {
this.searchTermCoherences = new ArrayList<>();
}
public SearchSubquery(List<String> searchTermsInclude,
public SearchQuery(String compiledQuery,
List<String> searchTermsInclude,
List<String> searchTermsExclude,
List<String> searchTermsAdvice,
List<String> searchTermsPriority,
List<List<String>> searchTermCoherences) {
this.compiledQuery = compiledQuery;
this.searchTermsInclude = searchTermsInclude;
this.searchTermsExclude = searchTermsExclude;
this.searchTermsAdvice = searchTermsAdvice;
@ -54,7 +60,7 @@ public class SearchSubquery {
}
@Deprecated // why does this exist?
public SearchSubquery setValue(double value) {
public SearchQuery setValue(double value) {
if (Double.isInfinite(value) || Double.isNaN(value)) {
this.value = Double.MAX_VALUE;
} else {
@ -66,7 +72,7 @@ public class SearchSubquery {
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
if (!searchTermsInclude.isEmpty()) sb.append("include=").append(searchTermsInclude.stream().collect(Collectors.joining(",", "[", "] ")));
if (!compiledQuery.isEmpty()) sb.append("compiledQuery=").append(compiledQuery).append(", ");
if (!searchTermsExclude.isEmpty()) sb.append("exclude=").append(searchTermsExclude.stream().collect(Collectors.joining(",", "[", "] ")));
if (!searchTermsAdvice.isEmpty()) sb.append("advice=").append(searchTermsAdvice.stream().collect(Collectors.joining(",", "[", "] ")));
if (!searchTermsPriority.isEmpty()) sb.append("priority=").append(searchTermsPriority.stream().collect(Collectors.joining(",", "[", "] ")));

View File

@ -10,7 +10,7 @@ import java.util.List;
@ToString @Getter @Builder @With @AllArgsConstructor
public class SearchSpecification {
public List<SearchSubquery> subqueries;
public SearchQuery query;
/** If present and not empty, limit the search to these domain IDs */
public List<Integer> domains;

View File

@ -30,6 +30,7 @@ public class DecoratedSearchResultItem implements Comparable<DecoratedSearchResu
public final Integer pubYear;
public final long dataHash;
public final int wordsTotal;
public final long bestPositions;
public final double rankingScore;
public long documentId() {
@ -65,6 +66,7 @@ public class DecoratedSearchResultItem implements Comparable<DecoratedSearchResu
Integer pubYear,
long dataHash,
int wordsTotal,
long bestPositions,
double rankingScore)
{
this.rawIndexResult = rawIndexResult;
@ -77,6 +79,7 @@ public class DecoratedSearchResultItem implements Comparable<DecoratedSearchResu
this.pubYear = pubYear;
this.dataHash = dataHash;
this.wordsTotal = wordsTotal;
this.bestPositions = bestPositions;
this.rankingScore = rankingScore;
}

View File

@ -1,38 +1,34 @@
package nu.marginalia.api.searchquery.model.results;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import lombok.ToString;
import java.util.Map;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
@ToString
public class ResultRankingContext {
private final int docCount;
public final ResultRankingParameters params;
private final Object2IntOpenHashMap<String> fullCounts = new Object2IntOpenHashMap<>(10, 0.5f);
private final Object2IntOpenHashMap<String> priorityCounts = new Object2IntOpenHashMap<>(10, 0.5f);
/** CqDataInt associated with frequency information of the terms in the query
* in the full index. The dataset is indexed by the compiled query. */
public final CqDataInt fullCounts;
/** CqDataInt associated with frequency information of the terms in the query
* in the full index. The dataset is indexed by the compiled query. */
public final CqDataInt priorityCounts;
public ResultRankingContext(int docCount,
ResultRankingParameters params,
Map<String, Integer> fullCounts,
Map<String, Integer> prioCounts
) {
CqDataInt fullCounts,
CqDataInt prioCounts)
{
this.docCount = docCount;
this.params = params;
this.fullCounts.putAll(fullCounts);
this.priorityCounts.putAll(prioCounts);
this.fullCounts = fullCounts;
this.priorityCounts = prioCounts;
}
public int termFreqDocCount() {
return docCount;
}
public int frequency(String keyword) {
return fullCounts.getOrDefault(keyword, 1);
}
public int priorityFrequency(String keyword) {
return priorityCounts.getOrDefault(keyword, 1);
}
}

View File

@ -15,15 +15,30 @@ public class SearchResultItem implements Comparable<SearchResultItem> {
* probably not what you want, use getDocumentId() instead */
public final long combinedId;
/** Encoded document metadata */
public final long encodedDocMetadata;
/** Encoded html features of document */
public final int htmlFeatures;
/** How did the subqueries match against the document ? */
public final List<SearchResultKeywordScore> keywordScores;
/** How many other potential results existed in the same domain */
public int resultsFromDomain;
public SearchResultItem(long combinedId, int scoresCount) {
public boolean hasPrioTerm;
public SearchResultItem(long combinedId,
long encodedDocMetadata,
int htmlFeatures,
boolean hasPrioTerm) {
this.combinedId = combinedId;
this.keywordScores = new ArrayList<>(scoresCount);
this.encodedDocMetadata = encodedDocMetadata;
this.keywordScores = new ArrayList<>();
this.htmlFeatures = htmlFeatures;
this.hasPrioTerm = hasPrioTerm;
}
@ -76,4 +91,6 @@ public class SearchResultItem implements Comparable<SearchResultItem> {
return Long.compare(this.combinedId, o.combinedId);
}
}

View File

@ -2,41 +2,27 @@ package nu.marginalia.api.searchquery.model.results;
import nu.marginalia.model.idx.WordFlags;
import nu.marginalia.model.idx.WordMetadata;
import nu.marginalia.model.idx.DocumentMetadata;
import java.util.Objects;
public final class SearchResultKeywordScore {
public final int subquery;
public final long termId;
public final String keyword;
private final long encodedWordMetadata;
private final long encodedDocMetadata;
private final int htmlFeatures;
public SearchResultKeywordScore(int subquery,
String keyword,
long encodedWordMetadata,
long encodedDocMetadata,
int htmlFeatures) {
this.subquery = subquery;
public SearchResultKeywordScore(String keyword,
long termId,
long encodedWordMetadata) {
this.termId = termId;
this.keyword = keyword;
this.encodedWordMetadata = encodedWordMetadata;
this.encodedDocMetadata = encodedDocMetadata;
this.htmlFeatures = htmlFeatures;
}
public boolean hasTermFlag(WordFlags flag) {
return WordMetadata.hasFlags(encodedWordMetadata, flag.asBit());
}
public int positionCount() {
return Long.bitCount(positions());
}
public int subquery() {
return subquery;
}
public long positions() {
return WordMetadata.decodePositions(encodedWordMetadata);
}
@ -45,46 +31,28 @@ public final class SearchResultKeywordScore {
return keyword.contains(":") || hasTermFlag(WordFlags.Synthetic);
}
public boolean isKeywordRegular() {
return !keyword.contains(":")
&& !hasTermFlag(WordFlags.Synthetic);
}
public long encodedWordMetadata() {
return encodedWordMetadata;
}
public long encodedDocMetadata() {
return encodedDocMetadata;
}
public int htmlFeatures() {
return htmlFeatures;
}
@Override
public boolean equals(Object obj) {
if (obj == this) return true;
if (obj == null || obj.getClass() != this.getClass()) return false;
var that = (SearchResultKeywordScore) obj;
return this.subquery == that.subquery &&
Objects.equals(this.keyword, that.keyword) &&
this.encodedWordMetadata == that.encodedWordMetadata &&
this.encodedDocMetadata == that.encodedDocMetadata;
return Objects.equals(this.termId, that.termId);
}
@Override
public int hashCode() {
return Objects.hash(subquery, keyword, encodedWordMetadata, encodedDocMetadata);
return Objects.hash(termId);
}
@Override
public String toString() {
return "SearchResultKeywordScore[" +
"set=" + subquery + ", " +
"keyword=" + keyword + ", " +
"encodedWordMetadata=" + new WordMetadata(encodedWordMetadata) + ", " +
"encodedDocMetadata=" + new DocumentMetadata(encodedDocMetadata) + ']';
"encodedWordMetadata=" + new WordMetadata(encodedWordMetadata) + ']';
}
}

View File

@ -52,7 +52,7 @@ message RpcTemporalBias {
/* Index service query request */
message RpcIndexQuery {
repeated RpcSubquery subqueries = 1;
RpcQuery query = 1;
repeated int32 domains = 2; // (optional) A list of domain IDs to consider
string searchSetIdentifier = 3; // (optional) A named set of domains to consider
string humanQuery = 4; // The search query as the user entered it
@ -91,23 +91,23 @@ message RpcDecoratedResultItem {
int64 dataHash = 9;
int32 wordsTotal = 10;
double rankingScore = 11; // The ranking score of this search result item, lower is better
int64 bestPositions = 12;
}
/** A raw index-service view of a search result */
message RpcRawResultItem {
int64 combinedId = 1; // raw ID with bit-encoded ranking information still present
int32 resultsFromDomain = 2; // number of other results from the same domain
repeated RpcResultKeywordScore keywordScores = 3;
int64 encodedDocMetadata = 3; // bit encoded document metadata
int32 htmlFeatures = 4; // bitmask encoding features of the document
repeated RpcResultKeywordScore keywordScores = 5;
bool hasPriorityTerms = 6; // true if the document matches one or more of the query's priority terms
}
/* Information about how well a keyword matches a query */
message RpcResultKeywordScore {
int32 subquery = 1; // index of the subquery this keyword relates to
string keyword = 2; // the keyword
int64 encodedWordMetadata = 3; // bit encoded word metadata
int64 encodedDocMetadata = 4; // bit encoded document metadata
bool hasPriorityTerms = 5; // true if this word is important to the document
int32 htmlFeatures = 6; // bit encoded document features
string keyword = 1; // the keyword
int64 encodedWordMetadata = 2; // bit encoded word metadata
}
/* Query execution parameters */
@ -137,12 +137,13 @@ message RpcResultRankingParameters {
}
/* Defines a single subquery */
message RpcSubquery {
message RpcQuery {
repeated string include = 1; // These terms must be present
repeated string exclude = 2; // These terms must be absent
repeated string advice = 3; // These terms must be present, but do not affect ranking
repeated string priority = 4; // These terms are not mandatory, but affect ranking positively if they are present
repeated RpcCoherences coherences = 5; // Groups of terms that must exist in proximity of each other
string compiledQuery = 6; // Compiled query in infix notation
}
/* Defines a group of search terms that must exist in close proximity within the document */

View File

@ -0,0 +1,79 @@
package nu.marginalia.api.searchquery.model.compiled;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
class CompiledQueryParserTest {
@Test
public void testEmpty() {
assertEquals(CqExpression.empty(), CompiledQueryParser.parse("").root);
assertEquals(CqExpression.empty(), CompiledQueryParser.parse("( )").root);
assertEquals(CqExpression.empty(), CompiledQueryParser.parse("( | )").root);
assertEquals(CqExpression.empty(), CompiledQueryParser.parse("| ( | ) |").root);
}
@Test
public void testSingleWord() {
CompiledQuery<String> q = CompiledQueryParser.parse("foo");
assertEquals(w(q, "foo"), q.root);
}
@Test
public void testAndTwoWords() {
CompiledQuery<String> q = CompiledQueryParser.parse("foo bar");
assertEquals(and(w(q, "foo"), w(q,"bar")), q.root);
}
@Test
public void testOrTwoWords() {
CompiledQuery<String> q = CompiledQueryParser.parse("foo | bar");
assertEquals(or(w(q, "foo"), w(q,"bar")), q.root);
}
@Test
public void testOrAndWords() {
CompiledQuery<String> q = CompiledQueryParser.parse("foo | bar baz");
assertEquals(or(w(q,"foo"), and(w(q,"bar"), w(q,"baz"))), q.root);
}
@Test
public void testAndAndOrAndAndWords() {
CompiledQuery<String> q = CompiledQueryParser.parse("foo foobar | bar baz");
assertEquals(or(
and(w(q, "foo"), w(q, "foobar")),
and(w(q, "bar"), w(q, "baz")))
, q.root);
}
@Test
public void testComplex1() {
CompiledQuery<String> q = CompiledQueryParser.parse("foo ( bar | baz ) quux");
assertEquals(and(w(q,"foo"), or(w(q, "bar"), w(q, "baz")), w(q, "quux")), q.root);
}
@Test
public void testComplex2() {
CompiledQuery<String> q = CompiledQueryParser.parse("( ( ( a ) b ) c ) d");
assertEquals(and(and(and(w(q, "a"), w(q, "b")), w(q, "c")), w(q, "d")), q.root);
}
@Test
public void testNested() {
CompiledQuery<String> q = CompiledQueryParser.parse("( ( ( a ) ) )");
assertEquals(w(q,"a"), q.root);
}
private CqExpression.Word w(CompiledQuery<String> query, String word) {
return new CqExpression.Word(query.indices().filter(idx -> word.equals(query.at(idx))).findAny().orElseThrow());
}
private CqExpression and(CqExpression... parts) {
return new CqExpression.And(List.of(parts));
}
private CqExpression or(CqExpression... parts) {
return new CqExpression.Or(List.of(parts));
}
}

View File

@ -0,0 +1,35 @@
package nu.marginalia.api.searchquery.model.compiled.aggregate;
import static nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser.parse;
import static nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates.*;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class CompiledQueryAggregatesTest {
@Test
void booleanAggregates() {
assertFalse(booleanAggregate(parse("false"), Boolean::parseBoolean));
assertTrue(booleanAggregate(parse("true"), Boolean::parseBoolean));
assertFalse(booleanAggregate(parse("false true"), Boolean::parseBoolean));
assertTrue(booleanAggregate(parse("( true ) | ( true false )"), Boolean::parseBoolean));
assertTrue(booleanAggregate(parse("( false ) | ( true )"), Boolean::parseBoolean));
assertTrue(booleanAggregate(parse("( true false ) | ( true true )"), Boolean::parseBoolean));
assertFalse(booleanAggregate(parse("( true false ) | ( true false )"), Boolean::parseBoolean));
}
@Test
void intMaxMinAggregates() {
assertEquals(5, intMaxMinAggregate(parse("5"), Integer::parseInt));
assertEquals(3, intMaxMinAggregate(parse("5 3"), Integer::parseInt));
assertEquals(6, intMaxMinAggregate(parse("5 3 | 6 7"), Integer::parseInt));
}
@Test
void doubleSumAggregates() {
assertEquals(5, (int) doubleSumAggregate(parse("5"), Double::parseDouble));
assertEquals(8, (int) doubleSumAggregate(parse("5 3"), Double::parseDouble));
assertEquals(13, (int) doubleSumAggregate(parse("1 ( 5 3 | 2 10 )"), Double::parseDouble));
}
}
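
Read together, the assertions above suggest (an inference from the tests, not something stated in the changeset) that AND-sequences are folded element-wise (summed, or min'd for intMaxMinAggregate) while OR-branches take the best branch. A hypothetical worked example for the last doubleSumAggregate assertion, with illustrative names only:

// Hedged reading of doubleSumAggregate(parse("1 ( 5 3 | 2 10 )"), Double::parseDouble) == 13,
// inferred from the assertions above; class and variable names are illustrative only.
public class AggregateArithmeticSketch {
    public static void main(String[] args) {
        double left  = 5 + 3;                      // first OR-branch, summed   -> 8.0
        double right = 2 + 10;                     // second OR-branch, summed  -> 12.0
        double total = 1 + Math.max(left, right);  // leading term plus the better branch -> 13.0
        System.out.println(total);                 // prints 13.0, matching the assertion
    }
}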

View File

@ -1,7 +1,7 @@
package nu.marginalia.index.client;
import nu.marginalia.api.searchquery.IndexProtobufCodec;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.index.query.limit.QueryLimits;
import nu.marginalia.index.query.limit.SpecificationLimit;
@ -35,14 +35,15 @@ class IndexProtobufCodecTest {
}
@Test
public void testSubqery() {
verifyIsIdentityTransformation(new SearchSubquery(
verifyIsIdentityTransformation(new SearchQuery(
"qs",
List.of("a", "b"),
List.of("c", "d"),
List.of("e", "f"),
List.of("g", "h"),
List.of(List.of("i", "j"), List.of("k"))
),
s -> IndexProtobufCodec.convertSearchSubquery(IndexProtobufCodec.convertSearchSubquery(s))
s -> IndexProtobufCodec.convertRpcQuery(IndexProtobufCodec.convertRpcQuery(s))
);
}
private <T> void verifyIsIdentityTransformation(T val, Function<T,T> transformation) {

View File

@ -26,6 +26,9 @@ dependencies {
implementation project(':code:libraries:term-frequency-dict')
implementation project(':third-party:porterstemmer')
implementation project(':third-party:openzim')
implementation project(':third-party:commons-codec')
implementation project(':code:libraries:language-processing')
implementation project(':code:libraries:term-frequency-dict')
implementation project(':code:features-convert:keyword-extraction')
@ -36,6 +39,8 @@ dependencies {
implementation libs.bundles.grpc
implementation libs.notnull
implementation libs.guice
implementation libs.jsoup
implementation libs.commons.lang3
implementation libs.trove
implementation libs.fastutil
implementation libs.bundles.gson

View File

@ -0,0 +1,181 @@
package nu.marginalia.functions.searchquery.query_parser;
import ca.rmen.porterstemmer.PorterStemmer;
import com.google.inject.Inject;
import nu.marginalia.functions.searchquery.query_parser.model.QWord;
import nu.marginalia.functions.searchquery.query_parser.model.QWordGraph;
import nu.marginalia.functions.searchquery.query_parser.model.QWordPathsRenderer;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import org.apache.commons.lang3.StringUtils;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/** Responsible for expanding a query, that is, creating alternative branches of query execution
* to increase the number of results
*/
public class QueryExpansion {
private static final PorterStemmer ps = new PorterStemmer();
private final TermFrequencyDict dict;
private final NgramLexicon lexicon;
private final List<ExpansionStrategy> expansionStrategies = List.of(
this::joinDashes,
this::splitWordNum,
this::joinTerms,
this::createSegments
);
@Inject
public QueryExpansion(TermFrequencyDict dict,
NgramLexicon lexicon
) {
this.dict = dict;
this.lexicon = lexicon;
}
public String expandQuery(List<String> words) {
QWordGraph graph = new QWordGraph(words);
for (var strategy : expansionStrategies) {
strategy.expand(graph);
}
return QWordPathsRenderer.render(graph);
}
private static final Pattern dashPattern = Pattern.compile("-");
private static final Pattern numWordBoundary = Pattern.compile("[0-9][a-zA-Z]|[a-zA-Z][0-9]");
// Turn 'lawn-chair' into 'lawnchair'
public void joinDashes(QWordGraph graph) {
for (var qw : graph) {
if (qw.word().contains("-")) {
var joined = StringUtils.join(dashPattern.split(qw.word()));
graph.addVariant(qw, joined);
}
}
}
// Turn 'MP3' into 'MP-3'
public void splitWordNum(QWordGraph graph) {
for (var qw : graph) {
var matcher = numWordBoundary.matcher(qw.word());
if (matcher.matches()) {
var joined = StringUtils.join(dashPattern.split(qw.word()), '-');
graph.addVariant(qw, joined);
}
}
}
// Turn 'lawn chair' into 'lawnchair'
public void joinTerms(QWordGraph graph) {
QWord prev = null;
for (var qw : graph) {
if (prev != null) {
var joinedWord = prev.word() + qw.word();
var joinedStemmed = ps.stemWord(joinedWord);
var scoreA = dict.getTermFreqStemmed(prev.stemmed());
var scoreB = dict.getTermFreqStemmed(qw.stemmed());
var scoreCombo = dict.getTermFreqStemmed(joinedStemmed);
if (scoreCombo > scoreA + scoreB || scoreCombo > 1000) {
graph.addVariantForSpan(prev, qw, joinedWord);
}
}
prev = qw;
}
}
/** Create an alternative interpretation of the query that replaces a sequence of words
* with a word n-gram. This makes it so that when possible, the order of words in the document
* matches the order of the words in the query.
*/
public void createSegments(QWordGraph graph) {
List<QWord> nodes = new ArrayList<>();
for (var qw : graph) {
nodes.add(qw);
}
String[] words = nodes.stream().map(QWord::stemmed).toArray(String[]::new);
// Grab all segments
List<NgramLexicon.SentenceSegment> allSegments = new ArrayList<>();
for (int length = 2; length < Math.min(10, words.length); length++) {
allSegments.addAll(lexicon.findSegmentOffsets(length, words));
}
allSegments.sort(Comparator.comparing(NgramLexicon.SentenceSegment::start));
if (allSegments.isEmpty()) {
return;
}
Set<NgramLexicon.SentenceSegment> bestSegmentation =
findBestSegmentation(allSegments);
for (var segment : bestSegmentation) {
int start = segment.start();
int end = segment.start() + segment.length();
var word = IntStream.range(start, end)
.mapToObj(nodes::get)
.map(QWord::word)
.collect(Collectors.joining("_"));
graph.addVariantForSpan(nodes.get(start), nodes.get(end - 1), word);
}
}
private Set<NgramLexicon.SentenceSegment> findBestSegmentation(List<NgramLexicon.SentenceSegment> allSegments) {
Set<NgramLexicon.SentenceSegment> bestSet = Set.of();
double bestScore = Double.MIN_VALUE;
for (int i = 0; i < allSegments.size(); i++) {
Set<NgramLexicon.SentenceSegment> parts = new HashSet<>();
parts.add(allSegments.get(i));
outer:
for (int j = i+1; j < allSegments.size(); j++) {
var candidate = allSegments.get(j);
for (var part : parts) {
if (part.overlaps(candidate)) {
continue outer;
}
}
parts.add(candidate);
}
double score = 0.;
for (var part : parts) {
// |s|^|s|-normalization per M Hagen et al
double normFactor = Math.pow(part.length(), part.length());
score += normFactor * part.count();
}
if (bestScore < score) {
bestScore = score;
bestSet = parts;
}
}
return bestSet;
}
public interface ExpansionStrategy {
void expand(QWordGraph graph);
}
}
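
To make the |s|^|s| scoring used in findBestSegmentation concrete, here is a small hypothetical sketch; the segment lengths and n-gram counts are invented, and the class is not part of the changeset:

import java.util.List;

// Hypothetical illustration of the |s|^|s|-normalized segmentation score computed in
// findBestSegmentation above; the lengths and counts below are made up for the example.
public class SegmentationScoreSketch {
    record Segment(int length, int count) {}

    static double score(List<Segment> parts) {
        double score = 0.;
        for (var part : parts) {
            // |s|^|s| normalization: longer segments are weighted much more heavily
            double normFactor = Math.pow(part.length(), part.length());
            score += normFactor * part.count();
        }
        return score;
    }

    public static void main(String[] args) {
        System.out.println(score(List.of(new Segment(2, 500)))); // 2^2 * 500 = 2000.0
        System.out.println(score(List.of(new Segment(3, 120)))); // 3^3 * 120 = 3240.0, preferred despite the lower count
    }
}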

View File

@ -1,8 +1,7 @@
package nu.marginalia.functions.searchquery.query_parser;
import nu.marginalia.functions.searchquery.query_parser.token.QueryToken;
import nu.marginalia.language.WordPatterns;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenType;
import nu.marginalia.util.transform_list.TransformList;
import java.util.List;
@ -11,95 +10,126 @@ public class QueryParser {
private final QueryTokenizer tokenizer = new QueryTokenizer();
public List<Token> parse(String query) {
List<Token> basicTokens = tokenizer.tokenizeQuery(query);
public List<QueryToken> parse(String query) {
List<QueryToken> basicTokens = tokenizer.tokenizeQuery(query);
TransformList<Token> list = new TransformList<>(basicTokens);
TransformList<QueryToken> list = new TransformList<>(basicTokens);
list.transformEach(QueryParser::handleQuoteTokens);
list.transformEach(QueryParser::trimLiterals);
list.transformEachPair(QueryParser::createNegatedTerms);
list.transformEachPair(QueryParser::createPriorityTerms);
list.transformEach(QueryParser::handleSpecialOperations);
list.scanAndTransform(TokenType.LPAREN, TokenType.RPAREN, QueryParser::handleAdvisoryTerms);
list.scanAndTransform(QueryToken.LParen.class::isInstance, QueryToken.RParen.class::isInstance, QueryParser::handleAdvisoryTerms);
list.transformEach(QueryParser::normalizeDomainName);
return list.getBackingList();
}
private static void handleQuoteTokens(TransformList<Token>.Entity entity) {
var t = entity.value();
if (t.type == TokenType.QUOT) {
entity.replace(new Token(TokenType.QUOT_TERM,
t.str.replaceAll("\\s+", WordPatterns.WORD_TOKEN_JOINER),
t.displayStr));
}
}
private static void trimLiterals(TransformList<Token>.Entity entity) {
private static void normalizeDomainName(TransformList<QueryToken>.Entity entity) {
var t = entity.value();
if (t.type == TokenType.LITERAL_TERM
&& (t.str.endsWith(":") || t.str.endsWith("."))
&& t.str.length() > 1) {
entity.replace(new Token(TokenType.LITERAL_TERM, t.str.substring(0, t.str.length() - 1), t.displayStr));
if (!(t instanceof QueryToken.LiteralTerm))
return;
if (t.str().startsWith("site:")) {
entity.replace(new QueryToken.LiteralTerm(t.str().toLowerCase(), t.displayStr()));
}
}
private static void createNegatedTerms(TransformList<Token>.Entity first, TransformList<Token>.Entity second) {
var t = first.value();
var tn = second.value();
if (t.type == TokenType.MINUS && tn.type == TokenType.LITERAL_TERM) {
first.remove();
second.replace(new Token(TokenType.EXCLUDE_TERM, tn.str, "-" + tn.str));
}
}
private static void createPriorityTerms(TransformList<Token>.Entity first, TransformList<Token>.Entity second) {
var t = first.value();
var tn = second.value();
if (t.type == TokenType.QMARK && tn.type == TokenType.LITERAL_TERM) {
first.remove();
second.replace(new Token(TokenType.PRIORTY_TERM, tn.str, "?" + tn.str));
}
}
private static void handleSpecialOperations(TransformList<Token>.Entity entity) {
private static void handleQuoteTokens(TransformList<QueryToken>.Entity entity) {
var t = entity.value();
if (t.type != TokenType.LITERAL_TERM) {
if (!(t instanceof QueryToken.Quot)) {
return;
}
if (t.str.startsWith("q") && t.str.matches("q[=><]\\d+")) {
entity.replace(new Token(TokenType.QUALITY_TERM, t.str.substring(1), t.displayStr));
} else if (t.str.startsWith("near:")) {
entity.replace(new Token(TokenType.NEAR_TERM, t.str.substring(5), t.displayStr));
} else if (t.str.startsWith("year") && t.str.matches("year[=><]\\d{4}")) {
entity.replace(new Token(TokenType.YEAR_TERM, t.str.substring(4), t.displayStr));
} else if (t.str.startsWith("size") && t.str.matches("size[=><]\\d+")) {
entity.replace(new Token(TokenType.SIZE_TERM, t.str.substring(4), t.displayStr));
} else if (t.str.startsWith("rank") && t.str.matches("rank[=><]\\d+")) {
entity.replace(new Token(TokenType.RANK_TERM, t.str.substring(4), t.displayStr));
} else if (t.str.startsWith("qs=")) {
entity.replace(new Token(TokenType.QS_TERM, t.str.substring(3), t.displayStr));
} else if (t.str.contains(":")) {
entity.replace(new Token(TokenType.ADVICE_TERM, t.str, t.displayStr));
}
entity.replace(new QueryToken.QuotTerm(
t.str().replaceAll("\\s+", WordPatterns.WORD_TOKEN_JOINER),
t.displayStr()));
}
private static void handleAdvisoryTerms(TransformList<Token>.Entity entity) {
private static void trimLiterals(TransformList<QueryToken>.Entity entity) {
var t = entity.value();
if (t.type == TokenType.LPAREN) {
if (!(t instanceof QueryToken.LiteralTerm lt))
return;
String str = lt.str();
if (str.isBlank())
return;
if (str.endsWith(":") || str.endsWith(".")) {
entity.replace(new QueryToken.LiteralTerm(str.substring(0, str.length() - 1), lt.displayStr()));
}
}
private static void createNegatedTerms(TransformList<QueryToken>.Entity first, TransformList<QueryToken>.Entity second) {
var t = first.value();
var tn = second.value();
if (!(t instanceof QueryToken.Minus))
return;
if (!(tn instanceof QueryToken.LiteralTerm) && !(tn instanceof QueryToken.AdviceTerm))
return;
first.remove();
second.replace(new QueryToken.ExcludeTerm(tn.str(), "-" + tn.displayStr()));
}
private static void createPriorityTerms(TransformList<QueryToken>.Entity first, TransformList<QueryToken>.Entity second) {
var t = first.value();
var tn = second.value();
if (!(t instanceof QueryToken.QMark))
return;
if (!(tn instanceof QueryToken.LiteralTerm) && !(tn instanceof QueryToken.AdviceTerm))
return;
var replacement = new QueryToken.PriorityTerm(tn.str(), "?" + tn.displayStr());
first.remove();
second.replace(replacement);
}
private static void handleSpecialOperations(TransformList<QueryToken>.Entity entity) {
var t = entity.value();
if (!(t instanceof QueryToken.LiteralTerm)) {
return;
}
String str = t.str();
if (str.startsWith("q") && str.matches("q[=><]\\d+")) {
entity.replace(new QueryToken.QualityTerm(str.substring(1)));
} else if (str.startsWith("near:")) {
entity.replace(new QueryToken.NearTerm(str.substring(5)));
} else if (str.startsWith("year") && str.matches("year[=><]\\d{4}")) {
entity.replace(new QueryToken.YearTerm(str.substring(4)));
} else if (str.startsWith("size") && str.matches("size[=><]\\d+")) {
entity.replace(new QueryToken.SizeTerm(str.substring(4)));
} else if (str.startsWith("rank") && str.matches("rank[=><]\\d+")) {
entity.replace(new QueryToken.RankTerm(str.substring(4)));
} else if (str.startsWith("qs=")) {
entity.replace(new QueryToken.QsTerm(str.substring(3)));
} else if (str.contains(":")) {
entity.replace(new QueryToken.AdviceTerm(str, t.displayStr()));
}
}
private static void handleAdvisoryTerms(TransformList<QueryToken>.Entity entity) {
var t = entity.value();
if (t instanceof QueryToken.LParen) {
entity.remove();
} else if (t.type == TokenType.RPAREN) {
} else if (t instanceof QueryToken.RParen) {
entity.remove();
} else if (t.type == TokenType.LITERAL_TERM) {
entity.replace(new Token(TokenType.ADVICE_TERM, t.str, "(" + t.str + ")"));
} else if (t instanceof QueryToken.LiteralTerm) {
entity.replace(new QueryToken.AdviceTerm(t.str(), "(" + t.displayStr() + ")"));
}
}
}
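
As a rough usage sketch of the rewritten parser (the query string, the class name, and the token sequence in the comment are assumptions; the actual output also depends on QueryTokenizer):

import nu.marginalia.functions.searchquery.query_parser.QueryParser;
import nu.marginalia.functions.searchquery.query_parser.token.QueryToken;

import java.util.List;

class QueryParserSketch {
    public static void main(String[] args) {
        var parser = new QueryParser();

        // With the transformations above this should come out roughly as:
        // LiteralTerm(plato), ExcludeTerm(socrates), YearTerm(>2000), AdviceTerm(site:marginalia.nu)
        List<QueryToken> tokens = parser.parse("plato -socrates year>2000 site:marginalia.nu");

        tokens.forEach(System.out::println);
    }
}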

View File

@ -1,229 +0,0 @@
package nu.marginalia.functions.searchquery.query_parser;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenType;
import nu.marginalia.language.WordPatterns;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.util.stream.Stream.concat;
public class QueryPermutation {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final QueryVariants queryVariants;
public static final Pattern wordPattern = Pattern.compile("[#]?[_@.a-zA-Z0-9'+\\-\\u00C0-\\u00D6\\u00D8-\\u00f6\\u00f8-\\u00ff]+[#]?");
public static final Pattern wordAppendixPattern = Pattern.compile("[.]?[0-9a-zA-Z\\u00C0-\\u00D6\\u00D8-\\u00f6\\u00f8-\\u00ff]{1,3}[0-9]?");
public static final Predicate<String> wordQualitiesPredicate = wordPattern.asMatchPredicate();
public static final Predicate<String> wordAppendixPredicate = wordAppendixPattern.asMatchPredicate();
public static final Predicate<String> wordPredicateEither = wordQualitiesPredicate.or(wordAppendixPredicate);
public QueryPermutation(QueryVariants queryVariants) {
this.queryVariants = queryVariants;
}
public List<List<Token>> permuteQueries(List<Token> items) {
int start = -1;
int end = items.size();
for (int i = 0; i < items.size(); i++) {
var token = items.get(i);
if (start < 0) {
if (token.type == TokenType.LITERAL_TERM && wordQualitiesPredicate.test(token.str)) {
start = i;
}
}
else {
if (token.type != TokenType.LITERAL_TERM || !wordPredicateEither.test(token.str)) {
end = i;
break;
}
}
}
if (start >= 0 && end - start > 1) {
List<List<Token>> permuteParts = combineSearchTerms(items.subList(start, end));
int s = start;
int e = end;
return permuteParts.stream().map(part ->
concat(items.subList(0, s).stream(), concat(part.stream(), items.subList(e, items.size()).stream()))
.collect(Collectors.toList()))
.peek(lst -> lst.removeIf(this::isJunkWord))
.limit(24)
.collect(Collectors.toList());
}
else {
return List.of(items);
}
}
public List<List<Token>> permuteQueriesNew(List<Token> items) {
int start = -1;
int end = items.size();
for (int i = 0; i < items.size(); i++) {
var token = items.get(i);
if (start < 0) {
if (token.type == TokenType.LITERAL_TERM && wordQualitiesPredicate.test(token.str)) {
start = i;
}
}
else {
if (token.type != TokenType.LITERAL_TERM || !wordPredicateEither.test(token.str)) {
end = i;
break;
}
}
}
if (start >= 0 && end - start >= 1) {
var result = queryVariants.getQueryVariants(items.subList(start, end));
logger.debug("{}", result);
if (result.isEmpty()) {
logger.warn("Empty variants result, falling back on old code");
return permuteQueries(items);
}
List<List<Token>> queryVariants = new ArrayList<>();
for (var query : result.faithful) {
var tokens = query.terms.stream().map(term -> new Token(TokenType.LITERAL_TERM, term)).collect(Collectors.toList());
tokens.addAll(result.nonLiterals);
queryVariants.add(tokens);
}
for (var query : result.alternative) {
if (queryVariants.size() >= 6)
break;
var tokens = query.terms.stream().map(term -> new Token(TokenType.LITERAL_TERM, term)).collect(Collectors.toList());
tokens.addAll(result.nonLiterals);
queryVariants.add(tokens);
}
List<List<Token>> returnValue = new ArrayList<>(queryVariants.size());
for (var variant: queryVariants) {
List<Token> r = new ArrayList<>(start + variant.size() + (items.size() - end));
r.addAll(items.subList(0, start));
r.addAll(variant);
r.addAll(items.subList(end, items.size()));
returnValue.add(r);
}
return returnValue;
}
else {
return List.of(items);
}
}
private boolean isJunkWord(Token token) {
if (WordPatterns.isStopWord(token.str) &&
!token.str.matches("^(\\d+|([a-z]+:.*))$")) {
return true;
}
return switch (token.str) {
case "vs", "versus", "or", "and" -> true;
default -> false;
};
}
private List<List<Token>> combineSearchTerms(List<Token> subList) {
int size = subList.size();
if (size < 1) {
return Collections.emptyList();
}
else if (size == 1) {
if (WordPatterns.isStopWord(subList.get(0).str)) {
return Collections.emptyList();
}
return List.of(subList);
}
List<List<Token>> results = new ArrayList<>(size*(size+1)/2);
if (subList.size() <= 4 && subList.get(0).str.length() >= 2 && !isPrefixWord(subList.get(subList.size()-1).str)) {
results.add(List.of(joinTokens(subList)));
}
outer: for (int i = size - 1; i >= 1; i--) {
var left = combineSearchTerms(subList.subList(0, i));
var right = combineSearchTerms(subList.subList(i, size));
for (var l : left) {
if (results.size() > 48) {
break outer;
}
for (var r : right) {
if (results.size() > 48) {
break outer;
}
List<Token> combined = new ArrayList<>(l.size() + r.size());
combined.addAll(l);
combined.addAll(r);
if (!results.contains(combined)) {
results.add(combined);
}
}
}
}
if (!results.contains(subList)) {
results.add(subList);
}
Comparator<List<Token>> tc = (o1, o2) -> {
int dJoininess = o2.stream().mapToInt(s->(int)Math.pow(joininess(s.str), 2)).sum() -
o1.stream().mapToInt(s->(int)Math.pow(joininess(s.str), 2)).sum();
if (dJoininess == 0) {
return (o2.stream().mapToInt(s->(int)Math.pow(rightiness(s.str), 2)).sum() -
o1.stream().mapToInt(s->(int)Math.pow(rightiness(s.str), 2)).sum());
}
return (int) Math.signum(dJoininess);
};
results.sort(tc);
return results;
}
private boolean isPrefixWord(String str) {
return switch (str) {
case "the", "of", "when" -> true;
default -> false;
};
}
int joininess(String s) {
return (int) s.chars().filter(c -> c == '_').count();
}
int rightiness(String s) {
int rightiness = 0;
for (int i = 0; i < s.length(); i++) {
if (s.charAt(i) == '_') {
rightiness+=i;
}
}
return rightiness;
}
private Token joinTokens(List<Token> subList) {
return new Token(TokenType.LITERAL_TERM,
subList.stream().map(t -> t.str).collect(Collectors.joining("_")),
subList.stream().map(t -> t.str).collect(Collectors.joining(" ")));
}
}

View File

@ -1,7 +1,6 @@
package nu.marginalia.functions.searchquery.query_parser;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenType;
import nu.marginalia.functions.searchquery.query_parser.token.QueryToken;
import nu.marginalia.language.encoding.AsciiFlattener;
import java.util.ArrayList;
@ -11,8 +10,8 @@ import java.util.regex.Pattern;
public class QueryTokenizer {
private static final Pattern noisePattern = Pattern.compile("[,\\s]");
public List<Token> tokenizeQuery(String rawQuery) {
List<Token> tokens = new ArrayList<>();
public List<QueryToken> tokenizeQuery(String rawQuery) {
List<QueryToken> tokens = new ArrayList<>();
String query = AsciiFlattener.flattenUnicode(rawQuery);
query = noisePattern.matcher(query).replaceAll(" ");
@ -21,26 +20,27 @@ public class QueryTokenizer {
int chr = query.charAt(i);
if ('(' == chr) {
tokens.add(new Token(TokenType.LPAREN, "(", "("));
tokens.add(new QueryToken.LParen());
}
else if (')' == chr) {
tokens.add(new Token(TokenType.RPAREN, ")", ")"));
tokens.add(new QueryToken.RParen());
}
else if ('"' == chr) {
int end = query.indexOf('"', i+1);
if (end == -1) {
end = query.length();
}
tokens.add(new Token(TokenType.QUOT,
query.substring(i+1, end).toLowerCase(),
query.substring(i, Math.min(query.length(), end+1))));
tokens.add(new QueryToken.Quot(query.substring(i + 1, end).toLowerCase()));
i = end;
}
else if ('-' == chr) {
tokens.add(new Token(TokenType.MINUS, "-"));
tokens.add(new QueryToken.Minus());
}
else if ('?' == chr) {
tokens.add(new Token(TokenType.QMARK, "?"));
tokens.add(new QueryToken.QMark());
}
else if (Character.isSpaceChar(chr)) {
//
@ -52,9 +52,12 @@ public class QueryTokenizer {
if (query.charAt(end) == ' ' || query.charAt(end) == ')')
break;
}
tokens.add(new Token(TokenType.LITERAL_TERM,
query.substring(i, end).toLowerCase(),
query.substring(i, end)));
String displayStr = query.substring(i, end);
String str = displayStr.toLowerCase();
tokens.add(new QueryToken.LiteralTerm(str, displayStr));
i = end-1;
}
}

View File

@ -1,378 +0,0 @@
package nu.marginalia.functions.searchquery.query_parser;
import ca.rmen.porterstemmer.PorterStemmer;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenType;
import nu.marginalia.util.language.EnglishDictionary;
import nu.marginalia.LanguageModels;
import nu.marginalia.keyword.KeywordExtractor;
import nu.marginalia.language.sentence.SentenceExtractor;
import nu.marginalia.util.ngrams.NGramBloomFilter;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import nu.marginalia.language.model.DocumentSentence;
import nu.marginalia.language.model.WordSpan;
import java.util.*;
import java.util.regex.Pattern;
public class QueryVariants {
private final KeywordExtractor keywordExtractor;
private final TermFrequencyDict dict;
private final PorterStemmer ps = new PorterStemmer();
private final NGramBloomFilter nGramBloomFilter;
private final EnglishDictionary englishDictionary;
private final ThreadLocal<SentenceExtractor> sentenceExtractor;
public QueryVariants(LanguageModels lm,
TermFrequencyDict dict,
NGramBloomFilter nGramBloomFilter,
EnglishDictionary englishDictionary) {
this.nGramBloomFilter = nGramBloomFilter;
this.englishDictionary = englishDictionary;
this.keywordExtractor = new KeywordExtractor();
this.sentenceExtractor = ThreadLocal.withInitial(() -> new SentenceExtractor(lm));
this.dict = dict;
}
final Pattern numWordBoundary = Pattern.compile("[0-9][a-zA-Z]|[a-zA-Z][0-9]");
final Pattern dashBoundary = Pattern.compile("-");
@AllArgsConstructor
private static class Word {
public final String stemmed;
public final String word;
public final String wordOriginal;
}
@AllArgsConstructor @Getter @ToString @EqualsAndHashCode
public static class QueryVariant {
public final List<String> terms;
public final double value;
}
@Getter @ToString
public static class QueryVariantSet {
final List<QueryVariant> faithful = new ArrayList<>();
final List<QueryVariant> alternative = new ArrayList<>();
final List<Token> nonLiterals = new ArrayList<>();
public boolean isEmpty() {
return faithful.isEmpty() && alternative.isEmpty() && nonLiterals.isEmpty();
}
}
public QueryVariantSet getQueryVariants(List<Token> query) {
final JoinedQueryAndNonLiteralTokens joinedQuery = joinQuery(query);
final TreeMap<Integer, List<WordSpan>> byStart = new TreeMap<>();
var se = sentenceExtractor.get();
var sentence = se.extractSentence(joinedQuery.joinedQuery);
for (int i = 0; i < sentence.posTags.length; i++) {
if (sentence.posTags[i].startsWith("N") || sentence.posTags[i].startsWith("V")) {
sentence.posTags[i] = "NNP";
}
else if ("JJ".equals(sentence.posTags[i]) || "CD".equals(sentence.posTags[i]) || sentence.posTags[i].startsWith("P")) {
sentence.posTags[i] = "NNP";
sentence.setIsStopWord(i, false);
}
}
for (var kw : keywordExtractor.getKeywordsFromSentence(sentence)) {
byStart.computeIfAbsent(kw.start, k -> new ArrayList<>()).add(kw);
}
final List<ArrayList<WordSpan>> livingSpans = new ArrayList<>();
var first = byStart.firstEntry();
if (first == null) {
var span = new WordSpan(0, sentence.length());
byStart.put(0, List.of(span));
}
else if (first.getKey() > 0) {
List<WordSpan> elongatedFirstWords = new ArrayList<>(first.getValue().size());
first.getValue().forEach(span -> {
elongatedFirstWords.add(new WordSpan(0, span.start));
elongatedFirstWords.add(new WordSpan(0, span.end));
});
byStart.put(0, elongatedFirstWords);
}
final List<List<Word>> goodSpans = getWordSpans(byStart, sentence, livingSpans);
List<List<String>> faithfulQueries = new ArrayList<>();
List<List<String>> alternativeQueries = new ArrayList<>();
for (var ls : goodSpans) {
faithfulQueries.addAll(createTokens(ls));
}
for (var span : goodSpans) {
alternativeQueries.addAll(joinTerms(span));
}
for (var ls : goodSpans) {
var last = ls.get(ls.size() - 1);
if (!last.wordOriginal.isBlank() && !Character.isUpperCase(last.wordOriginal.charAt(0))) {
var altLast = englishDictionary.getWordVariants(last.word);
for (String s : altLast) {
List<String> newList = new ArrayList<>(ls.size());
for (int i = 0; i < ls.size() - 1; i++) {
newList.add(ls.get(i).word);
}
newList.add(s);
alternativeQueries.add(newList);
}
}
}
QueryVariantSet returnValue = new QueryVariantSet();
returnValue.faithful.addAll(evaluateQueries(faithfulQueries));
returnValue.alternative.addAll(evaluateQueries(alternativeQueries));
returnValue.faithful.sort(Comparator.comparing(QueryVariant::getValue));
returnValue.alternative.sort(Comparator.comparing(QueryVariant::getValue));
returnValue.nonLiterals.addAll(joinedQuery.nonLiterals);
return returnValue;
}
final Pattern underscore = Pattern.compile("_");
private List<QueryVariant> evaluateQueries(List<List<String>> queryStrings) {
Set<QueryVariant> variantsSet = new HashSet<>();
List<QueryVariant> ret = new ArrayList<>();
for (var lst : queryStrings) {
double q = 0;
for (var word : lst) {
String[] parts = underscore.split(word);
double qp = 0;
for (String part : parts) {
qp += 1./(1+ dict.getTermFreq(part));
}
q += 1.0 / qp;
}
var qv = new QueryVariant(lst, q);
if (variantsSet.add(qv)) {
ret.add(qv);
}
}
return ret;
}
private Collection<List<String>> createTokens(List<Word> ls) {
List<String> asTokens = new ArrayList<>();
List<List<String>> ret = new ArrayList<>();
boolean dash = false;
boolean num = false;
for (var span : ls) {
dash |= dashBoundary.matcher(span.word).find();
num |= numWordBoundary.matcher(span.word).find();
if (ls.size() == 1 || !isOmittableWord(span.word)) {
asTokens.add(span.word);
}
}
ret.add(asTokens);
if (dash) {
ret.addAll(combineDashWords(ls));
}
if (num) {
ret.addAll(splitWordNum(ls));
}
return ret;
}
private boolean isOmittableWord(String word) {
return switch (word) {
case "vs", "or", "and", "versus", "is", "the", "why", "when", "if", "who", "are", "am" -> true;
default -> false;
};
}
private Collection<? extends List<String>> splitWordNum(List<Word> ls) {
List<String> asTokens2 = new ArrayList<>();
boolean num = false;
for (var span : ls) {
var wordMatcher = numWordBoundary.matcher(span.word);
var stemmedMatcher = numWordBoundary.matcher(span.stemmed);
int ws = 0;
int ss = 0;
boolean didSplit = false;
while (wordMatcher.find(ws) && stemmedMatcher.find(ss)) {
ws = wordMatcher.start()+1;
ss = stemmedMatcher.start()+1;
if (nGramBloomFilter.isKnownNGram(splitAtNumBoundary(span.word, stemmedMatcher.start(), "_"))
|| nGramBloomFilter.isKnownNGram(splitAtNumBoundary(span.word, stemmedMatcher.start(), "-")))
{
String combined = splitAtNumBoundary(span.word, wordMatcher.start(), "_");
asTokens2.add(combined);
didSplit = true;
num = true;
}
}
if (!didSplit) {
asTokens2.add(span.word);
}
}
if (num) {
return List.of(asTokens2);
}
return Collections.emptyList();
}
private Collection<? extends List<String>> combineDashWords(List<Word> ls) {
List<String> asTokens2 = new ArrayList<>();
boolean dash = false;
for (var span : ls) {
var matcher = dashBoundary.matcher(span.word);
if (matcher.find() && nGramBloomFilter.isKnownNGram(ps.stemWord(dashBoundary.matcher(span.word).replaceAll("")))) {
dash = true;
String combined = dashBoundary.matcher(span.word).replaceAll("");
asTokens2.add(combined);
}
else {
asTokens2.add(span.word);
}
}
if (dash) {
return List.of(asTokens2);
}
return Collections.emptyList();
}
private String splitAtNumBoundary(String in, int splitPoint, String joiner) {
return in.substring(0, splitPoint+1) + joiner + in.substring(splitPoint+1);
}
private List<List<Word>> getWordSpans(TreeMap<Integer, List<WordSpan>> byStart, DocumentSentence sentence, List<ArrayList<WordSpan>> livingSpans) {
List<List<Word>> goodSpans = new ArrayList<>();
for (int i = 0; i < 1; i++) {
var spans = byStart.get(i);
if (spans == null )
continue;
for (var span : spans) {
ArrayList<WordSpan> fragment = new ArrayList<>();
fragment.add(span);
livingSpans.add(fragment);
}
if (sentence.posTags[i].startsWith("N") || sentence.posTags[i].startsWith("V")) break;
}
while (!livingSpans.isEmpty()) {
final List<ArrayList<WordSpan>> newLivingSpans = new ArrayList<>(livingSpans.size());
for (var span : livingSpans) {
int end = span.get(span.size()-1).end;
if (end == sentence.length()) {
var gs = new ArrayList<Word>(span.size());
for (var s : span) {
gs.add(new Word(sentence.constructStemmedWordFromSpan(s), sentence.constructWordFromSpan(s),
s.size() == 1 ? sentence.words[s.start] : ""));
}
goodSpans.add(gs);
}
var nextWordsKey = byStart.ceilingKey(end);
if (null == nextWordsKey)
continue;
for (var next : byStart.get(nextWordsKey)) {
var newSpan = new ArrayList<WordSpan>(span.size() + 1);
newSpan.addAll(span);
newSpan.add(next);
newLivingSpans.add(newSpan);
}
}
livingSpans.clear();
livingSpans.addAll(newLivingSpans);
}
return goodSpans;
}
private List<List<String>> joinTerms(List<Word> span) {
List<List<String>> ret = new ArrayList<>();
for (int i = 0; i < span.size()-1; i++) {
var a = span.get(i);
var b = span.get(i+1);
var stemmed = ps.stemWord(a.word + b.word);
double scoreCombo = dict.getTermFreqStemmed(stemmed);
if (scoreCombo > 10000) {
List<String> asTokens = new ArrayList<>();
for (int j = 0; j < i; j++) {
var word = span.get(j).word;
asTokens.add(word);
}
{
var word = a.word + b.word;
asTokens.add(word);
}
for (int j = i+2; j < span.size(); j++) {
var word = span.get(j).word;
asTokens.add(word);
}
ret.add(asTokens);
}
}
return ret;
}
private JoinedQueryAndNonLiteralTokens joinQuery(List<Token> query) {
StringJoiner s = new StringJoiner(" ");
List<Token> leftovers = new ArrayList<>(5);
for (var t : query) {
if (t.type == TokenType.LITERAL_TERM) {
s.add(t.displayStr);
}
else {
leftovers.add(t);
}
}
return new JoinedQueryAndNonLiteralTokens(s.toString(), leftovers);
}
record JoinedQueryAndNonLiteralTokens(String joinedQuery, List<Token> nonLiterals) {}
}

View File

@ -0,0 +1,51 @@
package nu.marginalia.functions.searchquery.query_parser.model;
import ca.rmen.porterstemmer.PorterStemmer;
public record QWord(
int ord,
boolean variant,
String stemmed,
String word,
String original)
{
// These are special words that are not in the input, but are added to the graph.
// Note the spaces around the ^ and $, which avoid collisions with real words.
private static final String BEG_MARKER = " ^ ";
private static final String END_MARKER = " $ ";
private static final PorterStemmer ps = new PorterStemmer();
public boolean isBeg() {
return word.equals(BEG_MARKER);
}
public boolean isEnd() {
return word.equals(END_MARKER);
}
public static QWord beg() {
return new QWord(Integer.MIN_VALUE, false, BEG_MARKER, BEG_MARKER, BEG_MARKER);
}
public static QWord end() {
return new QWord(Integer.MAX_VALUE, false, END_MARKER, END_MARKER, END_MARKER);
}
public boolean isOriginal() {
return !variant;
}
public QWord(int ord, String word) {
this(ord, false, ps.stemWord(word), word, word);
}
public QWord(int ord, QWord original, String word) {
this(ord, true, ps.stemWord(word), word, original.original);
}
public String toString() {
return STR."q{\{word}}";
}
}

View File

@ -0,0 +1,267 @@
package nu.marginalia.functions.searchquery.query_parser.model;
import org.jetbrains.annotations.NotNull;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** Graph structure for constructing query variants. The graph should be a directed acyclic graph,
* with a single start node and a single end node, denoted by QWord.beg() and QWord.end() respectively.
* <p></p>
* Naively, every path from the start to the end node should represent a valid query variant, although in
* practice it is desirable to be clever about how to evaluate the paths, to avoid a large number of queries
* being generated.
*/
public class QWordGraph implements Iterable<QWord> {
public record QWordGraphLink(QWord from, QWord to) {}
private final List<QWordGraphLink> links = new ArrayList<>();
private final Map<QWord, List<QWord>> fromTo = new HashMap<>();
private final Map<QWord, List<QWord>> toFrom = new HashMap<>();
private int wordId = 0;
public QWordGraph(String... words) {
this(List.of(words));
}
public QWordGraph(List<String> words) {
QWord beg = QWord.beg();
QWord end = QWord.end();
var prev = beg;
for (String s : words) {
var word = new QWord(wordId++, s);
addLink(prev, word);
prev = word;
}
addLink(prev, end);
}
public void addVariant(QWord original, String word) {
var siblings = getVariants(original);
if (siblings.stream().anyMatch(w -> w.word().equals(word)))
return;
var newWord = new QWord(wordId++, original, word);
for (var prev : getPrev(original))
addLink(prev, newWord);
for (var next : getNext(original))
addLink(newWord, next);
}
public void addVariantForSpan(QWord first, QWord last, String word) {
var newWord = new QWord(wordId++, first, word);
for (var prev : getPrev(first))
addLink(prev, newWord);
for (var next : getNext(last))
addLink(newWord, next);
}
public List<QWord> getVariants(QWord original) {
var prevNext = getPrev(original).stream()
.flatMap(prev -> getNext(prev).stream())
.collect(Collectors.toSet());
return getNext(original).stream()
.flatMap(next -> getPrev(next).stream())
.filter(prevNext::contains)
.collect(Collectors.toList());
}
public void addLink(QWord from, QWord to) {
links.add(new QWordGraphLink(from, to));
fromTo.computeIfAbsent(from, k -> new ArrayList<>()).add(to);
toFrom.computeIfAbsent(to, k -> new ArrayList<>()).add(from);
}
public List<QWordGraphLink> links() {
return Collections.unmodifiableList(links);
}
public List<QWord> nodes() {
return links.stream()
.flatMap(l -> Stream.of(l.from(), l.to()))
.sorted(Comparator.comparing(QWord::ord))
.distinct()
.collect(Collectors.toList());
}
public QWord node(String word) {
return nodes().stream()
.filter(n -> n.word().equals(word))
.findFirst()
.orElseThrow();
}
public List<QWord> getNext(QWord word) {
return fromTo.getOrDefault(word, List.of());
}
public List<QWord> getNextOriginal(QWord word) {
return fromTo.getOrDefault(word, List.of())
.stream()
.filter(QWord::isOriginal)
.toList();
}
public List<QWord> getPrev(QWord word) {
return toFrom.getOrDefault(word, List.of());
}
public List<QWord> getPrevOriginal(QWord word) {
return toFrom.getOrDefault(word, List.of())
.stream()
.filter(QWord::isOriginal)
.toList();
}
public Map<QWord, Set<QWord>> forwardReachability() {
Map<QWord, Set<QWord>> ret = new HashMap<>();
Set<QWord> edge = Set.of(QWord.beg());
Set<QWord> visited = new HashSet<>();
while (!edge.isEmpty()) {
Set<QWord> next = new LinkedHashSet<>();
for (var w : edge) {
for (var n : getNext(w)) {
var set = ret.computeIfAbsent(n, k -> new HashSet<>());
set.add(w);
set.addAll(ret.getOrDefault(w, Set.of()));
next.add(n);
}
}
next.removeAll(visited);
visited.addAll(next);
edge = next;
}
return ret;
}
public Map<QWord, Set<QWord>> reverseReachability() {
Map<QWord, Set<QWord>> ret = new HashMap<>();
Set<QWord> edge = Set.of(QWord.end());
Set<QWord> visited = new HashSet<>();
while (!edge.isEmpty()) {
Set<QWord> prev = new LinkedHashSet<>();
for (var w : edge) {
for (var p : getPrev(w)) {
var set = ret.computeIfAbsent(p, k -> new HashSet<>());
set.add(w);
set.addAll(ret.getOrDefault(w, Set.of()));
prev.add(p);
}
}
prev.removeAll(visited);
visited.addAll(prev);
edge = prev;
}
return ret;
}
public record ReachabilityData(List<QWord> sortedNodes,
Map<QWord, Integer> sortOrder,
Map<QWord, Set<QWord>> forward,
Map<QWord, Set<QWord>> reverse)
{
public Set<QWord> forward(QWord node) {
return forward.getOrDefault(node, Set.of());
}
public Set<QWord> reverse(QWord node) {
return reverse.getOrDefault(node, Set.of());
}
public Comparator<QWord> topologicalComparator() {
Comparator<QWord> comp = Comparator.comparing(sortOrder::get);
return comp.thenComparing(QWord::ord);
}
}
/** Gather data about graph reachability, including the topological order of nodes */
public ReachabilityData reachability() {
var forwardReachability = forwardReachability();
var reverseReachability = reverseReachability();
List<QWord> nodes = new ArrayList<>(nodes());
nodes.sort(new SetMembershipComparator<>(forwardReachability));
Map<QWord, Integer> topologicalOrder = new HashMap<>();
for (int i = 0; i < nodes.size(); i++) {
topologicalOrder.put(nodes.get(i), i);
}
return new ReachabilityData(nodes, topologicalOrder, forwardReachability, reverseReachability);
}
static class SetMembershipComparator<T> implements Comparator<T> {
private final Map<T, Set<T>> membership;
SetMembershipComparator(Map<T, Set<T>> membership) {
this.membership = membership;
}
@Override
public int compare(T o1, T o2) {
return Boolean.compare(isIn(o1, o2), isIn(o2, o1));
}
private boolean isIn(T a, T b) {
return membership.getOrDefault(a, Set.of()).contains(b);
}
}
public String compileToQuery() {
return QWordPathsRenderer.render(this);
}
public String compileToDot() {
StringBuilder sb = new StringBuilder();
sb.append("digraph {\n");
for (var link : links) {
sb.append(STR."\"\{link.from().word()}\" -> \"\{link.to.word()}\";\n");
}
sb.append("}\n");
return sb.toString();
}
@NotNull
@Override
public Iterator<QWord> iterator() {
return new Iterator<>() {
QWord pos = QWord.beg();
@Override
public boolean hasNext() {
return !pos.isEnd();
}
@Override
public QWord next() {
pos = getNextOriginal(pos).getFirst();
return pos;
}
};
}
}
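
A minimal usage sketch of the graph (the class name and the rendered string in the comments are assumptions; the exact output is decided by QWordPathsRenderer):

import nu.marginalia.functions.searchquery.query_parser.model.QWordGraph;

class QWordGraphSketch {
    public static void main(String[] args) {
        // Builds ^ -> lawn -> chair -> $
        var graph = new QWordGraph("lawn", "chair");

        // Add "lawnchair" as an alternative spanning the whole "lawn chair" sequence,
        // roughly what QueryExpansion.joinTerms does when the joined form is frequent enough
        graph.addVariantForSpan(graph.node("lawn"), graph.node("chair"), "lawnchair");

        // Expected to render something along the lines of "( lawn chair | lawnchair )"
        System.out.println(graph.compileToQuery());

        // Graphviz dump of the same graph, handy when debugging expansion strategies
        System.out.println(graph.compileToDot());
    }
}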

View File

@ -0,0 +1,57 @@
package nu.marginalia.functions.searchquery.query_parser.model;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Objects;
import java.util.Set;
/** Utility class for listing each path in a {@link QWordGraph}, from the beginning node to the end.
* Normally this would risk combinatorial explosion, but in practice the graph is constructed
* in a way that avoids it.
* */
public class QWordGraphPathLister {
private final QWordGraph graph;
public QWordGraphPathLister(QWordGraph graph) {
this.graph = graph;
}
public static Set<QWordPath> listPaths(QWordGraph graph) {
return new QWordGraphPathLister(graph).listPaths();
}
Set<QWordPath> listPaths() {
Set<QWordPath> paths = new HashSet<>();
listPaths(paths, new LinkedList<>(), QWord.beg(), QWord.end());
return paths;
}
void listPaths(Set<QWordPath> acc,
LinkedList<QWord> stack,
QWord start,
QWord end)
{
stack.addLast(start);
if (Objects.equals(start, end)) {
var nodes = new HashSet<>(stack);
// Remove the start and end nodes from the path, as these are
// not part of the query but merely used to simplify the construction
// of the graph
nodes.remove(QWord.beg());
nodes.remove(QWord.end());
acc.add(new QWordPath(nodes));
}
else {
for (var next : graph.getNext(start)) {
listPaths(acc, stack, next, end);
}
}
stack.removeLast();
}
}
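
A small sketch of what the lister yields for the two-word graph used in the earlier example (the class name is illustrative; each QWordPath is an unordered node set):

import nu.marginalia.functions.searchquery.query_parser.model.QWordGraph;
import nu.marginalia.functions.searchquery.query_parser.model.QWordGraphPathLister;
import nu.marginalia.functions.searchquery.query_parser.model.QWordPath;

import java.util.Set;

class PathListerSketch {
    public static void main(String[] args) {
        var graph = new QWordGraph("lawn", "chair");
        graph.addVariantForSpan(graph.node("lawn"), graph.node("chair"), "lawnchair");

        // Two ways through the graph: {lawn, chair} and {lawnchair}
        Set<QWordPath> paths = QWordGraphPathLister.listPaths(graph);
        System.out.println(paths.size()); // expected: 2
    }
}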

View File

@ -0,0 +1,68 @@
package nu.marginalia.functions.searchquery.query_parser.model;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** Represents a path of QWords in a QWordGraph. Since the order of operations when
* evaluating a query does not affect its semantics, only performance, the order of the
* nodes in the path is not significant; thus the path is represented with a set.
*/
public class QWordPath {
private final Set<QWord> nodes;
QWordPath(Collection<QWord> nodes) {
this.nodes = new HashSet<>(nodes);
}
public boolean contains(QWord node) {
return nodes.contains(node);
}
/** Construct a new path by removing a word from the path. */
public QWordPath without(QWord word) {
Set<QWord> newNodes = new HashSet<>(nodes);
newNodes.remove(word);
return new QWordPath(newNodes);
}
public Stream<QWord> stream() {
return nodes.stream();
}
/** Construct a new path by projecting the path onto a set of nodes, such that
* the nodes in the new path are a subset of the provided nodes */
public QWordPath project(Set<QWord> nodes) {
return new QWordPath(this.nodes.stream().filter(nodes::contains).collect(Collectors.toSet()));
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
QWordPath wordPath = (QWordPath) o;
return nodes.equals(wordPath.nodes);
}
public boolean isEmpty() {
return nodes.isEmpty();
}
public int size() {
return nodes.size();
}
@Override
public int hashCode() {
return nodes.hashCode();
}
@Override
public String toString() {
return STR."WordPath{nodes=\{nodes}\{'}'}";
}
}

View File

@ -0,0 +1,149 @@
package nu.marginalia.functions.searchquery.query_parser.model;
import java.util.*;
import java.util.stream.Collectors;
/** Renders a set of QWordPaths into a human-readable infix-style expression. It's not guaranteed to find
* the globally optimal expression, but rather uses a greedy algorithm as a tradeoff between effort and outcome.
*/
public class QWordPathsRenderer {
private final Set<QWordPath> paths;
private QWordPathsRenderer(Collection<QWordPath> paths) {
this.paths = Set.copyOf(paths);
}
private QWordPathsRenderer(QWordGraph graph) {
this.paths = Set.copyOf(QWordGraphPathLister.listPaths(graph));
}
public static String render(QWordGraph graph) {
return new QWordPathsRenderer(graph).render(graph.reachability());
}
private static String render(Collection<QWordPath> paths,
QWordGraph.ReachabilityData reachability)
{
return new QWordPathsRenderer(paths).render(reachability);
}
/** Render the paths into a human-readable infix-style expression.
* <p></p>
* This method is recursive, but the recursion depth is limited by the
* maximum length of the paths, which is hard limited to a value typically around 10,
* so we don't need to worry about stack overflows here...
*/
String render(QWordGraph.ReachabilityData reachability) {
if (paths.size() == 1) {
return paths.iterator().next().stream().map(QWord::word).collect(Collectors.joining(" "));
}
// Find the commonality of words in the paths
Map<QWord, Integer> commonality = nodeCommonality(paths);
// Break the words into two categories: those that are common to all paths, and those that are not
List<QWord> commonToAll = new ArrayList<>();
Set<QWord> notCommonToAll = new HashSet<>();
commonality.forEach((k, v) -> {
if (v == paths.size()) {
commonToAll.add(k);
} else {
notCommonToAll.add(k);
}
});
StringJoiner resultJoiner = new StringJoiner(" ");
if (!commonToAll.isEmpty()) { // Case where one or more words are common to all paths
commonToAll.sort(reachability.topologicalComparator());
for (var word : commonToAll) {
resultJoiner.add(word.word());
}
// Deal with the portion of the paths that do not all share a common word
if (!notCommonToAll.isEmpty()) {
List<QWordPath> nonOverlappingPortions = new ArrayList<>();
// Create a new path for each path that does not contain the common words we just printed
for (var path : paths) {
var np = path.project(notCommonToAll);
if (np.isEmpty())
continue;
nonOverlappingPortions.add(np);
}
// Recurse into the non-overlapping portions
resultJoiner.add(render(nonOverlappingPortions, reachability));
}
} else if (commonality.size() > 1) { // The case where no words are common to all paths
// Sort the words by commonality, so that we can consider the most common words first
Map<QWord, List<QWordPath>> pathsByCommonWord = new HashMap<>();
// Mutable copy of the paths
List<QWordPath> allDivergentPaths = new ArrayList<>(paths);
// Break the paths into branches by the first common word they contain, in order of decreasing commonality
while (!allDivergentPaths.isEmpty()) {
QWord mostCommon = mostCommonQWord(allDivergentPaths);
var iter = allDivergentPaths.iterator();
while (iter.hasNext()) {
var path = iter.next();
if (!path.contains(mostCommon)) {
continue;
}
// Remove the common word from the path
var newPath = path.without(mostCommon);
pathsByCommonWord
.computeIfAbsent(mostCommon, k -> new ArrayList<>())
.add(newPath);
// Remove the path from the list of divergent paths since we've now accounted for it and
// we don't want redundant branches:
iter.remove();
}
}
var branches = pathsByCommonWord.entrySet().stream()
.sorted(Map.Entry.comparingByKey(reachability.topologicalComparator())) // Sort by topological order to ensure consistent output
.map(e -> {
String commonWord = e.getKey().word();
// Recurse into the branches:
String branchPart = render(e.getValue(), reachability);
return STR."\{commonWord} \{branchPart}";
})
.collect(Collectors.joining(" | ", " ( ", " ) "));
resultJoiner.add(branches);
}
// Remove any double spaces that may have been introduced
return resultJoiner.toString().replaceAll("\\s+", " ").trim();
}
/** Compute how many paths each word is part of */
private static Map<QWord, Integer> nodeCommonality(Collection<QWordPath> paths) {
return paths.stream().flatMap(QWordPath::stream)
.collect(Collectors.groupingBy(w -> w, Collectors.summingInt(w -> 1)));
}
private static QWord mostCommonQWord(Collection<QWordPath> paths) {
assert !paths.isEmpty();
return nodeCommonality(paths).entrySet().stream()
.max(Map.Entry.comparingByValue())
.map(Map.Entry::getKey)
.orElseThrow();
}
}
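
A sketch of the greedy factoring described above, using an invented two-path graph; the class name is an assumption and the rendered string in the comment is indicative rather than guaranteed:

import nu.marginalia.functions.searchquery.query_parser.model.QWordGraph;
import nu.marginalia.functions.searchquery.query_parser.model.QWordPathsRenderer;

class RendererSketch {
    public static void main(String[] args) {
        var graph = new QWordGraph("mechanical", "keyboard");
        graph.addVariant(graph.node("keyboard"), "keyboards");

        // Both paths share "mechanical", so it should be emitted once, with only the divergent
        // part grouped, i.e. something like "mechanical ( keyboard | keyboards )"
        System.out.println(QWordPathsRenderer.render(graph));
    }
}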

View File

@ -0,0 +1,86 @@
package nu.marginalia.functions.searchquery.query_parser.token;
public sealed interface QueryToken {
String str();
String displayStr();
record LiteralTerm(String str, String displayStr) implements QueryToken {}
record QuotTerm(String str, String displayStr) implements QueryToken {}
record ExcludeTerm(String str, String displayStr) implements QueryToken {}
record AdviceTerm(String str, String displayStr) implements QueryToken {}
record PriorityTerm(String str, String displayStr) implements QueryToken {}
record QualityTerm(String str) implements QueryToken {
public String displayStr() {
return "q" + str;
}
}
record YearTerm(String str) implements QueryToken {
public String displayStr() {
return "year" + str;
}
}
record SizeTerm(String str) implements QueryToken {
public String displayStr() {
return "size" + str;
}
}
record RankTerm(String str) implements QueryToken {
public String displayStr() {
return "rank" + str;
}
}
record NearTerm(String str) implements QueryToken {
public String displayStr() {
return "near:" + str;
}
}
record QsTerm(String str) implements QueryToken {
public String displayStr() {
return "qs" + str;
}
}
record Quot(String str) implements QueryToken {
public String displayStr() {
return "\"" + str + "\"";
}
}
record Minus() implements QueryToken {
public String str() {
return "-";
}
public String displayStr() {
return "-";
}
}
record QMark() implements QueryToken {
public String str() {
return "?";
}
public String displayStr() {
return "?";
}
}
record LParen() implements QueryToken {
public String str() {
return "(";
}
public String displayStr() {
return "(";
}
}
record RParen() implements QueryToken {
public String str() {
return ")";
}
public String displayStr() {
return ")";
}
}
record Ignore(String str, String displayStr) implements QueryToken {}
}
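Because QueryToken is a sealed hierarchy of records, consumers can pattern-match over a token stream with a switch, which is how the reworked QueryFactory below consumes it. A minimal sketch with a hand-built token list (the literal values are made up for illustration):

import java.util.List;
import nu.marginalia.functions.searchquery.query_parser.token.QueryToken;

class QueryTokenSketch {
    public static void main(String[] args) {
        List<QueryToken> tokens = List.of(
                new QueryToken.LiteralTerm("marginalia", "marginalia"),
                new QueryToken.ExcludeTerm("spam", "-spam"),
                new QueryToken.YearTerm(">2010"));

        for (QueryToken t : tokens) {
            switch (t) {
                case QueryToken.LiteralTerm(String str, String displayStr) ->
                        System.out.println("include: " + str);
                case QueryToken.ExcludeTerm(String str, String displayStr) ->
                        System.out.println("exclude: " + str);
                case QueryToken.YearTerm(String str) ->
                        System.out.println("year limit: " + str);
                default -> { } // remaining token types are ignored in this sketch
            }
        }
    }
}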

View File

@ -1,49 +0,0 @@
package nu.marginalia.functions.searchquery.query_parser.token;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import lombok.With;
@ToString
@EqualsAndHashCode
@With
public class Token {
public TokenType type;
public String str;
public final String displayStr;
public Token(TokenType type, String str, String displayStr) {
this.type = type;
this.str = str;
this.displayStr = safeString(displayStr);
}
public Token(TokenType type, String str) {
this.type = type;
this.str = str;
this.displayStr = safeString(str);
}
private static String safeString(String s) {
return s.replaceAll("<", "&lt;")
.replaceAll(">", "&gt;");
}
public void visit(TokenVisitor visitor) {
switch (type) {
case QUOT_TERM: visitor.onQuotTerm(this); break;
case EXCLUDE_TERM: visitor.onExcludeTerm(this); break;
case PRIORTY_TERM: visitor.onPriorityTerm(this); break;
case ADVICE_TERM: visitor.onAdviceTerm(this); break;
case LITERAL_TERM: visitor.onLiteralTerm(this); break;
case YEAR_TERM: visitor.onYearTerm(this); break;
case RANK_TERM: visitor.onRankTerm(this); break;
case SIZE_TERM: visitor.onSizeTerm(this); break;
case QS_TERM: visitor.onQsTerm(this); break;
case QUALITY_TERM: visitor.onQualityTerm(this); break;
}
}
}

View File

@ -1,34 +0,0 @@
package nu.marginalia.functions.searchquery.query_parser.token;
import java.util.function.Predicate;
public enum TokenType implements Predicate<Token> {
TERM,
LITERAL_TERM,
QUOT_TERM,
EXCLUDE_TERM,
ADVICE_TERM,
PRIORTY_TERM,
QUALITY_TERM,
YEAR_TERM,
SIZE_TERM,
RANK_TERM,
NEAR_TERM,
QS_TERM,
QUOT,
MINUS,
QMARK,
LPAREN,
RPAREN,
IGNORE;
public boolean test(Token t) {
return t.type == this;
}
}

View File

@ -1,14 +0,0 @@
package nu.marginalia.functions.searchquery.query_parser.token;
public interface TokenVisitor {
void onLiteralTerm(Token token);
void onQuotTerm(Token token);
void onExcludeTerm(Token token);
void onPriorityTerm(Token token);
void onAdviceTerm(Token token);
void onYearTerm(Token token);
void onSizeTerm(Token token);
void onRankTerm(Token token);
void onQualityTerm(Token token);
void onQsTerm(Token token);
}

View File

@ -2,72 +2,42 @@ package nu.marginalia.functions.searchquery.svc;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.LanguageModels;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.util.language.EnglishDictionary;
import nu.marginalia.functions.searchquery.query_parser.QueryExpansion;
import nu.marginalia.functions.searchquery.query_parser.token.QueryToken;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.query.limit.SpecificationLimit;
import nu.marginalia.language.WordPatterns;
import nu.marginalia.util.ngrams.NGramBloomFilter;
import nu.marginalia.api.searchquery.model.query.QueryParams;
import nu.marginalia.api.searchquery.model.query.ProcessedQuery;
import nu.marginalia.functions.searchquery.query_parser.QueryParser;
import nu.marginalia.functions.searchquery.query_parser.QueryPermutation;
import nu.marginalia.functions.searchquery.query_parser.QueryVariants;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenType;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Singleton
public class QueryFactory {
private final Logger logger = LoggerFactory.getLogger(getClass());
private static final int RETAIN_QUERY_VARIANT_COUNT = 5;
private final ThreadLocal<QueryVariants> queryVariants;
private final QueryParser queryParser = new QueryParser();
private final QueryExpansion queryExpansion;
@Inject
public QueryFactory(LanguageModels lm,
TermFrequencyDict dict,
EnglishDictionary englishDictionary,
NGramBloomFilter nGramBloomFilter) {
this.queryVariants = ThreadLocal.withInitial(() -> new QueryVariants(lm ,dict, nGramBloomFilter, englishDictionary));
public QueryFactory(QueryExpansion queryExpansion)
{
this.queryExpansion = queryExpansion;
}
public QueryPermutation getQueryPermutation() {
return new QueryPermutation(queryVariants.get());
}
public ProcessedQuery createQuery(QueryParams params) {
final var processedQuery = createQuery(getQueryPermutation(), params);
final List<SearchSubquery> subqueries = processedQuery.specs.subqueries;
// There used to be a piece of logic here that would try to figure out which one of these subqueries was the "best";
// it's gone for the moment, but it would be neat if it were resurrected somehow
trimArray(subqueries, RETAIN_QUERY_VARIANT_COUNT);
return processedQuery;
}
private void trimArray(List<?> arr, int maxSize) {
if (arr.size() > maxSize) {
arr.subList(0, arr.size() - maxSize).clear();
}
}
public ProcessedQuery createQuery(QueryPermutation queryPermutation,
QueryParams params)
{
final var query = params.humanQuery();
if (query.length() > 1000) {
@ -77,42 +47,86 @@ public class QueryFactory {
List<String> searchTermsHuman = new ArrayList<>();
List<String> problems = new ArrayList<>();
String domain = null;
var basicQuery = queryParser.parse(query);
List<QueryToken> basicQuery = queryParser.parse(query);
if (basicQuery.size() >= 12) {
problems.add("Your search query is too long");
basicQuery.clear();
}
List<String> searchTermsExclude = new ArrayList<>();
List<String> searchTermsInclude = new ArrayList<>();
List<String> searchTermsAdvice = new ArrayList<>();
List<String> searchTermsPriority = new ArrayList<>();
List<List<String>> searchTermCoherences = new ArrayList<>();
QueryLimitsAccumulator qualityLimits = new QueryLimitsAccumulator(params);
SpecificationLimit qualityLimit = SpecificationLimit.none();
SpecificationLimit year = SpecificationLimit.none();
SpecificationLimit size = SpecificationLimit.none();
SpecificationLimit rank = SpecificationLimit.none();
QueryStrategy queryStrategy = QueryStrategy.AUTO;
for (Token t : basicQuery) {
if (t.type == TokenType.QUOT_TERM || t.type == TokenType.LITERAL_TERM) {
if (t.str.startsWith("site:")) {
t.str = normalizeDomainName(t.str);
String domain = null;
for (QueryToken t : basicQuery) {
switch (t) {
case QueryToken.QuotTerm(String str, String displayStr) -> {
analyzeSearchTerm(problems, str, displayStr);
searchTermsHuman.addAll(Arrays.asList(displayStr.replace("\"", "").split("\\s+")));
String[] parts = StringUtils.split(str, '_');
// Checking for stop words here is a bit of a stop-gap to fix the issue of stop words being
// required in the query (which is a problem because they are not indexed). How to do this
// in a clean way is a bit of an open problem that may not get resolved until query-parsing is
// improved.
if (parts.length > 1 && !anyPartIsStopWord(parts)) {
// Prefer that the actual n-gram is present
searchTermsAdvice.add(str);
// Require that the terms appear in the same sentence
searchTermCoherences.add(Arrays.asList(parts));
// Require that each term exists in the document
// (needed for ranking)
searchTermsInclude.addAll(Arrays.asList(parts));
}
else {
searchTermsInclude.add(str);
}
}
case QueryToken.LiteralTerm(String str, String displayStr) -> {
analyzeSearchTerm(problems, str, displayStr);
searchTermsHuman.addAll(Arrays.asList(displayStr.split("\\s+")));
searchTermsInclude.add(str);
}
searchTermsHuman.addAll(toHumanSearchTerms(t));
analyzeSearchTerm(problems, t);
case QueryToken.ExcludeTerm(String str, String displayStr) -> searchTermsExclude.add(str);
case QueryToken.PriorityTerm(String str, String displayStr) -> searchTermsPriority.add(str);
case QueryToken.AdviceTerm(String str, String displayStr) -> {
searchTermsAdvice.add(str);
if (str.toLowerCase().startsWith("site:")) {
domain = str.substring("site:".length());
}
}
t.visit(qualityLimits);
case QueryToken.YearTerm(String str) -> year = parseSpecificationLimit(str);
case QueryToken.SizeTerm(String str) -> size = parseSpecificationLimit(str);
case QueryToken.RankTerm(String str) -> rank = parseSpecificationLimit(str);
case QueryToken.QualityTerm(String str) -> qualityLimit = parseSpecificationLimit(str);
case QueryToken.QsTerm(String str) -> queryStrategy = parseQueryStrategy(str);
default -> {}
}
}
var queryPermutations = queryPermutation.permuteQueriesNew(basicQuery);
List<SearchSubquery> subqueries = new ArrayList<>();
for (var parts : queryPermutations) {
QuerySearchTermsAccumulator termsAccumulator = new QuerySearchTermsAccumulator(parts);
SearchSubquery subquery = termsAccumulator.createSubquery();
domain = termsAccumulator.domain;
subqueries.add(subquery);
if (searchTermsInclude.isEmpty() && !searchTermsAdvice.isEmpty()) {
searchTermsInclude.addAll(searchTermsAdvice);
searchTermsAdvice.clear();
}
List<Integer> domainIds = params.domainIds();
@ -123,55 +137,85 @@ public class QueryFactory {
limits = limits.forSingleDomain();
}
var searchQuery = new SearchQuery(
queryExpansion.expandQuery(
searchTermsInclude
),
searchTermsInclude,
searchTermsExclude,
searchTermsAdvice,
searchTermsPriority,
searchTermCoherences
);
var specsBuilder = SearchSpecification.builder()
.subqueries(subqueries)
.query(searchQuery)
.humanQuery(query)
.quality(qualityLimits.qualityLimit)
.year(qualityLimits.year)
.size(qualityLimits.size)
.rank(qualityLimits.rank)
.quality(qualityLimit)
.year(year)
.size(size)
.rank(rank)
.domains(domainIds)
.queryLimits(limits)
.searchSetIdentifier(params.identifier())
.rankingParams(ResultRankingParameters.sensibleDefaults())
.queryStrategy(qualityLimits.queryStrategy);
.queryStrategy(queryStrategy);
SearchSpecification specs = specsBuilder.build();
for (var sq : specs.subqueries) {
sq.searchTermsAdvice.addAll(params.tacitAdvice());
sq.searchTermsPriority.addAll(params.tacitPriority());
sq.searchTermsInclude.addAll(params.tacitIncludes());
sq.searchTermsExclude.addAll(params.tacitExcludes());
}
specs.query.searchTermsAdvice.addAll(params.tacitAdvice());
specs.query.searchTermsPriority.addAll(params.tacitPriority());
specs.query.searchTermsExclude.addAll(params.tacitExcludes());
return new ProcessedQuery(specs, searchTermsHuman, domain);
}
private String normalizeDomainName(String str) {
return str.toLowerCase();
}
private List<String> toHumanSearchTerms(Token t) {
if (t.type == TokenType.LITERAL_TERM) {
return Arrays.asList(t.displayStr.split("\\s+"));
}
else if (t.type == TokenType.QUOT_TERM) {
return Arrays.asList(t.displayStr.replace("\"", "").split("\\s+"));
}
return Collections.emptyList();
}
private void analyzeSearchTerm(List<String> problems, Token term) {
final String word = term.str;
private void analyzeSearchTerm(List<String> problems, String str, String displayStr) {
final String word = str;
if (word.length() < WordPatterns.MIN_WORD_LENGTH) {
problems.add("Search term \"" + term.displayStr + "\" too short");
problems.add("Search term \"" + displayStr + "\" too short");
}
if (!word.contains("_") && word.length() >= WordPatterns.MAX_WORD_LENGTH) {
problems.add("Search term \"" + term.displayStr + "\" too long");
problems.add("Search term \"" + displayStr + "\" too long");
}
}
private SpecificationLimit parseSpecificationLimit(String str) {
int startChar = str.charAt(0);
int val = Integer.parseInt(str.substring(1));
if (startChar == '=') {
return SpecificationLimit.equals(val);
} else if (startChar == '<') {
return SpecificationLimit.lessThan(val);
} else if (startChar == '>') {
return SpecificationLimit.greaterThan(val);
} else {
return SpecificationLimit.none();
}
}
private QueryStrategy parseQueryStrategy(String str) {
return switch (str.toUpperCase()) {
case "RF_TITLE" -> QueryStrategy.REQUIRE_FIELD_TITLE;
case "RF_SUBJECT" -> QueryStrategy.REQUIRE_FIELD_SUBJECT;
case "RF_SITE" -> QueryStrategy.REQUIRE_FIELD_SITE;
case "RF_URL" -> QueryStrategy.REQUIRE_FIELD_URL;
case "RF_DOMAIN" -> QueryStrategy.REQUIRE_FIELD_DOMAIN;
case "RF_LINK" -> QueryStrategy.REQUIRE_FIELD_LINK;
case "SENTENCE" -> QueryStrategy.SENTENCE;
case "TOPIC" -> QueryStrategy.TOPIC;
default -> QueryStrategy.AUTO;
};
}
private boolean anyPartIsStopWord(String[] parts) {
for (String part : parts) {
if (WordPatterns.isStopWord(part)) {
return true;
}
}
return false;
}
}
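As a worked example of the limit syntax handled by parseSpecificationLimit above: the assumption is that the parser turns e.g. "year>2010" into a YearTerm whose string is ">2010", so the comparison operator travels as the first character of the token string. A small self-contained sketch that mirrors the same mapping, for illustration only:

class SpecificationLimitSketch {
    // Mirrors the private parseSpecificationLimit above; not the production code
    static String describe(String str) {
        int val = Integer.parseInt(str.substring(1));
        return switch (str.charAt(0)) {
            case '=' -> "equals " + val;
            case '<' -> "less than " + val;
            case '>' -> "greater than " + val;
            default -> "no constraint";
        };
    }

    public static void main(String[] args) {
        System.out.println(describe(">2010")); // e.g. from a hypothetical YearTerm(">2010"), i.e. "year>2010"
        System.out.println(describe("<100"));  // e.g. from a hypothetical SizeTerm("<100")
    }
}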

View File

@ -1,93 +0,0 @@
package nu.marginalia.functions.searchquery.svc;
import nu.marginalia.api.searchquery.model.query.QueryParams;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.query.limit.SpecificationLimit;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenVisitor;
public class QueryLimitsAccumulator implements TokenVisitor {
public SpecificationLimit qualityLimit;
public SpecificationLimit year;
public SpecificationLimit size;
public SpecificationLimit rank;
public QueryStrategy queryStrategy = QueryStrategy.AUTO;
public QueryLimitsAccumulator(QueryParams params) {
qualityLimit = params.quality();
year = params.year();
size = params.size();
rank = params.rank();
}
private SpecificationLimit parseSpecificationLimit(String str) {
int startChar = str.charAt(0);
int val = Integer.parseInt(str.substring(1));
if (startChar == '=') {
return SpecificationLimit.equals(val);
} else if (startChar == '<') {
return SpecificationLimit.lessThan(val);
} else if (startChar == '>') {
return SpecificationLimit.greaterThan(val);
} else {
return SpecificationLimit.none();
}
}
private QueryStrategy parseQueryStrategy(String str) {
return switch (str.toUpperCase()) {
case "RF_TITLE" -> QueryStrategy.REQUIRE_FIELD_TITLE;
case "RF_SUBJECT" -> QueryStrategy.REQUIRE_FIELD_SUBJECT;
case "RF_SITE" -> QueryStrategy.REQUIRE_FIELD_SITE;
case "RF_URL" -> QueryStrategy.REQUIRE_FIELD_URL;
case "RF_DOMAIN" -> QueryStrategy.REQUIRE_FIELD_DOMAIN;
case "RF_LINK" -> QueryStrategy.REQUIRE_FIELD_LINK;
case "SENTENCE" -> QueryStrategy.SENTENCE;
case "TOPIC" -> QueryStrategy.TOPIC;
default -> QueryStrategy.AUTO;
};
}
@Override
public void onYearTerm(Token token) {
year = parseSpecificationLimit(token.str);
}
@Override
public void onSizeTerm(Token token) {
size = parseSpecificationLimit(token.str);
}
@Override
public void onRankTerm(Token token) {
rank = parseSpecificationLimit(token.str);
}
@Override
public void onQualityTerm(Token token) {
qualityLimit = parseSpecificationLimit(token.str);
}
@Override
public void onQsTerm(Token token) {
queryStrategy = parseQueryStrategy(token.str);
}
@Override
public void onLiteralTerm(Token token) {}
@Override
public void onQuotTerm(Token token) {}
@Override
public void onExcludeTerm(Token token) {}
@Override
public void onPriorityTerm(Token token) {}
@Override
public void onAdviceTerm(Token token) {}
}

View File

@ -1,109 +0,0 @@
package nu.marginalia.functions.searchquery.svc;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.language.WordPatterns;
import nu.marginalia.functions.searchquery.query_parser.token.Token;
import nu.marginalia.functions.searchquery.query_parser.token.TokenVisitor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/** @see SearchSubquery */
public class QuerySearchTermsAccumulator implements TokenVisitor {
public List<String> searchTermsExclude = new ArrayList<>();
public List<String> searchTermsInclude = new ArrayList<>();
public List<String> searchTermsAdvice = new ArrayList<>();
public List<String> searchTermsPriority = new ArrayList<>();
public List<List<String>> searchTermCoherences = new ArrayList<>();
public String domain;
public SearchSubquery createSubquery() {
return new SearchSubquery(searchTermsInclude, searchTermsExclude, searchTermsAdvice, searchTermsPriority, searchTermCoherences);
}
public QuerySearchTermsAccumulator(List<Token> parts) {
for (Token t : parts) {
t.visit(this);
}
if (searchTermsInclude.isEmpty() && !searchTermsAdvice.isEmpty()) {
searchTermsInclude.addAll(searchTermsAdvice);
searchTermsAdvice.clear();
}
}
@Override
public void onLiteralTerm(Token token) {
searchTermsInclude.add(token.str);
}
@Override
public void onQuotTerm(Token token) {
String[] parts = token.str.split("_");
// HACK (2023-05-02 vlofgren)
//
// Checking for stop words here is a bit of a stop-gap to fix the issue of stop words being
// required in the query (which is a problem because they are not indexed). How to do this
// in a clean way is a bit of an open problem that may not get resolved until query-parsing is
// improved.
if (parts.length > 1 && !anyPartIsStopWord(parts)) {
// Prefer that the actual n-gram is present
searchTermsAdvice.add(token.str);
// Require that the terms appear in the same sentence
searchTermCoherences.add(Arrays.asList(parts));
// Require that each term exists in the document
// (needed for ranking)
searchTermsInclude.addAll(Arrays.asList(parts));
}
else {
searchTermsInclude.add(token.str);
}
}
private boolean anyPartIsStopWord(String[] parts) {
for (String part : parts) {
if (WordPatterns.isStopWord(part)) {
return true;
}
}
return false;
}
@Override
public void onExcludeTerm(Token token) {
searchTermsExclude.add(token.str);
}
@Override
public void onPriorityTerm(Token token) {
searchTermsPriority.add(token.str);
}
@Override
public void onAdviceTerm(Token token) {
searchTermsAdvice.add(token.str);
if (token.str.toLowerCase().startsWith("site:")) {
domain = token.str.substring("site:".length());
}
}
@Override
public void onYearTerm(Token token) {}
@Override
public void onSizeTerm(Token token) {}
@Override
public void onRankTerm(Token token) {}
@Override
public void onQualityTerm(Token token) {}
@Override
public void onQsTerm(Token token) {}
}

View File

@ -1,69 +0,0 @@
package nu.marginalia.util.ngrams;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.BitSet;
// It's unclear why this exists; we should probably use a BitSet instead?
// Chesterton's fence?
public class DenseBitMap {
public static final long MAX_CAPACITY_2GB_16BN_ITEMS=(1L<<34)-8;
public final long cardinality;
private final ByteBuffer buffer;
public DenseBitMap(long cardinality) {
this.cardinality = cardinality;
boolean misaligned = (cardinality & 7) > 0;
this.buffer = ByteBuffer.allocateDirect((int)((cardinality / 8) + (misaligned ? 1 : 0)));
}
public static DenseBitMap loadFromFile(Path file) throws IOException {
long size = Files.size(file);
var dbm = new DenseBitMap(size/8);
try (var bc = Files.newByteChannel(file)) {
while (dbm.buffer.position() < dbm.buffer.capacity()) {
bc.read(dbm.buffer);
}
}
dbm.buffer.clear();
return dbm;
}
public void writeToFile(Path file) throws IOException {
try (var bc = Files.newByteChannel(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE)) {
while (buffer.position() < buffer.capacity()) {
bc.write(buffer);
}
}
buffer.clear();
}
public boolean get(long pos) {
return (buffer.get((int)(pos >>> 3)) & ((byte)1 << (int)(pos & 7))) != 0;
}
/** Set the bit indexed by pos, returns
* its previous value.
*/
public boolean set(long pos) {
int offset = (int) (pos >>> 3);
int oldVal = buffer.get(offset);
int mask = (byte) 1 << (int) (pos & 7);
buffer.put(offset, (byte) (oldVal | mask));
return (oldVal & mask) != 0;
}
public void clear(long pos) {
int offset = (int)(pos >>> 3);
buffer.put(offset, (byte)(buffer.get(offset) & ~(byte)(1 << (int)(pos & 7))));
}
}

View File

@ -1,64 +0,0 @@
package nu.marginalia.util.ngrams;
import ca.rmen.porterstemmer.PorterStemmer;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.google.inject.Inject;
import nu.marginalia.LanguageModels;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.regex.Pattern;
public class NGramBloomFilter {
private final DenseBitMap bitMap;
private static final PorterStemmer ps = new PorterStemmer();
private static final HashFunction hasher = Hashing.murmur3_128(0);
private static final Logger logger = LoggerFactory.getLogger(NGramBloomFilter.class);
@Inject
public NGramBloomFilter(LanguageModels lm) throws IOException {
this(loadSafely(lm.ngramBloomFilter));
}
private static DenseBitMap loadSafely(Path path) throws IOException {
if (Files.isRegularFile(path)) {
return DenseBitMap.loadFromFile(path);
}
else {
logger.warn("NGrams file missing " + path);
return new DenseBitMap(1);
}
}
public NGramBloomFilter(DenseBitMap bitMap) {
this.bitMap = bitMap;
}
public boolean isKnownNGram(String word) {
long bit = bitForWord(word, bitMap.cardinality);
return bitMap.get(bit);
}
public static NGramBloomFilter load(Path file) throws IOException {
return new NGramBloomFilter(DenseBitMap.loadFromFile(file));
}
private static final Pattern underscore = Pattern.compile("_");
private static long bitForWord(String s, long n) {
String[] parts = underscore.split(s);
long hc = 0;
for (String part : parts) {
hc = hc * 31 + hasher.hashString(ps.stemWord(part), StandardCharsets.UTF_8).padToLong();
}
return (hc & 0x7FFF_FFFF_FFFF_FFFFL) % n;
}
}

View File

@ -80,6 +80,15 @@ public class TransformList<T> {
iter.remove();
}
}
else if (firstEntity.action == Action.NO_OP) {
if (secondEntry.action == Action.REPLACE) {
backingList.set(iter.nextIndex(), secondEntry.value);
}
else if (secondEntry.action == Action.REMOVE) {
iter.next();
iter.remove();
}
}
}
}

View File

@ -0,0 +1,115 @@
package nu.marginalia.functions.searchquery.query_parser.model;
import org.junit.jupiter.api.Test;
import java.util.Comparator;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
class QWordGraphTest {
@Test
void forwardReachability() {
// Construct a graph like
// ^ - a - b - c - $
// \- d -/
QWordGraph graph = new QWordGraph("a", "b", "c");
graph.addVariant(graph.node("b"), "d");
var reachability = graph.forwardReachability();
System.out.println(reachability.get(graph.node("a")));
System.out.println(reachability.get(graph.node("b")));
System.out.println(reachability.get(graph.node("c")));
System.out.println(reachability.get(graph.node("d")));
assertEquals(Set.of(graph.node(" ^ ")), reachability.get(graph.node("a")));
assertEquals(Set.of(graph.node(" ^ "), graph.node("a")), reachability.get(graph.node("b")));
assertEquals(Set.of(graph.node(" ^ "), graph.node("a")), reachability.get(graph.node("d")));
assertEquals(Set.of(graph.node(" ^ "), graph.node("a"), graph.node("b"), graph.node("d")), reachability.get(graph.node("c")));
assertEquals(Set.of(graph.node(" ^ "), graph.node("a"), graph.node("b"), graph.node("d"), graph.node("c")), reachability.get(graph.node(" $ ")));
}
@Test
void reverseReachability() {
// Construct a graph like
// ^ - a - b - c - $
// \- d -/
QWordGraph graph = new QWordGraph("a", "b", "c");
graph.addVariant(graph.node("b"), "d");
var reachability = graph.reverseReachability();
System.out.println(reachability.get(graph.node("a")));
System.out.println(reachability.get(graph.node("b")));
System.out.println(reachability.get(graph.node("c")));
System.out.println(reachability.get(graph.node("d")));
assertEquals(Set.of(graph.node(" $ ")), reachability.get(graph.node("c")));
assertEquals(Set.of(graph.node(" $ "), graph.node("c")), reachability.get(graph.node("b")));
assertEquals(Set.of(graph.node(" $ "), graph.node("c")), reachability.get(graph.node("d")));
assertEquals(Set.of(graph.node(" $ "), graph.node("c"), graph.node("b"), graph.node("d")), reachability.get(graph.node("a")));
assertEquals(Set.of(graph.node(" $ "), graph.node("c"), graph.node("b"), graph.node("d"), graph.node("a")), reachability.get(graph.node(" ^ ")));
}
@Test
void testCompile1() {
// Construct a graph like
// ^ - a - b - c - $
// \- d -/
QWordGraph graph = new QWordGraph("a", "b", "c");
graph.addVariant(graph.node("b"), "d");
assertEquals("a c ( b | d )", graph.compileToQuery());
}
@Test
void testCompile2() {
// Construct a graph like
// ^ - a - b - c - $
QWordGraph graph = new QWordGraph("a", "b", "c");
assertEquals("a b c", graph.compileToQuery());
}
@Test
void testCompile3() {
// Construct a graph like
// ^ - a - b - c - $
// \- d -/
QWordGraph graph = new QWordGraph("a", "b", "c");
graph.addVariant(graph.node("a"), "d");
assertEquals("b c ( a | d )", graph.compileToQuery());
}
@Test
void testCompile4() {
// Construct a graph like
// ^ - a - b - c - $
// \- d -/
QWordGraph graph = new QWordGraph("a", "b", "c");
graph.addVariant(graph.node("c"), "d");
assertEquals("a b ( c | d )", graph.compileToQuery());
}
@Test
void testCompile5() {
// Construct a graph like
// /- e -\
// ^ - a - b - c - $
// \- d -/
QWordGraph graph = new QWordGraph("a", "b", "c");
graph.addVariant(graph.node("c"), "d");
graph.addVariant(graph.node("b"), "e");
assertEquals("a ( c ( b | e ) | d ( b | e ) )", graph.compileToQuery());
}
}
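As a usage sketch of the same API the tests above exercise, with hypothetical words; following testCompile1, the words common to all paths should be emitted first and the divergent ones grouped in brackets:

import nu.marginalia.functions.searchquery.query_parser.model.QWordGraph;

class QWordGraphSketch {
    public static void main(String[] args) {
        QWordGraph graph = new QWordGraph("mechanical", "keyboard", "reviews");

        // Add a variant spelling for one of the words, as a query-expansion rule might
        graph.addVariant(graph.node("keyboard"), "keyboards");

        // Should print something along the lines of "mechanical reviews ( keyboard | keyboards )"
        System.out.println(graph.compileToQuery());
    }
}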

View File

@ -3,19 +3,21 @@ package nu.marginalia.query.svc;
import nu.marginalia.WmsaHome;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.functions.searchquery.query_parser.QueryExpansion;
import nu.marginalia.functions.searchquery.svc.QueryFactory;
import nu.marginalia.index.query.limit.QueryLimits;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.query.limit.SpecificationLimit;
import nu.marginalia.index.query.limit.SpecificationLimitType;
import nu.marginalia.util.language.EnglishDictionary;
import nu.marginalia.util.ngrams.NGramBloomFilter;
import nu.marginalia.segmentation.NgramLexicon;
import nu.marginalia.api.searchquery.model.query.QueryParams;
import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@ -28,12 +30,9 @@ public class QueryFactoryTest {
public static void setUpAll() throws IOException {
var lm = WmsaHome.getLanguageModels();
var tfd = new TermFrequencyDict(lm);
queryFactory = new QueryFactory(lm,
tfd,
new EnglishDictionary(tfd),
new NGramBloomFilter(lm)
queryFactory = new QueryFactory(
new QueryExpansion(new TermFrequencyDict(lm), new NgramLexicon(lm))
);
}
@ -55,6 +54,21 @@ public class QueryFactoryTest {
ResultRankingParameters.TemporalBias.NONE)).specs;
}
@Test
void qsec10() {
try (var lines = Files.lines(Path.of("/home/vlofgren/Exports/qsec10/webis-qsec-10-training-set/webis-qsec-10-training-set-queries.txt"))) {
lines.limit(1000).forEach(line -> {
String[] parts = line.split("\t");
if (parts.length == 2) {
System.out.println(parseAndGetSpecs(parts[1]).getQuery().compiledQuery);
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Test
public void testParseNoSpecials() {
var year = parseAndGetSpecs("in the year 2000").year;
@ -114,17 +128,15 @@ public class QueryFactoryTest {
{
// "the" is a stopword, so it should generate an ngram search term
var specs = parseAndGetSpecs("\"the shining\"");
assertEquals(List.of("the_shining"), specs.subqueries.iterator().next().searchTermsInclude);
assertEquals(List.of(), specs.subqueries.iterator().next().searchTermsAdvice);
assertEquals(List.of(), specs.subqueries.iterator().next().searchTermCoherences);
assertEquals("the_shining", specs.query.compiledQuery);
}
{
// "tde" isn't a stopword, so we should get the normal behavior
var specs = parseAndGetSpecs("\"tde shining\"");
assertEquals(List.of("tde", "shining"), specs.subqueries.iterator().next().searchTermsInclude);
assertEquals(List.of("tde_shining"), specs.subqueries.iterator().next().searchTermsAdvice);
assertEquals(List.of(List.of("tde", "shining")), specs.subqueries.iterator().next().searchTermCoherences);
assertEquals("tde shining", specs.query.compiledQuery);
assertEquals(List.of("tde_shining"), specs.query.searchTermsAdvice);
assertEquals(List.of(List.of("tde", "shining")), specs.query.searchTermCoherences);
}
}
@ -152,8 +164,18 @@ public class QueryFactoryTest {
@Test
public void testPriorityTerm() {
var subquery = parseAndGetSpecs("physics ?tld:edu").subqueries.iterator().next();
var subquery = parseAndGetSpecs("physics ?tld:edu").query;
assertEquals(List.of("tld:edu"), subquery.searchTermsPriority);
assertEquals(List.of("physics"), subquery.searchTermsInclude);
assertEquals("physics", subquery.compiledQuery);
}
@Test
public void testExpansion() {
long start = System.currentTimeMillis();
var subquery = parseAndGetSpecs("elden ring mechanical keyboard slackware linux duke nukem 3d").query;
System.out.println("Time: " + (System.currentTimeMillis() - start));
System.out.println(subquery.compiledQuery);
}
}

View File

@ -7,6 +7,7 @@ import nu.marginalia.index.query.EntrySource;
import static java.lang.Math.min;
public class ReverseIndexEntrySource implements EntrySource {
private final String name;
private final BTreeReader reader;
int pos;
@ -15,9 +16,11 @@ public class ReverseIndexEntrySource implements EntrySource {
final int entrySize;
private final long wordId;
public ReverseIndexEntrySource(BTreeReader reader,
public ReverseIndexEntrySource(String name,
BTreeReader reader,
int entrySize,
long wordId) {
this.name = name;
this.reader = reader;
this.entrySize = entrySize;
this.wordId = wordId;
@ -46,7 +49,7 @@ public class ReverseIndexEntrySource implements EntrySource {
return;
for (int ri = entrySize, wi=1; ri < buffer.end ; ri+=entrySize, wi++) {
buffer.data[wi] = buffer.data[ri];
buffer.data.set(wi, buffer.data.get(ri));
}
buffer.end /= entrySize;
@ -60,6 +63,6 @@ public class ReverseIndexEntrySource implements EntrySource {
@Override
public String indexName() {
return "Full:" + Long.toHexString(wordId);
return name + ":" + Long.toHexString(wordId);
}
}

View File

@ -25,8 +25,11 @@ public class ReverseIndexReader {
private final long wordsDataOffset;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final BTreeReader wordsBTreeReader;
private final String name;
public ReverseIndexReader(String name, Path words, Path documents) throws IOException {
this.name = name;
public ReverseIndexReader(Path words, Path documents) throws IOException {
if (!Files.exists(words) || !Files.exists(documents)) {
this.words = null;
this.documents = null;
@ -65,8 +68,12 @@ public class ReverseIndexReader {
}
long wordOffset(long wordId) {
long idx = wordsBTreeReader.findEntry(wordId);
/** Calculate the offset of the word in the documents.
* If the return-value is negative, the term does not exist
* in the index.
*/
long wordOffset(long termId) {
long idx = wordsBTreeReader.findEntry(termId);
if (idx < 0)
return -1L;
@ -74,37 +81,43 @@ public class ReverseIndexReader {
return words.get(wordsDataOffset + idx + 1);
}
public EntrySource documents(long wordId) {
public EntrySource documents(long termId) {
if (null == words) {
logger.warn("Reverse index is not ready, dropping query");
return new EmptyEntrySource();
}
long offset = wordOffset(wordId);
long offset = wordOffset(termId);
if (offset < 0) return new EmptyEntrySource();
if (offset < 0) // No documents
return new EmptyEntrySource();
return new ReverseIndexEntrySource(createReaderNew(offset), 2, wordId);
return new ReverseIndexEntrySource(name, createReaderNew(offset), 2, termId);
}
public QueryFilterStepIf also(long wordId) {
long offset = wordOffset(wordId);
/** Create a filter step requiring the specified termId to exist in the documents */
public QueryFilterStepIf also(long termId) {
long offset = wordOffset(termId);
if (offset < 0) return new QueryFilterNoPass();
if (offset < 0) // No documents
return new QueryFilterNoPass();
return new ReverseIndexRetainFilter(createReaderNew(offset), "full", wordId);
return new ReverseIndexRetainFilter(createReaderNew(offset), name, termId);
}
public QueryFilterStepIf not(long wordId) {
long offset = wordOffset(wordId);
/** Create a filter step requiring the specified termId to be absent from the documents */
public QueryFilterStepIf not(long termId) {
long offset = wordOffset(termId);
if (offset < 0) return new QueryFilterLetThrough();
if (offset < 0) // No documents
return new QueryFilterLetThrough();
return new ReverseIndexRejectFilter(createReaderNew(offset));
}
public int numDocuments(long wordId) {
long offset = wordOffset(wordId);
/** Return the number of documents with the termId in the index */
public int numDocuments(long termId) {
long offset = wordOffset(termId);
if (offset < 0)
return 0;
@ -112,15 +125,20 @@ public class ReverseIndexReader {
return createReaderNew(offset).numEntries();
}
/** Create a BTreeReader for the document offset associated with a termId */
private BTreeReader createReaderNew(long offset) {
return new BTreeReader(documents, ReverseIndexParameters.docsBTreeContext, offset);
return new BTreeReader(
documents,
ReverseIndexParameters.docsBTreeContext,
offset);
}
public long[] getTermMeta(long wordId, long[] docIds) {
long offset = wordOffset(wordId);
public long[] getTermMeta(long termId, long[] docIds) {
long offset = wordOffset(termId);
if (offset < 0) {
logger.debug("Missing offset for word {}", wordId);
// This is likely a bug in the code, but we can't throw an exception here
logger.debug("Missing offset for word {}", termId);
return new long[docIds.length];
}
@ -133,10 +151,9 @@ public class ReverseIndexReader {
private boolean isUniqueAndSorted(long[] ids) {
if (ids.length == 0)
return true;
long prev = ids[0];
for (int i = 1; i < ids.length; i++) {
if(ids[i] <= prev)
if(ids[i] <= ids[i-1])
return false;
}
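A minimal sketch of how the reader is consumed; the index file paths are hypothetical, the constructor is the same one used by ReverseIndexReaderTest below, and the term id would in practice come from SearchTermsUtil.getWordId:

import nu.marginalia.index.ReverseIndexReader;
import java.nio.file.Path;

class ReverseIndexReaderSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical paths to an already-built index
        Path wordsFile = Path.of("/tmp/index/words.dat");
        Path docsFile = Path.of("/tmp/index/docs.dat");

        var reader = new ReverseIndexReader("full", wordsFile, docsFile);

        long termId = 0xDEADBEEFL; // hypothetical; normally derived from the term text

        if (reader.numDocuments(termId) > 0) {
            var documents = reader.documents(termId); // EntrySource over documents containing the term
            var require   = reader.also(termId);      // filter step requiring the term
            var exclude   = reader.not(termId);       // filter step rejecting documents with the term
        }
    }
}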

View File

@ -102,7 +102,7 @@ class ReverseIndexReaderTest {
preindex.finalizeIndex(docsFile, wordsFile);
preindex.delete();
return new ReverseIndexReader(wordsFile, docsFile);
return new ReverseIndexReader("test", wordsFile, docsFile);
}
}

View File

@ -41,14 +41,14 @@ public class IndexFactory {
public ReverseIndexReader getReverseIndexReader() throws IOException {
return new ReverseIndexReader(
return new ReverseIndexReader("full",
ReverseIndexFullFileNames.resolve(liveStorage, ReverseIndexFullFileNames.FileIdentifier.WORDS, ReverseIndexFullFileNames.FileVersion.CURRENT),
ReverseIndexFullFileNames.resolve(liveStorage, ReverseIndexFullFileNames.FileIdentifier.DOCS, ReverseIndexFullFileNames.FileVersion.CURRENT)
);
}
public ReverseIndexReader getReverseIndexPrioReader() throws IOException {
return new ReverseIndexReader(
return new ReverseIndexReader("prio",
ReverseIndexPrioFileNames.resolve(liveStorage, ReverseIndexPrioFileNames.FileIdentifier.WORDS, ReverseIndexPrioFileNames.FileVersion.CURRENT),
ReverseIndexPrioFileNames.resolve(liveStorage, ReverseIndexPrioFileNames.FileIdentifier.DOCS, ReverseIndexPrioFileNames.FileVersion.CURRENT)
);

View File

@ -9,14 +9,15 @@ import io.prometheus.client.Histogram;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.SneakyThrows;
import nu.marginalia.api.searchquery.*;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.results.*;
import nu.marginalia.array.buffer.LongQueryBuffer;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.model.SearchTerms;
import nu.marginalia.index.model.SearchTermsUtil;
import nu.marginalia.index.query.IndexQuery;
import nu.marginalia.index.query.IndexSearchBudget;
import nu.marginalia.index.results.IndexResultValuatorService;
@ -135,15 +136,15 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
var rawItem = RpcRawResultItem.newBuilder();
rawItem.setCombinedId(rawResult.combinedId);
rawItem.setResultsFromDomain(rawResult.resultsFromDomain);
rawItem.setHtmlFeatures(rawResult.htmlFeatures);
rawItem.setEncodedDocMetadata(rawResult.encodedDocMetadata);
rawItem.setHasPriorityTerms(rawResult.hasPrioTerm);
for (var score : rawResult.keywordScores) {
rawItem.addKeywordScores(
RpcResultKeywordScore.newBuilder()
.setEncodedDocMetadata(score.encodedDocMetadata())
.setEncodedWordMetadata(score.encodedWordMetadata())
.setKeyword(score.keyword)
.setHtmlFeatures(score.htmlFeatures())
.setSubquery(score.subquery)
);
}
@ -156,6 +157,7 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
.setTitle(result.title)
.setUrl(result.url.toString())
.setWordsTotal(result.wordsTotal)
.setBestPositions(result.bestPositions)
.setRawItem(rawItem);
if (result.pubYear != null) {
@ -203,7 +205,7 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
return new SearchResultSet(List.of());
}
ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.subqueries);
ResultRankingContext rankingContext = createRankingContext(params.rankingParams, params.compiledQueryIds);
var queryExecution = new QueryExecution(rankingContext, params.fetchSize);
@ -255,15 +257,11 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
/** Execute a search query */
public SearchResultSet run(SearchParameters parameters) throws SQLException, InterruptedException {
for (var subquery : parameters.subqueries) {
var terms = new SearchTerms(subquery);
if (terms.isEmpty())
continue;
var terms = new SearchTerms(parameters.query, parameters.compiledQueryIds);
for (var indexQuery : index.createQueries(terms, parameters.queryParams)) {
workerPool.execute(new IndexLookup(indexQuery, parameters.budget));
}
}
for (int i = 0; i < indexValuationThreads; i++) {
workerPool.execute(new ResultRanker(parameters, resultRankingContext));
@ -327,7 +325,9 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
buffer.reset();
query.getMoreResults(buffer);
results.addElements(0, buffer.data, 0, buffer.end);
for (int i = 0; i < buffer.end; i++) {
results.add(buffer.data.get(i));
}
if (results.size() < 512) {
enqueueResults(new CombinedDocIdList(results));
@ -413,18 +413,23 @@ public class IndexGrpcService extends IndexApiGrpc.IndexApiImplBase {
}
private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams, List<SearchSubquery> subqueries) {
final var termToId = SearchTermsUtil.getAllIncludeTerms(subqueries);
final Map<String, Integer> termFrequencies = new HashMap<>(termToId.size());
final Map<String, Integer> prioFrequencies = new HashMap<>(termToId.size());
private ResultRankingContext createRankingContext(ResultRankingParameters rankingParams,
CompiledQueryLong compiledQueryIds)
{
termToId.forEach((key, id) -> termFrequencies.put(key, index.getTermFrequency(id)));
termToId.forEach((key, id) -> prioFrequencies.put(key, index.getTermFrequencyPrio(id)));
int[] full = new int[compiledQueryIds.size()];
int[] prio = new int[compiledQueryIds.size()];
for (int idx = 0; idx < compiledQueryIds.size(); idx++) {
long id = compiledQueryIds.at(idx);
full[idx] = index.getTermFrequency(id);
prio[idx] = index.getTermFrequencyPrio(id);
}
return new ResultRankingContext(index.getTotalDocCount(),
rankingParams,
termFrequencies,
prioFrequencies);
new CqDataInt(full),
new CqDataInt(prio));
}
}

View File

@ -38,6 +38,14 @@ public class CombinedIndexReader {
return new IndexQueryBuilderImpl(reverseIndexFullReader, reverseIndexPriorityReader, query);
}
public QueryFilterStepIf hasWordFull(long termId) {
return reverseIndexFullReader.also(termId);
}
public QueryFilterStepIf hasWordPrio(long termId) {
return reverseIndexPriorityReader.also(termId);
}
/** Creates a query builder for terms in the priority index */
public IndexQueryBuilder findPriorityWord(long wordId) {

View File

@ -1,9 +1,11 @@
package nu.marginalia.index.index;
import java.util.List;
import gnu.trove.set.hash.TLongHashSet;
import nu.marginalia.index.ReverseIndexReader;
import nu.marginalia.index.query.IndexQuery;
import nu.marginalia.index.query.IndexQueryBuilder;
import nu.marginalia.index.query.filter.QueryFilterAnyOf;
import nu.marginalia.index.query.filter.QueryFilterStepIf;
public class IndexQueryBuilderImpl implements IndexQueryBuilder {
@ -34,7 +36,7 @@ public class IndexQueryBuilderImpl implements IndexQueryBuilder {
return this;
}
public IndexQueryBuilder alsoFull(long termId) {
public IndexQueryBuilder also(long termId) {
if (alreadyConsideredTerms.add(termId)) {
query.addInclusionFilter(reverseIndexFullReader.also(termId));
@ -43,16 +45,7 @@ public class IndexQueryBuilderImpl implements IndexQueryBuilder {
return this;
}
public IndexQueryBuilder alsoPrio(long termId) {
if (alreadyConsideredTerms.add(termId)) {
query.addInclusionFilter(reverseIndexPrioReader.also(termId));
}
return this;
}
public IndexQueryBuilder notFull(long termId) {
public IndexQueryBuilder not(long termId) {
query.addInclusionFilter(reverseIndexFullReader.not(termId));
@ -66,6 +59,20 @@ public class IndexQueryBuilderImpl implements IndexQueryBuilder {
return this;
}
public IndexQueryBuilder addInclusionFilterAny(List<QueryFilterStepIf> filterSteps) {
if (filterSteps.isEmpty())
return this;
if (filterSteps.size() == 1) {
query.addInclusionFilter(filterSteps.getFirst());
}
else {
query.addInclusionFilter(new QueryFilterAnyOf(filterSteps));
}
return this;
}
public IndexQuery build() {
return query;
}
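A fragment-level sketch of how the renamed builder methods compose once a CombinedIndexReader and the term ids are in hand (all identifiers here are hypothetical placeholders):

// Assuming combinedIndexReader, firstTermId, secondTermId, excludedTermId and filterSteps already exist:
IndexQuery query = combinedIndexReader
        .findFullWord(firstTermId)          // entry source for the first include term
        .also(secondTermId)                 // require a further term via the full index
        .not(excludedTermId)                // reject documents containing an excluded term
        .addInclusionFilterAny(filterSteps) // at least one of the remaining filter branches must pass
        .build();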

View File

@ -0,0 +1,108 @@
package nu.marginalia.index.index;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
/** Helper class for index query construction */
public class QueryBranchWalker {
private static final Logger logger = LoggerFactory.getLogger(QueryBranchWalker.class);
public final long[] priorityOrder;
public final List<LongSet> paths;
public final long termId;
private QueryBranchWalker(long[] priorityOrder, List<LongSet> paths, long termId) {
this.priorityOrder = priorityOrder;
this.paths = paths;
this.termId = termId;
}
public boolean atEnd() {
return priorityOrder.length == 0;
}
/** Group the provided paths by the lowest termId they contain per the provided priorityOrder,
* into a list of QueryBranchWalkers. This can be performed iteratively on the resulting walkers
* to traverse the tree via the next() method.
* <p></p>
* The paths can be extracted through the {@link nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates CompiledQueryAggregates}
* queriesAggregate method.
*/
public static List<QueryBranchWalker> create(long[] priorityOrder, List<LongSet> paths) {
if (paths.isEmpty())
return List.of();
List<QueryBranchWalker> ret = new ArrayList<>();
List<LongSet> remainingPaths = new LinkedList<>(paths);
remainingPaths.removeIf(LongSet::isEmpty);
List<LongSet> pathsForPrio = new ArrayList<>();
for (int i = 0; i < priorityOrder.length; i++) {
long termId = priorityOrder[i];
var it = remainingPaths.iterator();
while (it.hasNext()) {
var path = it.next();
if (path.contains(termId)) {
// Remove the current termId from the path
path.remove(termId);
// Add it to the set of paths associated with the termId
pathsForPrio.add(path);
// Remove it from consideration
it.remove();
}
}
if (!pathsForPrio.isEmpty()) {
long[] newPrios = keepRelevantPriorities(priorityOrder, pathsForPrio);
ret.add(new QueryBranchWalker(newPrios, new ArrayList<>(pathsForPrio), termId));
pathsForPrio.clear();
}
}
// This happens if the priorityOrder array doesn't contain all items in the paths,
// in practice only when an index doesn't contain all the search terms, so we can just
// skip those paths
if (!remainingPaths.isEmpty()) {
logger.debug("Dropping: {}", remainingPaths);
}
return ret;
}
/** From the provided priorityOrder array, keep the elements that are present in any set in paths */
private static long[] keepRelevantPriorities(long[] priorityOrder, List<LongSet> paths) {
LongArrayList remainingPrios = new LongArrayList(paths.size());
// these sets are typically very small so array set is a good choice
LongSet allElements = new LongArraySet(priorityOrder.length);
for (var path : paths) {
allElements.addAll(path);
}
for (var p : priorityOrder) {
if (allElements.contains(p))
remainingPrios.add(p);
}
return remainingPrios.elements();
}
/** Convenience method that applies the create() method
* to the priority order and paths associated with this instance */
public List<QueryBranchWalker> next() {
return create(priorityOrder, paths);
}
}
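A minimal usage sketch of the walker with hypothetical term ids; note that create() removes the grouped term ids from the path sets as it buckets them, so the path sets are mutated in the process:

import it.unimi.dsi.fastutil.longs.LongArraySet;
import it.unimi.dsi.fastutil.longs.LongSet;
import nu.marginalia.index.index.QueryBranchWalker;

import java.util.ArrayList;
import java.util.List;

class QueryBranchWalkerSketch {
    public static void main(String[] args) {
        // Hypothetical term ids, already sorted into the desired priority order
        long[] priorityOrder = { 101, 202, 303 };

        // Two paths through a compiled query, say "101 ( 202 | 303 )"
        LongSet pathA = new LongArraySet(); pathA.add(101); pathA.add(202);
        LongSet pathB = new LongArraySet(); pathB.add(101); pathB.add(303);

        List<LongSet> paths = new ArrayList<>(List.of(pathA, pathB));

        for (QueryBranchWalker walker : QueryBranchWalker.create(priorityOrder, paths)) {
            System.out.println("head term: " + walker.termId);        // 101, shared by both paths
            for (QueryBranchWalker child : walker.next()) {
                System.out.println("  child term: " + child.termId);  // 202 and 303
            }
        }
    }
}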

View File

@ -2,6 +2,12 @@ package nu.marginalia.index.index;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates;
import nu.marginalia.index.query.filter.QueryFilterAllOf;
import nu.marginalia.index.query.filter.QueryFilterAnyOf;
import nu.marginalia.index.query.filter.QueryFilterStepIf;
import nu.marginalia.index.results.model.ids.CombinedDocIdList;
import nu.marginalia.index.results.model.ids.DocMetadataList;
import nu.marginalia.index.model.QueryParams;
@ -14,12 +20,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
/** This class delegates SearchIndexReader and deals with the stateful nature of the index,
* i.e. it may be possible to reconstruct the index and load a new set of data.
@ -87,7 +92,6 @@ public class StatefulIndex {
logger.error("Uncaught exception", ex);
}
finally {
lock.unlock();
}
@ -105,7 +109,6 @@ public class StatefulIndex {
return combinedIndexReader != null && combinedIndexReader.isLoaded();
}
public List<IndexQuery> createQueries(SearchTerms terms, QueryParams params) {
if (!isLoaded()) {
@ -113,55 +116,106 @@ public class StatefulIndex {
return Collections.emptyList();
}
final long[] orderedIncludes = terms.sortedDistinctIncludes(this::compareKeywords);
final long[] orderedIncludesPrio = terms.sortedDistinctIncludes(this::compareKeywordsPrio);
List<IndexQueryBuilder> queryHeads = new ArrayList<>(10);
List<IndexQuery> queries = new ArrayList<>(10);
// To ensure that good results are discovered, create separate query heads for the priority index that
// filter for terms that contain pairs of two search terms
if (orderedIncludesPrio.length > 1) {
for (int i = 0; i + 1 < orderedIncludesPrio.length; i++) {
for (int j = i + 1; j < orderedIncludesPrio.length; j++) {
var entrySource = combinedIndexReader
.findPriorityWord(orderedIncludesPrio[i])
.alsoPrio(orderedIncludesPrio[j]);
queryHeads.add(entrySource);
final long[] termPriority = terms.sortedDistinctIncludes(this::compareKeywords);
List<LongSet> paths = CompiledQueryAggregates.queriesAggregate(terms.compiledQuery());
// Remove any paths that do not contain all prioritized terms, as this means
// the term is missing from the index and can never be found
paths.removeIf(containsAll(termPriority).negate());
List<QueryBranchWalker> walkers = QueryBranchWalker.create(termPriority, paths);
for (var walker : walkers) {
for (var builder : List.of(
combinedIndexReader.findPriorityWord(walker.termId),
combinedIndexReader.findFullWord(walker.termId)
))
{
queryHeads.add(builder);
if (walker.atEnd())
continue; // Single term search query
// Add filter steps for the remaining combinations of terms
List<QueryFilterStepIf> filterSteps = new ArrayList<>();
for (var step : walker.next()) {
filterSteps.add(createFilter(step, 0));
}
builder.addInclusionFilterAny(filterSteps);
}
}
// Next consider entries that appear only once in the priority index
for (var wordId : orderedIncludesPrio) {
queryHeads.add(combinedIndexReader.findPriorityWord(wordId));
}
// Finally consider terms in the full index
queryHeads.add(combinedIndexReader.findFullWord(orderedIncludes[0]));
// Add additional conditions to the query heads
for (var query : queryHeads) {
if (query == null) {
return Collections.emptyList();
}
// Note that we can add all includes as filters, even though
// they may not be present in the query head, as the query builder
// will ignore redundant include filters:
for (long orderedInclude : orderedIncludes) {
query = query.alsoFull(orderedInclude);
// Advice terms are a special case, mandatory but not ranked, and exempt from re-writing
for (long term : terms.advice()) {
query = query.also(term);
}
for (long term : terms.excludes()) {
query = query.notFull(term);
query = query.not(term);
}
// Run these filter steps last, as they'll worst-case cause as many page faults as there are
// items in the buffer
queries.add(query.addInclusionFilter(combinedIndexReader.filterForParams(params)).build());
query.addInclusionFilter(combinedIndexReader.filterForParams(params));
}
return queries;
return queryHeads
.stream()
.map(IndexQueryBuilder::build)
.toList();
}
/** Recursively create a filter step based on the QBW and its children */
private QueryFilterStepIf createFilter(QueryBranchWalker walker, int depth) {
// Create a filter for the current termId
final QueryFilterStepIf ownFilterCondition = ownFilterCondition(walker, depth);
var childSteps = walker.next();
if (childSteps.isEmpty()) // no children, and so we're satisfied with just a single filter condition
return ownFilterCondition;
// If there are children, we append the filter conditions for each child as an anyOf condition
// to the current filter condition
List<QueryFilterStepIf> combinedFilters = new ArrayList<>();
for (var step : childSteps) {
// Recursion will be limited to a fairly shallow stack depth due to how the queries are constructed.
var childFilter = createFilter(step, depth+1);
combinedFilters.add(new QueryFilterAllOf(ownFilterCondition, childFilter));
}
// Flatten the filter conditions if there's only one branch
if (combinedFilters.size() == 1)
return combinedFilters.getFirst();
else
return new QueryFilterAnyOf(combinedFilters);
}
/** Create a filter condition based on the termId associated with the QBW */
private QueryFilterStepIf ownFilterCondition(QueryBranchWalker walker, int depth) {
if (depth < 2) {
// At shallow depths we prioritize terms that appear in the priority index,
// to increase the odds we find "good" results before the execution timer runs out
return new QueryFilterAnyOf(
combinedIndexReader.hasWordPrio(walker.termId),
combinedIndexReader.hasWordFull(walker.termId)
);
} else {
return combinedIndexReader.hasWordFull(walker.termId);
}
}
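// To make the recursion above concrete: for a hypothetical compiled query "a c ( b | d )"
// (letters standing in for term ids, with "a" being the term that sorts first in the
// priority order), createQueries/createFilter would build roughly the following structure:
//
//   query heads:  findPriorityWord(a)  and  findFullWord(a)
//   filter:       anyOf( allOf( hasWord(c), hasWord(b) ),
//                        allOf( hasWord(c), hasWord(d) ) )
//
// where, at shallow depths, each hasWord(x) is itself anyOf(hasWordPrio(x), hasWordFull(x));
// the advice/exclude terms and filterForParams(params) are appended to each head afterwards.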
private Predicate<LongSet> containsAll(long[] permitted) {
LongSet permittedTerms = new LongOpenHashSet(permitted);
return permittedTerms::containsAll;
}
private int compareKeywords(long a, long b) {
@ -171,13 +225,6 @@ public class StatefulIndex {
);
}
private int compareKeywordsPrio(long a, long b) {
return Long.compare(
combinedIndexReader.numHitsPrio(a),
combinedIndexReader.numHitsPrio(b)
);
}
/** Return an array of encoded document metadata longs corresponding to the
* document identifiers provided; with metadata for termId. The input array
* docs[] *must* be sorted.

View File

@ -2,16 +2,16 @@ package nu.marginalia.index.model;
import nu.marginalia.api.searchquery.IndexProtobufCodec;
import nu.marginalia.api.searchquery.RpcIndexQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryParser;
import nu.marginalia.api.searchquery.model.query.SearchSpecification;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.index.query.IndexSearchBudget;
import nu.marginalia.index.query.limit.QueryStrategy;
import nu.marginalia.index.searchset.SearchSet;
import java.util.ArrayList;
import java.util.List;
import static nu.marginalia.api.searchquery.IndexProtobufCodec.convertSpecLimit;
public class SearchParameters {
@ -21,13 +21,16 @@ public class SearchParameters {
*/
public final int fetchSize;
public final IndexSearchBudget budget;
public final List<SearchSubquery> subqueries;
public final SearchQuery query;
public final QueryParams queryParams;
public final ResultRankingParameters rankingParams;
public final int limitByDomain;
public final int limitTotal;
public final CompiledQuery<String> compiledQuery;
public final CompiledQueryLong compiledQueryIds;
// mutable:
/**
@ -40,7 +43,7 @@ public class SearchParameters {
this.fetchSize = limits.fetchSize();
this.budget = new IndexSearchBudget(limits.timeoutMs());
this.subqueries = specsSet.subqueries;
this.query = specsSet.query;
this.limitByDomain = limits.resultsByDomain();
this.limitTotal = limits.resultsTotal();
@ -52,6 +55,9 @@ public class SearchParameters {
searchSet,
specsSet.queryStrategy);
compiledQuery = CompiledQueryParser.parse(this.query.compiledQuery);
compiledQueryIds = compiledQuery.mapToLong(SearchTermsUtil::getWordId);
rankingParams = specsSet.rankingParams;
}
@ -63,11 +69,8 @@ public class SearchParameters {
// The time budget is halved because this is the point when we start to
// wrap up the search and return the results.
this.budget = new IndexSearchBudget(limits.timeoutMs() / 2);
this.query = IndexProtobufCodec.convertRpcQuery(request.getQuery());
this.subqueries = new ArrayList<>(request.getSubqueriesCount());
for (int i = 0; i < request.getSubqueriesCount(); i++) {
this.subqueries.add(IndexProtobufCodec.convertSearchSubquery(request.getSubqueries(i)));
}
this.limitByDomain = limits.resultsByDomain();
this.limitTotal = limits.resultsTotal();
@ -79,9 +82,13 @@ public class SearchParameters {
searchSet,
QueryStrategy.valueOf(request.getQueryStrategy()));
compiledQuery = CompiledQueryParser.parse(this.query.compiledQuery);
compiledQueryIds = compiledQuery.mapToLong(SearchTermsUtil::getWordId);
rankingParams = IndexProtobufCodec.convertRankingParameterss(request.getParameters());
}
public long getDataCost() {
return dataCost;
}

View File

@ -3,49 +3,36 @@ package nu.marginalia.index.model;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongComparator;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static nu.marginalia.index.model.SearchTermsUtil.getWordId;
public final class SearchTerms {
private final LongList includes;
private final LongList advice;
private final LongList excludes;
private final LongList priority;
private final List<LongList> coherences;
public SearchTerms(
LongList includes,
LongList excludes,
LongList priority,
List<LongList> coherences
) {
this.includes = includes;
this.excludes = excludes;
this.priority = priority;
this.coherences = coherences;
private final CompiledQueryLong compiledQueryIds;
public SearchTerms(SearchQuery query,
CompiledQueryLong compiledQueryIds)
{
this.excludes = new LongArrayList();
this.priority = new LongArrayList();
this.coherences = new ArrayList<>();
this.advice = new LongArrayList();
this.compiledQueryIds = compiledQueryIds;
for (var word : query.searchTermsAdvice) {
advice.add(getWordId(word));
}
public SearchTerms(SearchSubquery subquery) {
this(new LongArrayList(),
new LongArrayList(),
new LongArrayList(),
new ArrayList<>());
for (var word : subquery.searchTermsInclude) {
includes.add(getWordId(word));
}
for (var word : subquery.searchTermsAdvice) {
// This looks like a bug, but it's not
includes.add(getWordId(word));
}
for (var coherence : subquery.searchTermCoherences) {
for (var coherence : query.searchTermCoherences) {
LongList parts = new LongArrayList(coherence.size());
for (var word : coherence) {
@ -55,39 +42,32 @@ public final class SearchTerms {
coherences.add(parts);
}
for (var word : subquery.searchTermsExclude) {
for (var word : query.searchTermsExclude) {
excludes.add(getWordId(word));
}
for (var word : subquery.searchTermsPriority) {
for (var word : query.searchTermsPriority) {
priority.add(getWordId(word));
}
}
public boolean isEmpty() {
return includes.isEmpty();
return compiledQueryIds.isEmpty();
}
public long[] sortedDistinctIncludes(LongComparator comparator) {
if (includes.isEmpty())
return includes.toLongArray();
LongList list = new LongArrayList(new LongOpenHashSet(includes));
LongList list = new LongArrayList(compiledQueryIds.copyData());
list.sort(comparator);
return list.toLongArray();
}
public int size() {
return includes.size() + excludes.size() + priority.size();
}
public LongList includes() {
return includes;
}
public LongList excludes() {
return excludes;
}
public LongList advice() {
return advice;
}
public LongList priority() {
return priority;
}
@ -96,29 +76,6 @@ public final class SearchTerms {
return coherences;
}
@Override
public boolean equals(Object obj) {
if (obj == this) return true;
if (obj == null || obj.getClass() != this.getClass()) return false;
var that = (SearchTerms) obj;
return Objects.equals(this.includes, that.includes) &&
Objects.equals(this.excludes, that.excludes) &&
Objects.equals(this.priority, that.priority) &&
Objects.equals(this.coherences, that.coherences);
}
@Override
public int hashCode() {
return Objects.hash(includes, excludes, priority, coherences);
}
@Override
public String toString() {
return "SearchTerms[" +
"includes=" + includes + ", " +
"excludes=" + excludes + ", " +
"priority=" + priority + ", " +
"coherences=" + coherences + ']';
}
public CompiledQueryLong compiledQuery() { return compiledQueryIds; }
}

View File

@ -1,29 +1,9 @@
package nu.marginalia.index.model;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.hash.MurmurHash3_128;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class SearchTermsUtil {
/** Extract all include-terms from the specified subqueries,
* and return a map of the terms and their termIds.
*/
public static Map<String, Long> getAllIncludeTerms(List<SearchSubquery> subqueries) {
Map<String, Long> ret = new HashMap<>();
for (var subquery : subqueries) {
for (var include : subquery.searchTermsInclude) {
ret.computeIfAbsent(include, i -> getWordId(include));
}
}
return ret;
}
private static final MurmurHash3_128 hasher = new MurmurHash3_128();
/** Translate the word to a unique id. */

View File

@ -4,7 +4,8 @@ import com.google.inject.Inject;
import gnu.trove.map.hash.TObjectLongHashMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectArrayMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.query.SearchQuery;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.SearchTermsUtil;
import nu.marginalia.index.results.model.QuerySearchTerms;
@ -13,9 +14,6 @@ import nu.marginalia.index.results.model.TermMetadataForCombinedDocumentIds;
import nu.marginalia.index.results.model.ids.CombinedDocIdList;
import nu.marginalia.index.results.model.ids.TermIdList;
import java.util.ArrayList;
import java.util.List;
import static nu.marginalia.index.results.model.TermCoherenceGroupList.TermCoherenceGroup;
import static nu.marginalia.index.results.model.TermMetadataForCombinedDocumentIds.DocumentsWithMetadata;
@ -42,14 +40,20 @@ public class IndexMetadataService {
return new TermMetadataForCombinedDocumentIds(termdocToMeta);
}
public QuerySearchTerms getSearchTerms(List<SearchSubquery> searchTermVariants) {
public QuerySearchTerms getSearchTerms(CompiledQuery<String> compiledQuery, SearchQuery searchQuery) {
LongArrayList termIdsList = new LongArrayList();
LongArrayList termIdsPrio = new LongArrayList();
TObjectLongHashMap<String> termToId = new TObjectLongHashMap<>(10, 0.75f, -1);
for (var subquery : searchTermVariants) {
for (var term : subquery.searchTermsInclude) {
for (String word : compiledQuery) {
long id = SearchTermsUtil.getWordId(word);
termIdsList.add(id);
termToId.put(word, id);
}
for (var term : searchQuery.searchTermsAdvice) {
if (termToId.containsKey(term)) {
continue;
}
@ -58,27 +62,25 @@ public class IndexMetadataService {
termIdsList.add(id);
termToId.put(term, id);
}
for (var term : searchQuery.searchTermsPriority) {
if (termToId.containsKey(term)) {
continue;
}
long id = SearchTermsUtil.getWordId(term);
termIdsList.add(id);
termIdsPrio.add(id);
termToId.put(term, id);
}
return new QuerySearchTerms(termToId,
new TermIdList(termIdsList),
getTermCoherences(searchTermVariants));
}
private TermCoherenceGroupList getTermCoherences(List<SearchSubquery> searchTermVariants) {
List<TermCoherenceGroup> coherences = new ArrayList<>();
for (var subquery : searchTermVariants) {
for (var coh : subquery.searchTermCoherences) {
coherences.add(new TermCoherenceGroup(coh));
}
// It's assumed each subquery has identical coherences
break;
}
return new TermCoherenceGroupList(coherences);
new TermIdList(termIdsPrio),
new TermCoherenceGroupList(
searchQuery.searchTermCoherences.stream().map(TermCoherenceGroup::new).toList()
)
);
}
}
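In effect, getSearchTerms assigns every distinct term in the query a single id and keeps a separate list of priority-term ids; a hypothetical example (all term strings invented for illustration):
// compiledQuery leaves: "foo", "bar"; advice terms: "special:constraint"; priority terms: "baz"
//
//   termToId    -> maps each of the four strings to its word id
//   termIdsAll  -> the ids of all four terms
//   termIdsPrio -> the id of "baz" only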

View File

@ -1,10 +1,12 @@
package nu.marginalia.index.results;
import nu.marginalia.api.searchquery.model.query.SearchSubquery;
import nu.marginalia.api.searchquery.model.compiled.*;
import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.api.searchquery.model.results.SearchResultItem;
import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore;
import nu.marginalia.index.index.StatefulIndex;
import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.results.model.ids.CombinedDocIdList;
import nu.marginalia.index.model.QueryParams;
import nu.marginalia.index.results.model.QuerySearchTerms;
@ -23,7 +25,6 @@ import java.util.List;
* reasons to cache this data, and performs the calculations */
public class IndexResultValuationContext {
private final StatefulIndex statefulIndex;
private final List<List<String>> searchTermVariants;
private final QueryParams queryParams;
private final TermMetadataForCombinedDocumentIds termMetadataForCombinedDocumentIds;
@ -31,24 +32,28 @@ public class IndexResultValuationContext {
private final ResultRankingContext rankingContext;
private final ResultValuator searchResultValuator;
private final CompiledQuery<String> compiledQuery;
private final CompiledQueryLong compiledQueryIds;
public IndexResultValuationContext(IndexMetadataService metadataService,
ResultValuator searchResultValuator,
CombinedDocIdList ids,
StatefulIndex statefulIndex,
ResultRankingContext rankingContext,
List<SearchSubquery> subqueries,
QueryParams queryParams
SearchParameters params
) {
this.statefulIndex = statefulIndex;
this.rankingContext = rankingContext;
this.searchResultValuator = searchResultValuator;
this.searchTermVariants = subqueries.stream().map(sq -> sq.searchTermsInclude).distinct().toList();
this.queryParams = queryParams;
this.queryParams = params.queryParams;
this.compiledQuery = params.compiledQuery;
this.compiledQueryIds = params.compiledQueryIds;
this.searchTerms = metadataService.getSearchTerms(subqueries);
this.termMetadataForCombinedDocumentIds = metadataService.getTermMetadataForDocuments(ids, searchTerms.termIdsAll);
this.searchTerms = metadataService.getSearchTerms(params.compiledQuery, params.query);
this.termMetadataForCombinedDocumentIds = metadataService.getTermMetadataForDocuments(ids,
searchTerms.termIdsAll);
}
private final long flagsFilterMask =
@ -65,110 +70,97 @@ public class IndexResultValuationContext {
long docMetadata = statefulIndex.getDocumentMetadata(docId);
int htmlFeatures = statefulIndex.getHtmlFeatures(docId);
int maxFlagsCount = 0;
boolean anyAllSynthetic = false;
int maxPositionsSet = 0;
SearchResultItem searchResult = new SearchResultItem(docId,
searchTermVariants.stream().mapToInt(List::size).sum());
for (int querySetId = 0;
querySetId < searchTermVariants.size();
querySetId++)
{
var termList = searchTermVariants.get(querySetId);
SearchResultKeywordScore[] termScoresForSet = new SearchResultKeywordScore[termList.size()];
boolean synthetic = true;
for (int termIdx = 0; termIdx < termList.size(); termIdx++) {
String searchTerm = termList.get(termIdx);
long termMetadata = termMetadataForCombinedDocumentIds.getTermMetadata(
searchTerms.getIdForTerm(searchTerm),
combinedId
);
var score = new SearchResultKeywordScore(
querySetId,
searchTerm,
termMetadata,
docMetadata,
htmlFeatures
);
htmlFeatures,
hasPrioTerm(combinedId));
synthetic &= WordFlags.Synthetic.isPresent(termMetadata);
long[] wordMetas = new long[compiledQuery.size()];
SearchResultKeywordScore[] scores = new SearchResultKeywordScore[compiledQuery.size()];
searchResult.keywordScores.add(score);
for (int i = 0; i < wordMetas.length; i++) {
final long termId = compiledQueryIds.at(i);
final String term = compiledQuery.at(i);
termScoresForSet[termIdx] = score;
wordMetas[i] = termMetadataForCombinedDocumentIds.getTermMetadata(termId, combinedId);
scores[i] = new SearchResultKeywordScore(term, termId, wordMetas[i]);
}
if (!meetsQueryStrategyRequirements(termScoresForSet, queryParams.queryStrategy())) {
continue;
// DANGER: IndexResultValuatorService assumes that searchResult.keywordScores has this specific order, as it needs
// to be able to re-construct its own CompiledQuery<SearchResultKeywordScore> for re-ranking the results. This is
// a very flimsy assumption.
searchResult.keywordScores.addAll(List.of(scores));
CompiledQueryLong wordMetasQuery = new CompiledQueryLong(compiledQuery.root, new CqDataLong(wordMetas));
boolean allSynthetic = !CompiledQueryAggregates.booleanAggregate(wordMetasQuery, WordFlags.Synthetic::isAbsent);
int flagsCount = CompiledQueryAggregates.intMaxMinAggregate(wordMetasQuery, wordMeta -> Long.bitCount(wordMeta & flagsFilterMask));
int positionsCount = CompiledQueryAggregates.intMaxMinAggregate(wordMetasQuery, wordMeta -> Long.bitCount(WordMetadata.decodePositions(wordMeta)));
if (!meetsQueryStrategyRequirements(wordMetasQuery, queryParams.queryStrategy())) {
return null;
}
int minFlagsCount = 8;
int minPositionsSet = 4;
for (var termScore : termScoresForSet) {
final int flagCount = Long.bitCount(termScore.encodedWordMetadata() & flagsFilterMask);
minFlagsCount = Math.min(minFlagsCount, flagCount);
minPositionsSet = Math.min(minPositionsSet, termScore.positionCount());
}
maxFlagsCount = Math.max(maxFlagsCount, minFlagsCount);
maxPositionsSet = Math.max(maxPositionsSet, minPositionsSet);
anyAllSynthetic |= synthetic;
}
if (maxFlagsCount == 0 && !anyAllSynthetic && maxPositionsSet == 0)
if (flagsCount == 0 && !allSynthetic && positionsCount == 0)
return null;
double score = searchResultValuator.calculateSearchResultValue(searchResult.keywordScores,
double score = searchResultValuator.calculateSearchResultValue(
wordMetasQuery,
docMetadata,
htmlFeatures,
5000, // use a dummy value here as it's not present in the index
rankingContext);
if (searchResult.hasPrioTerm) {
score = 0.75 * score;
}
searchResult.setScore(score);
return searchResult;
}
private boolean meetsQueryStrategyRequirements(SearchResultKeywordScore[] termSet, QueryStrategy queryStrategy) {
private boolean hasPrioTerm(long combinedId) {
for (var term : searchTerms.termIdsPrio.array()) {
if (termMetadataForCombinedDocumentIds.hasTermMeta(term, combinedId)) {
return true;
}
}
return false;
}
private boolean meetsQueryStrategyRequirements(CompiledQueryLong queryGraphScores,
QueryStrategy queryStrategy)
{
if (queryStrategy == QueryStrategy.AUTO ||
queryStrategy == QueryStrategy.SENTENCE ||
queryStrategy == QueryStrategy.TOPIC) {
return true;
}
for (var keyword : termSet) {
if (!meetsQueryStrategyRequirements(keyword, queryParams.queryStrategy())) {
return false;
}
return CompiledQueryAggregates.booleanAggregate(queryGraphScores,
docs -> meetsQueryStrategyRequirements(docs, queryParams.queryStrategy()));
}
return true;
}
private boolean meetsQueryStrategyRequirements(SearchResultKeywordScore termScore, QueryStrategy queryStrategy) {
private boolean meetsQueryStrategyRequirements(long wordMeta, QueryStrategy queryStrategy) {
if (queryStrategy == QueryStrategy.REQUIRE_FIELD_SITE) {
return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.Site.asBit());
return WordFlags.Site.isPresent(wordMeta);
}
else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_SUBJECT) {
return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.Subjects.asBit());
return WordFlags.Subjects.isPresent(wordMeta);
}
else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_TITLE) {
return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.Title.asBit());
return WordFlags.Title.isPresent(wordMeta);
}
else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_URL) {
return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.UrlPath.asBit());
return WordFlags.UrlPath.isPresent(wordMeta);
}
else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_DOMAIN) {
return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.UrlDomain.asBit());
return WordFlags.UrlDomain.isPresent(wordMeta);
}
else if (queryStrategy == QueryStrategy.REQUIRE_FIELD_LINK) {
return WordMetadata.hasFlags(termScore.encodedWordMetadata(), WordFlags.ExternalLink.asBit());
return WordFlags.ExternalLink.isPresent(wordMeta);
}
return true;
}
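The aggregate calls above collapse the per-term word metadata back into scalar signals over the whole query graph. The exact fold lives in CompiledQueryAggregates rather than in this diff; the sketch below assumes OR nodes keep the best child value and AND nodes the worst, which is what the replaced per-subquery min/max loops computed. Plain ints stand in for the real API:
static void aggregateSketch() {
    // Hypothetical expression "( a | b ) c" with per-leaf flag counts.
    int flagsA = 1, flagsB = 3, flagsC = 2;

    // Assumed intMaxMinAggregate fold: OR -> max of children, AND -> min of children.
    int flagsCount = Math.min(Math.max(flagsA, flagsB), flagsC);      // == 2

    // Assumed booleanAggregate fold with a per-leaf predicate: true if the query
    // can be satisfied using only leaves where the predicate holds.
    boolean absentA = true, absentB = false, absentC = true;          // e.g. Synthetic flag absent?
    boolean satisfiable = (absentA || absentB) && absentC;            // == true
}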

View File

@ -4,7 +4,12 @@ import com.google.inject.Inject;
import com.google.inject.Singleton;
import gnu.trove.list.TLongList;
import gnu.trove.list.array.TLongArrayList;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongSet;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.compiled.CqDataLong;
import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates;
import nu.marginalia.api.searchquery.model.results.DecoratedSearchResultItem;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.api.searchquery.model.results.SearchResultItem;
@ -13,14 +18,13 @@ import nu.marginalia.index.model.SearchParameters;
import nu.marginalia.index.results.model.ids.CombinedDocIdList;
import nu.marginalia.linkdb.docs.DocumentDbReader;
import nu.marginalia.linkdb.model.DocdbUrlDetail;
import nu.marginalia.model.idx.WordMetadata;
import nu.marginalia.ranking.results.ResultValuator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.SQLException;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
@Singleton
public class IndexResultValuatorService {
@ -70,8 +74,7 @@ public class IndexResultValuatorService {
resultIds,
statefulIndex,
rankingContext,
params.subqueries,
params.queryParams);
params);
}
@ -96,12 +99,13 @@ public class IndexResultValuatorService {
item.resultsFromDomain = domainCountFilter.getCount(item);
}
return decorateAndRerank(resultsList, rankingContext);
return decorateAndRerank(resultsList, params.compiledQuery, rankingContext);
}
/** Decorate the result items with additional information from the link database
* and calculate an updated ranking with the additional information */
public List<DecoratedSearchResultItem> decorateAndRerank(List<SearchResultItem> rawResults,
CompiledQuery<String> compiledQuery,
ResultRankingContext rankingContext)
throws SQLException
{
@ -125,13 +129,31 @@ public class IndexResultValuatorService {
continue;
}
resultItems.add(createCombinedItem(result, docData, rankingContext));
// Reconstruct the compiled query for re-valuation
//
// CAVEAT: This hinges on the very fragile assumption that IndexResultValuationContext puts the scores
// in the same order as the data for the CompiledQuery<String>.
long[] wordMetas = new long[compiledQuery.size()];
for (int i = 0; i < compiledQuery.size(); i++) {
var score = result.keywordScores.get(i);
wordMetas[i] = score.encodedWordMetadata();
}
CompiledQueryLong metaQuery = new CompiledQueryLong(compiledQuery.root, new CqDataLong(wordMetas));
resultItems.add(createCombinedItem(
result,
docData,
metaQuery,
rankingContext));
}
return resultItems;
}
private DecoratedSearchResultItem createCombinedItem(SearchResultItem result,
DocdbUrlDetail docData,
CompiledQueryLong wordMetas,
ResultRankingContext rankingContext) {
return new DecoratedSearchResultItem(
result,
@ -144,8 +166,33 @@ public class IndexResultValuatorService {
docData.pubYear(),
docData.dataHash(),
docData.wordsTotal(),
resultValuator.calculateSearchResultValue(result.keywordScores, docData.wordsTotal(), rankingContext)
);
bestPositions(wordMetas),
resultValuator.calculateSearchResultValue(wordMetas,
result.encodedDocMetadata,
result.htmlFeatures,
docData.wordsTotal(),
rankingContext)
);
}
private long bestPositions(CompiledQueryLong wordMetas) {
LongSet positionsSet = CompiledQueryAggregates.positionsAggregate(wordMetas, WordMetadata::decodePositions);
int bestPc = 0;
long bestPositions = 0;
var li = positionsSet.longIterator();
while (li.hasNext()) {
long pos = li.nextLong();
int pc = Long.bitCount(pos);
if (pc > bestPc) {
bestPc = pc;
bestPositions = pos;
}
}
return bestPositions;
}
}
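bestPositions keeps whichever aggregated position mask covers the most positions; a small self-contained sketch with hypothetical masks, mirroring the loop above:
static void bestPositionsSketch() {
    long[] candidates = { 0b0000_1010L, 0b0111_0000L, 0b0000_0001L };  // 2, 3 and 1 bits set
    long best = 0;
    int bestPc = 0;
    for (long pos : candidates) {
        int pc = Long.bitCount(pos);
        if (pc > bestPc) {
            bestPc = pc;
            best = pos;
        }
    }
    // best == 0b0111_0000: the mask with the most positions set wins.
}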

View File

@ -6,14 +6,17 @@ import nu.marginalia.index.results.model.ids.TermIdList;
public class QuerySearchTerms {
private final TObjectLongHashMap<String> termToId;
public final TermIdList termIdsAll;
public final TermIdList termIdsPrio;
public final TermCoherenceGroupList coherences;
public QuerySearchTerms(TObjectLongHashMap<String> termToId,
TermIdList termIdsAll,
TermIdList termIdsPrio,
TermCoherenceGroupList coherences) {
this.termToId = termToId;
this.termIdsAll = termIdsAll;
this.termIdsPrio = termIdsPrio;
this.coherences = coherences;
}

View File

@ -18,12 +18,21 @@ public class TermMetadataForCombinedDocumentIds {
public long getTermMetadata(long termId, long combinedId) {
var metaByCombinedId = termdocToMeta.get(termId);
if (metaByCombinedId == null) {
logger.warn("Missing meta for term {}", termId);
return 0;
}
return metaByCombinedId.get(combinedId);
}
public boolean hasTermMeta(long termId, long combinedId) {
var metaByCombinedId = termdocToMeta.get(termId);
if (metaByCombinedId == null) {
return false;
}
return metaByCombinedId.get(combinedId) != 0;
}
public record DocumentsWithMetadata(Long2LongOpenHashMap data) {
public DocumentsWithMetadata(CombinedDocIdList combinedDocIdsAll, DocMetadataList metadata) {
this(new Long2LongOpenHashMap(combinedDocIdsAll.array(), metadata.array()));

View File

@ -1,26 +0,0 @@
package nu.marginalia.ranking.results;
import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore;
import java.util.List;
public record ResultKeywordSet(List<SearchResultKeywordScore> keywords) {
public int length() {
return keywords.size();
}
public boolean isEmpty() { return length() == 0; }
public boolean hasNgram() {
for (var word : keywords) {
if (word.keyword.contains("_")) {
return true;
}
}
return false;
}
@Override
public String toString() {
return "%s[%s]".formatted(getClass().getSimpleName(), keywords);
}
}

View File

@ -1,8 +1,8 @@
package nu.marginalia.ranking.results;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.api.searchquery.model.results.ResultRankingParameters;
import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore;
import nu.marginalia.model.crawl.HtmlFeature;
import nu.marginalia.model.crawl.PubDate;
import nu.marginalia.model.idx.DocumentFlags;
@ -14,33 +14,32 @@ import com.google.inject.Singleton;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
@Singleton
public class ResultValuator {
final static double scalingFactor = 500.;
private final Bm25Factor bm25Factor;
private final TermCoherenceFactor termCoherenceFactor;
private static final Logger logger = LoggerFactory.getLogger(ResultValuator.class);
@Inject
public ResultValuator(Bm25Factor bm25Factor,
TermCoherenceFactor termCoherenceFactor) {
this.bm25Factor = bm25Factor;
public ResultValuator(TermCoherenceFactor termCoherenceFactor) {
this.termCoherenceFactor = termCoherenceFactor;
}
public double calculateSearchResultValue(List<SearchResultKeywordScore> scores,
public double calculateSearchResultValue(CompiledQueryLong wordMeta,
long documentMetadata,
int features,
int length,
ResultRankingContext ctx)
{
int sets = numberOfSets(scores);
if (wordMeta.isEmpty())
return Double.MAX_VALUE;
if (length < 0) {
length = 5000;
}
long documentMetadata = documentMetadata(scores);
int features = htmlFeatures(scores);
var rankingParams = ctx.params;
int rank = DocumentMetadata.decodeRank(documentMetadata);
@ -75,32 +74,17 @@ public class ResultValuator {
+ temporalBias
+ flagsPenalty;
double bestTcf = 0;
double bestBM25F = 0;
double bestBM25P = 0;
double bestBM25PN = 0;
for (int set = 0; set < sets; set++) {
ResultKeywordSet keywordSet = createKeywordSet(scores, set);
if (keywordSet.isEmpty())
continue;
bestTcf = Math.max(bestTcf, rankingParams.tcfWeight * termCoherenceFactor.calculate(keywordSet));
bestBM25P = Math.max(bestBM25P, rankingParams.bm25PrioWeight * bm25Factor.calculateBm25Prio(rankingParams.prioParams, keywordSet, ctx));
bestBM25F = Math.max(bestBM25F, rankingParams.bm25FullWeight * bm25Factor.calculateBm25(rankingParams.fullParams, keywordSet, length, ctx));
if (keywordSet.hasNgram()) {
bestBM25PN = Math.max(bestBM25PN, rankingParams.bm25PrioWeight * bm25Factor.calculateBm25Prio(rankingParams.prioParams, keywordSet, ctx));
}
}
double bestTcf = rankingParams.tcfWeight * termCoherenceFactor.calculate(wordMeta);
double bestBM25F = rankingParams.bm25FullWeight * wordMeta.root.visit(new Bm25FullGraphVisitor(rankingParams.fullParams, wordMeta.data, length, ctx));
double bestBM25P = rankingParams.bm25PrioWeight * wordMeta.root.visit(new Bm25PrioGraphVisitor(rankingParams.prioParams, wordMeta.data, ctx));
double overallPartPositive = Math.max(0, overallPart);
double overallPartNegative = -Math.min(0, overallPart);
// Renormalize to 0...15, where 0 is the best possible score;
// this is a historical artifact of the original ranking function
return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + 0.25 * bestBM25PN + overallPartPositive, overallPartNegative);
return normalize(1.5 * bestTcf + bestBM25F + bestBM25P + overallPartPositive, overallPartNegative);
}
private double calculateQualityPenalty(int size, int quality, ResultRankingParameters rankingParams) {
@ -159,51 +143,6 @@ public class ResultValuator {
return (int) -penalty;
}
private long documentMetadata(List<SearchResultKeywordScore> rawScores) {
for (var score : rawScores) {
return score.encodedDocMetadata();
}
return 0;
}
private int htmlFeatures(List<SearchResultKeywordScore> rawScores) {
for (var score : rawScores) {
return score.htmlFeatures();
}
return 0;
}
private ResultKeywordSet createKeywordSet(List<SearchResultKeywordScore> rawScores,
int thisSet)
{
List<SearchResultKeywordScore> scoresList = new ArrayList<>();
for (var score : rawScores) {
if (score.subquery != thisSet)
continue;
// Don't consider synthetic keywords for ranking, these are keywords that don't
// have counts. E.g. "tld:edu"
if (score.isKeywordSpecial())
continue;
scoresList.add(score);
}
return new ResultKeywordSet(scoresList);
}
private int numberOfSets(List<SearchResultKeywordScore> scores) {
int maxSet = 0;
for (var score : scores) {
maxSet = Math.max(maxSet, score.subquery);
}
return 1 + maxSet;
}
public static double normalize(double value, double penalty) {
if (value < 0)
value = 0;

View File

@ -1,122 +0,0 @@
package nu.marginalia.ranking.results.factors;
import nu.marginalia.api.searchquery.model.results.Bm25Parameters;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.api.searchquery.model.results.SearchResultKeywordScore;
import nu.marginalia.model.idx.WordFlags;
import nu.marginalia.ranking.results.ResultKeywordSet;
public class Bm25Factor {
private static final int AVG_LENGTH = 5000;
/** This is an estimation of <a href="https://en.wikipedia.org/wiki/Okapi_BM25">BM-25</a>.
*
* @see Bm25Parameters
*/
public double calculateBm25(Bm25Parameters bm25Parameters, ResultKeywordSet keywordSet, int length, ResultRankingContext ctx) {
final int docCount = ctx.termFreqDocCount();
if (length <= 0)
length = AVG_LENGTH;
double sum = 0.;
for (var keyword : keywordSet.keywords()) {
double count = keyword.positionCount();
int freq = ctx.frequency(keyword.keyword);
sum += invFreq(docCount, freq) * f(bm25Parameters.k(), bm25Parameters.b(), count, length);
}
return sum;
}
/** Bm25 calculation, except that instead of counting positions in the document,
* the number of relevance signals for the term is counted.
*/
public double calculateBm25Prio(Bm25Parameters bm25Parameters, ResultKeywordSet keywordSet, ResultRankingContext ctx) {
final int docCount = ctx.termFreqDocCount();
double sum = 0.;
for (var keyword : keywordSet.keywords()) {
double count = evaluatePriorityScore(keyword);
int freq = ctx.priorityFrequency(keyword.keyword);
// note we override b to zero for priority terms as they are independent of document length
sum += invFreq(docCount, freq) * f(bm25Parameters.k(), 0, count, 0);
}
return sum;
}
private static double evaluatePriorityScore(SearchResultKeywordScore keyword) {
int pcount = keyword.positionCount();
double qcount = 0.;
if ((keyword.encodedWordMetadata() & WordFlags.ExternalLink.asBit()) != 0) {
qcount += 2.5;
if ((keyword.encodedWordMetadata() & WordFlags.UrlDomain.asBit()) != 0)
qcount += 2.5;
else if ((keyword.encodedWordMetadata() & WordFlags.UrlPath.asBit()) != 0)
qcount += 1.5;
if ((keyword.encodedWordMetadata() & WordFlags.Site.asBit()) != 0)
qcount += 1.25;
if ((keyword.encodedWordMetadata() & WordFlags.SiteAdjacent.asBit()) != 0)
qcount += 1.25;
}
else {
if ((keyword.encodedWordMetadata() & WordFlags.UrlDomain.asBit()) != 0)
qcount += 3;
else if ((keyword.encodedWordMetadata() & WordFlags.UrlPath.asBit()) != 0)
qcount += 1;
if ((keyword.encodedWordMetadata() & WordFlags.Site.asBit()) != 0)
qcount += 0.5;
if ((keyword.encodedWordMetadata() & WordFlags.SiteAdjacent.asBit()) != 0)
qcount += 0.5;
}
if ((keyword.encodedWordMetadata() & WordFlags.Title.asBit()) != 0)
qcount += 1.5;
if (pcount > 2) {
if ((keyword.encodedWordMetadata() & WordFlags.Subjects.asBit()) != 0)
qcount += 1.25;
if ((keyword.encodedWordMetadata() & WordFlags.NamesWords.asBit()) != 0)
qcount += 0.25;
if ((keyword.encodedWordMetadata() & WordFlags.TfIdfHigh.asBit()) != 0)
qcount += 0.5;
}
return qcount;
}
/**
*
* @param docCount Number of documents
* @param freq Number of matching documents
*/
private double invFreq(int docCount, int freq) {
return Math.log(1.0 + (docCount - freq + 0.5) / (freq + 0.5));
}
/**
*
* @param k determines the size of the impact of a single term
* @param b determines the magnitude of the length normalization
* @param count number of occurrences in the document
* @param length document length
*/
private double f(double k, double b, double count, int length) {
final double lengthRatio = (double) length / AVG_LENGTH;
return (count * (k + 1)) / (count + k * (1 - b + b * lengthRatio));
}
}

View File

@ -0,0 +1,81 @@
package nu.marginalia.ranking.results.factors;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.compiled.CqDataLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import nu.marginalia.api.searchquery.model.results.Bm25Parameters;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.model.idx.WordMetadata;
import java.util.List;
public class Bm25FullGraphVisitor implements CqExpression.DoubleVisitor {
private static final long AVG_LENGTH = 5000;
private final CqDataLong wordMetaData;
private final CqDataInt frequencies;
private final Bm25Parameters bm25Parameters;
private final int docCount;
private final int length;
public Bm25FullGraphVisitor(Bm25Parameters bm25Parameters,
CqDataLong wordMetaData,
int length,
ResultRankingContext ctx) {
this.length = length;
this.bm25Parameters = bm25Parameters;
this.docCount = ctx.termFreqDocCount();
this.wordMetaData = wordMetaData;
this.frequencies = ctx.fullCounts;
}
@Override
public double onAnd(List<? extends CqExpression> parts) {
double value = 0;
for (var part : parts) {
value += part.visit(this);
}
return value;
}
@Override
public double onOr(List<? extends CqExpression> parts) {
double value = 0;
for (var part : parts) {
value = Math.max(value, part.visit(this));
}
return value;
}
@Override
public double onLeaf(int idx) {
double count = Long.bitCount(WordMetadata.decodePositions(wordMetaData.get(idx)));
int freq = frequencies.get(idx);
return invFreq(docCount, freq) * f(bm25Parameters.k(), bm25Parameters.b(), count, length);
}
/**
*
* @param docCount Number of documents
* @param freq Number of matching documents
*/
private double invFreq(int docCount, int freq) {
return Math.log(1.0 + (docCount - freq + 0.5) / (freq + 0.5));
}
/**
*
* @param k determines the size of the impact of a single term
* @param b determines the magnitude of the length normalization
* @param count number of occurrences in the document
* @param length document length
*/
private double f(double k, double b, double count, int length) {
final double lengthRatio = (double) length / AVG_LENGTH;
return (count * (k + 1)) / (count + k * (1 - b + b * lengthRatio));
}
}
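A worked example of the leaf computation; the parameter values (k = 1.2, b = 0.75) are common BM25 defaults assumed for illustration, not values taken from this changeset:
// Hypothetical inputs for a single leaf:
//   docCount = 1_000_000 documents, freq = 1_000 documents contain the term,
//   count    = 4 position bits set, length = 5_000 (== AVG_LENGTH, so lengthRatio = 1)
//
//   invFreq = ln(1 + (1_000_000 - 1_000 + 0.5) / 1_000.5)              ≈ 6.91
//   f       = (4 * (1.2 + 1)) / (4 + 1.2 * (1 - 0.75 + 0.75 * 1))      ≈ 1.69
//   leaf score ≈ 6.91 * 1.69 ≈ 11.7
//
// AND nodes sum these leaf scores; OR nodes keep the best alternative.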

View File

@ -0,0 +1,127 @@
package nu.marginalia.ranking.results.factors;
import nu.marginalia.api.searchquery.model.compiled.CqDataInt;
import nu.marginalia.api.searchquery.model.compiled.CqDataLong;
import nu.marginalia.api.searchquery.model.compiled.CqExpression;
import nu.marginalia.api.searchquery.model.results.Bm25Parameters;
import nu.marginalia.api.searchquery.model.results.ResultRankingContext;
import nu.marginalia.model.idx.WordFlags;
import nu.marginalia.model.idx.WordMetadata;
import java.util.List;
public class Bm25PrioGraphVisitor implements CqExpression.DoubleVisitor {
private static final long AVG_LENGTH = 5000;
private final CqDataLong wordMetaData;
private final CqDataInt frequencies;
private final Bm25Parameters bm25Parameters;
private final int docCount;
public Bm25PrioGraphVisitor(Bm25Parameters bm25Parameters,
CqDataLong wordMetaData,
ResultRankingContext ctx) {
this.bm25Parameters = bm25Parameters;
this.docCount = ctx.termFreqDocCount();
this.wordMetaData = wordMetaData;
this.frequencies = ctx.fullCounts;
}
@Override
public double onAnd(List<? extends CqExpression> parts) {
double value = 0;
for (var part : parts) {
value += part.visit(this);
}
return value;
}
@Override
public double onOr(List<? extends CqExpression> parts) {
double value = 0;
for (var part : parts) {
value = Math.max(value, part.visit(this));
}
return value;
}
@Override
public double onLeaf(int idx) {
double count = evaluatePriorityScore(wordMetaData.get(idx));
int freq = frequencies.get(idx);
// note we override b to zero for priority terms as they are independent of document length
return invFreq(docCount, freq) * f(bm25Parameters.k(), 0, count, 0);
}
private static double evaluatePriorityScore(long wordMeta) {
int pcount = Long.bitCount(WordMetadata.decodePositions(wordMeta));
double qcount = 0.;
if ((wordMeta & WordFlags.ExternalLink.asBit()) != 0) {
qcount += 2.5;
if ((wordMeta & WordFlags.UrlDomain.asBit()) != 0)
qcount += 2.5;
else if ((wordMeta & WordFlags.UrlPath.asBit()) != 0)
qcount += 1.5;
if ((wordMeta & WordFlags.Site.asBit()) != 0)
qcount += 1.25;
if ((wordMeta & WordFlags.SiteAdjacent.asBit()) != 0)
qcount += 1.25;
}
else {
if ((wordMeta & WordFlags.UrlDomain.asBit()) != 0)
qcount += 3;
else if ((wordMeta & WordFlags.UrlPath.asBit()) != 0)
qcount += 1;
if ((wordMeta & WordFlags.Site.asBit()) != 0)
qcount += 0.5;
if ((wordMeta & WordFlags.SiteAdjacent.asBit()) != 0)
qcount += 0.5;
}
if ((wordMeta & WordFlags.Title.asBit()) != 0)
qcount += 1.5;
if (pcount > 2) {
if ((wordMeta & WordFlags.Subjects.asBit()) != 0)
qcount += 1.25;
if ((wordMeta & WordFlags.NamesWords.asBit()) != 0)
qcount += 0.25;
if ((wordMeta & WordFlags.TfIdfHigh.asBit()) != 0)
qcount += 0.5;
}
return qcount;
}
/**
*
* @param docCount Number of documents
* @param freq Number of matching documents
*/
private double invFreq(int docCount, int freq) {
return Math.log(1.0 + (docCount - freq + 0.5) / (freq + 0.5));
}
/**
*
* @param k determines the size of the impact of a single term
* @param b determines the magnitude of the length normalization
* @param count number of occurrences in the document
* @param length document length
*/
private double f(double k, double b, double count, int length) {
final double lengthRatio = (double) length / AVG_LENGTH;
return (count * (k + 1)) / (count + k * (1 - b + b * lengthRatio));
}
}

View File

@ -1,14 +1,16 @@
package nu.marginalia.ranking.results.factors;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
import nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates;
import nu.marginalia.model.idx.WordMetadata;
import nu.marginalia.ranking.results.ResultKeywordSet;
/** Rewards documents where terms appear frequently within the same sentences
*/
public class TermCoherenceFactor {
public double calculate(ResultKeywordSet keywordSet) {
long mask = combinedMask(keywordSet);
public double calculate(CompiledQueryLong wordMetadataQuery) {
long mask = CompiledQueryAggregates.longBitmaskAggregate(wordMetadataQuery,
score -> score >>> WordMetadata.POSITIONS_SHIFT);
return bitsSetFactor(mask);
}
@ -19,14 +21,5 @@ public class TermCoherenceFactor {
return Math.pow(bitsSetInMask/(float) WordMetadata.POSITIONS_COUNT, 0.25);
}
long combinedMask(ResultKeywordSet keywordSet) {
long mask = WordMetadata.POSITIONS_MASK;
for (var keyword : keywordSet.keywords()) {
mask &= keyword.positions();
}
return mask;
}
}
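A worked example of the resulting factor; treating WordMetadata.POSITIONS_COUNT as a 56-bit position field is an assumption for illustration, not something stated in this diff:
// If the aggregated mask ends up with two shared position bits, e.g.
//   mask = 0b0110  (2 bits set)
// then, assuming POSITIONS_COUNT == 56:
//   bitsSetFactor = (2 / 56.0) ^ 0.25 ≈ 0.43
// Documents whose terms share more positions score closer to 1.0.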

View File

@ -2,6 +2,8 @@ package nu.marginalia.index.query;
import nu.marginalia.index.query.filter.QueryFilterStepIf;
import java.util.List;
/** Builds a query.
* <p />
* Note: The query builder may omit predicates that are deemed redundant.
@ -9,18 +11,14 @@ import nu.marginalia.index.query.filter.QueryFilterStepIf;
public interface IndexQueryBuilder {
/** Filters documents that also contain termId, within the full index.
*/
IndexQueryBuilder alsoFull(long termId);
/**
* Filters documents that also contain the termId, within the priority index.
*/
IndexQueryBuilder alsoPrio(long termIds);
IndexQueryBuilder also(long termId);
/** Excludes documents that contain termId, within the full index
*/
IndexQueryBuilder notFull(long termId);
IndexQueryBuilder not(long termId);
IndexQueryBuilder addInclusionFilter(QueryFilterStepIf filterStep);
IndexQueryBuilder addInclusionFilterAny(List<QueryFilterStepIf> filterStep);
IndexQuery build();
}
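The interface now exposes a single also/not pair instead of separate full and priority variants. A sketch of caller usage under that interface; the builder instance and term ids are hypothetical:
static IndexQuery buildSketch(IndexQueryBuilder builder, long requiredTermId, long excludedTermId) {
    return builder
            .also(requiredTermId)   // documents must also contain this term
            .not(excludedTermId)    // ...and must not contain this one
            .build();
}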

View File

@ -0,0 +1,71 @@
package nu.marginalia.index.query.filter;
import nu.marginalia.array.buffer.LongQueryBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;
public class QueryFilterAllOf implements QueryFilterStepIf {
private final List<QueryFilterStepIf> steps;
public QueryFilterAllOf(List<? extends QueryFilterStepIf> steps) {
this.steps = new ArrayList<>(steps.size());
for (var step : steps) {
if (step instanceof QueryFilterAllOf allOf) {
this.steps.addAll(allOf.steps);
}
else {
this.steps.add(step);
}
}
}
public QueryFilterAllOf(QueryFilterStepIf... steps) {
this(List.of(steps));
}
public double cost() {
double prod = 1.;
for (var step : steps) {
double cost = step.cost();
if (cost > 1.0) {
prod *= Math.log(cost);
}
else {
prod += cost;
}
}
return prod;
}
@Override
public boolean test(long value) {
for (var step : steps) {
if (!step.test(value))
return false;
}
return true;
}
public void apply(LongQueryBuffer buffer) {
if (steps.isEmpty())
return;
for (var step : steps) {
step.apply(buffer);
}
}
public String describe() {
StringJoiner sj = new StringJoiner(",", "[All Of: ", "]");
for (var step : steps) {
sj.add(step.describe());
}
return sj.toString();
}
}
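A worked example of the composite cost above, using hypothetical per-step costs:
// Steps with costs 100, 0.5 and 20, in that order:
//   prod = 1
//   cost 100 > 1  ->  prod *= ln(100)  ≈ 4.61
//   cost 0.5 ≤ 1  ->  prod += 0.5      ≈ 5.11
//   cost 20  > 1  ->  prod *= ln(20)   ≈ 15.3
// so cost() ≈ 15.3: cheap steps add little, and expensive steps grow the
// estimate only logarithmically.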

View File

@ -2,19 +2,31 @@ package nu.marginalia.index.query.filter;
import nu.marginalia.array.buffer.LongQueryBuffer;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;
public class QueryFilterAnyOf implements QueryFilterStepIf {
private final List<? extends QueryFilterStepIf> steps;
private final List<QueryFilterStepIf> steps;
public QueryFilterAnyOf(List<? extends QueryFilterStepIf> steps) {
this.steps = steps;
this.steps = new ArrayList<>(steps.size());
for (var step : steps) {
if (step instanceof QueryFilterAnyOf anyOf) {
this.steps.addAll(anyOf.steps);
} else {
this.steps.add(step);
}
}
}
public QueryFilterAnyOf(QueryFilterStepIf... steps) {
this(List.of(steps));
}
public double cost() {
return steps.stream().mapToDouble(QueryFilterStepIf::cost).average().orElse(0.);
return steps.stream().mapToDouble(QueryFilterStepIf::cost).sum();
}
@Override
@ -31,31 +43,37 @@ public class QueryFilterAnyOf implements QueryFilterStepIf {
if (steps.isEmpty())
return;
int start;
int end = buffer.end;
if (steps.size() == 1) {
steps.getFirst().apply(buffer);
// The filter functions will partition the data in the buffer from 0 to END,
// and update END to the length of the retained items, keeping the retained
// items sorted but making no guarantees about the rejected half
//
// Therefore, we need to re-sort the rejected side, and to satisfy the
// constraint that the data is sorted up to END, finally sort it again.
//
// This sorting may seem like it's slower, but filter.apply(...) is
// typically much faster than iterating over filter.test(...); so this
// is more than made up for
for (int fi = 1; fi < steps.size(); fi++)
{
start = buffer.end;
Arrays.sort(buffer.data, start, end);
buffer.startFilterForRange(start, end);
steps.get(fi).apply(buffer);
return;
}
Arrays.sort(buffer.data, 0, buffer.end);
int start = 0;
final int endOfValidData = buffer.end; // End of valid data range
// The filters act as a partitioning function, where anything before buffer.end
// is "in", and is guaranteed to be sorted; and anything after buffer.end is "out"
// but no sorting guarantee is provided.
// To apply each subsequent filter only to the documents rejected so far, we re-sort the "out" range, slice it, and apply the filter to the slice
for (var step : steps)
{
var slice = buffer.slice(start, endOfValidData);
slice.data.quickSort(0, slice.size());
step.apply(slice);
start += slice.end;
}
// After we're done, read and write pointers should be 0 and "end" should be the length of valid data,
// normally done through buffer.finalizeFiltering(); but that won't work here
buffer.reset();
buffer.end = start;
// After all filters have been applied, we must re-sort all the retained data
// to uphold the sortedness contract
buffer.data.quickSort(0, buffer.end);
}
public String describe() {
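A walkthrough of the new apply() with hypothetical buffer contents, to make the slicing concrete:
// buffer.data = [2, 5, 7, 9], buffer.end = 4; two steps A and B.
//   pass 1: slice covers [0, 4) -> sorted -> A retains {5, 9}
//           retained docs now occupy positions 0..1, start = 2
//   pass 2: slice covers [2, 4) -> holds the docs A rejected, re-sorted to [2, 7]
//           B retains {2}, so start = 3
//   finish: buffer.end = 3 and positions 0..2 are re-sorted to [2, 5, 9]
// A document survives if any step accepts it, and the retained prefix stays sorted.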

View File

@ -16,7 +16,7 @@ public class QueryFilterLetThrough implements QueryFilterStepIf {
}
public double cost() {
return 0.;
return 1.;
}
public String describe() {

Some files were not shown because too many files have changed in this diff.