(ngram) Clean up ngram lexicon code

This is both an optimization that removes some GC churn and a clean-up that removes references to outdated concepts. Concretely, findSegmentsStrings() now hashes the sentence parts once and reuses the hash array across all ngram lengths, and findSegments() appends into a caller-supplied list instead of allocating and returning a fresh one per length; a sketch of the underlying sliding-window technique follows the commit metadata below.
Viktor Lofgren 2024-04-12 17:45:06 +02:00
parent c96da0ce1e
commit 150ee21f3c
2 changed files with 44 additions and 43 deletions
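
Both before and after this change, the segment search relies on the same rolling-hash idea: combine the first `length` word hashes once, then slide the window in O(1) per step by replacing the oldest hash with the newest. The following is a minimal, self-contained sketch of that pattern following the post-commit control flow; the class and helper names (RollingNgramSketch, P, agePow) are invented for the sketch, and the polynomial rolling hash plus plain HashMap are simplified stand-ins for the repository's HasherGroup combinators and Long2IntOpenCustomHashMap, whose internals are not shown in this diff.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch only: a stand-in polynomial rolling hash, not HasherGroup's actual scheme.
class RollingNgramSketch {
    static final long P = 1_000_000_007L;

    // P^(length-1), used to age the oldest word hash out of the window
    static long agePow(int length) {
        long r = 1;
        for (int i = 1; i < length; i++) r *= P;
        return r;
    }

    // Returns every length-word window of parts whose combined hash has a recorded count.
    static List<String[]> findSegments(Map<Long, Integer> counts, int length, String... parts) {
        List<String[]> positions = new ArrayList<>();

        // Don't look for ngrams longer than the sentence
        if (parts.length < length) return positions;

        // Hash the parts once, up front (the commit hoists this out of the per-length loop)
        long[] hashes = new long[parts.length];
        for (int i = 0; i < hashes.length; i++) {
            hashes[i] = parts[i].hashCode();
        }

        long hash = 0;
        int i = 0;

        // Prepare by combining the first `length` hashes
        for (; i < length; i++) {
            hash = hash * P + hashes[i];
        }

        // Slide the window and look for matches
        for (;;) {
            if (counts.getOrDefault(hash, 0) > 0) {
                positions.add(Arrays.copyOfRange(parts, i - length, i));
            }
            if (i < hashes.length) {
                // Remove the oldest hash and add the new one -- O(1) per step
                hash = (hash - hashes[i - length] * agePow(length)) * P + hashes[i];
                i++;
            } else {
                break;
            }
        }
        return positions;
    }

    public static void main(String[] args) {
        // Register the combined hash for the bigram "rye bread" the same way the window computes it
        Map<Long, Integer> counts = new HashMap<>();
        counts.put("rye".hashCode() * P + "bread".hashCode(), 1);

        for (String[] seg : findSegments(counts, 2, "hello", "world", "rye", "bread")) {
            System.out.println(String.join(" ", seg));   // prints: rye bread
        }
    }
}

Because the slide costs O(1) per position, a full pass is linear in the sentence length, which is why the code can afford to rerun the window once for every ngram length between minLength and maxLength.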

NgramLexicon.java

@@ -4,7 +4,6 @@ import com.google.inject.Inject;
 import com.google.inject.Singleton;
 import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
 import it.unimi.dsi.fastutil.longs.LongHash;
-import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
 import nu.marginalia.LanguageModels;
 
 import java.io.BufferedInputStream;
@@ -45,55 +44,54 @@ public class NgramLexicon {
         counts = new Long2IntOpenCustomHashMap(100_000_000, new KeyIsAlreadyHashStrategy());
     }
 
-    public List<String[]> findSegmentsStrings(int minLength, int maxLength, String... parts) {
+    public List<String[]> findSegmentsStrings(int minLength,
+                                              int maxLength,
+                                              String... parts)
+    {
         List<String[]> segments = new ArrayList<>();
 
-        for (int i = minLength; i <= maxLength; i++) {
-            segments.addAll(findSegments(i, parts));
-        }
-
-        return segments;
-    }
-
-    public List<String[]> findSegments(int length, String... parts) {
-        // Don't look for ngrams longer than the sentence
-        if (parts.length < length) return List.of();
-
-        List<String[]> positions = new ArrayList<>();
-
         // Hash the parts
         long[] hashes = new long[parts.length];
         for (int i = 0; i < hashes.length; i++) {
             hashes[i] = HasherGroup.hash(parts[i]);
         }
 
-        long ordered = 0;
+        for (int i = minLength; i <= maxLength; i++) {
+            findSegments(segments, i, parts, hashes);
+        }
+
+        return segments;
+    }
+
+    public void findSegments(List<String[]> positions,
+                             int length,
+                             String[] parts,
+                             long[] hashes)
+    {
+        // Don't look for ngrams longer than the sentence
+        if (parts.length < length) return;
+
+        long hash = 0;
         int i = 0;
 
         // Prepare by combining up to length hashes
         for (; i < length; i++) {
-            ordered = orderedHasher.apply(ordered, hashes[i]);
+            hash = orderedHasher.apply(hash, hashes[i]);
         }
 
         // Slide the window and look for matches
-        for (;; i++) {
-            int ct = counts.get(ordered);
-
-            if (ct > 0) {
+        for (;;) {
+            if (counts.get(hash) > 0) {
                 positions.add(Arrays.copyOfRange(parts, i - length, i));
             }
 
-            if (i >= hashes.length)
+            if (i < hashes.length) {
+                hash = orderedHasher.replace(hash, hashes[i], hashes[i - length], length);
+                i++;
+            } else {
                 break;
-
-            // Remove the oldest hash and add the new one
-            ordered = orderedHasher.replace(ordered,
-                    hashes[i],
-                    hashes[i - length],
-                    length);
+            }
         }
-
-        return positions;
     }
 
     public List<SentenceSegment> findSegmentOffsets(int length, String... parts) {
@@ -108,30 +106,28 @@ public class NgramLexicon {
             hashes[i] = HasherGroup.hash(parts[i]);
         }
 
-        long ordered = 0;
+        long hash = 0;
         int i = 0;
 
         // Prepare by combining up to length hashes
         for (; i < length; i++) {
-            ordered = orderedHasher.apply(ordered, hashes[i]);
+            hash = orderedHasher.apply(hash, hashes[i]);
         }
 
         // Slide the window and look for matches
-        for (;; i++) {
-            int ct = counts.get(ordered);
+        for (;;) {
+            int ct = counts.get(hash);
 
             if (ct > 0) {
                 positions.add(new SentenceSegment(i - length, length, ct));
             }
 
-            if (i >= hashes.length)
+            if (i < hashes.length) {
+                hash = orderedHasher.replace(hash, hashes[i], hashes[i - length], length);
+                i++;
+            } else {
                 break;
-
-            // Remove the oldest hash and add the new one
-            ordered = orderedHasher.replace(ordered,
-                    hashes[i],
-                    hashes[i - length],
-                    length);
+            }
         }
 
         return positions;
@@ -167,6 +163,10 @@ public class NgramLexicon {
         public String[] project(String... parts) {
             return Arrays.copyOfRange(parts, start, start + length);
         }
+
+        public boolean overlaps(SentenceSegment other) {
+            return start < other.start + other.length && start + length > other.start;
+        }
     }
 
     private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy {
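
The newly added overlaps() treats a segment as the half-open interval [start, start + length): two segments overlap only when each starts before the other ends, so adjacent segments that merely touch do not count. A small self-contained stand-in (hypothetical record, mirroring the logic of the SentenceSegment method added above) illustrates the semantics:

// Stand-in mirroring SentenceSegment.overlaps() from the diff above.
record Segment(int start, int length) {
    boolean overlaps(Segment other) {
        return start < other.start + other.length && start + length > other.start;
    }
}

// [0, 3) vs [2, 4): new Segment(0, 3).overlaps(new Segment(2, 2)) -> true  (share position 2)
// [0, 3) vs [3, 5): new Segment(0, 3).overlaps(new Segment(3, 2)) -> false (merely adjacent)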

NgramLexiconTest.java

@@ -3,6 +3,8 @@ package nu.marginalia.segmentation;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import java.util.List;
+
 import static org.junit.jupiter.api.Assertions.*;
 
 class NgramLexiconTest {
@@ -22,8 +24,7 @@ class NgramLexiconTest {
         addNgram("rye", "bread");
         addNgram("rye", "world");
 
-        String[] sent = { "hello", "world", "rye", "bread" };
-        var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread");
+        List<String[]> segments = lexicon.findSegmentsStrings(2, 2, "hello", "world", "rye", "bread");
 
         assertEquals(2, segments.size());
 