Fix typeahead suggestions

This commit is contained in:
Viktor Lofgren 2023-03-25 10:20:44 +01:00
parent 2f2c86a9f5
commit 3464ca514b
4 changed files with 62 additions and 44 deletions

View File

@ -63,7 +63,7 @@ class ResultValuatorTest {
@Test
void evaluateTerms() {
when(dict.getTermFreq("bob")).thenReturn(10L);
when(dict.getTermFreq("bob")).thenReturn(10);
SearchResultRankingContext context = new SearchResultRankingContext(100000,
Map.of("bob", 10));

View File

@ -65,17 +65,17 @@ public class TermFrequencyDict {
}
/** Get the term frequency for the string s */
public long getTermFreq(String s) {
public int getTermFreq(String s) {
return wordRates.get(getStringHash(s));
}
/** Get the term frequency for the already stemmed string s */
public long getTermFreqStemmed(String s) {
public int getTermFreqStemmed(String s) {
return wordRates.get(longHash(s.getBytes()));
}
/** Get the term frequency for the already stemmed and already hashed value 'hash' */
public long getTermFreqHash(long hash) {
public int getTermFreqHash(long hash) {
return wordRates.get(hash);
}

View File

@ -6,6 +6,7 @@ import nu.marginalia.term_frequency_dict.TermFrequencyDict;
import nu.marginalia.model.crawl.HtmlFeature;
import nu.marginalia.assistant.dict.SpellChecker;
import org.apache.commons.collections4.trie.PatriciaTrie;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -13,7 +14,6 @@ import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@ -54,11 +54,12 @@ public class Suggestions {
.map(String::toLowerCase)
.forEach(w -> ret.put(w, w));
// Add special keywords to the suggestions
for (var feature : HtmlFeature.values()) {
String keyword = feature.getKeyword();
ret.put(keyword, keyword);
ret.put("-" + keyword, "-"+ keyword);
ret.put("-" + keyword, "-" + keyword);
}
return ret;
@ -69,39 +70,38 @@ public class Suggestions {
}
}
private record SuggestionStream(String prefix, Stream<String> suggestionStream) {
public Stream<String> stream() {
return suggestionStream.map(s -> prefix + s);
}
}
public List<String> getSuggestions(int count, String searchWord) {
if (searchWord.length() < MIN_SUGGEST_LENGTH) {
return Collections.emptyList();
}
searchWord = trimLeading(searchWord.toLowerCase());
searchWord = StringUtils.stripStart(searchWord.toLowerCase(), " ");
List<SuggestionStream> streams = new ArrayList<>(4);
streams.add(new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)));
int sp = searchWord.lastIndexOf(' ');
if (sp >= 0) {
String prefixString = searchWord.substring(0, sp+1);
String suggestString = searchWord.substring(sp+1);
if (suggestString.length() >= MIN_SUGGEST_LENGTH) {
streams.add(new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString)));
}
}
streams.add(spellCheckStream(searchWord));
return streams.stream().flatMap(SuggestionStream::stream).limit(count).collect(Collectors.toList());
return Stream.of(
new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)),
suggestionsForLastWord(count, searchWord),
spellCheckStream(searchWord)
)
.flatMap(SuggestionsStreamable::stream)
.limit(count)
.collect(Collectors.toList());
}
private SuggestionStream spellCheckStream(String word) {
private SuggestionsStreamable suggestionsForLastWord(int count, String searchWord) {
int sp = searchWord.lastIndexOf(' ');
if (sp < 0) {
return Stream::empty;
}
String prefixString = searchWord.substring(0, sp+1);
String suggestString = searchWord.substring(sp+1);
return new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString));
}
private SuggestionsStreamable spellCheckStream(String word) {
int start = word.lastIndexOf(' ');
String prefix;
String corrWord;
@ -120,21 +120,16 @@ public class Suggestions {
return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get));
}
else {
return new SuggestionStream("", Stream.empty());
return Stream::empty;
}
}
private String trimLeading(String word) {
for (int i = 0; i < word.length(); i++) {
if (!Character.isWhitespace(word.charAt(i)))
return word.substring(i);
}
return "";
}
public Stream<String> getSuggestionsForKeyword(int count, String prefix) {
if (prefix.length() < MIN_SUGGEST_LENGTH) {
return Stream.empty();
}
var start = suggestionsTrie.select(prefix);
if (start == null) {
@ -145,14 +140,30 @@ public class Suggestions {
return Stream.empty();
}
Map<String, Long> scach = new HashMap<>(512);
Function<String, Long> valr = s -> -termFrequencyDict.getTermFreqHash(scach.computeIfAbsent(s, TermFrequencyDict::getStringHash));
SuggestionsValueCalculator sv = new SuggestionsValueCalculator();
return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey)
.takeWhile(s -> s.startsWith(prefix))
.limit(256)
.sorted(Comparator.comparing(valr).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
.sorted(Comparator.comparing(sv::get).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
.limit(count);
}
private record SuggestionStream(String prefix, Stream<String> suggestionStream) implements SuggestionsStreamable {
public Stream<String> stream() {
return suggestionStream.map(s -> prefix + s);
}
}
interface SuggestionsStreamable { Stream<String> stream(); }
private class SuggestionsValueCalculator {
private final Map<String, Long> hashCache = new HashMap<>(512);
public int get(String s) {
long hash = hashCache.computeIfAbsent(s, TermFrequencyDict::getStringHash);
return -termFrequencyDict.getTermFreqHash(hash);
}
}
}

View File

@ -5,6 +5,7 @@ x-svc: &service
- vol:/vol
- conf:/wmsa/conf:ro
- model:/wmsa/model
- data:/wmsa/data
- logs:/var/log/wmsa
networks:
- wmsa
@ -126,3 +127,9 @@ volumes:
type: none
o: bind
device: run/conf
data:
driver: local
driver_opts:
type: none
o: bind
device: run/data