Optimize RDRPosTagger to use integer comparisons instead of string comparisons.

Also reduce the cache-thrashing by deconstructing the tree's nodes into arrays.
This commit is contained in:
Viktor Lofgren 2023-06-12 17:42:30 +02:00 committed by Viktor
parent 6f2a7977c1
commit 186a02acfd
5 changed files with 115 additions and 53 deletions

View File

@ -9,6 +9,7 @@ java {
} }
dependencies { dependencies {
implementation libs.trove
} }
test { test {

View File

@ -13,6 +13,9 @@ import java.util.Arrays;
public class FWObject public class FWObject
{ {
public String[] context; public String[] context;
int[] objectCtxI = new int[13];
private final static String[] contextPrototype; private final static String[] contextPrototype;
static { static {
contextPrototype = new String[13]; contextPrototype = new String[13];

View File

@ -40,28 +40,14 @@ public class Node
this.fatherNode = node; this.fatherNode = node;
} }
public int countNodes()
{
int count = 1;
if (exceptNode != null) {
count += exceptNode.countNodes();
}
if (ifnotNode != null) {
count += ifnotNode.countNodes();
}
return count;
}
public boolean satisfy(FWObject object) public boolean satisfy(FWObject object)
{ {
for (int i = 0; i < 13; i++) { for (int i = 0; i < 13; i++) {
String key = condition.context[i]; String key = condition.context[i];
if (key != null) { if (key != null && !key.equals(object.context[i])) { // this is not equivalent to Objects.equals(a,b)
if (!key.equals(object.context[i])) {
return false; return false;
} }
} }
}
return true; return true;
} }
} }

View File

@ -1,8 +1,11 @@
package com.github.datquocnguyen; package com.github.datquocnguyen;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.*; import java.io.*;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
/** /**
@ -12,22 +15,59 @@ import java.util.HashMap;
public class RDRPOSTagger public class RDRPOSTagger
{ {
private final HashMap<String, String> FREQDICT; private final HashMap<String, String> FREQDICT;
public final Node root; final int OUGHT_TO_BE_ENOUGH = 5000;
final int CONTEXT_SIZE = 13;
// Use dense array representation to reduce the level of indirection
// and improve the performance of the tagger
int[] conditions = new int[OUGHT_TO_BE_ENOUGH * CONTEXT_SIZE];
String[] conclusions = new String[OUGHT_TO_BE_ENOUGH];
short[] exceptIdx = new short[OUGHT_TO_BE_ENOUGH];
short[] ifNotIdx = new short[OUGHT_TO_BE_ENOUGH];
short[] fatherIdx = new short[OUGHT_TO_BE_ENOUGH];
byte[] depthL = new byte[OUGHT_TO_BE_ENOUGH];
short size = 0;
private final TObjectIntHashMap<String> tagDict = new TObjectIntHashMap<>(10000, 0.75f, -1);
private short addNode(FWObject condition, String conclusion, byte d) {
short idx = size++;
for (int i = 0; i < CONTEXT_SIZE; i++) {
String context = condition.context[i];
if (context != null) {
tagDict.putIfAbsent(context, tagDict.size());
conditions[idx * CONTEXT_SIZE + i] = tagDict.get(context);
}
else {
conditions[idx * CONTEXT_SIZE + i] = -1;
}
}
conclusions[idx] = conclusion;
exceptIdx[idx] = -1;
ifNotIdx[idx] = -1;
fatherIdx[idx] = -1;
depthL[idx] = d;
return idx;
}
public RDRPOSTagger(Path dictPath, Path rulesFilePath) throws IOException { public RDRPOSTagger(Path dictPath, Path rulesFilePath) throws IOException {
this.FREQDICT = Utils.getDictionary(dictPath.toString()); this.FREQDICT = Utils.getDictionary(dictPath.toString());
Arrays.fill(conditions, -1);
BufferedReader buffer = new BufferedReader(new InputStreamReader( BufferedReader buffer = new BufferedReader(new InputStreamReader(
new FileInputStream(rulesFilePath.toFile()), StandardCharsets.UTF_8)); new FileInputStream(rulesFilePath.toFile()), StandardCharsets.UTF_8));
String line = buffer.readLine(); String line = buffer.readLine();
this.root = new Node(new FWObject(false), "NN", null, null, null, 0); short currentIdx = addNode(new FWObject(false), "NN", (byte) 0);
byte currentDepth = 0;
Node currentNode = this.root;
int currentDepth = 0;
while ((line = buffer.readLine()) != null) { while ((line = buffer.readLine()) != null) {
int depth = 0; byte depth = 0;
for (int i = 0; i <= 6; i++) { // Supposed that the maximum for (int i = 0; i <= 6; i++) { // Supposed that the maximum
// exception level is up to 6. // exception level is up to 6.
if (line.charAt(i) == '\t') if (line.charAt(i) == '\t')
@ -48,53 +88,72 @@ public class RDRPOSTagger
String conclusion = Utils.getConcreteValue(line.split(" : ")[1] String conclusion = Utils.getConcreteValue(line.split(" : ")[1]
.trim()); .trim());
Node node = new Node(condition, conclusion, null, null, null, depth); short newIdx = addNode(condition, conclusion, depth);
if (depth > currentDepth) { if (depth > currentDepth) {
currentNode.setExceptNode(node); exceptIdx[currentIdx] = newIdx;
} }
else if (depth == currentDepth) { else if (depth == currentDepth) {
currentNode.setIfnotNode(node); ifNotIdx[currentIdx] = newIdx;
} }
else { else {
while (currentNode.depth != depth) while (depthL[currentIdx] != depth) {
currentNode = currentNode.fatherNode; currentIdx = fatherIdx[currentIdx];
currentNode.setIfnotNode(node); }
ifNotIdx[currentIdx] = newIdx;
} }
node.setFatherNode(currentNode);
currentNode = node; fatherIdx[newIdx] = currentIdx;
currentIdx = newIdx;
currentDepth = depth; currentDepth = depth;
} }
buffer.close(); buffer.close();
} }
public Node findFiredNode(FWObject object) public String findFiredNode(FWObject object)
{ {
Node currentN = root; int currentIdx = 0;
Node firedN = null; int firedIdx = -1;
while (true) {
if (currentN.satisfy(object)) { int[] objCtxI = object.objectCtxI;
firedN = currentN;
if (currentN.exceptNode == null) { for (int i = 0; i < CONTEXT_SIZE; i++) {
break; objCtxI[i] = tagDict.get(object.context[i]);
}
int[] conditionsL = conditions;
short[] exceptIdxL = exceptIdx;
short[] ifNotIdxL = ifNotIdx;
while (currentIdx >= 0) {
if (satisfy(objCtxI, conditionsL, currentIdx)) {
firedIdx = currentIdx;
currentIdx = exceptIdxL[currentIdx];
} }
else { else {
currentN = currentN.exceptNode; currentIdx = ifNotIdxL[currentIdx];
}
}
else {
if (currentN.ifnotNode == null) {
break;
}
else {
currentN = currentN.ifnotNode;
} }
} }
if (firedIdx >= 0) {
return conclusions[firedIdx];
}
else {
return "";
}
} }
return firedN; public boolean satisfy(int[] objectCtxI, int[] conditions, int contextIdx)
{
// This is a good candidate for a vector operation
for (int i = 0; i < CONTEXT_SIZE; i++) {
int key = conditions[CONTEXT_SIZE *contextIdx + i];
if (key >= 0 && key != objectCtxI[i]) {
return false;
}
}
return true;
} }
public String[] tagsForEnSentence(String[] sentence) public String[] tagsForEnSentence(String[] sentence)
@ -107,7 +166,7 @@ public class RDRPOSTagger
for (int i = 0; i < initialTags.length; i++) { for (int i = 0; i < initialTags.length; i++) {
Utils.getObject(object, sentence, initialTags, initialTags.length, i); Utils.getObject(object, sentence, initialTags, initialTags.length, i);
tags[i] = findFiredNode(object).conclusion; tags[i] = findFiredNode(object);
} }
return tags; return tags;

View File

@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.function.Function;
/** /**
* @author DatQuocNguyen * @author DatQuocNguyen
@ -69,6 +71,7 @@ public class Utils
return true; return true;
} }
static Map<String, String> conditionInstancePool = new HashMap<>();
public static FWObject getCondition(String strCondition) public static FWObject getCondition(String strCondition)
{ {
FWObject condition = new FWObject(false); FWObject condition = new FWObject(false);
@ -120,6 +123,16 @@ public class Utils
} }
} }
// pool the conditions to increase the chances the data is in cache
// when comparing later
for (var i = 0; i < condition.context.length; i++) {
if (condition.context[i] != null) {
condition.context[i] = conditionInstancePool
.computeIfAbsent(condition.context[i], Function.identity());
}
}
return condition; return condition;
} }