(ngrams) Remove the vestigial logic for capturing permutations of n-grams

The change also reduces object churn in NgramLexicon, as findSegments is a very hot method in the converter.
Viktor Lofgren 2024-04-11 18:12:01 +02:00
parent 8bf7d090fd
commit 7dd8c78c6b
6 changed files with 60 additions and 120 deletions
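
For context on what is being removed: alongside the ordered n-gram counts, the old code kept an order-invariant ("unordered") hash of each n-gram in a permutations set, so that a window containing the same words in any order could be reported as a PERMUTATION match. A minimal sketch of the distinction, assuming HasherGroup.unordered() combines per-word hashes commutatively; only the method names appear in the diff below, the import path and combiner internals are assumptions:

    // Sketch only. Assumed import: nu.marginalia.segmentation.HasherGroup
    class PermutationHashDemo {
        public static void main(String[] args) {
            String[] a = { "rye", "bread" };
            String[] b = { "bread", "rye" };

            // If the unordered combiner is commutative, both orderings collide
            // on purpose: this is what let the old code flag "bread rye" as a
            // permutation of a known n-gram...
            System.out.println(HasherGroup.unordered().rollingHash(a)
                            == HasherGroup.unordered().rollingHash(b)); // expected: true

            // ...while the ordered hash, the only one kept after this commit,
            // distinguishes the two orderings.
            System.out.println(HasherGroup.ordered().rollingHash(a)
                            == HasherGroup.ordered().rollingHash(b)); // expected: false (barring collisions)
        }
    }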

ExportSegmentationModelActor.java

@@ -21,6 +21,7 @@ public class ExportSegmentationModelActor extends RecordActorPrototype {
     private final Logger logger = LoggerFactory.getLogger(getClass());

     public record Export(String zimFile) implements ActorStep {}

     @Override
     public ActorStep transition(ActorStep self) throws Exception {
         return switch(self) {
@@ -29,9 +30,8 @@ public class ExportSegmentationModelActor extends RecordActorPrototype {
             var storage = storageService.allocateStorage(FileStorageType.EXPORT, "segmentation-model", "Segmentation Model Export " + LocalDateTime.now());

             Path countsFile = storage.asPath().resolve("ngram-counts.bin");
-            Path permutationsFile = storage.asPath().resolve("ngram-permutations.bin");

-            NgramExtractorMain.dumpCounts(Path.of(zimFile), countsFile, permutationsFile);
+            NgramExtractorMain.dumpCounts(Path.of(zimFile), countsFile);

             yield new End();
         }

QueryExpansion.java

@@ -112,10 +112,15 @@ public class QueryExpansion {
         // Look for known segments within the query
         for (int length = 2; length < Math.min(10, words.length); length++) {
-            for (var segment : lexicon.findSegments(length, words)) {
+            for (var segment : lexicon.findSegmentOffsets(length, words)) {
                 int start = segment.start();
                 int end = segment.start() + segment.length();

-                var word = IntStream.range(start, end).mapToObj(nodes::get).map(QWord::word).collect(Collectors.joining("_"));
+                var word = IntStream.range(start, end)
+                        .mapToObj(nodes::get)
+                        .map(QWord::word)
+                        .collect(Collectors.joining("_"));

                 graph.addVariantForSpan(nodes.get(start), nodes.get(end - 1), word);
             }
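
To make the joining step concrete, here is a small self-contained sketch of what the loop produces for one matched segment; the query words and offsets are hypothetical, and the QWord/graph plumbing is omitted:

    import java.util.Arrays;

    // Illustration only: hypothetical query words and segment offsets.
    class SpanVariantDemo {
        public static void main(String[] args) {
            String[] words = { "best", "rye", "bread", "recipe" };

            // As if findSegmentOffsets(2, words) had matched "rye bread":
            int start = 1;
            int end = start + 2;

            // Mirrors the underscore-join above, minus the QWord graph.
            String variant = String.join("_", Arrays.copyOfRange(words, start, end));

            System.out.println(variant); // rye_bread
        }
    }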

NgramExporterMain.java (deleted)

@@ -1,46 +0,0 @@
-package nu.marginalia.segmentation;
-
-import nu.marginalia.LanguageModels;
-
-import java.io.IOException;
-import java.nio.file.Path;
-import java.util.Arrays;
-import java.util.Scanner;
-
-public class NgramExporterMain {
-
-    public static void main(String... args) throws IOException {
-        trial();
-    }
-
-    static void trial() throws IOException {
-        NgramLexicon lexicon = new NgramLexicon(
-                LanguageModels.builder()
-                        .segments(Path.of("/home/vlofgren/ngram-counts.bin"))
-                        .build()
-        );
-
-        System.out.println("Loaded!");
-
-        var scanner = new Scanner(System.in);
-        for (;;) {
-            System.out.println("Enter a sentence: ");
-            String line = scanner.nextLine();
-            System.out.println(".");
-            if (line == null)
-                break;
-
-            String[] terms = BasicSentenceExtractor.getStemmedParts(line);
-            System.out.println(".");
-
-            for (int i = 2; i < 8; i++) {
-                lexicon.findSegments(i, terms).forEach(p -> {
-                    System.out.println(STR."\{Arrays.toString(p.project(terms))}: \{p.count()}");
-                });
-            }
-        }
-    }
-}

NgramExtractorMain.java

@@ -115,8 +115,7 @@ public class NgramExtractorMain {
     }

     public static void dumpCounts(Path zimInputFile,
-                                  Path countsOutputFile,
-                                  Path permutationsOutputFile
+                                  Path countsOutputFile
     ) throws IOException, InterruptedException
     {
         ZIMReader reader = new ZIMReader(new ZIMFile(zimInputFile.toString()));
@@ -143,9 +142,6 @@ public class NgramExtractorMain {
                 for (var hash : orderedHashes) {
                     lexicon.incOrdered(hash);
                 }
-                for (var hash : unorderedHashes) {
-                    lexicon.addUnordered(hash);
-                }
             }
         });
@@ -153,7 +149,6 @@ public class NgramExtractorMain {
         }

         lexicon.saveCounts(countsOutputFile);
-        lexicon.savePermutations(permutationsOutputFile);
     }
}
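
With the permutations file gone, dumpCounts takes just the ZIM input and the counts output. A minimal usage sketch; both file names are hypothetical, and the import of NgramExtractorMain follows the repo layout:

    import java.nio.file.Path;

    // Hypothetical file names; the two-argument signature is the one
    // introduced by this commit.
    class DumpCountsDemo {
        public static void main(String[] args) throws Exception {
            NgramExtractorMain.dumpCounts(
                    Path.of("wikipedia.zim"),
                    Path.of("ngram-counts.bin"));
        }
    }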

NgramLexicon.java

@@ -21,10 +21,8 @@ import java.util.List;
 @Singleton
 public class NgramLexicon {
     private final Long2IntOpenCustomHashMap counts;
-    private final LongOpenHashSet permutations = new LongOpenHashSet();

     private static final HasherGroup orderedHasher = HasherGroup.ordered();
-    private static final HasherGroup unorderedHasher = HasherGroup.unordered();

     @Inject
     public NgramLexicon(LanguageModels models) {
@@ -48,16 +46,57 @@ public class NgramLexicon {
     }

     public List<String[]> findSegmentsStrings(int minLength, int maxLength, String... parts) {
-        List<SentenceSegment> segments = new ArrayList<>();
+        List<String[]> segments = new ArrayList<>();

         for (int i = minLength; i <= maxLength; i++) {
             segments.addAll(findSegments(i, parts));
         }

-        return segments.stream().map(seg -> seg.project(parts)).toList();
+        return segments;
     }

-    public List<SentenceSegment> findSegments(int length, String... parts) {
+    public List<String[]> findSegments(int length, String... parts) {
+        // Don't look for ngrams longer than the sentence
+        if (parts.length < length) return List.of();
+
+        List<String[]> positions = new ArrayList<>();
+
+        // Hash the parts
+        long[] hashes = new long[parts.length];
+        for (int i = 0; i < hashes.length; i++) {
+            hashes[i] = HasherGroup.hash(parts[i]);
+        }
+
+        long ordered = 0;
+        int i = 0;
+
+        // Prepare by combining up to length hashes
+        for (; i < length; i++) {
+            ordered = orderedHasher.apply(ordered, hashes[i]);
+        }
+
+        // Slide the window and look for matches
+        for (;; i++) {
+            int ct = counts.get(ordered);
+
+            if (ct > 0) {
+                positions.add(Arrays.copyOfRange(parts, i - length, i));
+            }
+
+            if (i >= hashes.length)
+                break;
+
+            // Remove the oldest hash and add the new one
+            ordered = orderedHasher.replace(ordered,
+                    hashes[i],
+                    hashes[i - length],
+                    length);
+        }
+
+        return positions;
+    }
+
+    public List<SentenceSegment> findSegmentOffsets(int length, String... parts) {
         // Don't look for ngrams longer than the sentence
         if (parts.length < length) return List.of();
@@ -70,13 +109,11 @@ public class NgramLexicon {
         }

         long ordered = 0;
-        long unordered = 0;
         int i = 0;

         // Prepare by combining up to length hashes
         for (; i < length; i++) {
             ordered = orderedHasher.apply(ordered, hashes[i]);
-            unordered = unorderedHasher.apply(unordered, hashes[i]);
         }

         // Slide the window and look for matches
@@ -84,10 +121,7 @@ public class NgramLexicon {
             int ct = counts.get(ordered);

             if (ct > 0) {
-                positions.add(new SentenceSegment(i - length, length, ct, PositionType.NGRAM));
-            }
-            else if (permutations.contains(unordered)) {
-                positions.add(new SentenceSegment(i - length, length, 0, PositionType.PERMUTATION));
+                positions.add(new SentenceSegment(i - length, length, ct));
             }

             if (i >= hashes.length)
@@ -98,10 +132,6 @@ public class NgramLexicon {
                     hashes[i],
                     hashes[i - length],
                     length);
-            unordered = unorderedHasher.replace(unordered,
-                    hashes[i],
-                    hashes[i - length],
-                    length);
         }

         return positions;
@@ -110,20 +140,6 @@ public class NgramLexicon {
     public void incOrdered(long hashOrdered) {
         counts.addTo(hashOrdered, 1);
     }
-
-    public void addUnordered(long hashUnordered) {
-        permutations.add(hashUnordered);
-    }
-
-    public void loadPermutations(Path path) throws IOException {
-        try (var dis = new DataInputStream(Files.newInputStream(path))) {
-            long size = dis.readInt();
-
-            for (int i = 0; i < size; i++) {
-                permutations.add(dis.readLong());
-            }
-        }
-    }

     public void saveCounts(Path file) throws IOException {
         try (var dos = new DataOutputStream(Files.newOutputStream(file,
@@ -142,37 +158,17 @@ public class NgramLexicon {
             });
         }
     }

-    public void savePermutations(Path file) throws IOException {
-        try (var dos = new DataOutputStream(Files.newOutputStream(file,
-                StandardOpenOption.CREATE,
-                StandardOpenOption.TRUNCATE_EXISTING,
-                StandardOpenOption.WRITE))) {
-            dos.writeInt(counts.size());
-
-            permutations.forEach(v -> {
-                try {
-                    dos.writeLong(v);
-                } catch (IOException e) {
-                    throw new RuntimeException(e);
-                }
-            });
-        }
-    }

     public void clear() {
-        permutations.clear();
         counts.clear();
     }

-    public record SentenceSegment(int start, int length, int count, PositionType type) {
+    public record SentenceSegment(int start, int length, int count) {
         public String[] project(String... parts) {
             return Arrays.copyOfRange(parts, start, start + length);
         }
     }

-    enum PositionType {
-        NGRAM, PERMUTATION
-    }

     private static class KeyIsAlreadyHashStrategy implements LongHash.Strategy {
         @Override
         public int hashCode(long l) {
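
The split leaves two lookup flavors: findSegments now returns the matched word spans directly as String[], skipping the intermediate SentenceSegment records (this is the reduced object churn the commit message refers to), while findSegmentOffsets keeps returning positional records for callers like QueryExpansion. A usage sketch, assuming a lexicon already populated with the bigrams "hello world" and "rye bread" as in the test below:

    // Assumes `lexicon` has been populated with the bigrams
    // "hello world" and "rye bread", as in the test below.
    String[] sent = { "hello", "world", "rye", "bread" };

    // Word spans directly, no SentenceSegment allocated:
    for (String[] gram : lexicon.findSegments(2, sent)) {
        System.out.println(String.join(" ", gram)); // "hello world", "rye bread"
    }

    // Positional records, as consumed by QueryExpansion:
    for (var seg : lexicon.findSegmentOffsets(2, sent)) {
        System.out.printf("start=%d length=%d count=%d%n",
                seg.start(), seg.length(), seg.count());
    }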

NgramLexiconTest.java

@@ -14,7 +14,6 @@ class NgramLexiconTest {
     void addNgram(String... ngram) {
         lexicon.incOrdered(HasherGroup.ordered().rollingHash(ngram));
-        lexicon.addUnordered(HasherGroup.unordered().rollingHash(ngram));
     }

     @Test
@@ -26,25 +25,16 @@ class NgramLexiconTest {
         String[] sent = { "hello", "world", "rye", "bread" };

         var segments = lexicon.findSegments(2, "hello", "world", "rye", "bread");
-        assertEquals(3, segments.size());
+        assertEquals(2, segments.size());

-        for (int i = 0; i < 3; i++) {
+        for (int i = 0; i < 2; i++) {
             var segment = segments.get(i);

             switch (i) {
                 case 0 -> {
-                    assertArrayEquals(new String[]{"hello", "world"}, segment.project(sent));
-                    assertEquals(1, segment.count());
-                    assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
+                    assertArrayEquals(new String[]{"hello", "world"}, segment);
                 }
                 case 1 -> {
-                    assertArrayEquals(new String[]{"world", "rye"}, segment.project(sent));
-                    assertEquals(0, segment.count());
-                    assertEquals(NgramLexicon.PositionType.PERMUTATION, segment.type());
-                }
-                case 2 -> {
-                    assertArrayEquals(new String[]{"rye", "bread"}, segment.project(sent));
-                    assertEquals(1, segment.count());
-                    assertEquals(NgramLexicon.PositionType.NGRAM, segment.type());
+                    assertArrayEquals(new String[]{"rye", "bread"}, segment);
                 }
             }
         }
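
Both scan loops depend on an invariant of HasherGroup: replacing the outgoing word's hash with the incoming one must give the same value as hashing the new window from scratch. A property sketch using only the calls that appear in this diff; that the assertion holds is an assumption about HasherGroup's contract:

    // Property sketch: sliding the ordered rolling hash one word forward
    // should equal hashing the new window directly.
    var ordered = HasherGroup.ordered();

    long hHello = HasherGroup.hash("hello");
    long hRye   = HasherGroup.hash("rye");

    // Build the window ["hello", "world"], then slide it to ["world", "rye"]
    // by removing "hello" and adding "rye".
    long window = ordered.apply(ordered.apply(0, hHello), HasherGroup.hash("world"));
    long slid   = ordered.replace(window, hRye, hHello, 2);

    System.out.println(slid == ordered.rollingHash("world", "rye")); // expected: true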