mirror of
https://github.com/MarginaliaSearch/MarginaliaSearch.git
synced 2025-02-23 21:18:58 +00:00
(WIP) Implement first take of new query segmentation algorithm
This commit is contained in:
parent
57e6a12d08
commit
8ae1f08095
@ -26,6 +26,9 @@ dependencies {
|
||||
implementation project(':code:libraries:term-frequency-dict')
|
||||
|
||||
implementation project(':third-party:porterstemmer')
|
||||
implementation project(':third-party:openzim')
|
||||
implementation project(':third-party:commons-codec')
|
||||
|
||||
implementation project(':code:libraries:language-processing')
|
||||
implementation project(':code:libraries:term-frequency-dict')
|
||||
implementation project(':code:features-convert:keyword-extraction')
|
||||
@ -36,6 +39,8 @@ dependencies {
|
||||
implementation libs.bundles.grpc
|
||||
implementation libs.notnull
|
||||
implementation libs.guice
|
||||
implementation libs.jsoup
|
||||
implementation libs.commons.lang3
|
||||
implementation libs.trove
|
||||
implementation libs.fastutil
|
||||
implementation libs.bundles.gson
|
||||
|
@ -0,0 +1,16 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import ca.rmen.porterstemmer.PorterStemmer;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class BasicSentenceExtractor {
|
||||
|
||||
private static PorterStemmer porterStemmer = new PorterStemmer();
|
||||
public static String[] getStemmedParts(String sentence) {
|
||||
String[] parts = StringUtils.split(sentence, ' ');
|
||||
for (int i = 0; i < parts.length; i++) {
|
||||
parts[i] = porterStemmer.stemWord(parts[i]);
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import nu.marginalia.hash.MurmurHash3_128;
|
||||
|
||||
/** A group of hash functions that can be used to hash a sequence of strings,
|
||||
* that also has an inverse operation that can be used to remove a previously applied
|
||||
* string from the sequence. */
|
||||
sealed interface HasherGroup {
|
||||
/** Apply a hash to the accumulator */
|
||||
long apply(long acc, long add);
|
||||
|
||||
/** Remove a hash that was added n operations ago from the accumulator, add a new one */
|
||||
long replace(long acc, long add, long rem, int n);
|
||||
|
||||
/** Create a new hasher group that preserves the order of appleid hash functions */
|
||||
static HasherGroup ordered() {
|
||||
return new OrderedHasher();
|
||||
}
|
||||
|
||||
/** Create a new hasher group that does not preserve the order of applied hash functions */
|
||||
static HasherGroup unordered() {
|
||||
return new UnorderedHasher();
|
||||
}
|
||||
|
||||
/** Bake the words in the sentence into a hash successively using the group's apply function */
|
||||
default long rollingHash(String[] parts) {
|
||||
long code = 0;
|
||||
for (String part : parts) {
|
||||
code = apply(code, hash(part));
|
||||
}
|
||||
return code;
|
||||
}
|
||||
|
||||
MurmurHash3_128 hash = new MurmurHash3_128();
|
||||
/** Calculate the hash of a string */
|
||||
static long hash(String term) {
|
||||
return hash.hashNearlyASCII(term);
|
||||
}
|
||||
|
||||
final class UnorderedHasher implements HasherGroup {
|
||||
|
||||
public long apply(long acc, long add) {
|
||||
return acc ^ add;
|
||||
}
|
||||
|
||||
public long replace(long acc, long add, long rem, int n) {
|
||||
return acc ^ rem ^ add;
|
||||
}
|
||||
}
|
||||
|
||||
final class OrderedHasher implements HasherGroup {
|
||||
|
||||
public long apply(long acc, long add) {
|
||||
return Long.rotateLeft(acc, 1) ^ add;
|
||||
}
|
||||
|
||||
public long replace(long acc, long add, long rem, int n) {
|
||||
return Long.rotateLeft(acc, 1) ^ add ^ Long.rotateLeft(rem, n);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import nu.marginalia.WmsaHome;
|
||||
import nu.marginalia.language.sentence.SentenceExtractor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.Scanner;
|
||||
|
||||
public class NgramExporterMain {
|
||||
|
||||
public static void main(String... args) throws IOException {
|
||||
trial();
|
||||
}
|
||||
|
||||
static void trial() throws IOException {
|
||||
SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels());
|
||||
|
||||
NgramLexicon lexicon = new NgramLexicon();
|
||||
lexicon.loadCounts(Path.of("/home/vlofgren/ngram-counts.bin"));
|
||||
|
||||
System.out.println("Loaded!");
|
||||
|
||||
var scanner = new Scanner(System.in);
|
||||
for (;;) {
|
||||
System.out.println("Enter a sentence: ");
|
||||
String line = scanner.nextLine();
|
||||
System.out.println(".");
|
||||
if (line == null)
|
||||
break;
|
||||
|
||||
String[] terms = BasicSentenceExtractor.getStemmedParts(line);
|
||||
System.out.println(".");
|
||||
|
||||
for (int i = 2; i< 8; i++) {
|
||||
lexicon.findSegments(i, terms).forEach(p -> {
|
||||
System.out.println(STR."\{Arrays.toString(p.project(terms))}: \{p.count()}");
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,113 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import it.unimi.dsi.fastutil.longs.*;
|
||||
import nu.marginalia.hash.MurmurHash3_128;
|
||||
import org.jsoup.Jsoup;
|
||||
import org.jsoup.nodes.Document;
|
||||
import org.openzim.ZIMTypes.ZIMFile;
|
||||
import org.openzim.ZIMTypes.ZIMReader;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
public class NgramExtractorMain {
|
||||
static MurmurHash3_128 hash = new MurmurHash3_128();
|
||||
|
||||
public static void main(String... args) {
|
||||
}
|
||||
|
||||
private static List<String> getNgramTerms(Document document) {
|
||||
List<String> terms = new ArrayList<>();
|
||||
|
||||
document.select("a[href]").forEach(e -> {
|
||||
var href = e.attr("href");
|
||||
if (href.contains(":"))
|
||||
return;
|
||||
if (href.contains("/"))
|
||||
return;
|
||||
|
||||
var text = e.text().toLowerCase();
|
||||
if (!text.contains(" "))
|
||||
return;
|
||||
|
||||
terms.add(text);
|
||||
});
|
||||
|
||||
return terms;
|
||||
}
|
||||
|
||||
public static void dumpNgramsList(
|
||||
Path zimFile,
|
||||
Path ngramFile
|
||||
) throws IOException, InterruptedException {
|
||||
ZIMReader reader = new ZIMReader(new ZIMFile(zimFile.toString()));
|
||||
|
||||
PrintWriter printWriter = new PrintWriter(Files.newOutputStream(ngramFile,
|
||||
StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE));
|
||||
|
||||
LongOpenHashSet known = new LongOpenHashSet();
|
||||
|
||||
try (var executor = Executors.newWorkStealingPool()) {
|
||||
reader.forEachArticles((title, body) -> {
|
||||
executor.submit(() -> {
|
||||
var terms = getNgramTerms(Jsoup.parse(body));
|
||||
synchronized (known) {
|
||||
for (String term : terms) {
|
||||
if (known.add(hash.hashNearlyASCII(term))) {
|
||||
printWriter.println(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
}, p -> true);
|
||||
}
|
||||
printWriter.close();
|
||||
}
|
||||
|
||||
public static void dumpCounts(Path zimInputFile,
|
||||
Path countsOutputFile) throws IOException, InterruptedException
|
||||
{
|
||||
ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString()));
|
||||
|
||||
NgramLexicon lexicon = new NgramLexicon();
|
||||
|
||||
var orderedHasher = HasherGroup.ordered();
|
||||
var unorderedHasher = HasherGroup.unordered();
|
||||
|
||||
try (var executor = Executors.newWorkStealingPool()) {
|
||||
reader.forEachArticles((title, body) -> {
|
||||
executor.submit(() -> {
|
||||
LongArrayList orderedHashes = new LongArrayList();
|
||||
LongArrayList unorderedHashes = new LongArrayList();
|
||||
|
||||
for (var sent : getNgramTerms(Jsoup.parse(body))) {
|
||||
String[] terms = BasicSentenceExtractor.getStemmedParts(sent);
|
||||
|
||||
orderedHashes.add(orderedHasher.rollingHash(terms));
|
||||
unorderedHashes.add(unorderedHasher.rollingHash(terms));
|
||||
}
|
||||
|
||||
synchronized (lexicon) {
|
||||
for (var hash : orderedHashes) {
|
||||
lexicon.incOrdered(hash);
|
||||
}
|
||||
for (var hash : unorderedHashes) {
|
||||
lexicon.addUnordered(hash);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
}, p -> true);
|
||||
}
|
||||
|
||||
lexicon.saveCounts(countsOutputFile);
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,165 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
|
||||
import it.unimi.dsi.fastutil.longs.LongHash;
|
||||
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
|
||||
|
||||
import java.io.DataInputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
class NgramLexicon {
|
||||
private final Long2IntOpenCustomHashMap counts = new Long2IntOpenCustomHashMap(
|
||||
100_000_000,
|
||||
new KeyIsAlreadyHashStrategy()
|
||||
);
|
||||
private final LongOpenHashSet permutations = new LongOpenHashSet();
|
||||
|
||||
private static final HasherGroup orderedHasher = HasherGroup.ordered();
|
||||
private static final HasherGroup unorderedHasher = HasherGroup.unordered();
|
||||
|
||||
public List<SentenceSegment> findSegments(int length, String... parts) {
|
||||
// Don't look for ngrams longer than the sentence
|
||||
if (parts.length < length) return List.of();
|
||||
|
||||
List<SentenceSegment> positions = new ArrayList<>();
|
||||
|
||||
// Hash the parts
|
||||
long[] hashes = new long[parts.length];
|
||||
for (int i = 0; i < hashes.length; i++) {
|
||||
hashes[i] = HasherGroup.hash(parts[i]);
|
||||
}
|
||||
|
||||
long ordered = 0;
|
||||
long unordered = 0;
|
||||
int i = 0;
|
||||
|
||||
// Prepare by combining up to length hashes
|
||||
for (; i < length; i++) {
|
||||
ordered = orderedHasher.apply(ordered, hashes[i]);
|
||||
unordered = unorderedHasher.apply(unordered, hashes[i]);
|
||||
}
|
||||
|
||||
// Slide the window and look for matches
|
||||
for (;; i++) {
|
||||
int ct = counts.get(ordered);
|
||||
|
||||
if (ct > 0) {
|
||||
positions.add(new SentenceSegment(i - length, length, ct, PositionType.NGRAM));
|
||||
}
|
||||
else if (permutations.contains(unordered)) {
|
||||
positions.add(new SentenceSegment(i - length, length, 0, PositionType.PERMUTATION));
|
||||
}
|
||||
|
||||
if (i >= hashes.length)
|
||||
break;
|
||||
|
||||
// Remove the oldest hash and add the new one
|
||||
ordered = orderedHasher.replace(ordered,
|
||||
hashes[i],
|
||||
hashes[i - length],
|
||||
length);
|
||||
unordered = unorderedHasher.replace(unordered,
|
||||
hashes[i],
|
||||
hashes[i - length],
|
||||
length);
|
||||
}
|
||||
|
||||
return positions;
|
||||
}
|
||||
|
||||
public void incOrdered(long hashOrdered) {
|
||||
counts.addTo(hashOrdered, 1);
|
||||
}
|
||||
public void addUnordered(long hashUnordered) {
|
||||
permutations.add(hashUnordered);
|
||||
}
|
||||
|
||||
public void loadCounts(Path path) throws IOException {
|
||||
try (var dis = new DataInputStream(Files.newInputStream(path))) {
|
||||
long size = dis.readInt();
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
counts.put(dis.readLong(), dis.readInt());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void loadPermutations(Path path) throws IOException {
|
||||
try (var dis = new DataInputStream(Files.newInputStream(path))) {
|
||||
long size = dis.readInt();
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
permutations.add(dis.readLong());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void saveCounts(Path file) throws IOException {
|
||||
try (var dos = new DataOutputStream(Files.newOutputStream(file,
|
||||
StandardOpenOption.CREATE,
|
||||
StandardOpenOption.TRUNCATE_EXISTING,
|
||||
StandardOpenOption.WRITE))) {
|
||||
dos.writeInt(counts.size());
|
||||
|
||||
counts.forEach((k, v) -> {
|
||||
try {
|
||||
dos.writeLong(k);
|
||||
dos.writeInt(v);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
public void savePermutations(Path file) throws IOException {
|
||||
try (var dos = new DataOutputStream(Files.newOutputStream(file,
|
||||
StandardOpenOption.CREATE,
|
||||
StandardOpenOption.TRUNCATE_EXISTING,
|
||||
StandardOpenOption.WRITE))) {
|
||||
dos.writeInt(counts.size());
|
||||
|
||||
permutations.forEach(v -> {
|
||||
try {
|
||||
dos.writeLong(v);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
public void clear() {
|
||||
permutations.clear();
|
||||
counts.clear();
|
||||
}
|
||||
|
||||
public record SentenceSegment(int start, int length, int count, PositionType type) {
|
||||
public String[] project(String... parts) {
|
||||
return Arrays.copyOfRange(parts, start, start + length);
|
||||
}
|
||||
}
|
||||
|
||||
enum PositionType {
|
||||
NGRAM, PERMUTATION
|
||||
}
|
||||
|
||||
private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy {
|
||||
@Override
|
||||
public int hashCode(long l) {
|
||||
return (int) l;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(long l, long l1) {
|
||||
return l == l1;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,33 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class HasherGroupTest {
|
||||
|
||||
@Test
|
||||
void ordered() {
|
||||
long a = 5;
|
||||
long b = 3;
|
||||
long c = 2;
|
||||
|
||||
var group = HasherGroup.ordered();
|
||||
assertNotEquals(group.apply(a, b), group.apply(b, a));
|
||||
assertEquals(group.apply(b,c), group.replace(group.apply(a, b), c, a, 2));
|
||||
}
|
||||
|
||||
@Test
|
||||
void unordered() {
|
||||
long a = 5;
|
||||
long b = 3;
|
||||
long c = 2;
|
||||
|
||||
var group = HasherGroup.unordered();
|
||||
|
||||
assertEquals(group.apply(a, b), group.apply(b, a));
|
||||
assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
package nu.marginalia.functions.searchquery.segmentation;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class NgramLexiconTest {
|
||||
NgramLexicon lexicon = new NgramLexicon();
|
||||
@BeforeEach
|
||||
public void setUp() {
|
||||
lexicon.clear();
|
||||
}
|
||||
|
||||
void addNgram(String... ngram) {
|
||||
lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram));
|
||||
lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram));
|
||||
}
|
||||
|
||||
@Test
|
||||
void findSegments() {
|
||||
addNgram("hello", "world");
|
||||
addNgram("rye", "bread");
|
||||
addNgram("rye", "world");
|
||||
|
||||
String[] sent = { "hello", "world", "rye", "bread" };
|
||||
var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread");
|
||||
|
||||
assertEquals(3, segments.size());
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
var segment = segments.get(i);
|
||||
switch (i) {
|
||||
case 0 -> {
|
||||
assertArrayEquals(new String[]{"hello", "world"}, segment.project(sent));
|
||||
assertEquals(1, segment.count());
|
||||
assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
|
||||
}
|
||||
case 1 -> {
|
||||
assertArrayEquals(new String[]{"world", "rye"}, segment.project(sent));
|
||||
assertEquals(0, segment.count());
|
||||
assertEquals(NgramLexicon.PositionType.PERMUTATION, segment.type());
|
||||
}
|
||||
case 2 -> {
|
||||
assertArrayEquals(new String[]{"rye", "bread"}, segment.project(sent));
|
||||
assertEquals(1, segment.count());
|
||||
assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user