(index) Add min-dist factor and adjust rankings

This commit is contained in:
Viktor Lofgren 2024-08-03 12:04:23 +02:00
parent bf26ead010
commit 8462e88b8f
4 changed files with 95 additions and 9 deletions

View File

@ -50,7 +50,7 @@ public class ResultRankingParameters {
.shortSentencePenalty(5)
.bm25Weight(1.)
.tcfAvgDist(25.)
.tcfFirstPosition(1) // FIXME: what's a good default?
.tcfFirstPosition(5) // FIXME: what's a good default?
.temporalBias(TemporalBias.NONE)
.temporalBiasWeight(1. / (5.))
.exportDebugData(false)

View File

@ -1,5 +1,6 @@
package nu.marginalia.index.results;
import it.unimi.dsi.fastutil.ints.IntIterator;
import nu.marginalia.api.searchquery.model.compiled.CompiledQuery;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryInt;
import nu.marginalia.api.searchquery.model.compiled.CompiledQueryLong;
@ -25,6 +26,8 @@ import nu.marginalia.sequence.SequenceOperations;
import javax.annotation.Nullable;
import java.lang.foreign.Arena;
import java.util.ArrayList;
import java.util.List;
import static nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates.booleanAggregate;
import static nu.marginalia.api.searchquery.model.compiled.aggregate.CompiledQueryAggregates.intMaxMinAggregate;
@ -221,18 +224,35 @@ public class IndexResultScoreCalculator {
float[] weightedCounts = new float[compiledQuery.size()];
int firstPosition = Integer.MAX_VALUE;
for (int i = 0; i < weightedCounts.length; i++) {
if (positions[i] != null) {
var iter = positions[i].iterator();
float keywordMinDistFac = 0;
if (positions.length > 2) {
List<IntIterator> iterators = new ArrayList<>(positions.length);
if (!ctx.regularMask.get(i)) {
continue;
for (int i = 0; i < positions.length; i++) {
if (positions[i] != null && ctx.regularMask.get(i)) {
iterators.add(positions[i].iterator());
}
}
if (iterators.size() > 2) {
int minDist = SequenceOperations.minDistance(iterators);
if (minDist < 32) {
keywordMinDistFac = 2.0f / (1.f + (float) Math.sqrt(minDist));
} else {
keywordMinDistFac = -1.0f * (float) Math.sqrt(minDist);
}
}
}
for (int i = 0; i < weightedCounts.length; i++) {
if (positions[i] != null && ctx.regularMask.get(i)) {
var iter = positions[i].iterator();
while (iter.hasNext()) {
int pos = iter.nextInt();
firstPosition = Math.min(firstPosition, pos);
firstPosition = Math.max(firstPosition, pos);
if (spans.title.containsPosition(pos) || spans.heading.containsPosition(pos))
weightedCounts[i] += 2.5f;
@ -254,10 +274,11 @@ public class IndexResultScoreCalculator {
+ topologyBonus
+ temporalBias
+ flagsPenalty
+ coherenceScore;
+ coherenceScore
+ keywordMinDistFac;
double tcfAvgDist = rankingParams.tcfAvgDist * (1.0 / calculateAvgMinDistance(positionsQuery, ctx));
double tcfFirstPosition = rankingParams.tcfFirstPosition * (1.0 / Math.max(1, firstPosition));
double tcfFirstPosition = rankingParams.tcfFirstPosition * (1.0 / Math.sqrt(Math.max(1, firstPosition)));
double bM25 = rankingParams.bm25Weight * wordFlagsQuery.root.visit(new Bm25GraphVisitor(rankingParams.bm25Params, weightedCounts, length, ctx));
double bFlags = rankingParams.bm25Weight * wordFlagsQuery.root.visit(new TermFlagsGraphVisitor(rankingParams.bm25Params, wordFlagsQuery.data, weightedCounts, ctx));

View File

@ -4,6 +4,8 @@ import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntList;
import java.util.List;
public class SequenceOperations {
/** Return true if the sequences intersect, false otherwise.
@ -142,4 +144,55 @@ public class SequenceOperations {
return minDistance;
}
/** Return the width of the narrowest window that contains at least one value
 * from every one of the given sequences, i.e. the smallest observed
 * {@code max - min} over one current value per sequence.
 * <p>
 * All sequences are walked in lock-step: each round, the sequence holding
 * the smallest current value is advanced, since moving any other sequence
 * forward could only widen the window. The walk ends when the sequence
 * holding the smallest value has nothing left to offer.
 * <p>
 * NOTE(review): this assumes every iterator yields its values in ascending
 * order; the result is not meaningful otherwise -- confirm against callers.
 *
 * @param iterators one iterator per sequence; the iterators are consumed
 *                  (partially or fully) by this call
 * @return the smallest window width found, or 0 if fewer than two iterators
 *         are given or if any iterator is empty to begin with
 */
public static int minDistance(List<IntIterator> iterators) {
    final int count = iterators.size();
    if (count <= 1)
        return 0;

    // Seed one current value per sequence, tracking the largest as we go
    int[] values = new int[count];
    int maxVal = Integer.MIN_VALUE;
    for (int i = 0; i < count; i++) {
        IntIterator iter = iterators.get(i);
        if (!iter.hasNext())
            return 0;

        values[i] = iter.nextInt();
        maxVal = Math.max(maxVal, values[i]);
    }

    int minDist = Integer.MAX_VALUE;
    while (true) {
        // Locate the sequence that is currently furthest behind
        int minIdx = 0;
        for (int i = 1; i < count; i++) {
            if (values[i] < values[minIdx])
                minIdx = i;
        }

        minDist = Math.min(minDist, maxVal - values[minIdx]);

        // Once the trailing sequence is exhausted, the window can not get
        // any narrower: every later window would still have to reach back
        // to this sequence's final value.
        IntIterator trailing = iterators.get(minIdx);
        if (!trailing.hasNext())
            return minDist;

        values[minIdx] = trailing.nextInt();
        maxVal = Math.max(maxVal, values[minIdx]);
    }
}
}

View File

@ -4,6 +4,7 @@ import it.unimi.dsi.fastutil.ints.IntList;
import org.junit.jupiter.api.Test;
import java.nio.ByteBuffer;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@ -83,4 +84,15 @@ class SequenceOperationsTest {
assertFalse(SequenceOperations.intersectSequences(seq1.iterator(), seq2.iterator()));
}
// Three interleaved position lists; the tightest window covering one entry
// from each is the very first triple {11, 20, 30}, which spans 30 - 11 = 19.
@Test
void testMinDistance() {
    ByteBuffer workArea = ByteBuffer.allocate(1024);

    GammaCodedSequence first = GammaCodedSequence.generate(workArea, 11, 80, 160);
    GammaCodedSequence second = GammaCodedSequence.generate(workArea, 20, 50, 100);
    GammaCodedSequence third = GammaCodedSequence.generate(workArea, 30, 60, 90);

    int actual = SequenceOperations.minDistance(
            List.of(first.iterator(), second.iterator(), third.iterator()));

    assertEquals(19, actual);
}
}