From 8ae1f080956ee72a3d5f8a4a0971376ea242d3d5 Mon Sep 17 00:00:00 2001
From: Viktor Lofgren
Date: Tue, 12 Mar 2024 13:12:50 +0100
Subject: [PATCH] (WIP) Implement first take of new query segmentation
 algorithm

---
 code/functions/search-query/build.gradle      |   5 +
 .../segmentation/BasicSentenceExtractor.java  |  16 ++
 .../searchquery/segmentation/HasherGroup.java |  61 +++++++
 .../segmentation/NgramExporterMain.java       |  46 +++++
 .../segmentation/NgramExtractorMain.java      | 113 ++++++++++++
 .../segmentation/NgramLexicon.java            | 165 ++++++++++++++++++
 .../segmentation/HasherGroupTest.java         |  33 ++++
 .../segmentation/NgramLexiconTest.java        |  53 ++++++
 8 files changed, 492 insertions(+)
 create mode 100644 code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/BasicSentenceExtractor.java
 create mode 100644 code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/HasherGroup.java
 create mode 100644 code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExporterMain.java
 create mode 100644 code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExtractorMain.java
 create mode 100644 code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramLexicon.java
 create mode 100644 code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/HasherGroupTest.java
 create mode 100644 code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/NgramLexiconTest.java

diff --git a/code/functions/search-query/build.gradle b/code/functions/search-query/build.gradle
index 86cafefa..76c520fb 100644
--- a/code/functions/search-query/build.gradle
+++ b/code/functions/search-query/build.gradle
@@ -26,6 +26,9 @@ dependencies {
     implementation project(':code:libraries:term-frequency-dict')
     implementation project(':third-party:porterstemmer')
+    implementation project(':third-party:openzim')
+    implementation project(':third-party:commons-codec')
+
     implementation project(':code:libraries:language-processing')
     implementation project(':code:libraries:term-frequency-dict')
     implementation project(':code:features-convert:keyword-extraction')
@@ -36,6 +39,8 @@ dependencies {
     implementation libs.bundles.grpc
     implementation libs.notnull
     implementation libs.guice
+    implementation libs.jsoup
+    implementation libs.commons.lang3
     implementation libs.trove
     implementation libs.fastutil
     implementation libs.bundles.gson
diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/BasicSentenceExtractor.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/BasicSentenceExtractor.java
new file mode 100644
index 00000000..e65c243d
--- /dev/null
+++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/BasicSentenceExtractor.java
@@ -0,0 +1,16 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import ca.rmen.porterstemmer.PorterStemmer;
+import org.apache.commons.lang3.StringUtils;
+
+public class BasicSentenceExtractor {
+
+    private static PorterStemmer porterStemmer = new PorterStemmer();
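+    /** Split the sentence on spaces and stem each token in place with the
+     * Porter stemmer, so that inflected forms of the same word (e.g. "bread",
+     * "breads") hash identically downstream. */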
+    public static String[] getStemmedParts(String sentence) {
+        String[] parts = StringUtils.split(sentence, ' ');
+        for (int i = 0; i < parts.length; i++) {
+            parts[i] = porterStemmer.stemWord(parts[i]);
+        }
+        return parts;
+    }
+}
diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/HasherGroup.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/HasherGroup.java
new file mode 100644
index 00000000..60bbb4dd
--- /dev/null
+++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/HasherGroup.java
@@ -0,0 +1,61 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import nu.marginalia.hash.MurmurHash3_128;
+
+/** A group of hash functions that can be used to hash a sequence of strings,
+ * and that also has an inverse operation for removing a previously applied
+ * string from the sequence. */
+sealed interface HasherGroup {
+    /** Apply a hash to the accumulator */
+    long apply(long acc, long add);
+
+    /** Remove a hash that was added n operations ago from the accumulator, and add a new one */
+    long replace(long acc, long add, long rem, int n);
+
+    /** Create a new hasher group that preserves the order of applied hash functions */
+    static HasherGroup ordered() {
+        return new OrderedHasher();
+    }
+
+    /** Create a new hasher group that does not preserve the order of applied hash functions */
+    static HasherGroup unordered() {
+        return new UnorderedHasher();
+    }
+
+    /** Bake the words in the sentence into a hash successively using the group's apply function */
+    default long rollingHash(String[] parts) {
+        long code = 0;
+        for (String part : parts) {
+            code = apply(code, hash(part));
+        }
+        return code;
+    }
+
+    MurmurHash3_128 hash = new MurmurHash3_128();
+    /** Calculate the hash of a string */
+    static long hash(String term) {
+        return hash.hashNearlyASCII(term);
+    }
+
+    final class UnorderedHasher implements HasherGroup {
+
+        public long apply(long acc, long add) {
+            return acc ^ add;
+        }
+
+        public long replace(long acc, long add, long rem, int n) {
+            return acc ^ rem ^ add;
+        }
+    }
+
+    final class OrderedHasher implements HasherGroup {
+
+        public long apply(long acc, long add) {
+            return Long.rotateLeft(acc, 1) ^ add;
+        }
+
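+        // A hash XOR:ed in n apply() calls ago has since been rotated n-1 steps;
+        // the rotation below advances it to rotateLeft(rem, n), so XOR:ing that
+        // value back out removes the old term while the new one is XOR:ed in.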
+        public long replace(long acc, long add, long rem, int n) {
+            return Long.rotateLeft(acc, 1) ^ add ^ Long.rotateLeft(rem, n);
+        }
+    }
+}
diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExporterMain.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExporterMain.java
new file mode 100644
index 00000000..087345f6
--- /dev/null
+++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExporterMain.java
@@ -0,0 +1,46 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import nu.marginalia.WmsaHome;
+import nu.marginalia.language.sentence.SentenceExtractor;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Scanner;
+
+public class NgramExporterMain {
+
+    public static void main(String... args) throws IOException {
+        trial();
+    }
+
+    static void trial() throws IOException {
+        SentenceExtractor se = new SentenceExtractor(WmsaHome.getLanguageModels());
+
+        NgramLexicon lexicon = new NgramLexicon();
+        lexicon.loadCounts(Path.of("/home/vlofgren/ngram-counts.bin"));
+
+        System.out.println("Loaded!");
+
+        var scanner = new Scanner(System.in);
+        for (;;) {
+            System.out.println("Enter a sentence: ");
+            if (!scanner.hasNextLine())
+                break;
+            String line = scanner.nextLine();
+
+            String[] terms = BasicSentenceExtractor.getStemmedParts(line);
+
+            for (int i = 2; i < 8; i++) {
+                lexicon.findSegments(i, terms).forEach(p -> {
+                    System.out.println(STR."\{Arrays.toString(p.project(terms))}: \{p.count()}");
+                });
+            }
+
+        }
+    }
+
+}
diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExtractorMain.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExtractorMain.java
new file mode 100644
index 00000000..0339b2c1
--- /dev/null
+++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramExtractorMain.java
@@ -0,0 +1,113 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import it.unimi.dsi.fastutil.longs.*;
+import nu.marginalia.hash.MurmurHash3_128;
+import org.jsoup.Jsoup;
+import org.jsoup.nodes.Document;
+import org.openzim.ZIMTypes.ZIMFile;
+import org.openzim.ZIMTypes.ZIMReader;
+
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Executors;
+
+public class NgramExtractorMain {
+    static MurmurHash3_128 hash = new MurmurHash3_128();
+
+    public static void main(String... args) {
+    }
+
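+    /** Collect n-gram candidates from the anchor texts of a parsed article,
+     * skipping links whose href contains ':' or '/' (interwiki links and paths)
+     * and anchor texts that consist of a single word. */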
+    private static List<String> getNgramTerms(Document document) {
+        List<String> terms = new ArrayList<>();
+
+        document.select("a[href]").forEach(e -> {
+            var href = e.attr("href");
+            if (href.contains(":"))
+                return;
+            if (href.contains("/"))
+                return;
+
+            var text = e.text().toLowerCase();
+            if (!text.contains(" "))
+                return;
+
+            terms.add(text);
+        });
+
+        return terms;
+    }
+
+    public static void dumpNgramsList(
+            Path zimFile,
+            Path ngramFile
+    ) throws IOException, InterruptedException {
+        ZIMReader reader = new ZIMReader(new ZIMFile(zimFile.toString()));
+
+        PrintWriter printWriter = new PrintWriter(Files.newOutputStream(ngramFile,
+                StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE));
+
+        LongOpenHashSet known = new LongOpenHashSet();
+
+        try (var executor = Executors.newWorkStealingPool()) {
+            reader.forEachArticles((title, body) -> {
+                executor.submit(() -> {
+                    var terms = getNgramTerms(Jsoup.parse(body));
+
+                    // Deduplicate on the term hash before writing
+                    synchronized (known) {
+                        for (String term : terms) {
+                            if (known.add(hash.hashNearlyASCII(term))) {
+                                printWriter.println(term);
+                            }
+                        }
+                    }
+                });
+
+            }, p -> true);
+        }
+        printWriter.close();
+    }
+
+    public static void dumpCounts(Path zimInputFile,
+                                  Path countsOutputFile) throws IOException, InterruptedException
+    {
+        ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString()));
+
+        NgramLexicon lexicon = new NgramLexicon();
+
+        var orderedHasher = HasherGroup.ordered();
+        var unorderedHasher = HasherGroup.unordered();
+
+        try (var executor = Executors.newWorkStealingPool()) {
+            reader.forEachArticles((title, body) -> {
+                executor.submit(() -> {
+                    LongArrayList orderedHashes = new LongArrayList();
+                    LongArrayList unorderedHashes = new LongArrayList();
+
+                    // Hash outside the lock, update the shared lexicon inside it
+                    for (var sent : getNgramTerms(Jsoup.parse(body))) {
+                        String[] terms = BasicSentenceExtractor.getStemmedParts(sent);
+
+                        orderedHashes.add(orderedHasher.rollingHash(terms));
+                        unorderedHashes.add(unorderedHasher.rollingHash(terms));
+                    }
+
+                    synchronized (lexicon) {
+                        for (var hash : orderedHashes) {
+                            lexicon.incOrdered(hash);
+                        }
+                        for (var hash : unorderedHashes) {
+                            lexicon.addUnordered(hash);
+                        }
+                    }
+                });
+
+            }, p -> true);
+        }
+
+        lexicon.saveCounts(countsOutputFile);
+    }
+
+}
diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramLexicon.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramLexicon.java
new file mode 100644
index 00000000..948347bf
--- /dev/null
+++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/segmentation/NgramLexicon.java
@@ -0,0 +1,165 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
+import it.unimi.dsi.fastutil.longs.LongHash;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+class NgramLexicon {
+    private final Long2IntOpenCustomHashMap counts = new Long2IntOpenCustomHashMap(
+            100_000_000,
+            new KeyIsAlreadyHashStrategy()
+    );
+    private final LongOpenHashSet permutations = new LongOpenHashSet();
+
+    private static final HasherGroup orderedHasher = HasherGroup.ordered();
+    private static final HasherGroup unorderedHasher = HasherGroup.unordered();
+
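+    /** Scan the stemmed parts for known n-grams of the given length, using a
+     * rolling hash so that each window is tested in constant time. Exact
+     * (ordered) matches report their count; windows that only match a known
+     * n-gram as a permutation of its words are reported with count 0. */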
+    public List<SentenceSegment> findSegments(int length, String... parts) {
+        // Don't look for ngrams longer than the sentence
+        if (parts.length < length) return List.of();
+
+        List<SentenceSegment> positions = new ArrayList<>();
+
+        // Hash the parts
+        long[] hashes = new long[parts.length];
+        for (int i = 0; i < hashes.length; i++) {
+            hashes[i] = HasherGroup.hash(parts[i]);
+        }
+
+        long ordered = 0;
+        long unordered = 0;
+        int i = 0;
+
+        // Prepare by combining up to length hashes
+        for (; i < length; i++) {
+            ordered = orderedHasher.apply(ordered, hashes[i]);
+            unordered = unorderedHasher.apply(unordered, hashes[i]);
+        }
+
+        // Slide the window and look for matches
+        for (;; i++) {
+            int ct = counts.get(ordered);
+
+            if (ct > 0) {
+                positions.add(new SentenceSegment(i - length, length, ct, PositionType.NGRAM));
+            }
+            else if (permutations.contains(unordered)) {
+                positions.add(new SentenceSegment(i - length, length, 0, PositionType.PERMUTATION));
+            }
+
+            if (i >= hashes.length)
+                break;
+
+            // Remove the oldest hash and add the new one
+            ordered = orderedHasher.replace(ordered,
+                    hashes[i],
+                    hashes[i - length],
+                    length);
+            unordered = unorderedHasher.replace(unordered,
+                    hashes[i],
+                    hashes[i - length],
+                    length);
+        }
+
+        return positions;
+    }
+
+    public void incOrdered(long hashOrdered) {
+        counts.addTo(hashOrdered, 1);
+    }
+    public void addUnordered(long hashUnordered) {
+        permutations.add(hashUnordered);
+    }
+
+    public void loadCounts(Path path) throws IOException {
+        try (var dis = new DataInputStream(Files.newInputStream(path))) {
+            int size = dis.readInt();
+
+            for (int i = 0; i < size; i++) {
+                counts.put(dis.readLong(), dis.readInt());
+            }
+        }
+    }
+
+    public void loadPermutations(Path path) throws IOException {
+        try (var dis = new DataInputStream(Files.newInputStream(path))) {
+            int size = dis.readInt();
+
+            for (int i = 0; i < size; i++) {
+                permutations.add(dis.readLong());
+            }
+        }
+    }
+
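+    /** The counts file is a 4-byte entry count followed by (long hash, int count)
+     * pairs; the permutations file is a 4-byte entry count followed by long hashes.
+     * loadCounts() and loadPermutations() read the same layouts back. */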
+    public void saveCounts(Path file) throws IOException {
+        try (var dos = new DataOutputStream(Files.newOutputStream(file,
+                StandardOpenOption.CREATE,
+                StandardOpenOption.TRUNCATE_EXISTING,
+                StandardOpenOption.WRITE))) {
+            dos.writeInt(counts.size());
+
+            counts.forEach((k, v) -> {
+                try {
+                    dos.writeLong(k);
+                    dos.writeInt(v);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            });
+        }
+    }
+    public void savePermutations(Path file) throws IOException {
+        try (var dos = new DataOutputStream(Files.newOutputStream(file,
+                StandardOpenOption.CREATE,
+                StandardOpenOption.TRUNCATE_EXISTING,
+                StandardOpenOption.WRITE))) {
+            dos.writeInt(permutations.size());
+
+            permutations.forEach(v -> {
+                try {
+                    dos.writeLong(v);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            });
+        }
+    }
+    public void clear() {
+        permutations.clear();
+        counts.clear();
+    }
+
+    public record SentenceSegment(int start, int length, int count, PositionType type) {
+        public String[] project(String... parts) {
+            return Arrays.copyOfRange(parts, start, start + length);
+        }
+    }
+
+    enum PositionType {
+        NGRAM, PERMUTATION
+    }
+
+    /** The map keys are already hashes, so pass them through rather than hashing again. */
+    private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy {
+        @Override
+        public int hashCode(long l) {
+            return (int) l;
+        }
+
+        @Override
+        public boolean equals(long l, long l1) {
+            return l == l1;
+        }
+    }
+
+}
diff --git a/code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/HasherGroupTest.java b/code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/HasherGroupTest.java
new file mode 100644
index 00000000..174bd553
--- /dev/null
+++ b/code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/HasherGroupTest.java
@@ -0,0 +1,33 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class HasherGroupTest {
+
+    @Test
+    void ordered() {
+        long a = 5;
+        long b = 3;
+        long c = 2;
+
+        var group = HasherGroup.ordered();
+        assertNotEquals(group.apply(a, b), group.apply(b, a));
+        assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
+    }
+
+    @Test
+    void unordered() {
+        long a = 5;
+        long b = 3;
+        long c = 2;
+
+        var group = HasherGroup.unordered();
+
+        assertEquals(group.apply(a, b), group.apply(b, a));
+        assertEquals(group.apply(b, c), group.replace(group.apply(a, b), c, a, 2));
+    }
+
+}
diff --git a/code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/NgramLexiconTest.java b/code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/NgramLexiconTest.java
new file mode 100644
index 00000000..28b9ef2f
--- /dev/null
+++ b/code/functions/search-query/test/nu/marginalia/functions/searchquery/segmentation/NgramLexiconTest.java
@@ -0,0 +1,53 @@
+package nu.marginalia.functions.searchquery.segmentation;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class NgramLexiconTest {
+    NgramLexicon lexicon = new NgramLexicon();
+
+    @BeforeEach
+    public void setUp() {
+        lexicon.clear();
+    }
+
+    void addNgram(String... ngram) {
+        lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram));
+        lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram));
+    }
+
+    @Test
+    void findSegments() {
+        addNgram("hello", "world");
+        addNgram("rye", "bread");
+        addNgram("rye", "world");
+
+        String[] sent = { "hello", "world", "rye", "bread" };
+        var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread");
+
+        assertEquals(3, segments.size());
+
+        for (int i = 0; i < 3; i++) {
+            var segment = segments.get(i);
+            switch (i) {
+                case 0 -> {
+                    assertArrayEquals(new String[]{"hello", "world"}, segment.project(sent));
+                    assertEquals(1, segment.count());
+                    assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
+                }
+                // "world rye" is not a known ngram, but it is a permutation of "rye world"
+                case 1 -> {
+                    assertArrayEquals(new String[]{"world", "rye"}, segment.project(sent));
+                    assertEquals(0, segment.count());
+                    assertEquals(NgramLexicon.PositionType.PERMUTATION, segment.type());
+                }
+                case 2 -> {
+                    assertArrayEquals(new String[]{"rye", "bread"}, segment.project(sent));
+                    assertEquals(1, segment.count());
+                    assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
+                }
+            }
+        }
+
+    }
+}
\ No newline at end of file