diff --git a/third-party/rdrpostagger/build.gradle b/third-party/rdrpostagger/build.gradle index de627417..a9231958 100644 --- a/third-party/rdrpostagger/build.gradle +++ b/third-party/rdrpostagger/build.gradle @@ -9,6 +9,7 @@ java { } dependencies { + implementation libs.trove } test { diff --git a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/FWObject.java b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/FWObject.java index 9017f23a..709c3504 100644 --- a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/FWObject.java +++ b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/FWObject.java @@ -13,6 +13,9 @@ import java.util.Arrays; public class FWObject { public String[] context; + + int[] objectCtxI = new int[13]; + private final static String[] contextPrototype; static { contextPrototype = new String[13]; diff --git a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Node.java b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Node.java index 281180da..d0e9ee59 100644 --- a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Node.java +++ b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Node.java @@ -40,26 +40,12 @@ public class Node this.fatherNode = node; } - public int countNodes() - { - int count = 1; - if (exceptNode != null) { - count += exceptNode.countNodes(); - } - if (ifnotNode != null) { - count += ifnotNode.countNodes(); - } - return count; - } - public boolean satisfy(FWObject object) { for (int i = 0; i < 13; i++) { String key = condition.context[i]; - if (key != null) { - if (!key.equals(object.context[i])) { - return false; - } + if (key != null && !key.equals(object.context[i])) { // this is not equivalent to Objects.equals(a,b) + return false; } } return true; diff --git a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/RDRPOSTagger.java b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/RDRPOSTagger.java index a0bea5b2..9feb2de0 100644 --- a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/RDRPOSTagger.java +++ b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/RDRPOSTagger.java @@ -1,8 +1,11 @@ package com.github.datquocnguyen; +import gnu.trove.map.hash.TObjectIntHashMap; + import java.io.*; import java.nio.charset.StandardCharsets; import java.nio.file.Path; +import java.util.Arrays; import java.util.HashMap; /** @@ -12,22 +15,59 @@ import java.util.HashMap; public class RDRPOSTagger { private final HashMap FREQDICT; - public final Node root; + final int OUGHT_TO_BE_ENOUGH = 5000; + final int CONTEXT_SIZE = 13; + + // Use dense array representation to reduce the level of indirection + // and improve the performance of the tagger + int[] conditions = new int[OUGHT_TO_BE_ENOUGH * CONTEXT_SIZE]; + String[] conclusions = new String[OUGHT_TO_BE_ENOUGH]; + short[] exceptIdx = new short[OUGHT_TO_BE_ENOUGH]; + short[] ifNotIdx = new short[OUGHT_TO_BE_ENOUGH]; + short[] fatherIdx = new short[OUGHT_TO_BE_ENOUGH]; + byte[] depthL = new byte[OUGHT_TO_BE_ENOUGH]; + + short size = 0; + + private final TObjectIntHashMap tagDict = new TObjectIntHashMap<>(10000, 0.75f, -1); + + private short addNode(FWObject condition, String conclusion, byte d) { + short idx = size++; + + for (int i = 0; i < CONTEXT_SIZE; i++) { + String context = condition.context[i]; + if (context != null) { + tagDict.putIfAbsent(context, tagDict.size()); + + conditions[idx * CONTEXT_SIZE + i] = tagDict.get(context); + } + else { + conditions[idx * CONTEXT_SIZE + i] = -1; + } + } + + conclusions[idx] = conclusion; + exceptIdx[idx] = -1; + ifNotIdx[idx] = -1; + fatherIdx[idx] = -1; + depthL[idx] = d; + + return idx; + } public RDRPOSTagger(Path dictPath, Path rulesFilePath) throws IOException { this.FREQDICT = Utils.getDictionary(dictPath.toString()); + Arrays.fill(conditions, -1); BufferedReader buffer = new BufferedReader(new InputStreamReader( new FileInputStream(rulesFilePath.toFile()), StandardCharsets.UTF_8)); String line = buffer.readLine(); - this.root = new Node(new FWObject(false), "NN", null, null, null, 0); - - Node currentNode = this.root; - int currentDepth = 0; + short currentIdx = addNode(new FWObject(false), "NN", (byte) 0); + byte currentDepth = 0; while ((line = buffer.readLine()) != null) { - int depth = 0; + byte depth = 0; for (int i = 0; i <= 6; i++) { // Supposed that the maximum // exception level is up to 6. if (line.charAt(i) == '\t') @@ -48,53 +88,72 @@ public class RDRPOSTagger String conclusion = Utils.getConcreteValue(line.split(" : ")[1] .trim()); - Node node = new Node(condition, conclusion, null, null, null, depth); + short newIdx = addNode(condition, conclusion, depth); if (depth > currentDepth) { - currentNode.setExceptNode(node); + exceptIdx[currentIdx] = newIdx; } else if (depth == currentDepth) { - currentNode.setIfnotNode(node); + ifNotIdx[currentIdx] = newIdx; } else { - while (currentNode.depth != depth) - currentNode = currentNode.fatherNode; - currentNode.setIfnotNode(node); + while (depthL[currentIdx] != depth) { + currentIdx = fatherIdx[currentIdx]; + } + ifNotIdx[currentIdx] = newIdx; } - node.setFatherNode(currentNode); - currentNode = node; + fatherIdx[newIdx] = currentIdx; + + currentIdx = newIdx; currentDepth = depth; } buffer.close(); } - public Node findFiredNode(FWObject object) + public String findFiredNode(FWObject object) { - Node currentN = root; - Node firedN = null; - while (true) { - if (currentN.satisfy(object)) { - firedN = currentN; - if (currentN.exceptNode == null) { - break; - } - else { - currentN = currentN.exceptNode; - } - } - else { - if (currentN.ifnotNode == null) { - break; - } - else { - currentN = currentN.ifnotNode; - } - } + int currentIdx = 0; + int firedIdx = -1; + int[] objCtxI = object.objectCtxI; + + for (int i = 0; i < CONTEXT_SIZE; i++) { + objCtxI[i] = tagDict.get(object.context[i]); } - return firedN; + int[] conditionsL = conditions; + short[] exceptIdxL = exceptIdx; + short[] ifNotIdxL = ifNotIdx; + + while (currentIdx >= 0) { + if (satisfy(objCtxI, conditionsL, currentIdx)) { + firedIdx = currentIdx; + currentIdx = exceptIdxL[currentIdx]; + } + else { + currentIdx = ifNotIdxL[currentIdx]; + } + } + + if (firedIdx >= 0) { + return conclusions[firedIdx]; + } + else { + return ""; + } + } + + public boolean satisfy(int[] objectCtxI, int[] conditions, int contextIdx) + { + // This is a good candidate for a vector operation + for (int i = 0; i < CONTEXT_SIZE; i++) { + int key = conditions[CONTEXT_SIZE *contextIdx + i]; + if (key >= 0 && key != objectCtxI[i]) { + return false; + } + } + return true; } public String[] tagsForEnSentence(String[] sentence) @@ -107,7 +166,7 @@ public class RDRPOSTagger for (int i = 0; i < initialTags.length; i++) { Utils.getObject(object, sentence, initialTags, initialTags.length, i); - tags[i] = findFiredNode(object).conclusion; + tags[i] = findFiredNode(object); } return tags; diff --git a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Utils.java b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Utils.java index 4cd91d58..6602db96 100644 --- a/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Utils.java +++ b/third-party/rdrpostagger/src/main/java/com/github/datquocnguyen/Utils.java @@ -9,6 +9,8 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.function.Function; /** * @author DatQuocNguyen @@ -69,6 +71,7 @@ public class Utils return true; } + static Map conditionInstancePool = new HashMap<>(); public static FWObject getCondition(String strCondition) { FWObject condition = new FWObject(false); @@ -120,6 +123,16 @@ public class Utils } } + // pool the conditions to increase the chances the data is in cache + // when comparing later + + for (var i = 0; i < condition.context.length; i++) { + if (condition.context[i] != null) { + condition.context[i] = conditionInstancePool + .computeIfAbsent(condition.context[i], Function.identity()); + } + } + return condition; }