Fix typeahead suggestions

This commit is contained in:
Viktor Lofgren 2023-03-25 10:20:44 +01:00
parent 2f2c86a9f5
commit 3464ca514b
4 changed files with 62 additions and 44 deletions

View File

@ -63,7 +63,7 @@ class ResultValuatorTest {
@Test @Test
void evaluateTerms() { void evaluateTerms() {
when(dict.getTermFreq("bob")).thenReturn(10L); when(dict.getTermFreq("bob")).thenReturn(10);
SearchResultRankingContext context = new SearchResultRankingContext(100000, SearchResultRankingContext context = new SearchResultRankingContext(100000,
Map.of("bob", 10)); Map.of("bob", 10));

View File

@ -65,17 +65,17 @@ public class TermFrequencyDict {
} }
/** Get the term frequency for the string s */ /** Get the term frequency for the string s */
public long getTermFreq(String s) { public int getTermFreq(String s) {
return wordRates.get(getStringHash(s)); return wordRates.get(getStringHash(s));
} }
/** Get the term frequency for the already stemmed string s */ /** Get the term frequency for the already stemmed string s */
public long getTermFreqStemmed(String s) { public int getTermFreqStemmed(String s) {
return wordRates.get(longHash(s.getBytes())); return wordRates.get(longHash(s.getBytes()));
} }
/** Get the term frequency for the already stemmed and already hashed value 'hash' */ /** Get the term frequency for the already stemmed and already hashed value 'hash' */
public long getTermFreqHash(long hash) { public int getTermFreqHash(long hash) {
return wordRates.get(hash); return wordRates.get(hash);
} }

View File

@ -6,6 +6,7 @@ import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import nu.marginalia.model.crawl.HtmlFeature; import nu.marginalia.model.crawl.HtmlFeature;
import nu.marginalia.assistant.dict.SpellChecker; import nu.marginalia.assistant.dict.SpellChecker;
import org.apache.commons.collections4.trie.PatriciaTrie; import org.apache.commons.collections4.trie.PatriciaTrie;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -13,7 +14,6 @@ import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.*; import java.util.*;
import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -54,11 +54,12 @@ public class Suggestions {
.map(String::toLowerCase) .map(String::toLowerCase)
.forEach(w -> ret.put(w, w)); .forEach(w -> ret.put(w, w));
// Add special keywords to the suggestions
for (var feature : HtmlFeature.values()) { for (var feature : HtmlFeature.values()) {
String keyword = feature.getKeyword(); String keyword = feature.getKeyword();
ret.put(keyword, keyword); ret.put(keyword, keyword);
ret.put("-" + keyword, "-"+ keyword); ret.put("-" + keyword, "-" + keyword);
} }
return ret; return ret;
@ -69,39 +70,38 @@ public class Suggestions {
} }
} }
private record SuggestionStream(String prefix, Stream<String> suggestionStream) {
public Stream<String> stream() {
return suggestionStream.map(s -> prefix + s);
}
}
public List<String> getSuggestions(int count, String searchWord) { public List<String> getSuggestions(int count, String searchWord) {
if (searchWord.length() < MIN_SUGGEST_LENGTH) { if (searchWord.length() < MIN_SUGGEST_LENGTH) {
return Collections.emptyList(); return Collections.emptyList();
} }
searchWord = trimLeading(searchWord.toLowerCase()); searchWord = StringUtils.stripStart(searchWord.toLowerCase(), " ");
List<SuggestionStream> streams = new ArrayList<>(4); return Stream.of(
streams.add(new SuggestionStream("", getSuggestionsForKeyword(count, searchWord))); new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)),
suggestionsForLastWord(count, searchWord),
int sp = searchWord.lastIndexOf(' '); spellCheckStream(searchWord)
if (sp >= 0) { )
String prefixString = searchWord.substring(0, sp+1); .flatMap(SuggestionsStreamable::stream)
String suggestString = searchWord.substring(sp+1); .limit(count)
.collect(Collectors.toList());
if (suggestString.length() >= MIN_SUGGEST_LENGTH) {
streams.add(new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString)));
}
}
streams.add(spellCheckStream(searchWord));
return streams.stream().flatMap(SuggestionStream::stream).limit(count).collect(Collectors.toList());
} }
private SuggestionStream spellCheckStream(String word) { private SuggestionsStreamable suggestionsForLastWord(int count, String searchWord) {
int sp = searchWord.lastIndexOf(' ');
if (sp < 0) {
return Stream::empty;
}
String prefixString = searchWord.substring(0, sp+1);
String suggestString = searchWord.substring(sp+1);
return new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString));
}
private SuggestionsStreamable spellCheckStream(String word) {
int start = word.lastIndexOf(' '); int start = word.lastIndexOf(' ');
String prefix; String prefix;
String corrWord; String corrWord;
@ -120,21 +120,16 @@ public class Suggestions {
return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get)); return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get));
} }
else { else {
return new SuggestionStream("", Stream.empty()); return Stream::empty;
} }
} }
private String trimLeading(String word) {
for (int i = 0; i < word.length(); i++) {
if (!Character.isWhitespace(word.charAt(i)))
return word.substring(i);
}
return "";
}
public Stream<String> getSuggestionsForKeyword(int count, String prefix) { public Stream<String> getSuggestionsForKeyword(int count, String prefix) {
if (prefix.length() < MIN_SUGGEST_LENGTH) {
return Stream.empty();
}
var start = suggestionsTrie.select(prefix); var start = suggestionsTrie.select(prefix);
if (start == null) { if (start == null) {
@ -145,14 +140,30 @@ public class Suggestions {
return Stream.empty(); return Stream.empty();
} }
Map<String, Long> scach = new HashMap<>(512); SuggestionsValueCalculator sv = new SuggestionsValueCalculator();
Function<String, Long> valr = s -> -termFrequencyDict.getTermFreqHash(scach.computeIfAbsent(s, TermFrequencyDict::getStringHash));
return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey) return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey)
.takeWhile(s -> s.startsWith(prefix)) .takeWhile(s -> s.startsWith(prefix))
.limit(256) .limit(256)
.sorted(Comparator.comparing(valr).thenComparing(String::length).thenComparing(Comparator.naturalOrder())) .sorted(Comparator.comparing(sv::get).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
.limit(count); .limit(count);
} }
private record SuggestionStream(String prefix, Stream<String> suggestionStream) implements SuggestionsStreamable {
public Stream<String> stream() {
return suggestionStream.map(s -> prefix + s);
}
}
/**
 * Anything that can produce a stream of suggestion strings.
 * <p>
 * It has a single abstract method, so a lambda or method reference
 * (for example {@code Stream::empty}) can serve as an instance.
 */
@FunctionalInterface
interface SuggestionsStreamable {
    /** Returns the suggestions offered by this source. */
    Stream<String> stream();
}
/**
 * Computes a numeric value for a suggestion string: the negated term
 * frequency reported by {@code termFrequencyDict} for the string's hash.
 * Because the value is negated, an ascending comparison on it places
 * strings with a higher frequency first.
 * <p>
 * String hashes are memoized per calculator instance, so a string that is
 * looked up repeatedly is only hashed once.
 */
private class SuggestionsValueCalculator {
    private final Map<String, Long> hashCache = new HashMap<>(512);

    /** Returns the negated term frequency for {@code s}. */
    public int get(String s) {
        Long wordHash = hashCache.get(s);
        if (wordHash == null) {
            // First time this string is seen: hash it and remember the result.
            wordHash = TermFrequencyDict.getStringHash(s);
            hashCache.put(s, wordHash);
        }
        return -termFrequencyDict.getTermFreqHash(wordHash);
    }
}
} }

View File

@ -5,6 +5,7 @@ x-svc: &service
- vol:/vol - vol:/vol
- conf:/wmsa/conf:ro - conf:/wmsa/conf:ro
- model:/wmsa/model - model:/wmsa/model
- data:/wmsa/data
- logs:/var/log/wmsa - logs:/var/log/wmsa
networks: networks:
- wmsa - wmsa
@ -126,3 +127,9 @@ volumes:
type: none type: none
o: bind o: bind
device: run/conf device: run/conf
data:
driver: local
driver_opts:
type: none
o: bind
device: run/data