From 3464ca514bc3f4e4f19cf435188b7b76415a24b9 Mon Sep 17 00:00:00 2001 From: Viktor Lofgren Date: Sat, 25 Mar 2023 10:20:44 +0100 Subject: [PATCH] Fix typeahead suggestions --- .../ranking/ResultValuatorTest.java | 2 +- .../TermFrequencyDict.java | 6 +- .../assistant/suggest/Suggestions.java | 91 +++++++++++-------- docker-compose.yml | 7 ++ 4 files changed, 62 insertions(+), 44 deletions(-) diff --git a/code/features-search/result-ranking/src/test/java/nu/marginalia/ranking/ResultValuatorTest.java b/code/features-search/result-ranking/src/test/java/nu/marginalia/ranking/ResultValuatorTest.java index 3305c015..b72e11a3 100644 --- a/code/features-search/result-ranking/src/test/java/nu/marginalia/ranking/ResultValuatorTest.java +++ b/code/features-search/result-ranking/src/test/java/nu/marginalia/ranking/ResultValuatorTest.java @@ -63,7 +63,7 @@ class ResultValuatorTest { @Test void evaluateTerms() { - when(dict.getTermFreq("bob")).thenReturn(10L); + when(dict.getTermFreq("bob")).thenReturn(10); SearchResultRankingContext context = new SearchResultRankingContext(100000, Map.of("bob", 10)); diff --git a/code/libraries/term-frequency-dict/src/main/java/nu/marginalia/term_frequency_dict/TermFrequencyDict.java b/code/libraries/term-frequency-dict/src/main/java/nu/marginalia/term_frequency_dict/TermFrequencyDict.java index 36a61827..7778aa97 100644 --- a/code/libraries/term-frequency-dict/src/main/java/nu/marginalia/term_frequency_dict/TermFrequencyDict.java +++ b/code/libraries/term-frequency-dict/src/main/java/nu/marginalia/term_frequency_dict/TermFrequencyDict.java @@ -65,17 +65,17 @@ public class TermFrequencyDict { } /** Get the term frequency for the string s */ - public long getTermFreq(String s) { + public int getTermFreq(String s) { return wordRates.get(getStringHash(s)); } /** Get the term frequency for the already stemmed string s */ - public long getTermFreqStemmed(String s) { + public int getTermFreqStemmed(String s) { return wordRates.get(longHash(s.getBytes())); } /** Get the term frequency for the already stemmed and already hashed value 'hash' */ - public long getTermFreqHash(long hash) { + public int getTermFreqHash(long hash) { return wordRates.get(hash); } diff --git a/code/services-core/assistant-service/src/main/java/nu/marginalia/assistant/suggest/Suggestions.java b/code/services-core/assistant-service/src/main/java/nu/marginalia/assistant/suggest/Suggestions.java index 0ea63f08..a75ab75a 100644 --- a/code/services-core/assistant-service/src/main/java/nu/marginalia/assistant/suggest/Suggestions.java +++ b/code/services-core/assistant-service/src/main/java/nu/marginalia/assistant/suggest/Suggestions.java @@ -6,6 +6,7 @@ import nu.marginalia.term_frequency_dict.TermFrequencyDict; import nu.marginalia.model.crawl.HtmlFeature; import nu.marginalia.assistant.dict.SpellChecker; import org.apache.commons.collections4.trie.PatriciaTrie; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -13,7 +14,6 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.*; -import java.util.function.Function; import java.util.function.Supplier; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -54,11 +54,12 @@ public class Suggestions { .map(String::toLowerCase) .forEach(w -> ret.put(w, w)); + // Add special keywords to the suggestions for (var feature : HtmlFeature.values()) { String keyword = feature.getKeyword(); ret.put(keyword, keyword); - ret.put("-" + keyword, "-"+ keyword); + ret.put("-" + keyword, "-" + keyword); } return ret; @@ -69,39 +70,38 @@ public class Suggestions { } } - private record SuggestionStream(String prefix, Stream suggestionStream) { - public Stream stream() { - return suggestionStream.map(s -> prefix + s); - } - - } - public List getSuggestions(int count, String searchWord) { if (searchWord.length() < MIN_SUGGEST_LENGTH) { return Collections.emptyList(); } - searchWord = trimLeading(searchWord.toLowerCase()); + searchWord = StringUtils.stripStart(searchWord.toLowerCase(), " "); - List streams = new ArrayList<>(4); - streams.add(new SuggestionStream("", getSuggestionsForKeyword(count, searchWord))); - - int sp = searchWord.lastIndexOf(' '); - if (sp >= 0) { - String prefixString = searchWord.substring(0, sp+1); - String suggestString = searchWord.substring(sp+1); - - if (suggestString.length() >= MIN_SUGGEST_LENGTH) { - streams.add(new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString))); - } - - } - streams.add(spellCheckStream(searchWord)); - - return streams.stream().flatMap(SuggestionStream::stream).limit(count).collect(Collectors.toList()); + return Stream.of( + new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)), + suggestionsForLastWord(count, searchWord), + spellCheckStream(searchWord) + ) + .flatMap(SuggestionsStreamable::stream) + .limit(count) + .collect(Collectors.toList()); } - private SuggestionStream spellCheckStream(String word) { + private SuggestionsStreamable suggestionsForLastWord(int count, String searchWord) { + int sp = searchWord.lastIndexOf(' '); + + if (sp < 0) { + return Stream::empty; + } + + String prefixString = searchWord.substring(0, sp+1); + String suggestString = searchWord.substring(sp+1); + + return new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString)); + + } + + private SuggestionsStreamable spellCheckStream(String word) { int start = word.lastIndexOf(' '); String prefix; String corrWord; @@ -120,21 +120,16 @@ public class Suggestions { return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get)); } else { - return new SuggestionStream("", Stream.empty()); + return Stream::empty; } } - private String trimLeading(String word) { - - for (int i = 0; i < word.length(); i++) { - if (!Character.isWhitespace(word.charAt(i))) - return word.substring(i); - } - - return ""; - } public Stream getSuggestionsForKeyword(int count, String prefix) { + if (prefix.length() < MIN_SUGGEST_LENGTH) { + return Stream.empty(); + } + var start = suggestionsTrie.select(prefix); if (start == null) { @@ -145,14 +140,30 @@ public class Suggestions { return Stream.empty(); } - Map scach = new HashMap<>(512); - Function valr = s -> -termFrequencyDict.getTermFreqHash(scach.computeIfAbsent(s, TermFrequencyDict::getStringHash)); + SuggestionsValueCalculator sv = new SuggestionsValueCalculator(); return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey) .takeWhile(s -> s.startsWith(prefix)) .limit(256) - .sorted(Comparator.comparing(valr).thenComparing(String::length).thenComparing(Comparator.naturalOrder())) + .sorted(Comparator.comparing(sv::get).thenComparing(String::length).thenComparing(Comparator.naturalOrder())) .limit(count); } + private record SuggestionStream(String prefix, Stream suggestionStream) implements SuggestionsStreamable { + public Stream stream() { + return suggestionStream.map(s -> prefix + s); + } + } + + interface SuggestionsStreamable { Stream stream(); } + + private class SuggestionsValueCalculator { + + private final Map hashCache = new HashMap<>(512); + + public int get(String s) { + long hash = hashCache.computeIfAbsent(s, TermFrequencyDict::getStringHash); + return -termFrequencyDict.getTermFreqHash(hash); + } + } } diff --git a/docker-compose.yml b/docker-compose.yml index ac5d933d..66094aa4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,6 +5,7 @@ x-svc: &service - vol:/vol - conf:/wmsa/conf:ro - model:/wmsa/model + - data:/wmsa/data - logs:/var/log/wmsa networks: - wmsa @@ -126,3 +127,9 @@ volumes: type: none o: bind device: run/conf + data: + driver: local + driver_opts: + type: none + o: bind + device: run/data \ No newline at end of file