/*
 * Source: MarginaliaSearch/code/libraries/coded-sequence/java/nu/marginalia/sequence/VarintCodedSequence.java
 * (305 lines, 8.5 KiB, Java; timestamp 2024-08-02 18:22:56 +00:00)
 */
package nu.marginalia.sequence;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntList;
import java.nio.ByteBuffer;
import java.util.Objects;
// 2024-08-02 18:22:56 +00:00
public class VarintCodedSequence implements CodedSequence {
private final ByteBuffer raw;
private final int startPos;
private final int startLimit;
public VarintCodedSequence(ByteBuffer buffer) {
this.raw = buffer;
this.startPos = buffer.position();
this.startLimit = buffer.limit();
}
public VarintCodedSequence(ByteBuffer buffer, int startPos, int startLimit) {
this.raw = buffer;
this.startPos = startPos;
this.startLimit = startLimit;
}
public static VarintCodedSequence generate(IntList values) {
int bufferSize = requiredBufferSize(values);
ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
int prev = 0;
encodeValue(buffer, values.size() + 1);
for (int i = 0; i < values.size(); i++) {
int value = values.getInt(i);
int toEncode = value - prev;
assert toEncode > 0 : "Values must be strictly increasing";
encodeValue(buffer, toEncode);
prev = value;
}
buffer.flip();
return new VarintCodedSequence(buffer);
}
2024-08-02 18:22:56 +00:00
public static VarintCodedSequence generate(int... values) {
int bufferSize = requiredBufferSize(values);
ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
int prev = 0;
encodeValue(buffer, values.length + 1);
for (int value : values) {
int toEncode = value - prev;
assert toEncode > 0 : "Values must be strictly increasing";
encodeValue(buffer, toEncode);
prev = value;
}
buffer.flip();
return new VarintCodedSequence(buffer);
}
/** Calculate the number of bytes required to encode a sequence of values as a varint. */
private static int requiredBufferSize(int[] values) {
int prev = 0;
int size = 0;
for (int value : values) {
size += varintSize(value - prev);
prev = value;
}
return size + varintSize(size + 1);
}
/** Calculate the number of bytes required to encode a sequence of values as a varint. */
private static int requiredBufferSize(IntList values) {
int prev = 0;
int size = 0;
for (int i = 0; i < values.size(); i++) {
int value = values.getInt(i);
size += varintSize(value - prev);
prev = value;
}
return size + varintSize(size + 1);
}
/** Calculate the number of bytes required to encode a value as a varint. */
private static int varintSize(int value) {
int bits = 32 - Integer.numberOfLeadingZeros(value);
return (bits + 6) / 7;
}
2024-08-02 18:22:56 +00:00
private static void encodeValue(ByteBuffer buffer, int value) {
if (value < (1<<7)) {
2024-08-02 18:22:56 +00:00
buffer.put((byte) value);
}
else if (value < (1<<14)) {
2024-08-02 18:22:56 +00:00
buffer.put((byte) (value >>> (7) | 0x80));
buffer.put((byte) (value & 0x7F));
}
else if (value < (1<<21)) {
2024-08-02 18:22:56 +00:00
buffer.put((byte) (value >>> (14) | 0x80));
buffer.put((byte) (value >>> (7) | 0x80));
buffer.put((byte) (value & 0x7F));
}
else if (value < (1<<28)) {
buffer.put((byte) ((value >>> 21) | 0x80));
buffer.put((byte) ((value >>> 14) | 0x80));
buffer.put((byte) ((value >>> 7) | 0x80));
buffer.put((byte) (value & 0x7F));
2024-08-02 18:22:56 +00:00
}
else {
throw new IllegalArgumentException("Value too large to encode");
}
}
@Override
public byte[] bytes() {
return raw.array();
}
@Override
public IntIterator iterator() {
return new VarintSequenceIterator(raw, startPos);
2024-08-02 18:22:56 +00:00
}
@Override
public IntIterator offsetIterator(int offset) {
return new VarintSequenceIterator(raw, startPos, offset);
2024-08-02 18:22:56 +00:00
}
@Override
public IntList values() {
var buffer = buffer();
int val = 0;
int count = decodeValue(buffer) - 1;
IntArrayList list = new IntArrayList(count);
while (buffer.hasRemaining()) {
val += decodeValue(buffer);
list.add(val);
}
return list;
}
@Override
public ByteBuffer buffer() {
raw.position(startPos);
raw.limit(startLimit);
return raw;
}
@Override
public int bufferSize() {
return raw.capacity();
}
@Override
public int valueCount() {
var buffer = buffer();
return decodeValue(buffer) - 1;
}
private static int decodeValue(ByteBuffer buffer) {
// most common case gets a fast path, this is a fairly large performance win
// on average, something like 10-20% faster than not having this check
2024-08-02 18:22:56 +00:00
byte b = buffer.get();
if ((b & 0x80) == 0) {
return b;
}
int value = b & 0x7F;
2024-08-02 18:22:56 +00:00
do {
b = buffer.get();
value = (value << 7) | (b & 0x7F);
2024-08-02 18:22:56 +00:00
} while ((b & 0x80) != 0);
return value;
}
public static class VarintSequenceIterator implements IntIterator {
private final ByteBuffer buffer;
// The position in the buffer where the next value is read from,
// we don't use the buffer's position, because we might want to access
// this buffer from multiple iterators simultaneously without interference.
private int bufferPos;
/** The number of values remaining to be decoded from the buffer. */
private int numRemainingValues;
/** The previous value that was read from the buffer,
* used in differential decoding. */
private int previousValue;
/** The next value that will be returned by nextInt,
* set to MIN_VALUE if no value is yet decoded */
private int nextValue = Integer.MIN_VALUE;
/** Create a new VarintSequenceIterator from a buffer.
* <p></p>
* The iterator will start at the given position in the buffer.
* The zero point is added to each value being read from the buffer.
* */
public VarintSequenceIterator(ByteBuffer buffer,
int startPos,
int zero) {
2024-08-02 18:22:56 +00:00
this.buffer = buffer;
if (zero == Integer.MIN_VALUE) {
throw new IllegalArgumentException("Integer.MIN_VALUE is a reserved offset that may not be used as zero point");
}
bufferPos = startPos;
previousValue = zero;
numRemainingValues = decodeValue() - 1;
2024-08-02 18:22:56 +00:00
}
/** Create a new VarintSequenceIterator from a buffer.
* <p></p>
* The iterator will start at the given position in the buffer.
* The zero point is 0.
* */
public VarintSequenceIterator(ByteBuffer buffer, int startPos) {
this(buffer, startPos, 0);
2024-08-02 18:22:56 +00:00
}
// This is BitWriter.getGamma with more checks in place for streaming iteration
@Override
public boolean hasNext() {
if (nextValue != Integer.MIN_VALUE) return true;
if (--numRemainingValues < 0) return false;
2024-08-02 18:22:56 +00:00
int delta = decodeValue();
2024-08-02 18:22:56 +00:00
previousValue += delta;
nextValue = previousValue;
2024-08-02 18:22:56 +00:00
return true;
}
// This is the same operation as decodeValue in the outer class,
// except we don't use the buffer's inbuilt position().
private int decodeValue() {
byte b = buffer.get(bufferPos++);
if ((b & 0x80) == 0) {
return b;
}
int value = b & 0x7F;
do {
b = buffer.get(bufferPos++);
value = (value << 7) | (b & 0x7F);
} while ((b & 0x80) != 0);
return value;
}
2024-08-02 18:22:56 +00:00
@Override
public int nextInt() {
if (hasNext()) {
int ret = nextValue;
nextValue = Integer.MIN_VALUE;
2024-08-02 18:22:56 +00:00
return ret;
}
throw new ArrayIndexOutOfBoundsException("No more data to read");
}
}
public int hashCode() {
return values().hashCode();
}
public boolean equals(Object other) {
if (other instanceof CodedSequence cs) {
return Objects.equals(values(), cs.values());
}
return false;
}
2024-08-02 18:22:56 +00:00
}