mirror of
https://github.com/MarginaliaSearch/MarginaliaSearch.git
synced 2025-02-23 13:09:00 +00:00
Fix typeahead suggestions
This commit is contained in:
parent
2f2c86a9f5
commit
3464ca514b
@ -63,7 +63,7 @@ class ResultValuatorTest {
|
||||
@Test
|
||||
void evaluateTerms() {
|
||||
|
||||
when(dict.getTermFreq("bob")).thenReturn(10L);
|
||||
when(dict.getTermFreq("bob")).thenReturn(10);
|
||||
SearchResultRankingContext context = new SearchResultRankingContext(100000,
|
||||
Map.of("bob", 10));
|
||||
|
||||
|
@ -65,17 +65,17 @@ public class TermFrequencyDict {
|
||||
}
|
||||
|
||||
/** Get the term frequency for the string s */
|
||||
public long getTermFreq(String s) {
|
||||
public int getTermFreq(String s) {
|
||||
return wordRates.get(getStringHash(s));
|
||||
}
|
||||
|
||||
/** Get the term frequency for the already stemmed string s */
|
||||
public long getTermFreqStemmed(String s) {
|
||||
public int getTermFreqStemmed(String s) {
|
||||
return wordRates.get(longHash(s.getBytes()));
|
||||
}
|
||||
|
||||
/** Get the term frequency for the already stemmed and already hashed value 'hash' */
|
||||
public long getTermFreqHash(long hash) {
|
||||
public int getTermFreqHash(long hash) {
|
||||
return wordRates.get(hash);
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@ import nu.marginalia.term_frequency_dict.TermFrequencyDict;
|
||||
import nu.marginalia.model.crawl.HtmlFeature;
|
||||
import nu.marginalia.assistant.dict.SpellChecker;
|
||||
import org.apache.commons.collections4.trie.PatriciaTrie;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@ -13,7 +14,6 @@ import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
@ -54,11 +54,12 @@ public class Suggestions {
|
||||
.map(String::toLowerCase)
|
||||
.forEach(w -> ret.put(w, w));
|
||||
|
||||
// Add special keywords to the suggestions
|
||||
for (var feature : HtmlFeature.values()) {
|
||||
String keyword = feature.getKeyword();
|
||||
|
||||
ret.put(keyword, keyword);
|
||||
ret.put("-" + keyword, "-"+ keyword);
|
||||
ret.put("-" + keyword, "-" + keyword);
|
||||
}
|
||||
|
||||
return ret;
|
||||
@ -69,39 +70,38 @@ public class Suggestions {
|
||||
}
|
||||
}
|
||||
|
||||
private record SuggestionStream(String prefix, Stream<String> suggestionStream) {
|
||||
public Stream<String> stream() {
|
||||
return suggestionStream.map(s -> prefix + s);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public List<String> getSuggestions(int count, String searchWord) {
|
||||
if (searchWord.length() < MIN_SUGGEST_LENGTH) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
searchWord = trimLeading(searchWord.toLowerCase());
|
||||
searchWord = StringUtils.stripStart(searchWord.toLowerCase(), " ");
|
||||
|
||||
List<SuggestionStream> streams = new ArrayList<>(4);
|
||||
streams.add(new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)));
|
||||
|
||||
int sp = searchWord.lastIndexOf(' ');
|
||||
if (sp >= 0) {
|
||||
String prefixString = searchWord.substring(0, sp+1);
|
||||
String suggestString = searchWord.substring(sp+1);
|
||||
|
||||
if (suggestString.length() >= MIN_SUGGEST_LENGTH) {
|
||||
streams.add(new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString)));
|
||||
}
|
||||
|
||||
}
|
||||
streams.add(spellCheckStream(searchWord));
|
||||
|
||||
return streams.stream().flatMap(SuggestionStream::stream).limit(count).collect(Collectors.toList());
|
||||
return Stream.of(
|
||||
new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)),
|
||||
suggestionsForLastWord(count, searchWord),
|
||||
spellCheckStream(searchWord)
|
||||
)
|
||||
.flatMap(SuggestionsStreamable::stream)
|
||||
.limit(count)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private SuggestionStream spellCheckStream(String word) {
|
||||
private SuggestionsStreamable suggestionsForLastWord(int count, String searchWord) {
|
||||
int sp = searchWord.lastIndexOf(' ');
|
||||
|
||||
if (sp < 0) {
|
||||
return Stream::empty;
|
||||
}
|
||||
|
||||
String prefixString = searchWord.substring(0, sp+1);
|
||||
String suggestString = searchWord.substring(sp+1);
|
||||
|
||||
return new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString));
|
||||
|
||||
}
|
||||
|
||||
private SuggestionsStreamable spellCheckStream(String word) {
|
||||
int start = word.lastIndexOf(' ');
|
||||
String prefix;
|
||||
String corrWord;
|
||||
@ -120,21 +120,16 @@ public class Suggestions {
|
||||
return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get));
|
||||
}
|
||||
else {
|
||||
return new SuggestionStream("", Stream.empty());
|
||||
return Stream::empty;
|
||||
}
|
||||
}
|
||||
|
||||
private String trimLeading(String word) {
|
||||
|
||||
for (int i = 0; i < word.length(); i++) {
|
||||
if (!Character.isWhitespace(word.charAt(i)))
|
||||
return word.substring(i);
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
public Stream<String> getSuggestionsForKeyword(int count, String prefix) {
|
||||
if (prefix.length() < MIN_SUGGEST_LENGTH) {
|
||||
return Stream.empty();
|
||||
}
|
||||
|
||||
var start = suggestionsTrie.select(prefix);
|
||||
|
||||
if (start == null) {
|
||||
@ -145,14 +140,30 @@ public class Suggestions {
|
||||
return Stream.empty();
|
||||
}
|
||||
|
||||
Map<String, Long> scach = new HashMap<>(512);
|
||||
Function<String, Long> valr = s -> -termFrequencyDict.getTermFreqHash(scach.computeIfAbsent(s, TermFrequencyDict::getStringHash));
|
||||
SuggestionsValueCalculator sv = new SuggestionsValueCalculator();
|
||||
|
||||
return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey)
|
||||
.takeWhile(s -> s.startsWith(prefix))
|
||||
.limit(256)
|
||||
.sorted(Comparator.comparing(valr).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
|
||||
.sorted(Comparator.comparing(sv::get).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
|
||||
.limit(count);
|
||||
}
|
||||
|
||||
private record SuggestionStream(String prefix, Stream<String> suggestionStream) implements SuggestionsStreamable {
|
||||
public Stream<String> stream() {
|
||||
return suggestionStream.map(s -> prefix + s);
|
||||
}
|
||||
}
|
||||
|
||||
interface SuggestionsStreamable { Stream<String> stream(); }
|
||||
|
||||
private class SuggestionsValueCalculator {
|
||||
|
||||
private final Map<String, Long> hashCache = new HashMap<>(512);
|
||||
|
||||
public int get(String s) {
|
||||
long hash = hashCache.computeIfAbsent(s, TermFrequencyDict::getStringHash);
|
||||
return -termFrequencyDict.getTermFreqHash(hash);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ x-svc: &service
|
||||
- vol:/vol
|
||||
- conf:/wmsa/conf:ro
|
||||
- model:/wmsa/model
|
||||
- data:/wmsa/data
|
||||
- logs:/var/log/wmsa
|
||||
networks:
|
||||
- wmsa
|
||||
@ -126,3 +127,9 @@ volumes:
|
||||
type: none
|
||||
o: bind
|
||||
device: run/conf
|
||||
data:
|
||||
driver: local
|
||||
driver_opts:
|
||||
type: none
|
||||
o: bind
|
||||
device: run/data
|
Loading…
Reference in New Issue
Block a user