(WIP) Implement first take of new query segmentation algorithm

Viktor Lofgren 2024-03-12 13:12:50 +01:00
parent 57e6a12d08
commit 8ae1f08095
8 changed files with 492 additions and 0 deletions

View File: build.gradle

@@ -26,6 +26,9 @@ dependencies {
 implementation project(':code:libraries:term-frequency-dict')
 implementation project(':third-party:porterstemmer')
+implementation project(':third-party:openzim')
+implementation project(':third-party:commons-codec')
 implementation project(':code:libraries:language-processing')
 implementation project(':code:libraries:term-frequency-dict')
 implementation project(':code:features-convert:keyword-extraction')
@@ -36,6 +39,8 @@ dependencies {
 implementation libs.bundles.grpc
 implementation libs.notnull
 implementation libs.guice
+implementation libs.jsoup
+implementation libs.commons.lang3
 implementation libs.trove
 implementation libs.fastutil
 implementation libs.bundles.gson

View File: BasicSentenceExtractor.java

@@ -0,0 +1,16 @@
package nu.marginalia.functions.searchquery.segmentation;
import ca.rmen.porterstemmer.PorterStemmer;
import org.apache.commons.lang3.StringUtils;
public class BasicSentenceExtractor {
private static PorterStemmer porterStemmer = new PorterStemmer();
public static String[] getStemmedParts(String sentence) {
String[] parts = StringUtils.split(sentence, ' ');
for (int i = 0; i < parts.length; i++) {
parts[i] = porterStemmer.stemWord(parts[i]);
}
return parts;
}
}
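A quick note on what the stemmer yields: the call below is purely illustrative and not part of the commit, with the expected output following the Porter algorithm's usual behavior.

// Illustrative only; the stemmed forms in the comment are what the
// Porter stemmer conventionally produces for these words.
String[] parts = BasicSentenceExtractor.getStemmedParts("searching for breads");
// parts == { "search", "for", "bread" }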

View File: HasherGroup.java

@@ -0,0 +1,61 @@
package nu.marginalia.functions.searchquery.segmentation;
import nu.marginalia.hash.MurmurHash3_128;
/** A group of hash functions for hashing a sequence of strings,
 * with an inverse operation for removing a previously applied
 * string from the sequence. */
sealed interface HasherGroup {
/** Apply a hash to the accumulator */
long apply(long acc, long add);
/** Remove a hash that was added n operations ago from the accumulator, add a new one */
long replace(long acc, long add, long rem, int n);
/** Create a new hasher group that preserves the order of applied hash functions */
static HasherGroup ordered() {
return new OrderedHasher();
}
/** Create a new hasher group that does not preserve the order of applied hash functions */
static HasherGroup unordered() {
return new UnorderedHasher();
}
/** Bake the words in the sentence into a hash successively using the group's apply function */
default long rollingHash(String[] parts) {
long code = 0;
for (String part : parts) {
code = apply(code, hash(part));
}
return code;
}
MurmurHash3_128 hash = new MurmurHash3_128();
/** Calculate the hash of a string */
static long hash(String term) {
return hash.hashNearlyASCII(term);
}
final class UnorderedHasher implements HasherGroup {
public long apply(long acc, long add) {
return acc ^ add;
}
public long replace(long acc, long add, long rem, int n) {
return acc ^ rem ^ add;
}
}
final class OrderedHasher implements HasherGroup {
public long apply(long acc, long add) {
return Long.rotateLeft(acc, 1) ^ add;
}
public long replace(long acc, long add, long rem, int n) {
return Long.rotateLeft(acc, 1) ^ add ^ Long.rotateLeft(rem, n);
}
}
}
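The point of replace() is that an n-gram window can slide across a sentence without rehashing the whole window on every step. A minimal sketch of the invariant, assuming a window of three tokens (the token names are illustrative):

// Hashing tokens 1..3 directly equals hashing tokens 0..2 and then
// replacing token 0 (applied three operations ago) with token 3.
var group = HasherGroup.ordered();
long window0 = group.rollingHash(new String[] { "t0", "t1", "t2" });
long window1 = group.rollingHash(new String[] { "t1", "t2", "t3" });
long slid = group.replace(window0, HasherGroup.hash("t3"), HasherGroup.hash("t0"), 3);
assert slid == window1;

NgramLexicon.findSegments() below relies on exactly this property to scan a sentence in linear time.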

View File: NgramExporterMain.java

@@ -0,0 +1,46 @@
package nu.marginalia.functions.searchquery.segmentation;
import nu.marginalia.WmsaHome;
import nu.marginalia.language.sentence.SentenceExtractor;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Scanner;
public class NgramExporterMain {
public static void main(String... args) throws IOException {
trial();
}
static void trial() throws IOException {
SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels());
NgramLexicon lexicon = new NgramLexicon();
lexicon.loadCounts(Path.of("/home/vlofgren/ngram-counts.bin"));
System.out.println("Loaded!");
var scanner = new Scanner(System.in);
for (;;) {
System.out.println("Enter a sentence: ");
// Scanner.nextLine() throws NoSuchElementException at EOF rather than
// returning null, so guard with hasNextLine() before reading
if (!scanner.hasNextLine())
break;
String line = scanner.nextLine();
String[] terms = BasicSentenceExtractor.getStemmedParts(line);
System.out.println(".");
for (int i = 2; i < 8; i++) {
lexicon.findSegments(i, terms).forEach(p -> {
System.out.println(Arrays.toString(p.project(terms)) + ": " + p.count());
});
}
}
}
}

View File: NgramExtractorMain.java

@@ -0,0 +1,113 @@
package nu.marginalia.functions.searchquery.segmentation;
import it.unimi.dsi.fastutil.longs.*;
import nu.marginalia.hash.MurmurHash3_128;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.openzim.ZIMTypes.ZIMFile;
import org.openzim.ZIMTypes.ZIMReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
public class NgramExtractorMain {
static MurmurHash3_128 hash = new MurmurHash3_128();
public static void main(String... args) {
}
private static List<String> getNgramTerms(Document document) {
List<String> terms = new ArrayList<>();
document.select("a[href]").forEach(e -> {
var href = e.attr("href");
if (href.contains(":"))
return;
if (href.contains("/"))
return;
var text = e.text().toLowerCase();
if (!text.contains(" "))
return;
terms.add(text);
});
return terms;
}
public static void dumpNgramsList(
Path zimFile,
Path ngramFile
) throws IOException, InterruptedException {
ZIMReader reader = new ZIMReader(new ZIMFile(zimFile.toString()));
PrintWriter printWriter = new PrintWriter(Files.newOutputStream(ngramFile,
StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE));
LongOpenHashSet known = new LongOpenHashSet();
try (var executor = Executors.newWorkStealingPool()) {
reader.forEachArticles((title, body) -> {
executor.submit(() -> {
var terms = getNgramTerms(Jsoup.parse(body));
synchronized (known) {
for (String term : terms) {
if (known.add(hash.hashNearlyASCII(term))) {
printWriter.println(term);
}
}
}
});
}, p -> true);
}
printWriter.close();
}
public static void dumpCounts(Path zimInputFile,
Path countsOutputFile) throws IOException, InterruptedException
{
ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString()));
NgramLexicon lexicon = new NgramLexicon();
var orderedHasher = HasherGroup.ordered();
var unorderedHasher = HasherGroup.unordered();
try (var executor = Executors.newWorkStealingPool()) {
reader.forEachArticles((title, body) -> {
executor.submit(() -> {
LongArrayList orderedHashes = new LongArrayList();
LongArrayList unorderedHashes = new LongArrayList();
for (var sent : getNgramTerms(Jsoup.parse(body))) {
String[] terms = BasicSentenceExtractor.getStemmedParts(sent);
orderedHashes.add(orderedHasher.rollingHash(terms));
unorderedHashes.add(unorderedHasher.rollingHash(terms));
}
synchronized (lexicon) {
for (var hash : orderedHashes) {
lexicon.incOrdered(hash);
}
for (var hash : unorderedHashes) {
lexicon.addUnordered(hash);
}
}
});
}, p -> true);
}
lexicon.saveCounts(countsOutputFile);
}
}
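Since main() is still empty, the entry points are presumably invoked ad hoc for now. A hypothetical wiring, in which every path is a placeholder rather than anything the commit specifies:

// Hypothetical invocation; a Wikipedia ZIM dump is implied by the
// article-link heuristics in getNgramTerms(), but no actual path is given.
Path zim = Path.of("/data/wikipedia.zim");
NgramExtractorMain.dumpNgramsList(zim, Path.of("/tmp/ngrams.txt"));
NgramExtractorMain.dumpCounts(zim, Path.of("/tmp/ngram-counts.bin"));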

View File: NgramLexicon.java

@@ -0,0 +1,165 @@
package nu.marginalia.functions.searchquery.segmentation;
import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
import it.unimi.dsi.fastutil.longs.LongHash;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
class NgramLexicon {
private final Long2IntOpenCustomHashMap counts = new Long2IntOpenCustomHashMap(
100_000_000,
new KeyIsAlreadyHashStrategy()
);
private final LongOpenHashSet permutations = new LongOpenHashSet();
private static final HasherGroup orderedHasher = HasherGroup.ordered();
private static final HasherGroup unorderedHasher = HasherGroup.unordered();
public List<SentenceSegment> findSegments(int length, String... parts) {
// Don't look for ngrams longer than the sentence
if (parts.length < length) return List.of();
List<SentenceSegment> positions = new ArrayList<>();
// Hash the parts
long[] hashes = new long[parts.length];
for (int i = 0; i < hashes.length; i++) {
hashes[i] = HasherGroup.hash(parts[i]);
}
long ordered = 0;
long unordered = 0;
int i = 0;
// Prepare by combining up to length hashes
for (; i < length; i++) {
ordered = orderedHasher.apply(ordered, hashes[i]);
unordered = unorderedHasher.apply(unordered, hashes[i]);
}
// Slide the window and look for matches
for (;; i++) {
int ct = counts.get(ordered);
if (ct > 0) {
positions.add(new SentenceSegment(i - length, length, ct, PositionType.NGRAM));
}
else if (permutations.contains(unordered)) {
positions.add(new SentenceSegment(i - length, length, 0, PositionType.PERMUTATION));
}
if (i >= hashes.length)
break;
// Remove the oldest hash and add the new one
ordered = orderedHasher.replace(ordered,
hashes[i],
hashes[i - length],
length);
unordered = unorderedHasher.replace(unordered,
hashes[i],
hashes[i - length],
length);
}
return positions;
}
public void incOrdered(long hashOrdered) {
counts.addTo(hashOrdered, 1);
}
public void addUnordered(long hashUnordered) {
permutations.add(hashUnordered);
}
public void loadCounts(Path path) throws IOException {
try (var dis = new DataInputStream(Files.newInputStream(path))) {
long size = dis.readInt();
for (int i = 0; i < size; i++) {
counts.put(dis.readLong(), dis.readInt());
}
}
}
public void loadPermutations(Path path) throws IOException {
try (var dis = new DataInputStream(Files.newInputStream(path))) {
long size = dis.readInt();
for (int i = 0; i < size; i++) {
permutations.add(dis.readLong());
}
}
}
public void saveCounts(Path file) throws IOException {
try (var dos = new DataOutputStream(Files.newOutputStream(file,
StandardOpenOption.CREATE,
StandardOpenOption.TRUNCATE_EXISTING,
StandardOpenOption.WRITE))) {
dos.writeInt(counts.size());
counts.forEach((k, v) -> {
try {
dos.writeLong(k);
dos.writeInt(v);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}
}
public void savePermutations(Path file) throws IOException {
try (var dos = new DataOutputStream(Files.newOutputStream(file,
StandardOpenOption.CREATE,
StandardOpenOption.TRUNCATE_EXISTING,
StandardOpenOption.WRITE))) {
dos.writeInt(permutations.size());
permutations.forEach(v -> {
try {
dos.writeLong(v);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}
}
public void clear() {
permutations.clear();
counts.clear();
}
public record SentenceSegment(int start, int length, int count, PositionType type) {
public String[] project(String... parts) {
return Arrays.copyOfRange(parts, start, start + length);
}
}
enum PositionType {
NGRAM, PERMUTATION
}
private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy {
@Override
public int hashCode(long l) {
return (int) l;
}
@Override
public boolean equals(long l, long l1) {
return l == l1;
}
}
}
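End to end, building and querying the lexicon presumably looks like the sketch below (the strings are illustrative; the test file at the end of the commit exercises the same flow):

// Hypothetical use: register one ngram under both hash schemes,
// then scan a stemmed sentence for known segments.
NgramLexicon lexicon = new NgramLexicon();
String[] ngram = BasicSentenceExtractor.getStemmedParts("rye bread");
lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram));
lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram));
String[] sent = BasicSentenceExtractor.getStemmedParts("fresh rye bread recipes");
for (var segment : lexicon.findSegments(2, sent)) {
System.out.println(Arrays.toString(segment.project(sent)) + ": " + segment.count());
}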

View File: HasherGroupTest.java

@@ -0,0 +1,33 @@
package nu.marginalia.functions.searchquery.segmentation;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class HasherGroupTest {
@Test
void ordered() {
long a = 5;
long b = 3;
long c = 2;
var group = HasherGroup.ordered();
assertNotEquals(group.apply(a, b), group.apply(b, a));
assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
}
@Test
void unordered() {
long a = 5;
long b = 3;
long c = 2;
var group = HasherGroup.unordered();
assertEquals(group.apply(a, b), group.apply(b, a));
assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
}
}

View File: NgramLexiconTest.java

@@ -0,0 +1,53 @@
package nu.marginalia.functions.searchquery.segmentation;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class NgramLexiconTest {
NgramLexicon lexicon = new NgramLexicon();
@BeforeEach
public void setUp() {
lexicon.clear();
}
void addNgram(String... ngram) {
lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram));
lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram));
}
@Test
void findSegments() {
addNgram("hello", "world");
addNgram("rye", "bread");
addNgram("rye", "world");
String[] sent = { "hello", "world", "rye", "bread" };
var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread");
assertEquals(3, segments.size());
for (int i = 0; i < 3; i++) {
var segment = segments.get(i);
switch (i) {
case 0 -> {
assertArrayEquals(new String[]{"hello", "world"}, segment.project(sent));
assertEquals(1, segment.count());
assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
}
case 1 -> {
assertArrayEquals(new String[]{"world", "rye"}, segment.project(sent));
assertEquals(0, segment.count());
assertEquals(NgramLexicon.PositionType.PERMUTATION, segment.type());
}
case 2 -> {
assertArrayEquals(new String[]{"rye", "bread"}, segment.project(sent));
assertEquals(1, segment.count());
assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
}
}
}
}
}