diff --git a/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java b/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java index 90baf009..98cf114e 100644 --- a/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java +++ b/code/execution/java/nu/marginalia/actor/task/ExportSegmentationModelActor.java @@ -21,6 +21,7 @@ public class ExportSegmentationModelActor extends RecordActorPrototype { private final Logger logger = LoggerFactory.getLogger(getClass()); public record Export(String zimFile) implements ActorStep {} + @Override public ActorStep transition(ActorStep self) throws Exception { return switch(self) { @@ -29,9 +30,8 @@ public class ExportSegmentationModelActor extends RecordActorPrototype { var storage = storageService.allocateStorage(FileStorageType.EXPORT, "segmentation-model", "Segmentation Model Export " + LocalDateTime.now()); Path countsFile = storage.asPath().resolve("ngram-counts.bin"); - Path permutationsFile = storage.asPath().resolve("ngram-permutations.bin"); - NgramExtractorMain.dumpCounts(Path.of(zimFile), countsFile, permutationsFile); + NgramExtractorMain.dumpCounts(Path.of(zimFile), countsFile); yield new End(); } diff --git a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java index 052516d8..9c9d81fa 100644 --- a/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java +++ b/code/functions/search-query/java/nu/marginalia/functions/searchquery/query_parser/QueryExpansion.java @@ -112,10 +112,15 @@ public class QueryExpansion { // Look for known segments within the query for (int length = 2; length < Math.min(10, words.length); length++) { - for (var segment : lexicon.findSegments(length, words)) { + for (var segment : lexicon.findSegmentOffsets(length, words)) { + int 
start = segment.start(); int end = segment.start() + segment.length(); - var word = IntStream.range(start, end).mapToObj(nodes::get).map(QWord::word).collect(Collectors.joining("_")); + + var word = IntStream.range(start, end) + .mapToObj(nodes::get) + .map(QWord::word) + .collect(Collectors.joining("_")); graph.addVariantForSpan(nodes.get(start), nodes.get(end - 1), word); } diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExporterMain.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExporterMain.java deleted file mode 100644 index ee6d2cd5..00000000 --- a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExporterMain.java +++ /dev/null @@ -1,46 +0,0 @@ -package nu.marginalia.segmentation; - -import nu.marginalia.LanguageModels; - -import java.io.IOException; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.Scanner; - -public class NgramExporterMain { - - public static void main(String... 
args) throws IOException { - trial(); - } - - static void trial() throws IOException { - NgramLexicon lexicon = new NgramLexicon( - LanguageModels.builder() - .segments(Path.of("/home/vlofgren/ngram-counts.bin")) - .build() - ); - - System.out.println("Loaded!"); - - var scanner = new Scanner(System.in); - for (;;) { - System.out.println("Enter a sentence: "); - String line = scanner.nextLine(); - System.out.println("."); - if (line == null) - break; - - String[] terms = BasicSentenceExtractor.getStemmedParts(line); - System.out.println("."); - - for (int i = 2; i< 8; i++) { - lexicon.findSegments(i, terms).forEach(p -> { - System.out.println(STR."\{Arrays.toString(p.project(terms))}: \{p.count()}"); - }); - } - - } - } - - -} diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java index 577aee6e..3f29c74c 100644 --- a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java +++ b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramExtractorMain.java @@ -115,8 +115,7 @@ public class NgramExtractorMain { } public static void dumpCounts(Path zimInputFile, - Path countsOutputFile, - Path permutationsOutputFile + Path countsOutputFile ) throws IOException, InterruptedException { ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString())); @@ -143,9 +142,6 @@ public class NgramExtractorMain { for (var hash : orderedHashes) { lexicon.incOrdered(hash); } - for (var hash : unorderedHashes) { - lexicon.addUnordered(hash); - } } }); @@ -153,7 +149,6 @@ public class NgramExtractorMain { } lexicon.saveCounts(countsOutputFile); - lexicon.savePermutations(permutationsOutputFile); } } diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java index 
91cee314..e7dc1017 100644 --- a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java +++ b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java @@ -21,10 +21,8 @@ import java.util.List; @Singleton public class NgramLexicon { private final Long2IntOpenCustomHashMap counts; - private final LongOpenHashSet permutations = new LongOpenHashSet(); private static final HasherGroup orderedHasher = HasherGroup.ordered(); - private static final HasherGroup unorderedHasher = HasherGroup.unordered(); @Inject public NgramLexicon(LanguageModels models) { @@ -48,16 +46,57 @@ public class NgramLexicon { } public List findSegmentsStrings(int minLength, int maxLength, String... parts) { - List segments = new ArrayList<>(); + List segments = new ArrayList<>(); for (int i = minLength; i <= maxLength; i++) { segments.addAll(findSegments(i, parts)); } - return segments.stream().map(seg -> seg.project(parts)).toList(); + return segments; } - public List findSegments(int length, String... parts) { + public List findSegments(int length, String... 
parts) { + // Don't look for ngrams longer than the sentence + if (parts.length < length) return List.of(); + + List positions = new ArrayList<>(); + + // Hash the parts + long[] hashes = new long[parts.length]; + for (int i = 0; i < hashes.length; i++) { + hashes[i] = HasherGroup.hash(parts[i]); + } + + long ordered = 0; + int i = 0; + + // Prepare by combining up to length hashes + for (; i < length; i++) { + ordered = orderedHasher.apply(ordered, hashes[i]); + } + + // Slide the window and look for matches + for (;; i++) { + int ct = counts.get(ordered); + + if (ct > 0) { + positions.add(Arrays.copyOfRange(parts, i - length, i)); + } + + if (i >= hashes.length) + break; + + // Remove the oldest hash and add the new one + ordered = orderedHasher.replace(ordered, + hashes[i], + hashes[i - length], + length); + } + + return positions; + } + + public List findSegmentOffsets(int length, String... parts) { // Don't look for ngrams longer than the sentence if (parts.length < length) return List.of(); @@ -70,13 +109,11 @@ public class NgramLexicon { } long ordered = 0; - long unordered = 0; int i = 0; // Prepare by combining up to length hashes for (; i < length; i++) { ordered = orderedHasher.apply(ordered, hashes[i]); - unordered = unorderedHasher.apply(unordered, hashes[i]); } // Slide the window and look for matches @@ -84,10 +121,7 @@ public class NgramLexicon { int ct = counts.get(ordered); if (ct > 0) { - positions.add(new SentenceSegment(i - length, length, ct, PositionType.NGRAM)); - } - else if (permutations.contains(unordered)) { - positions.add(new SentenceSegment(i - length, length, 0, PositionType.PERMUTATION)); + positions.add(new SentenceSegment(i - length, length, ct)); } if (i >= hashes.length) @@ -98,10 +132,6 @@ public class NgramLexicon { hashes[i], hashes[i - length], length); - unordered = unorderedHasher.replace(unordered, - hashes[i], - hashes[i - length], - length); } return positions; @@ -110,20 +140,6 @@ public class NgramLexicon { 
public void incOrdered(long hashOrdered) { counts.addTo(hashOrdered, 1); } - public void addUnordered(long hashUnordered) { - permutations.add(hashUnordered); - } - - - public void loadPermutations(Path path) throws IOException { - try (var dis = new DataInputStream(Files.newInputStream(path))) { - long size = dis.readInt(); - - for (int i = 0; i < size; i++) { - permutations.add(dis.readLong()); - } - } - } public void saveCounts(Path file) throws IOException { try (var dos = new DataOutputStream(Files.newOutputStream(file, @@ -142,37 +158,17 @@ public class NgramLexicon { }); } } - public void savePermutations(Path file) throws IOException { - try (var dos = new DataOutputStream(Files.newOutputStream(file, - StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING, - StandardOpenOption.WRITE))) { - dos.writeInt(counts.size()); - permutations.forEach(v -> { - try { - dos.writeLong(v); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - } public void clear() { - permutations.clear(); counts.clear(); } - public record SentenceSegment(int start, int length, int count, PositionType type) { + public record SentenceSegment(int start, int length, int count) { public String[] project(String... parts) { return Arrays.copyOfRange(parts, start, start + length); } } - enum PositionType { - NGRAM, PERMUTATION - } - private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy { @Override public int hashCode(long l) { diff --git a/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java index d5065959..351ce869 100644 --- a/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java +++ b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java @@ -14,7 +14,6 @@ class NgramLexiconTest { void addNgram(String... 
ngram) { lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram)); - lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram)); } @Test @@ -26,25 +25,16 @@ class NgramLexiconTest { String[] sent = { "hello", "world", "rye", "bread" }; var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread"); - assertEquals(3, segments.size()); + assertEquals(2, segments.size()); - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 2; i++) { var segment = segments.get(i); switch (i) { case 0 -> { - assertArrayEquals(new String[]{"hello", "world"}, segment.project(sent)); - assertEquals(1, segment.count()); - assertEquals(NgramLexicon.PositionType.NGRAM, segment.type()); + assertArrayEquals(new String[]{"hello", "world"}, segment); } case 1 -> { - assertArrayEquals(new String[]{"world", "rye"}, segment.project(sent)); - assertEquals(0, segment.count()); - assertEquals(NgramLexicon.PositionType.PERMUTATION, segment.type()); - } - case 2 -> { - assertArrayEquals(new String[]{"rye", "bread"}, segment.project(sent)); - assertEquals(1, segment.count()); - assertEquals(NgramLexicon.PositionType.NGRAM, segment.type()); + assertArrayEquals(new String[]{"rye", "bread"}, segment); } } }