Mirror of https://github.com/MarginaliaSearch/MarginaliaSearch.git, synced 2025-02-24 13:19:02 +00:00
(WIP) Implement first take of new query segmentation algorithm

parent 57e6a12d08
commit 8ae1f08095
@@ -26,6 +26,9 @@ dependencies {
     implementation project(':code:libraries:term-frequency-dict')
     implementation project(':third-party:porterstemmer')
+    implementation project(':third-party:openzim')
+    implementation project(':third-party:commons-codec')
+
     implementation project(':code:libraries:language-processing')
     implementation project(':code:libraries:term-frequency-dict')
     implementation project(':code:features-convert:keyword-extraction')
@@ -36,6 +39,8 @@ dependencies {
     implementation libs.bundles.grpc
     implementation libs.notnull
     implementation libs.guice
+    implementation libs.jsoup
+    implementation libs.commons.lang3
     implementation libs.trove
     implementation libs.fastutil
     implementation libs.bundles.gson
@@ -0,0 +1,16 @@ BasicSentenceExtractor.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import ca.rmen.porterstemmer.PorterStemmer;
import org.apache.commons.lang3.StringUtils;

public class BasicSentenceExtractor {

    private static final PorterStemmer porterStemmer = new PorterStemmer();

    /** Split the sentence on spaces and Porter-stem each part. */
    public static String[] getStemmedParts(String sentence) {
        String[] parts = StringUtils.split(sentence, ' ');
        for (int i = 0; i < parts.length; i++) {
            parts[i] = porterStemmer.stemWord(parts[i]);
        }
        return parts;
    }
}
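For illustration, a hypothetical call (the exact stems come from the Porter stemmer, so the output shown is approximate):

    // Hypothetical usage, not part of the commit
    String[] parts = BasicSentenceExtractor.getStemmedParts("searching for rye breads");
    // parts is now approximately ["search", "for", "rye", "bread"]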
@@ -0,0 +1,61 @@ HasherGroup.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import nu.marginalia.hash.MurmurHash3_128;

/** A group of hash functions that can be used to hash a sequence of strings,
 * that also has an inverse operation that can be used to remove a previously applied
 * string from the sequence. */
sealed interface HasherGroup {
    /** Apply a hash to the accumulator */
    long apply(long acc, long add);

    /** Remove a hash that was added n operations ago from the accumulator, and add a new one */
    long replace(long acc, long add, long rem, int n);

    /** Create a new hasher group that preserves the order of applied hash functions */
    static HasherGroup ordered() {
        return new OrderedHasher();
    }

    /** Create a new hasher group that does not preserve the order of applied hash functions */
    static HasherGroup unordered() {
        return new UnorderedHasher();
    }

    /** Bake the words in the sentence into a hash successively using the group's apply function */
    default long rollingHash(String[] parts) {
        long code = 0;
        for (String part : parts) {
            code = apply(code, hash(part));
        }
        return code;
    }

    MurmurHash3_128 hash = new MurmurHash3_128();

    /** Calculate the hash of a string */
    static long hash(String term) {
        return hash.hashNearlyASCII(term);
    }

    /** Order-insensitive hasher: XOR is commutative, so the accumulator is the
     * same regardless of the order in which the words were applied. */
    final class UnorderedHasher implements HasherGroup {

        public long apply(long acc, long add) {
            return acc ^ add;
        }

        public long replace(long acc, long add, long rem, int n) {
            return acc ^ rem ^ add;
        }
    }

    /** Order-sensitive hasher: each older hash ends up rotated one bit further left,
     * so a hash added n operations ago can be cancelled by XOR-ing in its n-bit rotation. */
    final class OrderedHasher implements HasherGroup {

        public long apply(long acc, long add) {
            return Long.rotateLeft(acc, 1) ^ add;
        }

        public long replace(long acc, long add, long rem, int n) {
            return Long.rotateLeft(acc, 1) ^ add ^ Long.rotateLeft(rem, n);
        }
    }
}
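A minimal sketch of how the inverse operation supports a sliding window (assumes same-package access, since HasherGroup is package-private; the words are illustrative):

    var group = HasherGroup.ordered();

    long a = HasherGroup.hash("hello");
    long b = HasherGroup.hash("world");
    long c = HasherGroup.hash("again");

    // Hash the window ["hello", "world"], then slide it one step to ["world", "again"]
    long window = group.apply(group.apply(0, a), b);
    long slid   = group.replace(window, c, a, 2); // "hello" was added 2 operations ago

    // The slid accumulator equals hashing ["world", "again"] from scratch
    assert slid == group.apply(group.apply(0, b), c);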
@@ -0,0 +1,46 @@ NgramExporterMain.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import nu.marginalia.WmsaHome;
import nu.marginalia.language.sentence.SentenceExtractor;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Scanner;

public class NgramExporterMain {

    public static void main(String... args) throws IOException {
        trial();
    }

    static void trial() throws IOException {
        SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels());

        NgramLexicon lexicon = new NgramLexicon();
        lexicon.loadCounts(Path.of("/home/vlofgren/ngram-counts.bin"));

        System.out.println("Loaded!");

        var scanner = new Scanner(System.in);
        for (;;) {
            System.out.println("Enter a sentence: ");
            if (!scanner.hasNextLine()) // nextLine() throws at EOF rather than returning null
                break;
            String line = scanner.nextLine();
            System.out.println(".");

            String[] terms = BasicSentenceExtractor.getStemmedParts(line);
            System.out.println(".");

            // Look for known segments of every length from 2 to 7 words
            for (int i = 2; i < 8; i++) {
                lexicon.findSegments(i, terms).forEach(p -> {
                    System.out.println(STR."\{Arrays.toString(p.project(terms))}: \{p.count()}");
                });
            }
        }
    }
}
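The STR."…" syntax is a string template (JEP 430), a preview feature in recent JDKs that needs --enable-preview. A sketch of the same output line with plain concatenation, in case preview features are unavailable:

    System.out.println(Arrays.toString(p.project(terms)) + ": " + p.count());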
@@ -0,0 +1,113 @@ NgramExtractorMain.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import it.unimi.dsi.fastutil.longs.*;
import nu.marginalia.hash.MurmurHash3_128;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.openzim.ZIMTypes.ZIMFile;
import org.openzim.ZIMTypes.ZIMReader;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;

public class NgramExtractorMain {
    static MurmurHash3_128 hash = new MurmurHash3_128();

    public static void main(String... args) {
    }

    /** Collect link texts that look like multi-word titles: skip hrefs with
     * namespaces or paths, and skip single-word anchor texts. */
    private static List<String> getNgramTerms(Document document) {
        List<String> terms = new ArrayList<>();

        document.select("a[href]").forEach(e -> {
            var href = e.attr("href");
            if (href.contains(":"))
                return;
            if (href.contains("/"))
                return;

            var text = e.text().toLowerCase();
            if (!text.contains(" "))
                return;

            terms.add(text);
        });

        return terms;
    }

    public static void dumpNgramsList(
            Path zimFile,
            Path ngramFile
    ) throws IOException, InterruptedException {
        ZIMReader reader = new ZIMReader(new ZIMFile(zimFile.toString()));

        PrintWriter printWriter = new PrintWriter(Files.newOutputStream(ngramFile,
                StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE));

        LongOpenHashSet known = new LongOpenHashSet();

        try (var executor = Executors.newWorkStealingPool()) {
            reader.forEachArticles((title, body) -> {
                executor.submit(() -> {
                    var terms = getNgramTerms(Jsoup.parse(body));
                    synchronized (known) {
                        for (String term : terms) {
                            // Deduplicate on the term's hash before writing it out
                            if (known.add(hash.hashNearlyASCII(term))) {
                                printWriter.println(term);
                            }
                        }
                    }
                });

            }, p -> true);
        }
        printWriter.close();
    }

    public static void dumpCounts(Path zimInputFile,
                                  Path countsOutputFile) throws IOException, InterruptedException
    {
        ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString()));

        NgramLexicon lexicon = new NgramLexicon();

        var orderedHasher = HasherGroup.ordered();
        var unorderedHasher = HasherGroup.unordered();

        try (var executor = Executors.newWorkStealingPool()) {
            reader.forEachArticles((title, body) -> {
                executor.submit(() -> {
                    LongArrayList orderedHashes = new LongArrayList();
                    LongArrayList unorderedHashes = new LongArrayList();

                    for (var sent : getNgramTerms(Jsoup.parse(body))) {
                        String[] terms = BasicSentenceExtractor.getStemmedParts(sent);

                        orderedHashes.add(orderedHasher.rollingHash(terms));
                        unorderedHashes.add(unorderedHasher.rollingHash(terms));
                    }

                    // Hash outside the lock, update the shared lexicon inside it
                    synchronized (lexicon) {
                        for (var hash : orderedHashes) {
                            lexicon.incOrdered(hash);
                        }
                        for (var hash : unorderedHashes) {
                            lexicon.addUnordered(hash);
                        }
                    }
                });

            }, p -> true);
        }

        lexicon.saveCounts(countsOutputFile);
    }
}
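main() is still empty in this WIP commit; a hypothetical invocation might look like the following (the paths are purely illustrative, with the ZIM input being something like a Wikipedia dump):

    public static void main(String... args) throws IOException, InterruptedException {
        // Illustrative paths, not part of the commit
        NgramExtractorMain.dumpCounts(
                Path.of("/data/wikipedia_en_all_nopic.zim"), // input ZIM archive
                Path.of("/data/ngram-counts.bin"));          // counts file read back by NgramLexicon.loadCounts()
    }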
@@ -0,0 +1,165 @@ NgramLexicon.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
import it.unimi.dsi.fastutil.longs.LongHash;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class NgramLexicon {
    private final Long2IntOpenCustomHashMap counts = new Long2IntOpenCustomHashMap(
            100_000_000,
            new KeyIsAlreadyHashStrategy()
    );
    private final LongOpenHashSet permutations = new LongOpenHashSet();

    private static final HasherGroup orderedHasher = HasherGroup.ordered();
    private static final HasherGroup unorderedHasher = HasherGroup.unordered();

    public List<SentenceSegment> findSegments(int length, String... parts) {
        // Don't look for ngrams longer than the sentence
        if (parts.length < length) return List.of();

        List<SentenceSegment> positions = new ArrayList<>();

        // Hash the parts
        long[] hashes = new long[parts.length];
        for (int i = 0; i < hashes.length; i++) {
            hashes[i] = HasherGroup.hash(parts[i]);
        }

        long ordered = 0;
        long unordered = 0;
        int i = 0;

        // Prepare by combining up to length hashes
        for (; i < length; i++) {
            ordered = orderedHasher.apply(ordered, hashes[i]);
            unordered = unorderedHasher.apply(unordered, hashes[i]);
        }

        // Slide the window and look for matches
        for (;; i++) {
            int ct = counts.get(ordered);

            if (ct > 0) {
                positions.add(new SentenceSegment(i - length, length, ct, PositionType.NGRAM));
            }
            else if (permutations.contains(unordered)) {
                positions.add(new SentenceSegment(i - length, length, 0, PositionType.PERMUTATION));
            }

            if (i >= hashes.length)
                break;

            // Remove the oldest hash and add the new one
            ordered = orderedHasher.replace(ordered,
                    hashes[i],
                    hashes[i - length],
                    length);
            unordered = unorderedHasher.replace(unordered,
                    hashes[i],
                    hashes[i - length],
                    length);
        }

        return positions;
    }

    public void incOrdered(long hashOrdered) {
        counts.addTo(hashOrdered, 1);
    }
    public void addUnordered(long hashUnordered) {
        permutations.add(hashUnordered);
    }

    public void loadCounts(Path path) throws IOException {
        try (var dis = new DataInputStream(Files.newInputStream(path))) {
            long size = dis.readInt();

            for (int i = 0; i < size; i++) {
                counts.put(dis.readLong(), dis.readInt());
            }
        }
    }

    public void loadPermutations(Path path) throws IOException {
        try (var dis = new DataInputStream(Files.newInputStream(path))) {
            long size = dis.readInt();

            for (int i = 0; i < size; i++) {
                permutations.add(dis.readLong());
            }
        }
    }

    public void saveCounts(Path file) throws IOException {
        try (var dos = new DataOutputStream(Files.newOutputStream(file,
                StandardOpenOption.CREATE,
                StandardOpenOption.TRUNCATE_EXISTING,
                StandardOpenOption.WRITE))) {
            dos.writeInt(counts.size());

            counts.forEach((k, v) -> {
                try {
                    dos.writeLong(k);
                    dos.writeInt(v);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            });
        }
    }

    public void savePermutations(Path file) throws IOException {
        try (var dos = new DataOutputStream(Files.newOutputStream(file,
                StandardOpenOption.CREATE,
                StandardOpenOption.TRUNCATE_EXISTING,
                StandardOpenOption.WRITE))) {
            // Write the number of permutations, not the number of ngram counts
            dos.writeInt(permutations.size());

            permutations.forEach(v -> {
                try {
                    dos.writeLong(v);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            });
        }
    }

    public void clear() {
        permutations.clear();
        counts.clear();
    }

    public record SentenceSegment(int start, int length, int count, PositionType type) {
        /** Copy the matched words back out of the original sentence */
        public String[] project(String... parts) {
            return Arrays.copyOfRange(parts, start, start + length);
        }
    }

    enum PositionType {
        NGRAM, PERMUTATION
    }

    /** The keys are already murmur hashes, so pass them through rather than hashing again */
    private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy {
        @Override
        public int hashCode(long l) {
            return (int) l;
        }

        @Override
        public boolean equals(long l, long l1) {
            return l == l1;
        }
    }
}
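A minimal sketch of the intended flow (same package assumed, since NgramLexicon is package-private; the ngram and sentence are illustrative, and stemming is skipped for brevity):

    NgramLexicon lexicon = new NgramLexicon();

    // Register the bigram "rye bread" as a known ngram
    lexicon.incOrdered(HasherGroup.ordered().rollingHash(new String[] { "rye", "bread" }));

    // Scan a sentence with a window of width 2
    String[] sentence = { "i", "like", "rye", "bread" };
    for (var segment : lexicon.findSegments(2, sentence)) {
        // Prints "rye bread: 1"
        System.out.println(String.join(" ", segment.project(sentence)) + ": " + segment.count());
    }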
@@ -0,0 +1,33 @@ HasherGroupTest.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;

class HasherGroupTest {

    @Test
    void ordered() {
        long a = 5;
        long b = 3;
        long c = 2;

        var group = HasherGroup.ordered();
        assertNotEquals(group.apply(a, b), group.apply(b, a));
        assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
    }

    @Test
    void unordered() {
        long a = 5;
        long b = 3;
        long c = 2;

        var group = HasherGroup.unordered();

        assertEquals(group.apply(a, b), group.apply(b, a));
        assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
    }

}
@@ -0,0 +1,53 @@ NgramLexiconTest.java (new file)

package nu.marginalia.functions.searchquery.segmentation;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;

class NgramLexiconTest {
    NgramLexicon lexicon = new NgramLexicon();

    @BeforeEach
    public void setUp() {
        lexicon.clear();
    }

    void addNgram(String... ngram) {
        lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram));
        lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram));
    }

    @Test
    void findSegments() {
        addNgram("hello", "world");
        addNgram("rye", "bread");
        addNgram("rye", "world");

        String[] sent = { "hello", "world", "rye", "bread" };
        var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread");

        assertEquals(3, segments.size());

        for (int i = 0; i < 3; i++) {
            var segment = segments.get(i);
            switch (i) {
                case 0 -> {
                    assertArrayEquals(new String[] { "hello", "world" }, segment.project(sent));
                    assertEquals(1, segment.count());
                    assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
                }
                case 1 -> {
                    assertArrayEquals(new String[] { "world", "rye" }, segment.project(sent));
                    assertEquals(0, segment.count());
                    assertEquals(NgramLexicon.PositionType.PERMUTATION, segment.type());
                }
                case 2 -> {
                    assertArrayEquals(new String[] { "rye", "bread" }, segment.project(sent));
                    assertEquals(1, segment.count());
                    assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
                }
            }
        }
    }
}