diff --git a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java index 7a6beeb8..5a82ab3e 100644 --- a/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java +++ b/code/libraries/term-frequency-dict/java/nu/marginalia/segmentation/NgramLexicon.java @@ -4,7 +4,6 @@ import com.google.inject.Inject; import com.google.inject.Singleton; import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap; import it.unimi.dsi.fastutil.longs.LongHash; -import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import nu.marginalia.LanguageModels; import java.io.BufferedInputStream; @@ -45,55 +44,54 @@ public class NgramLexicon { counts = new Long2IntOpenCustomHashMap(100_000_000, new KeyIsAlreadyHashStrategy()); } - public List findSegmentsStrings(int minLength, int maxLength, String... parts) { + public List findSegmentsStrings(int minLength, + int maxLength, + String... parts) + { List segments = new ArrayList<>(); - for (int i = minLength; i <= maxLength; i++) { - segments.addAll(findSegments(i, parts)); - } - - return segments; - } - - public List findSegments(int length, String... parts) { - // Don't look for ngrams longer than the sentence - if (parts.length < length) return List.of(); - - List positions = new ArrayList<>(); - // Hash the parts long[] hashes = new long[parts.length]; for (int i = 0; i < hashes.length; i++) { hashes[i] = HasherGroup.hash(parts[i]); } - long ordered = 0; + for (int i = minLength; i <= maxLength; i++) { + findSegments(segments, i, parts, hashes); + } + + return segments; + } + + public void findSegments(List positions, + int length, + String[] parts, + long[] hashes) + { + // Don't look for ngrams longer than the sentence + if (parts.length < length) return; + + long hash = 0; int i = 0; // Prepare by combining up to length hashes for (; i < length; i++) { - ordered = orderedHasher.apply(ordered, hashes[i]); + hash = orderedHasher.apply(hash, hashes[i]); } // Slide the window and look for matches - for (;; i++) { - int ct = counts.get(ordered); - - if (ct > 0) { + for (;;) { + if (counts.get(hash) > 0) { positions.add(Arrays.copyOfRange(parts, i - length, i)); } - if (i >= hashes.length) + if (i < hashes.length) { + hash = orderedHasher.replace(hash, hashes[i], hashes[i - length], length); + i++; + } else { break; - - // Remove the oldest hash and add the new one - ordered = orderedHasher.replace(ordered, - hashes[i], - hashes[i - length], - length); + } } - - return positions; } public List findSegmentOffsets(int length, String... parts) { @@ -108,30 +106,28 @@ public class NgramLexicon { hashes[i] = HasherGroup.hash(parts[i]); } - long ordered = 0; + long hash = 0; int i = 0; // Prepare by combining up to length hashes for (; i < length; i++) { - ordered = orderedHasher.apply(ordered, hashes[i]); + hash = orderedHasher.apply(hash, hashes[i]); } // Slide the window and look for matches - for (;; i++) { - int ct = counts.get(ordered); + for (;;) { + int ct = counts.get(hash); if (ct > 0) { positions.add(new SentenceSegment(i - length, length, ct)); } - if (i >= hashes.length) + if (i < hashes.length) { + hash = orderedHasher.replace(hash, hashes[i], hashes[i - length], length); + i++; + } else { break; - - // Remove the oldest hash and add the new one - ordered = orderedHasher.replace(ordered, - hashes[i], - hashes[i - length], - length); + } } return positions; @@ -167,6 +163,10 @@ public class NgramLexicon { public String[] project(String... parts) { return Arrays.copyOfRange(parts, start, start + length); } + + public boolean overlaps(SentenceSegment other) { + return start < other.start + other.length && start + length > other.start; + } } private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy { diff --git a/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java index 351ce869..f5068d07 100644 --- a/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java +++ b/code/libraries/term-frequency-dict/test/nu/marginalia/segmentation/NgramLexiconTest.java @@ -3,6 +3,8 @@ package nu.marginalia.segmentation; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.List; + import static org.junit.jupiter.api.Assertions.*; class NgramLexiconTest { @@ -22,8 +24,7 @@ class NgramLexiconTest { addNgram("rye", "bread"); addNgram("rye", "world"); - String[] sent = { "hello", "world", "rye", "bread" }; - var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread"); + List segments = lexicon.findSegmentsStrings(2, 2, "hello", "world", "rye", "bread"); assertEquals(2, segments.size());