mirror of
https://github.com/MarginaliaSearch/MarginaliaSearch.git
synced 2025-02-23 21:18:58 +00:00
Fix typeahead suggestions
This commit is contained in:
parent
2f2c86a9f5
commit
3464ca514b
@ -63,7 +63,7 @@ class ResultValuatorTest {
|
|||||||
@Test
|
@Test
|
||||||
void evaluateTerms() {
|
void evaluateTerms() {
|
||||||
|
|
||||||
when(dict.getTermFreq("bob")).thenReturn(10L);
|
when(dict.getTermFreq("bob")).thenReturn(10);
|
||||||
SearchResultRankingContext context = new SearchResultRankingContext(100000,
|
SearchResultRankingContext context = new SearchResultRankingContext(100000,
|
||||||
Map.of("bob", 10));
|
Map.of("bob", 10));
|
||||||
|
|
||||||
|
@ -65,17 +65,17 @@ public class TermFrequencyDict {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Get the term frequency for the string s */
|
/** Get the term frequency for the string s */
|
||||||
public long getTermFreq(String s) {
|
public int getTermFreq(String s) {
|
||||||
return wordRates.get(getStringHash(s));
|
return wordRates.get(getStringHash(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get the term frequency for the already stemmed string s */
|
/** Get the term frequency for the already stemmed string s */
|
||||||
public long getTermFreqStemmed(String s) {
|
public int getTermFreqStemmed(String s) {
|
||||||
return wordRates.get(longHash(s.getBytes()));
|
return wordRates.get(longHash(s.getBytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Get the term frequency for the already stemmed and already hashed value 'hash' */
|
/** Get the term frequency for the already stemmed and already hashed value 'hash' */
|
||||||
public long getTermFreqHash(long hash) {
|
public int getTermFreqHash(long hash) {
|
||||||
return wordRates.get(hash);
|
return wordRates.get(hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import nu.marginalia.term_frequency_dict.TermFrequencyDict;
|
|||||||
import nu.marginalia.model.crawl.HtmlFeature;
|
import nu.marginalia.model.crawl.HtmlFeature;
|
||||||
import nu.marginalia.assistant.dict.SpellChecker;
|
import nu.marginalia.assistant.dict.SpellChecker;
|
||||||
import org.apache.commons.collections4.trie.PatriciaTrie;
|
import org.apache.commons.collections4.trie.PatriciaTrie;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
@ -13,7 +14,6 @@ import java.io.IOException;
|
|||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
@ -54,11 +54,12 @@ public class Suggestions {
|
|||||||
.map(String::toLowerCase)
|
.map(String::toLowerCase)
|
||||||
.forEach(w -> ret.put(w, w));
|
.forEach(w -> ret.put(w, w));
|
||||||
|
|
||||||
|
// Add special keywords to the suggestions
|
||||||
for (var feature : HtmlFeature.values()) {
|
for (var feature : HtmlFeature.values()) {
|
||||||
String keyword = feature.getKeyword();
|
String keyword = feature.getKeyword();
|
||||||
|
|
||||||
ret.put(keyword, keyword);
|
ret.put(keyword, keyword);
|
||||||
ret.put("-" + keyword, "-"+ keyword);
|
ret.put("-" + keyword, "-" + keyword);
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
@ -69,39 +70,38 @@ public class Suggestions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private record SuggestionStream(String prefix, Stream<String> suggestionStream) {
|
|
||||||
public Stream<String> stream() {
|
|
||||||
return suggestionStream.map(s -> prefix + s);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getSuggestions(int count, String searchWord) {
|
public List<String> getSuggestions(int count, String searchWord) {
|
||||||
if (searchWord.length() < MIN_SUGGEST_LENGTH) {
|
if (searchWord.length() < MIN_SUGGEST_LENGTH) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
|
||||||
searchWord = trimLeading(searchWord.toLowerCase());
|
searchWord = StringUtils.stripStart(searchWord.toLowerCase(), " ");
|
||||||
|
|
||||||
List<SuggestionStream> streams = new ArrayList<>(4);
|
return Stream.of(
|
||||||
streams.add(new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)));
|
new SuggestionStream("", getSuggestionsForKeyword(count, searchWord)),
|
||||||
|
suggestionsForLastWord(count, searchWord),
|
||||||
int sp = searchWord.lastIndexOf(' ');
|
spellCheckStream(searchWord)
|
||||||
if (sp >= 0) {
|
)
|
||||||
String prefixString = searchWord.substring(0, sp+1);
|
.flatMap(SuggestionsStreamable::stream)
|
||||||
String suggestString = searchWord.substring(sp+1);
|
.limit(count)
|
||||||
|
.collect(Collectors.toList());
|
||||||
if (suggestString.length() >= MIN_SUGGEST_LENGTH) {
|
|
||||||
streams.add(new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString)));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
streams.add(spellCheckStream(searchWord));
|
|
||||||
|
|
||||||
return streams.stream().flatMap(SuggestionStream::stream).limit(count).collect(Collectors.toList());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private SuggestionStream spellCheckStream(String word) {
|
private SuggestionsStreamable suggestionsForLastWord(int count, String searchWord) {
|
||||||
|
int sp = searchWord.lastIndexOf(' ');
|
||||||
|
|
||||||
|
if (sp < 0) {
|
||||||
|
return Stream::empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
String prefixString = searchWord.substring(0, sp+1);
|
||||||
|
String suggestString = searchWord.substring(sp+1);
|
||||||
|
|
||||||
|
return new SuggestionStream(prefixString, getSuggestionsForKeyword(count, suggestString));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private SuggestionsStreamable spellCheckStream(String word) {
|
||||||
int start = word.lastIndexOf(' ');
|
int start = word.lastIndexOf(' ');
|
||||||
String prefix;
|
String prefix;
|
||||||
String corrWord;
|
String corrWord;
|
||||||
@ -120,21 +120,16 @@ public class Suggestions {
|
|||||||
return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get));
|
return new SuggestionStream(prefix, Stream.of(suggestionsLazyEval).flatMap(Supplier::get));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return new SuggestionStream("", Stream.empty());
|
return Stream::empty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private String trimLeading(String word) {
|
|
||||||
|
|
||||||
for (int i = 0; i < word.length(); i++) {
|
|
||||||
if (!Character.isWhitespace(word.charAt(i)))
|
|
||||||
return word.substring(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
public Stream<String> getSuggestionsForKeyword(int count, String prefix) {
|
public Stream<String> getSuggestionsForKeyword(int count, String prefix) {
|
||||||
|
if (prefix.length() < MIN_SUGGEST_LENGTH) {
|
||||||
|
return Stream.empty();
|
||||||
|
}
|
||||||
|
|
||||||
var start = suggestionsTrie.select(prefix);
|
var start = suggestionsTrie.select(prefix);
|
||||||
|
|
||||||
if (start == null) {
|
if (start == null) {
|
||||||
@ -145,14 +140,30 @@ public class Suggestions {
|
|||||||
return Stream.empty();
|
return Stream.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String, Long> scach = new HashMap<>(512);
|
SuggestionsValueCalculator sv = new SuggestionsValueCalculator();
|
||||||
Function<String, Long> valr = s -> -termFrequencyDict.getTermFreqHash(scach.computeIfAbsent(s, TermFrequencyDict::getStringHash));
|
|
||||||
|
|
||||||
return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey)
|
return Stream.iterate(start.getKey(), Objects::nonNull, suggestionsTrie::nextKey)
|
||||||
.takeWhile(s -> s.startsWith(prefix))
|
.takeWhile(s -> s.startsWith(prefix))
|
||||||
.limit(256)
|
.limit(256)
|
||||||
.sorted(Comparator.comparing(valr).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
|
.sorted(Comparator.comparing(sv::get).thenComparing(String::length).thenComparing(Comparator.naturalOrder()))
|
||||||
.limit(count);
|
.limit(count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private record SuggestionStream(String prefix, Stream<String> suggestionStream) implements SuggestionsStreamable {
|
||||||
|
public Stream<String> stream() {
|
||||||
|
return suggestionStream.map(s -> prefix + s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SuggestionsStreamable { Stream<String> stream(); }
|
||||||
|
|
||||||
|
private class SuggestionsValueCalculator {
|
||||||
|
|
||||||
|
private final Map<String, Long> hashCache = new HashMap<>(512);
|
||||||
|
|
||||||
|
public int get(String s) {
|
||||||
|
long hash = hashCache.computeIfAbsent(s, TermFrequencyDict::getStringHash);
|
||||||
|
return -termFrequencyDict.getTermFreqHash(hash);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ x-svc: &service
|
|||||||
- vol:/vol
|
- vol:/vol
|
||||||
- conf:/wmsa/conf:ro
|
- conf:/wmsa/conf:ro
|
||||||
- model:/wmsa/model
|
- model:/wmsa/model
|
||||||
|
- data:/wmsa/data
|
||||||
- logs:/var/log/wmsa
|
- logs:/var/log/wmsa
|
||||||
networks:
|
networks:
|
||||||
- wmsa
|
- wmsa
|
||||||
@ -126,3 +127,9 @@ volumes:
|
|||||||
type: none
|
type: none
|
||||||
o: bind
|
o: bind
|
||||||
device: run/conf
|
device: run/conf
|
||||||
|
data:
|
||||||
|
driver: local
|
||||||
|
driver_opts:
|
||||||
|
type: none
|
||||||
|
o: bind
|
||||||
|
device: run/data
|
Loading…
Reference in New Issue
Block a user