Normal file
Normal file
@ -0,0 +1,9 @@
# marginalia.nu
This is the source code for marginalia.nu, including the search engine,
the MEMEX/gemini server, and the encyclopedia service.
As it stands now, the project is a bit of a mess, as it wasn't developed
with the intention of going open source. A lot of tests and so on make
assumptions about the directory structure, much configuration is hard-coded,
and so on. It's a work in progress.
@ -0,0 +1,74 @@
plugins {
id 'java'
id 'com.github.johnrengelman.shadow' version '6.0.0'
group 'nu.marginalia'
version 'SNAPSHOT'
compileJava.options.encoding = "UTF-8"
compileTestJava.options.encoding = "UTF-8"
repositories {
maven { url "https://artifactory.cronapp.io/public-release/" }
maven { url "https://repo1.maven.org/maven2/" }
maven { url "https://www2.ph.ed.ac.uk/maven2/" }
maven { url "https://jitpack.io/" }
exclusiveContent {
forRepository {
maven {
url = uri("https://jitpack.io")
filter {
// Only use JitPack for the `gson-record-type-adapter-factory` library
includeModule("com.github.Marcono1234", "gson-record-type-adapter-factory")
shadowJar {
jar {
manifest {
attributes 'Main-Class': "nu.marginalia.wmsa.configuration.ServiceDescriptor"
from {
configurations.shadow.collect { it.isDirectory() ? it : zipTree(it) }
java {
toolchain {
dependencies {
implementation project(':marginalia_nu')
task version() { //
test {
maxParallelForks = 16
forkEvery = 1
maxHeapSize = "8G"
useJUnitPlatform {
excludeTags "db"
excludeTags "nobuild"
task dbTest(type: Test) {
maxParallelForks = 1
forkEvery = 1
maxHeapSize = "8G"
useJUnitPlatform {
includeTags "db"
# Copyright © 2015-2021 the original authors.
@ -0,0 +1,133 @@
plugins {
id 'java'
id "io.freefair.lombok" version ""
id "me.champeau.jmh" version "0.6.6"
repositories {
maven { url "https://artifactory.cronapp.io/public-release/" }
maven { url "https://repo1.maven.org/maven2/" }
maven { url "https://www2.ph.ed.ac.uk/maven2/" }
maven { url "https://jitpack.io/" }
exclusiveContent {
forRepository {
maven {
url = uri("https://jitpack.io")
filter {
// Only use JitPack for the `gson-record-type-adapter-factory` library
includeModule("com.github.Marcono1234", "gson-record-type-adapter-factory")
// Dependencies of this module: the sibling ':third_party' project plus external libraries.
dependencies {
implementation project(':third_party')
// NOTE(review): JUnit 4 is declared on the main 'implementation' classpath, not a test configuration — confirm this is intended.
implementation 'junit:junit:4.13.2'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine'
// Lombok is declared for the main, annotation-processor and test configurations.
implementation 'org.projectlombok:lombok:1.18.22'
annotationProcessor 'org.projectlombok:lombok:1.18.22'
testCompileOnly 'org.projectlombok:lombok:1.18.22'
testImplementation 'org.projectlombok:lombok:1.18.22'
testAnnotationProcessor 'org.projectlombok:lombok:1.18.22'
implementation 'com.github.jknack:handlebars:4.3.0'
implementation 'com.github.jknack:handlebars-markdown:4.2.1'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.0'
implementation 'io.reactivex.rxjava3:rxjava:3.1.4'
implementation "com.sparkjava:spark-core:2.9.3"
implementation 'com.opencsv:opencsv:5.6'
implementation group: 'org.apache.logging.log4j', name: 'log4j-api', version: '2.17.1'
implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version: '2.17.1'
implementation group: 'org.apache.logging.log4j', name: 'log4j-slf4j-impl', version: '2.17.1'
// NOTE(review): the next three log4j entries repeat the three directly above verbatim.
implementation group: 'org.apache.logging.log4j', name: 'log4j-api', version: '2.17.1'
implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version: '2.17.1'
implementation group: 'org.apache.logging.log4j', name: 'log4j-slf4j-impl', version: '2.17.1'
implementation 'org.slf4j:slf4j-api:1.7.36'
implementation 'com.google.guava:guava:31.1-jre'
implementation 'com.google.inject:guice:5.1.0'
implementation 'com.github.jnr:jnr-ffi:2.1.1'
implementation 'org.apache.httpcomponents:httpcore:4.4.15'
implementation 'org.apache.httpcomponents:httpclient:4.5.13'
implementation 'com.github.ThatJavaNerd:JRAW:1.1.0'
implementation group: 'com.h2database', name: 'h2', version: '2.1.210'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.3.1'
implementation 'org.jsoup:jsoup:1.14.3'
implementation group: 'com.github.crawler-commons', name: 'crawler-commons', version: '1.2'
implementation 'org.mariadb.jdbc:mariadb-java-client:3.0.3'
implementation group: 'net.sf.trove4j', name: 'trove4j', version: '3.0.3'
implementation 'com.zaxxer:HikariCP:5.0.1'
implementation 'org.apache.opennlp:opennlp-tools:1.9.4'
implementation 'io.prometheus:simpleclient:0.15.0'
implementation 'io.prometheus:simpleclient_servlet:0.15.0'
implementation 'io.prometheus:simpleclient_httpserver:0.15.0'
implementation 'io.prometheus:simpleclient_hotspot:0.15.0'
// NOTE(review): the version segment of this coordinate is empty — confirm the intended jackson-databind version.
implementation 'com.fasterxml.jackson.core:jackson-databind:'
// NOTE(review): the next six entries (opennlp, prometheus x4, jackson-databind) repeat the six directly above verbatim.
implementation 'org.apache.opennlp:opennlp-tools:1.9.4'
implementation 'io.prometheus:simpleclient:0.15.0'
implementation 'io.prometheus:simpleclient_servlet:0.15.0'
implementation 'io.prometheus:simpleclient_httpserver:0.15.0'
implementation 'io.prometheus:simpleclient_hotspot:0.15.0'
implementation 'com.fasterxml.jackson.core:jackson-databind:'
implementation group: 'org.yaml', name: 'snakeyaml', version: '1.30'
implementation 'com.syncthemall:boilerpipe:1.2.2'
implementation 'com.github.luben:zstd-jni:1.5.2-2'
implementation 'com.github.vladimir-bukhtoyarov:bucket4j-core:7.3.0'
implementation 'de.rototor.jeuclid:jeuclid-core:3.1.14'
implementation 'org.imgscalr:imgscalr-lib:4.2'
implementation 'org.jclarion:image4j:0.7'
implementation 'commons-net:commons-net:3.6'
// NOTE(review): both jgit coordinates below have an empty version segment — confirm the intended versions.
implementation 'org.eclipse.jgit:org.eclipse.jgit:'
implementation 'org.eclipse.jgit:org.eclipse.jgit.ssh.jsch:'
implementation 'com.jcraft:jsch:0.1.55'
implementation group: 'org.apache.commons', name: 'commons-compress', version: '1.21'
implementation 'edu.stanford.nlp:stanford-corenlp:4.4.0'
implementation group: 'it.unimi.dsi', name: 'fastutil', version: '8.5.8'
// NOTE(review): open-ended version range (anything >= 0.6); the resolved version can change between builds.
implementation 'org.roaringbitmap:RoaringBitmap:[0.6,)'
implementation group: 'mysql', name: 'mysql-connector-java', version: '8.0.29'
// Served from JitPack via the exclusiveContent filter in the repositories block.
implementation 'com.github.Marcono1234:gson-record-type-adapter-factory:0.2.0'
test {
maxParallelForks = 16
forkEvery = 1
maxHeapSize = "8G"
useJUnitPlatform {
excludeTags "db"
task dbTest(type: Test) {
maxParallelForks = 1
forkEvery = 1
maxHeapSize = "8G"
useJUnitPlatform {
includeTags "db"
@ -0,0 +1,37 @@
package bs_vs_ls;
import org.openjdk.jmh.annotations.*;
import java.util.Arrays;
import java.util.stream.LongStream;
public class BinSearchVsLinSearch {
static long[] data = LongStream.generate(() -> (long) (Long.MAX_VALUE * Math.random())).limit(512).sorted().toArray();
public static class Target {
long targetValue = 0;
public void setUp() {
targetValue = data[(int)(data.length * Math.random())];
// @Benchmark
public long testBs(Target t) {
return Arrays.binarySearch(data, t.targetValue);
// @Benchmark
public long testLs(Target t) {
for (int i = 0; i < 512; i++) {
if (data[i] > t.targetValue)
else if (data[i] == t.targetValue)
return i;
return -1;
@ -0,0 +1,68 @@
package bs_vs_ls;
import nu.marginalia.util.multimap.MultimapFileLong;
import nu.marginalia.util.multimap.MultimapSearcher;
import org.openjdk.jmh.annotations.*;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.LongStream;
public class BinSearchVsLinSearch2 {
static long[] data = LongStream.generate(() -> (long) (Long.MAX_VALUE * Math.random())).limit(512).sorted().toArray();
public static class Target {
Path tf;
MultimapFileLong file;
MultimapSearcher searcher;
long[] data = new long[512];
try {
tf = Files.createTempFile("tmpFileIOTest", "dat");
file = MultimapFileLong.forOutput(tf, 1024);
searcher = file.createSearcher();
for (int i = 0; i < 65535; i++) {
file.put(i, i);
} catch (IOException e) {
@Measurement(iterations = 1)
@Warmup(iterations = 1)
public long testLs(Target t) {
int target = (int)(4096 + 512 * Math.random());
for (int i = 4096; i < (4096+512); i++) {
long val = t.file.get(i);
if (val > target)
if (val == target)
return val;
return -1;
@Measurement(iterations = 1)
@Warmup(iterations = 1)
public long testLs2(Target t) {
int target = (int)(4096 + 512 * Math.random());
t.file.read(t.data, 4096);
for (int i = 0; i < (512); i++) {
long val = t.file.get(i);
if (val > target)
if (val == target)
return val;
return -1;
@ -0,0 +1,43 @@
package nu.marginalia.gemini;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.InetAddress;
import java.util.HashSet;
import java.util.Set;
public class BadBotList {
private final Set<InetAddress> shitlist = new HashSet<>();
public static BadBotList INSTANCE = new BadBotList();
private final Logger logger = LoggerFactory.getLogger(getClass().getSimpleName());
private BadBotList() {}
public boolean isAllowed(InetAddress address) {
return !shitlist.contains(address);
public boolean isQueryPermitted(InetAddress address, String query) {
if (isBadQuery(query)) {
logger.info("Banning {}", address);
return false;
return true;
private boolean isBadQuery(String query) {
if (query.startsWith("GET")) {
return true;
if (query.startsWith("OPTIONS")) {
return true;
if (query.contains("mstshash")) {
return true;
return false;
@ -0,0 +1,21 @@
package nu.marginalia.gemini;
import com.google.inject.AbstractModule;
import com.google.inject.Inject;
import com.google.inject.Provider;
import com.google.inject.name.Named;
import com.google.inject.name.Names;
import nu.marginalia.wmsa.memex.system.MemexFileWriter;
import java.nio.file.Path;
public class GeminiConfigurationModule extends AbstractModule {
public void configure() {
@ -0,0 +1,164 @@
package nu.marginalia.gemini;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.google.inject.name.Named;
import nu.marginalia.gemini.io.GeminiConnection;
import nu.marginalia.gemini.io.GeminiSSLSetUp;
import nu.marginalia.gemini.io.GeminiStatusCode;
import nu.marginalia.gemini.io.GeminiUserException;
import nu.marginalia.gemini.plugins.BareStaticPagePlugin;
import nu.marginalia.gemini.plugins.Plugin;
import nu.marginalia.gemini.plugins.SearchPlugin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
public class GeminiService {
public static final String DEFAULT_FILENAME = "index.gmi";
public final Path serverRoot;
private final Logger logger = LoggerFactory.getLogger("GeminiServer");
private final Executor pool = Executors.newFixedThreadPool(32);
private final SSLServerSocket serverSocket;
private final Plugin[] plugins;
private final BadBotList badBotList = BadBotList.INSTANCE;
public GeminiService(@Named("gemini-server-root") Path serverRoot,
@Named("gemini-server-port") Integer port,
GeminiSSLSetUp sslSetUp,
BareStaticPagePlugin pagePlugin,
SearchPlugin searchPlugin) throws Exception {
this.serverRoot = serverRoot;
logger.info("Setting up crypto");
final SSLServerSocketFactory socketFactory = sslSetUp.getServerSocketFactory();
serverSocket = (SSLServerSocket) socketFactory.createServerSocket(port /* 1965 */);
serverSocket.setEnabledProtocols(new String[] {"TLSv1.3", "TLSv1.2"});
logger.info("Verifying setup");
if (!Files.exists(this.serverRoot)) {
logger.error("Could not find SERVER_ROOT {}", this.serverRoot);
plugins = new Plugin[] {
public void run() {
logger.info("Awaiting connections");
try {
for (; ; ) {
SSLSocket connection = (SSLSocket) serverSocket.accept();
if (!badBotList.isAllowed(connection.getInetAddress())) {
} else {
pool.execute(() -> serve(connection));
catch (IOException ex) {
logger.error("IO Exception in gemini server", ex);
private void serve(SSLSocket socket) {
final GeminiConnection connection;
try {
connection = new GeminiConnection(socket);
catch (IOException ex) {
logger.error("Failed to create connection object", ex);
try {
catch (GeminiUserException ex) {
errorResponse(connection, ex.getMessage());
catch (SSLException ex) {
logger.error(connection.getAddress() + " SSL error");
catch (Exception ex) {
errorResponse(connection, "Error");
logger.error(connection.getAddress(), ex);
finally {
private void errorResponse(GeminiConnection connection, String message) {
if (connection.isConnected()) {
try {
logger.error("=> " + connection.getAddress(), message);
connection.writeStatusLine(GeminiStatusCode.ERROR_PERMANENT, message);
catch (IOException ex) {
logger.error("Exception while sending error", ex);
private void handleRequest(GeminiConnection connection) throws Exception {
final String address = connection.getAddress();
logger.info("Connect: " + address);
final Optional<URI> maybeUri = connection.readUrl();
if (maybeUri.isEmpty()) {
logger.info("Done: {}", address);
final URI uri = maybeUri.get();
logger.info("Request {}", uri);
if (!uri.getScheme().equals("gemini")) {
throw new GeminiUserException("Unsupported protocol");
servePage(connection, uri);
logger.info("Done: {}", address);
private void servePage(GeminiConnection connection, URI url) throws IOException {
String path = url.getPath();
for (Plugin p : plugins) {
if (p.serve(url, connection)) {
logger.error("FileNotFound {}", path);
connection.writeStatusLine(GeminiStatusCode.ERROR_TEMPORARY, "No such file");
@ -0,0 +1,130 @@
package nu.marginalia.gemini.client;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.cert.X509Certificate;
/** Unstable code! */
public class GeminiClient {
private final SSLSocketFactory socketFactory;
// Create a trust manager that does not validate anything
public static final TrustManager[] trustAllCerts = new TrustManager[]{
new X509TrustManager() {
public void checkClientTrusted(X509Certificate[] chain,
String authType) {
public void checkServerTrusted(X509Certificate[] chain,
String authType) {
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
public static SSLSocketFactory buildSocketFactory() throws Exception {
// Install the all-trusting trust manager
final SSLContext sslContext = SSLContext.getInstance("SSL");
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
return sslContext.getSocketFactory();
public GeminiClient() throws Exception {
socketFactory = buildSocketFactory();
public Response get(URI uri) throws IOException {
final int port = uri.getPort() == -1 ? 1965 : uri.getPort();
final String host = uri.getHost();
var requestString = String.format("%s\r\n", uri).getBytes(StandardCharsets.UTF_8);
try (var socket = socketFactory.createSocket(host, port)) {
var is = socket.getInputStream();
String statusLine = new GeminiInput(is).get();
int code = Integer.parseInt(statusLine.substring(0,2));
String meta = statusLine.substring(3);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
return new Response(code, meta, baos.toByteArray());
public static class Response {
public final int code;
public final String meta;
public final byte[] data;
Response(int code, String meta, byte[] data) {
this.code = code;
this.meta = meta;
this.data = data;
public static class GeminiInput {
private final InputStream is;
private final byte[] buffer = new byte[1024];
private int idx;
final String result;
public GeminiInput(InputStream is) throws IOException {
this.is = is;
for (idx = 0; idx < buffer.length; idx++) {
if (hasEndOfLine()) {
result = new String(buffer, 0, idx-2, StandardCharsets.UTF_8);
throw new RuntimeException("String too long");
public String get() {
return result;
private void readCharacter() throws IOException {
int rb = is.read();
if (-1 == rb) {
throw new RuntimeException("URL incomplete (no CR LF)");
buffer[idx] = (byte) rb;
public boolean hasEndOfLine() {
return idx > 2
&& buffer[idx - 1] == (byte) '\n'
&& buffer[idx - 2] == (byte) '\r';
@ -0,0 +1,53 @@
package nu.marginalia.gemini.gmi;
import lombok.Getter;
import nu.marginalia.gemini.gmi.line.AbstractGemtextLine;
import nu.marginalia.gemini.gmi.parser.GemtextParser;
import nu.marginalia.gemini.gmi.renderer.GemtextRenderer;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeUrl;
import java.io.IOException;
import java.io.Writer;
import java.util.Arrays;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class Gemtext {
private final AbstractGemtextLine[] lines;
private final MemexNodeUrl url;
public Gemtext(MemexNodeUrl url, String[] lines, MemexNodeHeadingId headingRoot) {
this.lines = GemtextParser.parse(lines, headingRoot);
this.url = url;
public Gemtext(MemexNodeUrl url, String[] lines) {
this.lines = GemtextParser.parse(lines, new MemexNodeHeadingId(0));
this.url = url;
public String render(GemtextRenderer renderer) {
return Arrays.stream(lines).map(renderer::renderLine).collect(Collectors.joining());
public void render(GemtextRenderer renderer, Writer w) throws IOException {
for (var line : lines) {
public Stream<AbstractGemtextLine> stream() {
return Arrays.stream(lines);
public AbstractGemtextLine get(int idx) {
return lines[idx];
public int size() {
return lines.length;
@ -0,0 +1,72 @@
package nu.marginalia.gemini.gmi;
import com.google.common.collect.Sets;
import nu.marginalia.gemini.gmi.line.GemtextLineVisitorAdapter;
import nu.marginalia.gemini.gmi.line.GemtextLink;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeUrl;
import nu.marginalia.wmsa.memex.model.MemexUrl;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
public class GemtextDatabase extends Gemtext {
public Map<String, Integer> links;
public GemtextDatabase(MemexNodeUrl url, String[] lines) {
super(url, lines);
links = new HashMap<>();
for (int i = 0; i < size(); i++) {
int linkIdx = i;
get(i).visit(new GemtextLineVisitorAdapter<>() {
public Object visit(GemtextLink g) {
links.put(g.getUrl().toString(), linkIdx);
return null;
public Set<String> keys() {
return links.keySet();
public Optional<String> getLinkData(MemexUrl url) {
Integer idx = links.get(url.getUrl());
if (idx != null) {
return Optional.empty();
public static GemtextDatabase of(MemexNodeUrl url, String[] lines) {
return new GemtextDatabase(url, lines);
public static GemtextDatabase of(MemexNodeUrl url, Path file) throws IOException {
try (var s = Files.lines(file)) {
return new GemtextDatabase(url, s.toArray(String[]::new));
public Set<MemexNodeUrl> difference(GemtextDatabase other) {
Set<MemexNodeUrl> differences = new HashSet<>();
Sets.difference(keys(), other.keys()).stream().map(MemexNodeUrl::new).forEach(differences::add);
Sets.intersection(keys(), other.keys())
.filter(url -> !Objects.equals(getLinkData(url), other.getLinkData(url)))
return differences;
@ -0,0 +1,163 @@
package nu.marginalia.gemini.gmi;
import lombok.Getter;
import nu.marginalia.gemini.gmi.line.*;
import nu.marginalia.gemini.gmi.renderer.GemtextRenderer;
import nu.marginalia.gemini.gmi.renderer.GemtextRendererFactory;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeTaskId;
import nu.marginalia.wmsa.memex.model.MemexNodeUrl;
import nu.marginalia.wmsa.memex.model.MemexTaskState;
import org.apache.commons.lang3.tuple.Pair;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
public class GemtextDocument extends Gemtext {
private final Map<MemexNodeHeadingId, String> headings;
private final Map<String, List<MemexNodeHeadingId>> headingsByName;
private final Set<String> pragmas;
private final List<GemtextTask> tasks;
private final String title;
private final String date;
private final List<GemtextLink> links;
private final int hashCode;
private static final Pattern datePattern = Pattern.compile(".*(\\d{4}-\\d{2}-\\d{2}).*");
private static final GemtextRenderer rawRenderer = new GemtextRendererFactory().gemtextRendererAsIs();
super(url, lines, headingRoot);
this.hashCode = Arrays.hashCode(lines);
GemtextDataExtractor extractor = new GemtextDataExtractor();
this.headings = extractor.getHeadings();
this.links = extractor.getLinks();
this.title = Objects.requireNonNullElse(extractor.getTitle(), url.getUrl());
this.pragmas = extractor.getPragmas();
this.headingsByName = extractor.getHeadingsByName();
this.tasks = extractor.getTasks();
this.date = extractor.getDate();
public String getHeadingForElement(AbstractGemtextLine line) {
return headings.getOrDefault(line.getHeading(), "");
public List<AbstractGemtextLine> getSection(MemexNodeHeadingId headingId) {
return stream()
.filter(line -> line.getHeading().isChildOf(headingId))
public String getSectionGemtext(MemexNodeHeadingId headingId) {
if (headingId.equals(new MemexNodeHeadingId(0))) {
return stream()
return stream()
.filter(line -> line.getHeading().isChildOf(headingId))
public Map<MemexNodeTaskId, Pair<String, MemexTaskState>> getOpenTopTasks() {
return tasks.stream()
.filter(task -> MemexTaskState.TODO.equals(task.getState())
|| MemexTaskState.URGENT.equals(task.getState()))
.filter(task -> task.getId().level() == 1)
.collect(Collectors.toMap(GemtextTask::getId, task -> Pair.of(task.getTask(), task.getState())));
public static GemtextDocument of(MemexNodeUrl url, String... lines) {
return new GemtextDocument(url, lines, new MemexNodeHeadingId(0));
public static GemtextDocument of(MemexNodeUrl url, Path file) throws IOException {
try (var s = Files.lines(file)) {
return new GemtextDocument(url, s.toArray(String[]::new), new MemexNodeHeadingId(0));
public boolean isIndex() {
return getUrl().getFilename().equals("index.gmi");
public int hashCode() {
return hashCode;
public Optional<String> getHeading(MemexNodeHeadingId heading) {
return Optional.ofNullable(headings.get(heading));
public Optional<MemexNodeHeadingId> getHeadingByName(MemexNodeHeadingId parent, String name) {
var headings = headingsByName.get(name);
if (null == headings) {
return Optional.empty();
return headings.stream().filter(heading -> heading.isChildOf(parent)).findAny();
private static class GemtextDataExtractor extends GemtextLineVisitorAdapter<Object> {
private String title;
private String date;
private final Map<MemexNodeHeadingId, String> headings = new TreeMap<>((a, b) -> Arrays.compare(a.getIds(), b.getIds()));
private final Map<String, List<MemexNodeHeadingId>> headingsByName = new HashMap<>();
private final Set<String> pragmas = new HashSet<>();
private final List<GemtextLink> links = new ArrayList<>();
private final List<GemtextTask> tasks = new ArrayList<>();
public Object visit(GemtextHeading g) {
headings.put(g.getLevel(), g.getName());
headingsByName.computeIfAbsent(g.getName(), t -> new ArrayList<>()).add(g.getLevel());
if (title == null) {
title = g.getName();
var dateMatcher = datePattern.matcher(title);
if (dateMatcher.matches()) {
date = dateMatcher.group(1);
return null;
public Object visit(GemtextLink g) {
return null;
public Object visit(GemtextTask g) {
return null;
public Object visit(GemtextPragma g) {
return null;
@ -0,0 +1,18 @@
package nu.marginalia.gemini.gmi.line;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.Optional;
import java.util.function.Function;
/**
 * Base type for one parsed gemtext line.
 *
 * The map* methods let a caller apply a function only when the line is of a
 * particular kind; the implementations here all return an empty Optional.
 */
public abstract class AbstractGemtextLine {
/** Base implementation: the mapper is not invoked and an empty Optional is returned. */
public <T> Optional<T> mapLink(Function<GemtextLink, T> mapper) {
return Optional.empty();
/** Base implementation: the mapper is not invoked and an empty Optional is returned. */
public <T> Optional<T> mapHeading(Function<GemtextHeading, T> mapper) { return Optional.empty(); }
/** Base implementation: the mapper is not invoked and an empty Optional is returned. */
public <T> Optional<T> mapTask(Function<GemtextTask, T> mapper) { return Optional.empty(); }
/** Double-dispatch entry point: implementations pass themselves to the visitor. */
public abstract <T> T visit(GemtextLineVisitor<T> visitor);
/** Whether this kind of line breaks a task (exact parser semantics are not visible here). */
public abstract boolean breaksTask();
/** The heading id associated with this line. */
public abstract MemexNodeHeadingId getHeading();
@ -0,0 +1,21 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
@AllArgsConstructor @Getter @ToString
public class GemtextAside extends AbstractGemtextLine {
private final String line;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return false;
@ -0,0 +1,32 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.Optional;
import java.util.function.Function;
public class GemtextHeading extends AbstractGemtextLine {
private final MemexNodeHeadingId level;
private final String name;
private final MemexNodeHeadingId heading;
public <T> Optional<T> mapHeading(Function<GemtextHeading, T> mapper) {
return Optional.of(mapper.apply(this));
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return true;
@ -0,0 +1,18 @@
package nu.marginalia.gemini.gmi.line;
/**
 * Visitor over the gemtext line types, producing a value of type T per line.
 *
 * @param <T> result type of each visit
 */
public interface GemtextLineVisitor<T> {
/** Convenience dispatch: calls line.visit(this) so the line selects the overload. */
default T take(AbstractGemtextLine line) {
return line.visit(this);
// One overload per line type.
T visit(GemtextHeading g);
T visit(GemtextLink g);
T visit(GemtextList g);
T visit(GemtextPreformat g);
T visit(GemtextQuote g);
T visit(GemtextText g);
T visit(GemtextTextLiteral g);
T visit(GemtextAside g);
T visit(GemtextTask g);
T visit(GemtextPragma g);
@ -0,0 +1,53 @@
package nu.marginalia.gemini.gmi.line;
public class GemtextLineVisitorAdapter<T> implements GemtextLineVisitor<T> {
public T visit(GemtextHeading g) {
return null;
public T visit(GemtextLink g) {
return null;
public T visit(GemtextList g) {
return null;
public T visit(GemtextPreformat g) {
return null;
public T visit(GemtextQuote g) {
return null;
public T visit(GemtextText g) {
return null;
public T visit(GemtextTextLiteral g) {
return null;
public T visit(GemtextAside g) {
return null;
public T visit(GemtextTask g) {
return null;
public T visit(GemtextPragma g) {
return null;
@ -0,0 +1,33 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexUrl;
import javax.annotation.Nullable;
import java.util.Optional;
import java.util.function.Function;
@AllArgsConstructor @Getter @ToString
public class GemtextLink extends AbstractGemtextLine {
private final MemexUrl url;
private final String title;
private final MemexNodeHeadingId heading;
public <T> Optional<T> mapLink(Function<GemtextLink, T> mapper) {
return Optional.ofNullable(mapper.apply(this));
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return false;
@ -0,0 +1,23 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.List;
@AllArgsConstructor @Getter @ToString
public class GemtextList extends AbstractGemtextLine {
private final List<String> items;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return true;
@ -0,0 +1,21 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
@AllArgsConstructor @Getter @ToString
public class GemtextPragma extends AbstractGemtextLine {
private final String line;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return false;
@ -0,0 +1,23 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.List;
@AllArgsConstructor @Getter @ToString
public class GemtextPreformat extends AbstractGemtextLine {
private final List<String> items;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return true;
@ -0,0 +1,23 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.List;
@AllArgsConstructor @Getter @ToString
public class GemtextQuote extends AbstractGemtextLine {
private final List<String> items;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return true;
@ -0,0 +1,42 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeTaskId;
import nu.marginalia.wmsa.memex.model.MemexTaskState;
import nu.marginalia.wmsa.memex.model.MemexTaskTags;
import java.util.Optional;
import java.util.function.Function;
@AllArgsConstructor @Getter @ToString
public class GemtextTask extends AbstractGemtextLine {
private final MemexNodeTaskId id;
private final String task;
private final MemexNodeHeadingId heading;
private final MemexTaskTags tags;
public MemexTaskState getState() {
return MemexTaskState.of(tags);
public int getLevel() {
return id.level();
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return true;
public <T> Optional<T> mapTask(Function<GemtextTask, T> mapper) {
return Optional.of(mapper.apply(this));
@ -0,0 +1,21 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
@AllArgsConstructor @Getter @ToString
public class GemtextText extends AbstractGemtextLine {
private final String line;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return !line.isBlank();
@ -0,0 +1,23 @@
package nu.marginalia.gemini.gmi.line;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.ToString;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.List;
@AllArgsConstructor @Getter @ToString
public class GemtextTextLiteral extends AbstractGemtextLine {
private final List<String> items;
private final MemexNodeHeadingId heading;
public <T> T visit(GemtextLineVisitor<T> visitor) {
return visitor.visit(this);
public boolean breaksTask() {
return false;
@ -0,0 +1,20 @@
package nu.marginalia.gemini.gmi.parser;
import nu.marginalia.gemini.gmi.line.GemtextAside;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.regex.Pattern;
/** Recognises "aside" lines: lines that are entirely wrapped in parentheses. */
public class GemtextAsideParser {
// Whole line must be "(" + anything + ")"; group 1 is the text between the parentheses.
private static final Pattern listItemPattern = Pattern.compile("^\\((.*)\\)$");
/**
 * @param s       the raw line
 * @param heading heading id to attach to the resulting aside
 * @return a GemtextAside holding the text inside the parentheses, or null if
 *         the line is not wrapped in parentheses
 */
public static GemtextAside parse(String s, MemexNodeHeadingId heading) {
var matcher = listItemPattern.matcher(s);
if (!matcher.matches()) {
return null;
return new GemtextAside(matcher.group(1), heading);
@ -0,0 +1,26 @@
package nu.marginalia.gemini.gmi.parser;
import nu.marginalia.gemini.gmi.line.AbstractGemtextLine;
import nu.marginalia.gemini.gmi.line.GemtextHeading;
import nu.marginalia.gemini.gmi.line.GemtextText;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.regex.Pattern;
/** Recognises gemtext heading lines ("#", "##", ...). */
public class GemtextHeadingParser {
// One or more leading '#', optional whitespace, then either text that does
// not start with '#' or nothing at all. Group 1 is the run of '#', group 2 the text.
private static final Pattern headingPattern = Pattern.compile("^(#+)\\s*([^#].*|$)$");
/**
 * @param s       the raw line
 * @param heading the current heading id; attached unchanged to non-heading
 *                lines and used as the base for heading.next(level) otherwise
 * @return a GemtextText wrapping s when the line is not a heading, otherwise
 *         a GemtextHeading
 */
public static AbstractGemtextLine parse(String s, MemexNodeHeadingId heading) {
var matcher = headingPattern.matcher(s);
if (!matcher.matches()) {
return new GemtextText(s, heading);
// Zero-based depth: "#" gives 0, "##" gives 1, and so on.
int level = matcher.group(1).length() - 1;
var newHeading = heading.next(level);
// The new id is passed as both the first and the third constructor argument.
return new GemtextHeading(newHeading, matcher.group(2), newHeading);
@ -0,0 +1,42 @@
package nu.marginalia.gemini.gmi.parser;
import nu.marginalia.gemini.gmi.line.AbstractGemtextLine;
import nu.marginalia.gemini.gmi.line.GemtextLink;
import nu.marginalia.gemini.gmi.line.GemtextText;
import nu.marginalia.wmsa.memex.model.MemexExternalUrl;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeUrl;
import nu.marginalia.wmsa.memex.model.MemexUrl;
import javax.annotation.Nullable;
import java.util.regex.Pattern;
public class GemtextLinkParser {
private static Pattern linkPattern = Pattern.compile("^=>\\s?([^\\s]+)\\s*(.+)?$");
public static AbstractGemtextLine parse(String s, MemexNodeHeadingId heading) {
var matcher = linkPattern.matcher(s);
if (!matcher.matches()) {
return new GemtextText(s, heading);
if (matcher.groupCount() == 2) {
return new GemtextLink(toMemexUrl(matcher.group(1)), matcher.group(2), heading);
else {
return new GemtextLink(toMemexUrl(matcher.group(1)), null, heading);
private static MemexUrl toMemexUrl(String url) {
if (url.startsWith("/")) {
return new MemexNodeUrl(url);
else {
return new MemexExternalUrl(url);
@ -0,0 +1,17 @@
package nu.marginalia.gemini.gmi.parser;
import java.util.regex.Pattern;
/**
 * Recognises list item lines: a '*', at most one whitespace character, then the item text.
 */
public class GemtextListParser {
    private static final Pattern bulletPattern = Pattern.compile("^\\*\\s?(.+)$");

    /**
     * Extracts the text of a list item.
     *
     * @param s the raw line
     * @return the text following the bullet, or {@code null} when the line is not a list item
     *         (a bare "*" with nothing after it does not count as one)
     */
    public static String parse(String s) {
        var m = bulletPattern.matcher(s);
        return m.matches() ? m.group(1) : null;
    }
}
@ -0,0 +1,135 @@
package nu.marginalia.gemini.gmi.parser;
import nu.marginalia.gemini.gmi.line.*;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeTaskId;
import java.util.*;
/**
 * Converts raw Gemtext lines into AbstractGemtextLine objects. Each line is classified by
 * the marker it starts with and handed to the matching sub-parser or block collector.
 */
public class GemtextParser {
// Line prefixes tested with String.startsWith() in parse()
private static final String PREFORMAT_MARKER = "```";
// Whitespace prefix that marks a literal line
private static final String LITERAL_MARKER = " ";
private static final String LINK_MARKER = "=>";
private static final String HEADING_MARKER = "#";
private static final String LIST_MARKER = "*";
private static final String QUOTE_MARKER = ">";
private static final String ASIDE_MARKER = "(";
// Only honoured when the "TASKS" pragma is in the pragma set, see parse()
private static final String TASK_MARKER = "-";
private static final String PRAGMA_MARKER = "%%%";
/**
 * Parses a whole document.
 *
 * @param lines       the document, one element per line
 * @param headingRoot heading id in effect before the first heading line is seen
 * @return the parsed items as an array
 */
public static AbstractGemtextLine[] parse(String[] lines, MemexNodeHeadingId headingRoot) {
List<AbstractGemtextLine> items = new ArrayList<>();
// Current heading; replaced whenever a heading line is parsed successfully
MemexNodeHeadingId heading = headingRoot;
// Current task id; replaced whenever a task line is parsed successfully
MemexNodeTaskId task = new MemexNodeTaskId(0);
// NOTE(review): no statement adding to this set is visible here; presumably the pragma
// branch below records the pragma text in it — confirm
Set<String> pragmas = new HashSet<>();
for (int i = 0; i < lines.length; i++) {
String line = lines[i];
// Multi-line constructs return the index of the last line they consumed, which the
// loop's i++ then steps past
if (line.startsWith(PREFORMAT_MARKER)) {
i = getBlockQuote(items, lines, heading, i);
else if (line.startsWith(PRAGMA_MARKER)) {
var pragma = GemtextPragmaParser.parse(line, heading);
// The pragma parser falls back to GemtextText for lines it cannot parse
if (pragma instanceof GemtextPragma) {
GemtextPragma gtp = (GemtextPragma) pragma;
else if (line.startsWith(LINK_MARKER)) {
var link = GemtextLinkParser.parse(line, heading);
else if (line.startsWith(HEADING_MARKER)) {
var tag = GemtextHeadingParser.parse(line, heading);
// Keep the previous heading when mapHeading yields an empty result
heading = tag.mapHeading(GemtextHeading::getHeading).orElse(heading);
else if (line.startsWith(LIST_MARKER)) {
i = getList(items, lines, heading, i);
else if (line.startsWith(LITERAL_MARKER)) {
i = getLitteral(items, lines, heading, i);
// Task lines are only recognised when the "TASKS" pragma is present
else if (pragmas.contains("TASKS")
&& line.startsWith(TASK_MARKER))
var tag = GemtextTaskParser.parse(line, heading, task);
// Keep the previous task id when mapTask yields an empty result
task = tag.mapTask(GemtextTask::getId).orElse(task);
else if (line.startsWith(QUOTE_MARKER)) {
i = getQuote(items, lines, heading, i);
else if (line.startsWith(ASIDE_MARKER)) {
var aside = GemtextAsideParser.parse(line, heading);
// The aside parser returns null on mismatch; such lines become plain text
items.add(Objects.requireNonNullElse(aside, new GemtextText(line, heading)));
else {
items.add(new GemtextText(line, heading));
return items.toArray(AbstractGemtextLine[]::new);
/**
 * Handles a preformatted block opened on line i. Scans forward from line i+1; on reaching a
 * line that starts with the preformat marker it adds a GemtextPreformat and returns that
 * line's index.
 */
private static int getBlockQuote(List<AbstractGemtextLine> items, String[] lines, MemexNodeHeadingId heading, int i) {
int j = i+1;
List<String> quotedLines = new ArrayList<>();
for (;j < lines.length; j++) {
if (lines[j].startsWith(PREFORMAT_MARKER)) {
items.add(new GemtextPreformat(quotedLines, heading));
return j;
/**
 * Handles a run of list lines starting at line i. On reaching the first line that does not
 * start with the list marker it adds a GemtextList and returns the index of the line before.
 */
private static int getList(List<AbstractGemtextLine> items, String[] lines, MemexNodeHeadingId heading, int i) {
int j = i;
List<String> listLines = new ArrayList<>();
for (;j < lines.length; j++) {
if (!lines[j].startsWith(LIST_MARKER)) {
items.add(new GemtextList(listLines, heading));
return j-1;
/**
 * Handles a run of literal lines starting at line i. On reaching the first line that does
 * not start with the literal marker it adds a GemtextTextLiteral and returns the index of
 * the line before.
 */
private static int getLitteral(List<AbstractGemtextLine> items, String[] lines, MemexNodeHeadingId heading, int i) {
int j = i;
List<String> listLines = new ArrayList<>();
for (;j < lines.length; j++) {
if (!lines[j].startsWith(LITERAL_MARKER)) {
items.add(new GemtextTextLiteral(listLines, heading));
return j-1;
/**
 * Handles a run of quote lines starting at line i. On reaching the first line that does not
 * start with the quote marker it adds a GemtextQuote and returns the index of the line before.
 */
private static int getQuote(List<AbstractGemtextLine> items, String[] lines, MemexNodeHeadingId heading, int i) {
int j = i;
List<String> listLines = new ArrayList<>();
for (;j < lines.length; j++) {
if (!lines[j].startsWith(QUOTE_MARKER)) {
items.add(new GemtextQuote(listLines, heading));
return j-1;
@ -0,0 +1,26 @@
package nu.marginalia.gemini.gmi.parser;
import nu.marginalia.gemini.gmi.line.AbstractGemtextLine;
import nu.marginalia.gemini.gmi.line.GemtextPragma;
import nu.marginalia.gemini.gmi.line.GemtextText;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import java.util.regex.Pattern;
public class GemtextPragmaParser {
private static final Pattern pragmaPattern = Pattern.compile("^%%%\\s*(.*|$)$");
public static AbstractGemtextLine parse(String s, MemexNodeHeadingId heading) {
var matcher = pragmaPattern.matcher(s);
if (!matcher.matches()) {
return new GemtextText(s, heading);
String task = matcher.group(1);
return new GemtextPragma(task, heading);
@ -0,0 +1,17 @@
package nu.marginalia.gemini.gmi.parser;
import java.util.regex.Pattern;
/**
 * Recognises quote lines: a '>' followed by at least one more character.
 */
public class GemtextQuoteParser {
    private static final Pattern quotePattern = Pattern.compile("^>(.+)$");

    /**
     * Extracts the quoted text.
     *
     * @param s the raw line
     * @return everything after the leading '>' (leading whitespace is kept), or {@code null}
     *         when the line is not a quote line; a bare ">" does not count as one
     */
    public static String parse(String s) {
        var m = quotePattern.matcher(s);
        return m.matches() ? m.group(1) : null;
    }
}
@ -0,0 +1,31 @@
package nu.marginalia.gemini.gmi.parser;
import nu.marginalia.gemini.gmi.line.AbstractGemtextLine;
import nu.marginalia.gemini.gmi.line.GemtextTask;
import nu.marginalia.gemini.gmi.line.GemtextText;
import nu.marginalia.wmsa.memex.model.MemexNodeHeadingId;
import nu.marginalia.wmsa.memex.model.MemexNodeTaskId;
import nu.marginalia.wmsa.memex.model.MemexTaskTags;
import java.util.regex.Pattern;
public class GemtextTaskParser {
private static final Pattern taskPattern = Pattern.compile("^(-+)\\s*([^-].*|$)$");
public static AbstractGemtextLine parse(String s, MemexNodeHeadingId heading,
MemexNodeTaskId taskId) {
var matcher = taskPattern.matcher(s);
if (!matcher.matches()) {
return new GemtextText(s, heading);
int level = matcher.group(1).length() - 1;
String task = matcher.group(2);
return new GemtextTask(taskId.next(level), task, heading, new MemexTaskTags(task));
@ -0,0 +1,91 @@
package nu.marginalia.gemini.gmi.renderer;
import nu.marginalia.gemini.gmi.line.*;
import java.util.function.Function;
/**
 * A GemtextLineVisitor that renders lines to strings. It holds one converter function per
 * line type and each visit() overload simply applies the converter for its type, so the
 * output format is decided entirely by the functions passed to the constructor.
 */
public class GemtextRenderer implements GemtextLineVisitor<String> {
// One converter per concrete line type
private final Function<GemtextHeading, String> headingConverter;
private final Function<GemtextLink, String> linkConverter;
private final Function<GemtextList, String> listConverter;
private final Function<GemtextPreformat, String> preformatConverter;
private final Function<GemtextQuote, String> quoteConverter;
private final Function<GemtextText, String> textConverter;
private final Function<GemtextAside, String> asideConverter;
private final Function<GemtextTask, String> taskConverter;
private final Function<GemtextTextLiteral, String> literalConverter;
private final Function<GemtextPragma, String> pragmaConverter;
/**
 * Stores the converters. The arguments are positional and all share the same shape, so the
 * order at the call site must match the parameter order here.
 */
public GemtextRenderer(Function<GemtextHeading, String> headingConverter,
Function<GemtextLink, String> linkConverter,
Function<GemtextList, String> listConverter,
Function<GemtextPreformat, String> preformatConverter,
Function<GemtextQuote, String> quoteConverter,
Function<GemtextText, String> textConverter,
Function<GemtextAside, String> asideConverter,
Function<GemtextTask, String> taskConverter,
Function<GemtextTextLiteral, String> literalConverter,
Function<GemtextPragma, String> pragmaConverter
) {
this.headingConverter = headingConverter;
this.linkConverter = linkConverter;
this.listConverter = listConverter;
this.preformatConverter = preformatConverter;
this.quoteConverter = quoteConverter;
this.textConverter = textConverter;
this.asideConverter = asideConverter;
this.taskConverter = taskConverter;
this.literalConverter = literalConverter;
this.pragmaConverter = pragmaConverter;
/** Renders one line by letting it call back into the visit() overload for its type. */
public String renderLine(AbstractGemtextLine line) {
return line.visit(this);
// Visitor overloads: each applies the converter registered for its argument type
public String visit(GemtextHeading g) {
return headingConverter.apply(g);
public String visit(GemtextLink g) {
return linkConverter.apply(g);
public String visit(GemtextList g) {
return listConverter.apply(g);
public String visit(GemtextPreformat g) {
return preformatConverter.apply(g);
public String visit(GemtextQuote g) {
return quoteConverter.apply(g);
public String visit(GemtextText g) {
return textConverter.apply(g);
public String visit(GemtextTextLiteral g) {
return literalConverter.apply(g);
public String visit(GemtextAside g) { return asideConverter.apply(g); }
public String visit(GemtextTask g) { return taskConverter.apply(g); }
public String visit(GemtextPragma g) { return pragmaConverter.apply(g); }
@ -0,0 +1,227 @@
package nu.marginalia.gemini.gmi.renderer;
import nu.marginalia.gemini.gmi.line.*;
import nu.marginalia.wmsa.memex.model.MemexNodeUrl;
import nu.marginalia.wmsa.memex.model.MemexUrl;
import org.apache.logging.log4j.util.Strings;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Builds GemtextRenderer instances for HTML output and for raw Gemtext output, and hosts
 * the per-line-type conversion methods those renderers are wired to.
 */
public class GemtextRendererFactory {
// Prefix prepended to internal link targets by getLinkUrl(); required by htmlLink()
public final String urlBase;
// Required by htmlHeadingEditable()
public final String docUrl;
/** Creates a factory usable for every renderer. Neither argument may be null. */
public GemtextRendererFactory(String urlBase, String docUrl) {
this.urlBase = Objects.requireNonNull(urlBase, "urlBase must not be null");
this.docUrl = Objects.requireNonNull(docUrl, "docUrl must not be null");
/** Creates a factory without a document URL; htmlHeadingEditable() will throw on it. */
public GemtextRendererFactory(String urlBase) {
this.urlBase = Objects.requireNonNull(urlBase, "urlBase must not be null");
this.docUrl = null;
/** Creates a factory with neither URL; htmlLink() and htmlHeadingEditable() will throw on it. */
public GemtextRendererFactory() {
this.urlBase = null;
this.docUrl = null;
// NOTE(review): in the four factory methods below the final constructor argument (the
// pragma converter) is not visible in this listing — confirm which method each one passes
/** HTML renderer that uses htmlHeadingEditable for headings. */
public GemtextRenderer htmlRendererEditable() {
return new GemtextRenderer(this::htmlHeadingEditable,
this::htmlLink, this::htmlList,
this::htmlPre, this::htmlQuote,
this::htmlText, this::htmlAside,
this::htmlTask, this::htmlLiteral,
/** HTML renderer that uses htmlHeadingReadOnly for headings. */
public GemtextRenderer htmlRendererReadOnly() {
return new GemtextRenderer(this::htmlHeadingReadOnly,
this::htmlLink, this::htmlList,
this::htmlPre, this::htmlQuote,
this::htmlText, this::htmlAside,
this::htmlTask, this::htmlLiteral,
/** Renderer that writes lines back out as Gemtext using the raw* converters. */
public GemtextRenderer gemtextRendererAsIs() {
return new GemtextRenderer(this::rawHeading,
this::rawLink, this::rawList,
this::rawPre, this::rawQuote,
this::rawText, this::rawAside,
this::rawTask, this::rawLiteral,
/**
 * Renderer that writes lines back out as Gemtext using the raw* converters.
 * NOTE(review): the visible arguments are identical to gemtextRendererAsIs(); presumably
 * the two differ only in the pragma converter — confirm.
 */
public GemtextRenderer gemtextRendererPublic() {
return new GemtextRenderer(this::rawHeading,
this::rawLink, this::rawList,
this::rawPre, this::rawQuote,
this::rawText, this::rawAside,
this::rawTask, this::rawLiteral,
/** Renders a pragma as an HTML comment. */
private String htmlPragma(GemtextPragma gemtextPragma) {
return "<!-- pragma: " + sanitizeText(gemtextPragma.getLine()) + " -->\n";
/**
 * Heading converter for the editable renderer. Currently produces the same markup as
 * htmlHeadingReadOnly(), but still insists on docUrl being set.
 *
 * @throws UnsupportedOperationException if the factory was built without a docUrl
 */
public String htmlHeadingEditable(GemtextHeading g) {
if (docUrl == null) {
throw new UnsupportedOperationException("Wrong constructor used, need urlBase and docUrl");
// String editLink = String.format("\n<a class=\"utility\" href=\"%s/edit/%s\">Edit</a>\n", urlBase + docUrl, g.getLevel());
return htmlHeadingReadOnly(g);
/**
 * Renders a heading as h1, h2 or h3 for levels 1 to 3 and as h4 for every other level.
 * The element id is the string form of g.getLevel().
 */
public String htmlHeadingReadOnly(GemtextHeading g) {
if (g.getLevel().getLevel() == 1)
return String.format("<h1 id=\"%s\">%s</h1>\n", g.getLevel(), sanitizeText(g.getName()));
if (g.getLevel().getLevel() == 2)
return String.format("<h2 id=\"%s\">%s</h2>\n", g.getLevel(), sanitizeText(g.getName()));
if (g.getLevel().getLevel() == 3)
return String.format("<h3 id=\"%s\">%s</h3>\n", g.getLevel(), sanitizeText(g.getName()));
return String.format("<h4 id=\"%s\">%s</h4>\n", g.getLevel(), sanitizeText(g.getName()));
/**
 * Renders a link. Links with a title become a definition list entry (anchor plus title);
 * links without one become a bare anchor followed by a line break.
 *
 * @throws UnsupportedOperationException if the factory was built without a urlBase
 */
public String htmlLink(GemtextLink g) {
if (urlBase == null) {
throw new UnsupportedOperationException("Wrong constructor used, need urlBase");
final String linkClass = getLinkClass(g.getUrl());
// gemini:// targets are rewritten to go through an HTTPS proxy
final String linkUrl = getLinkUrl(g.getUrl()).replaceFirst("^gemini://", "https://proxy.vulpes.one/gemini/");
// NOTE(review): the URL is formatted into the href and the link text without passing
// through sanitizeText(); only the title is sanitized
if (g.getTitle() != null) {
return String.format("<dl class=\"link\"><dt><a class=\"%s\" href=\"%s\">%s</a></dt><dd>%s</dd></dl>\n",
linkClass, linkUrl, g.getUrl(), sanitizeText(g.getTitle()));
else {
return String.format("<a class=\"%s\" href=\"%s\">%s</a><br>\n",
linkClass, linkUrl, g.getUrl());
/** Prefixes node URLs and URLs starting with "/" with urlBase; other URLs are returned as-is. */
private String getLinkUrl(MemexUrl url) {
if (url instanceof MemexNodeUrl || url.getUrl().startsWith("/")) {
return urlBase + url;
return url.toString();
/** CSS class for a link: "internal" for node URLs, "external" for everything else. */
private String getLinkClass(MemexUrl url) {
if (url instanceof MemexNodeUrl) {
return "internal";
return "external";
/** Renders a list as a ul element with one sanitized li per item. */
public String htmlList(GemtextList g) {
return g.getItems()
.map(s -> "<li>" + sanitizeText(s) + "</li>")
Collectors.joining("\n", "<ul>\n", "</ul>\n"));
/** Renders a preformatted block as a pre element, items joined by newlines. */
public String htmlPre(GemtextPreformat g) {
return g.getItems().stream()
Collectors.joining("\n", "<pre>\n", "</pre>\n"));
/** Renders literal text as a pre element with class "literal", items joined by newlines. */
public String htmlLiteral(GemtextTextLiteral g) {
return g.getItems().stream()
Collectors.joining("\n", "<pre class=\"literal\">\n", "</pre>\n"));
/** Renders a quote as a blockquote element, items separated by br tags. */
public String htmlQuote(GemtextQuote g) {
return g.getItems().stream()
Collectors.joining("<br>\n", "<blockquote>\n", "</blockquote>\n"));
/** Renders a text line as sanitized text followed by a line break. */
public String htmlText(GemtextText g) {
return sanitizeText(g.getLine()) + "<br>\n";
/** Renders an aside as an aside element around the sanitized text. */
public String htmlAside(GemtextAside g) {
return "<aside>" + sanitizeText(g.getLine()) + "</aside>\n";
public String sanitizeText(String s) {
return s.replaceAll("<", "<").replaceAll(">", ">");
/**
 * Renders a task as an anchor followed by a div.
 * NOTE(review): the arguments fed into the format string are not visible in this listing.
 */
public String htmlTask(GemtextTask g) {
return String.format("<a class=\"task-pointer\" name=\"t%s\"></a><div class=\"task %s\" id=\"%s\">%s %s</div>\n",
/**
 * Writes a heading back as Gemtext. Levels 1 and 2 get one and two '#'; level 3 and every
 * other level get three.
 */
public String rawHeading(GemtextHeading g) {
if (g.getLevel().getLevel() == 1)
return "# " + g.getName();
if (g.getLevel().getLevel() == 2)
return "## " + g.getName();
if (g.getLevel().getLevel() == 3)
return "### " + g.getName();
return "### " + g.getName();
/** Writes a link back as Gemtext; a non-blank title is appended after a tab. */
public String rawLink(GemtextLink g) {
if (g.getTitle() != null && !g.getTitle().isBlank()) {
return "=> " + g.getUrl().getUrl() + "\t" + g.getTitle();
return "=> " + g.getUrl().getUrl();
/** Writes a list back as Gemtext, prefixing each item with "* ". */
public String rawList(GemtextList g) {
return g.getItems()
.map(s -> "* " + s)
/** Writes a preformatted block back as Gemtext, wrapped in triple-backtick fence lines. */
public String rawPre(GemtextPreformat g) {
return g.getItems().stream()
.collect(Collectors.joining("\n", "```\n", "\n```"));
/** Writes a quote back as Gemtext, prefixing each item with "> ". */
public String rawQuote(GemtextQuote g) {
return g.getItems().stream()
.map(s -> "> " + s)
/** Returns the line's text unchanged. */
public String rawText(GemtextText g) {
return g.getLine();
/** Joins the literal lines with newlines. */
public String rawLiteral(GemtextTextLiteral g) {
return Strings.join(g.getItems(), '\n');
/** Writes an aside back as Gemtext by wrapping the text in parentheses. */
public String rawAside(GemtextAside g) {
return "(" + g.getLine() + ")";
/** Writes a task back as Gemtext: g.getLevel() dashes (never fewer than zero), a space, the task text. */
public String rawTask(GemtextTask g) {
return "-".repeat(Math.max(0, g.getLevel())) + " " + g.getTask();
/** Writes a pragma back as Gemtext with the "%%% " prefix. */
private String rawPragma(GemtextPragma gemtextPragma) {
return "%%% " + gemtextPragma.getLine();
/** Pragma converter that drops the pragma by rendering it as an empty string. */
private String rawSupressPragma(GemtextPragma gemtextPragma) {
return "";
@ -0,0 +1,185 @@
package nu.marginalia.gemini.io;
import nu.marginalia.gemini.BadBotList;
import nu.marginalia.gemini.plugins.FileType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLSocket;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;
import java.util.stream.Stream;
/**
 * Wraps one accepted SSLSocket and offers helpers for reading the request line and for
 * writing status lines and response bodies. Most writers return this so calls can be chained.
 */
public class GeminiConnection {
private final SSLSocket connection;
private final Logger logger = LoggerFactory.getLogger("Server");
private final OutputStream os;
private final InputStream is;
// Consulted in readUrl() to decide whether a request is allowed
private static final BadBotList badBotList = BadBotList.INSTANCE;
/** Keeps the socket and fetches its output and input streams. */
public GeminiConnection(SSLSocket connection) throws IOException {
this.connection = connection;
this.os = connection.getOutputStream();
this.is = connection.getInputStream();
/** @return the remote peer's address in textual form */
public String getAddress() {
return connection.getInetAddress().getHostAddress();
/**
 * Reads the request line from the socket and turns it into a URI.
 *
 * @return the URI, or an empty Optional when the bad-bot list does not permit the query
 * @throws GeminiUserException with "Bad URI" when the request line is blank
 */
public Optional<URI> readUrl() throws Exception {
var str = new GeminiInput().get();
if (!badBotList.isQueryPermitted(connection.getInetAddress(), str)) {
return Optional.empty();
if (!str.isBlank()) {
return Optional.of(new URI(str));
throw new GeminiUserException("Bad URI");
/** Sends a REDIRECT status line pointing at the given address. */
public void redirect(String address) throws IOException {
writeStatusLine(GeminiStatusCode.REDIRECT, address);
/** Sends a REDIRECT_PERMANENT status line pointing at the given address. */
public void redirectPermanent(String address) throws IOException {
writeStatusLine(GeminiStatusCode.REDIRECT_PERMANENT, address);
/** Writes the response header: the status code, a space, then the meta string. */
public GeminiConnection writeStatusLine(int code, String meta) throws IOException {
write(String.format("%2d %s", code, meta));
return this;
// NOTE(review): the statement that actually writes the data is not visible in this listing
public GeminiConnection writeBytes(byte[] data) throws IOException {
return this;
/** Formats the arguments with String.format and writes the result as one line. */
public GeminiConnection printf(String pattern, Object...args) throws IOException {
write(String.format(pattern, args));
return this;
// NOTE(review): the loop body is not visible in this listing; presumably each line is written
public GeminiConnection writeLines(String... lines) throws IOException {
for (String s : lines) {
return this;
/**
 * Streams the lines of a file; the stream over the file is closed by the try-with-resources.
 * IOExceptions raised inside the per-line callback are logged rather than propagated.
 */
public GeminiConnection writeLinesFromFile(Path file) throws IOException {
try (Stream<String> lines = Files.lines(file)) {
lines.forEach(line -> {
try {
} catch (IOException e) {
logger.error("IO Error", e);
return this;
/**
 * Consumes a stream of lines. IOExceptions raised inside the per-line callback are logged
 * rather than propagated. The stream is not closed here.
 */
public GeminiConnection acceptLines(Stream<String> lines) {
lines.forEach(line -> {
try {
} catch (IOException e) {
logger.error("IO exception", e);
return this;
// Writes a string followed by CR LF.
// NOTE(review): the statement writing the string's bytes is not visible in this listing
private void write(String s) throws IOException {
os.write(new byte[] { '\r', '\n'});
// NOTE(review): body not visible in this listing
private void write(byte[] bs) throws IOException {
// This is a weird pattern but it makes the listing code very much cleaner
/** Logs the message at error level and then always throws GeminiUserException with it. */
public void error(String message) {
logger.error("{}", message);
throw new GeminiUserException(message);
// NOTE(review): the close call and the catch body are not visible in this listing
public void close() {
try {
} catch (IOException e) {
/** @return what SSLSocket.isConnected() reports for the underlying socket */
public boolean isConnected() {
return connection.isConnected();
/**
 * Answers with a SUCCESS status line carrying the file type's MIME type. Binary and text
 * file types take separate branches.
 * NOTE(review): the calls that send the file content are not visible in this listing.
 */
public void respondWithFile(Path serverPath, FileType fileType) throws IOException {
if (fileType.binary) {
writeStatusLine(GeminiStatusCode.SUCCESS, fileType.mime)
else {
writeStatusLine(GeminiStatusCode.SUCCESS, fileType.mime)
/**
 * Reads the request line from the enclosing connection's input stream inside its
 * constructor. Being a non-static inner class, it uses the connection's stream and error().
 */
public class GeminiInput {
// Fixed-size request buffer; longer requests are reported via error("String too long")
private final byte[] buffer = new byte[1024];
// Index into the buffer; driven by the loop in the constructor
private int idx = 0;
final String result;
public GeminiInput() throws IOException {
for (idx = 0; idx < buffer.length; idx++) {
if (hasEndOfLine()) {
// Drop the trailing CR LF and decode the rest as UTF-8
result = new String(buffer, 0, idx-2, StandardCharsets.UTF_8);
error("String too long");
// unreachable
result = "";
/** @return the request line without its terminating CR LF */
public String get() {
return result;
/** Reads one byte into buffer[idx]; end of stream is reported via error(). */
private void readCharacter() throws IOException {
int rb = is.read();
if (-1 == rb) {
error("URL incomplete (no CR LF)");
buffer[idx] = (byte) rb;
/**
 * @return true when idx is greater than 2 and the two bytes before idx are CR LF;
 *         a request consisting of CR LF alone (idx == 2) is therefore not accepted
 */
public boolean hasEndOfLine() {
return idx > 2
&& buffer[idx - 1] == (byte) '\n'
&& buffer[idx - 2] == (byte) '\r';
@ -0,0 +1,49 @@
package nu.marginalia.gemini.io;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import javax.net.ssl.*;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import java.security.SecureRandom;
/**
 * Builds the TLS server socket factory from a JKS key store file and a file holding the
 * key store password.
 */
public class GeminiSSLSetUp {
private final Path certPasswordFile;
private final Path certFile;
/**
 * @param certFile         key store file, bound under the name "gemini-cert-file"
 * @param certPasswordFile password file, bound under the name "gemini-cert-password-file"
 */
public GeminiSSLSetUp(
@Named("gemini-cert-file") Path certFile,
@Named("gemini-cert-password-file") Path certPasswordFile) {
this.certFile = certFile;
this.certPasswordFile = certPasswordFile;
/**
 * Reads the whole password file as a string. The content is returned unmodified, so a
 * trailing newline in the file becomes part of the password.
 */
public String getCertPassword() throws IOException {
return Files.readString(certPasswordFile);
/** Loads the key store and assembles an SSLContext for TLSv1.3. */
private SSLContext getContext() throws Exception {
KeyStore ks = KeyStore.getInstance("JKS", "SUN");
// NOTE(review): the stream opened here is not closed in the visible code; consider
// try-with-resources
ks.load(Files.newInputStream(certFile), getCertPassword().toCharArray());
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
// The password file is read a second time here; the same password protects the keys
kmf.init(ks, getCertPassword().toCharArray());
KeyManager[] keyManagers = kmf.getKeyManagers();
TrustManagerFactory tmf = TrustManagerFactory.getInstance("X509");
// NOTE(review): no tmf.init(...) call is visible before getTrustManagers() — confirm one
// exists, since an uninitialized factory cannot hand out trust managers
TrustManager[] trustManagers = tmf.getTrustManagers();
var ctx = SSLContext.getInstance("TLSv1.3");
ctx.init(keyManagers, trustManagers, new SecureRandom());
return ctx;
/** @return a server socket factory backed by a freshly built SSLContext */
public SSLServerSocketFactory getServerSocketFactory() throws Exception {
return getContext().getServerSocketFactory();
@ -0,0 +1,11 @@
package nu.marginalia.gemini.io;
/**
 * Numeric status codes used as the first field of a response header, as written by
 * GeminiConnection.writeStatusLine().
 */
public class GeminiStatusCode {
// Ask the client for input; the meta field carries the prompt
public static final int INPUT = 10;
// Success; the meta field carries the MIME type of the body
public static final int SUCCESS = 20;
public static final int ERROR_PERMANENT = 50;
public static final int ERROR_TEMPORARY = 40;
// Used when a backend request made on the client's behalf fails
public static final int PROXY_ERROR = 43;
// Redirects; the meta field carries the new address
public static final int REDIRECT = 30;
public static final int REDIRECT_PERMANENT = 31;
@ -0,0 +1,8 @@
package nu.marginalia.gemini.io;
/**
 * Throw to report a message to the user. Unchecked, so it can be raised from anywhere in
 * request handling, including lambdas.
 */
public class GeminiUserException extends RuntimeException {
// NOTE(review): the constructor body is not visible in this listing; presumably it passes
// the message to the superclass — confirm
public GeminiUserException(String message) {
@ -0,0 +1,53 @@
package nu.marginalia.gemini.plugins;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import nu.marginalia.gemini.io.GeminiConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import static nu.marginalia.gemini.GeminiService.DEFAULT_FILENAME;
public class BareStaticPagePlugin implements Plugin {
private final Logger logger = LoggerFactory.getLogger(getClass());
private Path geminiServerRoot;
public BareStaticPagePlugin(@Named("gemini-server-root") Path geminiServerRoot) {
this.geminiServerRoot = geminiServerRoot;
public boolean serve(URI url, GeminiConnection connection) throws IOException {
final Path serverPath = getServerPath(url.getPath());
if (!Files.isRegularFile(serverPath)) {
return false;
verifyPath(geminiServerRoot, serverPath);
logger.info("Serving {}", serverPath);
connection.respondWithFile(serverPath, FileType.match(serverPath));
return true;
private Path getServerPath(String requestPath) {
final Path serverPath = Path.of(geminiServerRoot + requestPath);
if (Files.isDirectory(serverPath) && Files.isRegularFile(serverPath.resolve(DEFAULT_FILENAME))) {
return serverPath.resolve(DEFAULT_FILENAME);
return serverPath;
@ -0,0 +1,58 @@
package nu.marginalia.gemini.plugins;
import java.nio.file.Path;
/**
 * File types the server knows how to label. Each constant carries the file name suffix it
 * is recognised by, the MIME type announced to clients, an icon for listings, and whether
 * the content is binary.
 */
public enum FileType {
    GMI("gmi", "text/gemini", FileIcons.DOCUMENT, false),
    GEM("gem", "text/gemini", FileIcons.DOCUMENT, false),
    TXT("txt", "text/plain", FileIcons.DOCUMENT, false),
    MARKDOWN("md", "text/markdown", FileIcons.DOCUMENT, false),
    JAVA("java", "text/java", FileIcons.JAVA, false),
    PROPERTIES("properties", "text/properties", FileIcons.SETTINGS, false),
    GRADLE("gradle", "text/gradle", FileIcons.SETTINGS, false),
    ZIP("zip", "application/zip", FileIcons.ZIP, true),
    PNG("png", "image/png", FileIcons.IMAGE, true),
    // The registered media type for JPEG images is image/jpeg; "image/jpg" is not a
    // registered type and clients are not required to recognise it.
    JPG("jpg", "image/jpeg", FileIcons.IMAGE, true),
    JPEG("jpeg", "image/jpeg", FileIcons.IMAGE, true),
    BIN("bin", "application/binary", FileIcons.BINARY, true),
    SH("sh", "text/sh", FileIcons.SETTINGS, false),
    XML("xml", "text/xml", FileIcons.DOCUMENT, false),
    DOCKERFILE("Dockerfile", "text/dockerfile", FileIcons.SETTINGS, false);

    /**
     * Finds the type of a file from its name.
     *
     * The comparison is a plain, case-sensitive {@code endsWith} on the suffix (no dot is
     * required), and constants are tried in declaration order.
     *
     * @param fileName file name or path string
     * @return the first constant whose suffix the name ends with, or {@link #BIN} if none does
     */
    public static FileType match(String fileName) {
        for (var type : values()) {
            if (fileName.endsWith(type.suffix)) {
                return type;
            }
        }
        return BIN;
    }

    /** Same as {@link #match(String)}, applied to the path's string form. */
    public static FileType match(Path path) {
        return match(path.toString());
    }

    FileType(String suffix, String mime, String icon, boolean binary) {
        this.suffix = suffix;
        this.mime = mime;
        this.icon = icon;
        this.binary = binary;
    }

    /** File name ending this type is recognised by. */
    public final String suffix;
    /** MIME type announced for this type. */
    public final String mime;
    /** Icon shown for this type. */
    public final String icon;
    /** True when the content is binary rather than text. */
    public final boolean binary;
}

/** Icon glyphs used by {@link FileType}. */
class FileIcons {
    public static final String DOCUMENT = "🗒";
    public static final String JAVA = "♨";
    public static final String SETTINGS = "💻";
    public static final String ZIP = "🗜";
    public static final String IMAGE = "🖼";
    public static final String DIRECTORY = "🗂";
    public static final String BINARY = "📚";
}
@ -0,0 +1,19 @@
package nu.marginalia.gemini.plugins;
import nu.marginalia.gemini.io.GeminiConnection;
import nu.marginalia.gemini.io.GeminiUserException;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Path;
public interface Plugin {
/** @return true if content served */
boolean serve(URI url, GeminiConnection connection) throws IOException;
default void verifyPath(Path root, Path p) {
if (!p.normalize().startsWith(root)) {
throw new GeminiUserException("ಠ_ಠ That path is off limits!");
@ -0,0 +1,78 @@
package nu.marginalia.gemini.plugins;
import com.google.inject.Inject;
import nu.marginalia.gemini.io.GeminiConnection;
import nu.marginalia.gemini.io.GeminiStatusCode;
import org.apache.http.HttpHost;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.conn.routing.HttpRoute;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
/**
 * Answers requests for the path "/search" by forwarding the query to the search backend
 * over HTTPS and relaying the response body to the client as text/gemini.
 */
public class SearchPlugin implements Plugin {
private final PoolingHttpClientConnectionManager connectionManager;
private final Logger logger = LoggerFactory.getLogger(getClass());
/** Sets up the connection pool with a per-route limit of 20 for the backend host. */
public SearchPlugin() {
connectionManager = new PoolingHttpClientConnectionManager();
// NOTE(review): HttpHost's single-argument constructor takes a host name, but a full URL
// string is passed here — confirm that the limit below applies to the route actually
// used for requests to https://search.marginalia.nu
HttpHost host = new HttpHost("https://search.marginalia.nu/");
connectionManager.setMaxPerRoute(new HttpRoute(host), 20);
/**
 * Handles "/search" requests and declines all others.
 *
 * Without a query the client is asked for one (INPUT status). With a query the backend is
 * called; an IOException while doing so is logged and reported as PROXY_ERROR.
 *
 * @return false when the path is not "/search", true otherwise
 */
public boolean serve(URI url, GeminiConnection connection) throws IOException {
// NOTE(review): the rest of this builder chain is not visible in this listing
var client = HttpClients.custom()
if (!"/search".equals(url.getPath())) {
return false;
// The raw (still percent-encoded) query is used so it can be appended to the backend URL as-is
String query = url.getRawQuery();
if (null == query || "".equals(query)) {
logger.info("Requesting search terms");
connection.writeStatusLine(GeminiStatusCode.INPUT, "Please enter a search query");
else {
logger.info("Delegating search query '{}'", query);
final HttpGet get = new HttpGet(createSearchUri(query));
final byte[] binaryResponse;
// The response is closed by the try-with-resources once the body has been read fully
try (var rsp = client.execute(get)) {
binaryResponse = rsp.getEntity().getContent().readAllBytes();
catch (IOException ex) {
logger.error("backend error", ex);
connection.writeStatusLine(GeminiStatusCode.PROXY_ERROR, "Failed to reach backend server");
return true;
// NOTE(review): the receiver and the remainder of this call chain are not visible here
.writeStatusLine(GeminiStatusCode.SUCCESS, "text/gemini")
return true;
/**
 * Builds the backend URL for a query, requesting Gemtext output via format=gmi.
 * A URISyntaxException is rethrown wrapped in a RuntimeException.
 */
private URI createSearchUri(String query) {
try {
return new URI("https://search.marginalia.nu/search?format=gmi&query="+query);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
@ -0,0 +1,80 @@
package nu.marginalia.util;
public class ByteFolder {
public byte[] foldBytes(int p, int q) {
int pw = bitWidth(p);
int qw = bitWidth(q);
int qpw = qw + pw;
long qp = Integer.toUnsignedLong(q) << pw | Integer.toUnsignedLong(p);
int qpwBytes = ((qpw - 1) / Byte.SIZE) + 1;
byte[] bytes = new byte[qpwBytes + 1];
bytes[0] = (byte) pw;
for (int i = 1; i < bytes.length; i++) {
bytes[i] = (byte) (qp >>> (qpwBytes - i) * Byte.SIZE & 0xff);
return bytes;
// Function such that (decodeBytes o foldBytes) = identity
public static int[] decodeBytes(byte[] data) {
int[] dest = new int[2];
decodeBytes(data, data.length, dest);
return dest;
public static void decodeBytes(byte[] data, int length, int[] dest) {
long val = 0;
for (int i = 1; i < length; i++) {
val = (val << 8) | ((0xFF)&data[i]);
dest[1] = (int)(val >>> data[0]);
dest[0] = (int)(val & ~(dest[1]<<data[0]));
private static int bitWidth(int q) {
int v = Integer.numberOfLeadingZeros(q);
if (v == 32) return 1;
return 32-v;
/** Renders every byte of the array as binary digits; see byteBits(byte[], int). */
public static String byteBits(byte[] b) {
return byteBits(b, b.length);
/**
 * Renders the first n bytes of the array as binary digits, eight per byte, most
 * significant bit first.
 * NOTE(review): something is inserted between bytes once the output is non-blank, but the
 * statement doing so is not visible in this listing.
 */
public static String byteBits(byte[] b, int n) {
StringBuilder s = new StringBuilder();
for (int j = 0; j < n;j++) {
// True for every byte after the first, since by then digits have been appended
if (!s.toString().isBlank()) {
for (int i = 7; i >= 0; i--) {
s.append((b[j] & (1L << i)) > 0 ? 1 : 0);
return s.toString();
public static String intBits(int v) {
StringBuilder s = new StringBuilder();
for (int i = 32; i >=0; i--) {
s.append((v & (1L << i)) > 0 ? 1 : 0);
return s.toString();
public static String longBits(long v) {
StringBuilder s = new StringBuilder();
for (int i = 64; i >=0; i--) {
s.append((v & (1L << i)) > 0 ? 1 : 0);
return s.toString();
@ -0,0 +1,18 @@
package nu.marginalia.util;
/**
 * Formats byte counts for display.
 */
public class FileSizeUtil {
    private static final long KB = 1024L;
    private static final long MB = KB * 1024L;
    private static final long GB = MB * 1024L;

    /**
     * Formats a byte count using the largest unit (b, Kb, Mb, Gb; powers of 1024) that
     * keeps the number at or above 1, with two decimals for everything but plain bytes.
     *
     * Each value is divided once, in floating point. Dividing by 1024 in integer
     * arithmetic first discards the remainder, which could push the displayed Mb and Gb
     * figures one hundredth too low.
     *
     * @param byteCount the size in bytes
     * @return the formatted size, e.g. {@code "512b"} or {@code "1.50Kb"}; the decimal
     *         separator follows the default locale
     */
    public static String readableSize(long byteCount) {
        if (byteCount < KB) {
            return String.format("%db", byteCount);
        }
        if (byteCount < MB) {
            return String.format("%2.2fKb", byteCount / (double) KB);
        }
        if (byteCount < GB) {
            return String.format("%2.2fMb", byteCount / (double) MB);
        }
        return String.format("%2.2fGb", byteCount / (double) GB);
    }
}
@ -0,0 +1,101 @@
package nu.marginalia.util;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
/**
 * A two-stage pipeline: inputs are processed by a pool of process threads via onProcess(),
 * and the intermediate results are consumed by a single receiver thread via onReceive().
 * Both stages are connected by bounded LinkedBlockingQueues.
 *
 * @param <INPUT>        type of the items fed into the pipe
 * @param <INTERMEDIATE> type of the results passed from the process threads to the receiver
 */
public abstract class ParallelPipe<INPUT,INTERMEDIATE> {
private final LinkedBlockingQueue<INPUT> inputs;
private final LinkedBlockingQueue<INTERMEDIATE> intermediates;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final List<Thread> processThreads = new ArrayList<>();
private final Thread receiverThread;
// Cleared by join() to tell the process threads / the receiver thread to stop once their
// queues have drained; volatile because they are read from other threads
private volatile boolean expectingInput = true;
private volatile boolean expectingOutput = true;
/**
 * Creates the queues and the threads.
 * NOTE(review): no Thread.start() call is visible in this listing — confirm where the
 * threads are started.
 *
 * @param name                  prefix for the thread names
 * @param numberOfThreads       number of process threads
 * @param inputQueueSize        capacity of the input queue
 * @param intermediateQueueSize capacity of the intermediate queue
 */
public ParallelPipe(String name, int numberOfThreads, int inputQueueSize, int intermediateQueueSize) {
inputs = new LinkedBlockingQueue<>(inputQueueSize);
intermediates = new LinkedBlockingQueue<>(intermediateQueueSize);
for (int i = 0; i < numberOfThreads; i++) {
processThreads.add(new Thread(this::runProcessThread, name + "-process["+i+"]"));
receiverThread = new Thread(this::runReceiverThread, name + "-receiver");
// NOTE(review): body not visible in this listing
public void clearQueues() {
/**
 * Loop run by each process thread: keeps going while input is still expected or the input
 * queue is non-empty. Polls with a one second timeout so the exit condition is rechecked
 * regularly. An InterruptedException from processing is rethrown; any other exception is
 * logged and the loop continues.
 */
private void runProcessThread() {
while (expectingInput || !inputs.isEmpty()) {
var in = inputs.poll(1, TimeUnit.SECONDS);
if (in != null) {
try {
var ret = onProcess(in);
// A null result from onProcess is skipped
if (ret != null) {
catch (InterruptedException ex) {
throw ex;
catch (Exception ex) {
logger.error("Exception", ex);
logger.debug("Terminating {}", Thread.currentThread().getName());
/**
 * Loop run by the receiver thread: keeps going while output is still expected or either
 * queue is non-empty. Exceptions raised while handling an item are logged and the loop
 * continues.
 */
private void runReceiverThread() {
while (expectingOutput || !inputs.isEmpty() || !intermediates.isEmpty()) {
var intermediate = intermediates.poll(997, TimeUnit.MILLISECONDS);
if (intermediate != null) {
try {
catch (Exception ex) {
logger.error("Exception", ex);
logger.info("Terminating {}", Thread.currentThread().getName());
// NOTE(review): body not visible in this listing; presumably enqueues the input
public void accept(INPUT input) {
/** Turns one input into an intermediate result; called on a process thread. */
protected abstract INTERMEDIATE onProcess(INPUT input) throws Exception;
/** Consumes one intermediate result; called on the receiver thread. */
protected abstract void onReceive(INTERMEDIATE intermediate) throws Exception;
/**
 * Shuts the pipe down in order: first signals that no more input is coming, then deals
 * with the process threads, and only then signals the receiver that no more output is coming.
 * NOTE(review): the join calls themselves are not visible in this listing.
 */
public void join() throws InterruptedException {
expectingInput = false;
for (var thread : processThreads) {
expectingOutput = false;
@ -0,0 +1,41 @@
package nu.marginalia.util;
// This is not a fast way of finding primes
/**
 * Small helpers for finding prime numbers by trial division.
 */
public class PrimeUtil {

    /**
     * Searches for a prime starting near {@code start} and moving in the direction of
     * {@code step}.
     *
     * If {@code start} is divisible by 2 in the sense of {@link #isDivisible} (which also
     * covers 1 and 2), the search begins at {@code start + step}. Candidates are then
     * visited in increments of {@code 2 * step}, so an odd candidate stays odd.
     *
     * When searching downwards the candidates can fall below 2, where no primes exist; the
     * smallest prime, 2, is returned in that case instead of a non-prime value.
     *
     * The caller must pass a non-zero step that is odd whenever {@code start} is even;
     * otherwise no candidate may ever be prime and the search does not terminate.
     *
     * @param start where to begin looking
     * @param step  direction and stride of the search, typically 1 or -1
     * @return the first prime found
     */
    public static long nextPrime(long start, long step) {
        if (isDivisible(start, 2)) {
            start = start + step;
        }

        long val = start;
        while (!isPrime(val)) {
            if (val < 2 && step <= 0) {
                // Nothing prime lies further down
                return 2;
            }
            val += 2 * step;
        }
        return val;
    }

    /**
     * Tests primality by trial division with odd divisors up to the square root.
     *
     * @param v the number to test
     * @return true if v is prime; false for composite numbers and for everything below 2
     */
    public static boolean isPrime(long v) {
        if (v < 2) {
            return false;
        }
        if (v < 4) {
            return true; // 2 and 3
        }
        if ((v & 1) == 0) {
            return false;
        }
        // t <= v / t is t * t <= v without the risk of overflowing
        for (long t = 3; t <= v / t; t += 2) {
            if ((v % t) == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tests whether one of the two numbers divides the other: when {@code a > b} this is
     * {@code a % b == 0}, otherwise {@code b % a == 0}.
     *
     * @return false if either argument is 0, otherwise the result of the test above
     */
    public static boolean isDivisible(long a, long b) {
        if (a == 0 || b == 0) {
            return false;
        }
        if (a > b) {
            return (a % b) == 0;
        }
        return (b % a) == 0;
    }
}
@ -0,0 +1,139 @@
package nu.marginalia.util;
import io.prometheus.client.Gauge;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
/** For managing random writes on SSDs
* See https://en.wikipedia.org/wiki/Write_amplification
* */
public class RandomWriteFunnel implements AutoCloseable {
private final static Gauge write_rate = Gauge.build("wmsa_rwf_write_bytes", "Bytes/s")
private final static Gauge transfer_rate = Gauge.build("wmsa_rwf_transfer_bytes", "Bytes/s")
private static final Logger logger = LoggerFactory.getLogger(RandomWriteFunnel.class);
private DataBin[] bins;
private final int binSize;
/**
 * Creates a funnel covering {@code size} addresses, split into bins of at most
 * {@code binSize} addresses each. The last bin is sized to the remainder.
 *
 * @param tempDir directory handed to each bin for its temporary file
 * @param size    total number of addresses; values {@code <= 0} produce no bins
 * @param binSize number of addresses per bin
 * @throws IOException if a bin cannot be created
 */
public RandomWriteFunnel(Path tempDir, long size, int binSize) throws IOException {
    this.binSize = binSize;

    if (size > 0) {
        int binCount = (int) (size / binSize + ((size % binSize) != 0L ? 1 : 0));
        bins = new DataBin[binCount];
        for (int i = 0; i < binCount; i++) {
            // Widen before multiplying: binSize * i as an int product wraps
            // around once it exceeds Integer.MAX_VALUE.
            long remaining = size - (long) binSize * i;
            bins[i] = new DataBin(tempDir, (int) Math.min(remaining, binSize));
        }
    }
    else {
        bins = new DataBin[0];
    }
}
public void put(long address, long data) throws IOException {
bins[((int)(address / binSize))].put((int)(address%binSize), data);
/**
 * Writes the funnel's contents to the given channel, visiting the bins in
 * index order and staging data in a direct buffer of binSize*8 bytes.
 * NOTE(review): the statements that fill the staging buffer from each bin are
 * not visible in this excerpt.
 *
 * @param o destination channel
 * @throws IOException if writing fails
 */
public void write(FileChannel o) throws IOException {
ByteBuffer buffer = ByteBuffer.allocateDirect(binSize*8);
logger.debug("Writing from RWF");
for (int i = 0; i < bins.length; i++) {
var bin = bins[i];
// A single channel write may be partial, so keep writing until the buffer is drained.
while (buffer.hasRemaining()) {
int wb = o.write(buffer);
/**
 * Visits every bin.
 * NOTE(review): the loop body is not visible in this excerpt; presumably each
 * bin is closed — confirm.
 *
 * @throws IOException declared by the signature
 */
public void close() throws IOException {
for (DataBin bin : bins) {
/**
 * One bin of the funnel: an in-memory staging buffer backed by a temporary
 * file. Records are read back as a 4-byte int address followed by an 8-byte
 * long value (12 bytes per record).
 * NOTE(review): several statements of put, flushBuffer, eval and close are
 * not visible in this excerpt.
 */
static class DataBin implements AutoCloseable {
private final ByteBuffer buffer;
private int size;
private final FileChannel channel;
private final File file;
/**
 * Allocates a 360,000 byte direct staging buffer and opens a read/write
 * channel on a new "scatter-writer*.dat" temporary file in tempDir.
 */
DataBin(Path tempDir, int size) throws IOException {
buffer = ByteBuffer.allocateDirect(360_000);
this.size = size;
file = Files.createTempFile(tempDir, "scatter-writer", ".dat").toFile();
channel = new RandomAccessFile(file, "rw").getChannel();
/** Records one (address, data) pair for this bin. */
void put(int address, long data) throws IOException {
// Fewer than 12 bytes left means the next record would not fit.
if (buffer.capacity() - buffer.position() < 12) {
/** Writes the staged bytes to the channel. */
private void flushBuffer() throws IOException {
// Nothing staged when the position is still 0.
if (buffer.position() == 0)
// Empty-bodied loop: repeat until a write call reports that no bytes were written.
while (channel.write(buffer) > 0);
/** Reads the recorded pairs back from the channel and stores each value into dest. */
private void eval(ByteBuffer dest) throws IOException {
for (int i = 0; i < size; i++) {
while (channel.position() < channel.size()) {
int rb = channel.read(buffer);
// A negative read count signals end of stream.
if (rb < 0) {
else {
// Decode every complete 12-byte record currently in the buffer.
while (buffer.limit() - buffer.position() >= 12) {
int addr = buffer.getInt();
long data = buffer.getLong();
// Absolute put: the value lands at byte offset 8*addr of dest.
dest.putLong(8*addr, data);
public void close() throws IOException {
@ -0,0 +1,73 @@
package nu.marginalia.util;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.function.ToIntFunction;
/**
 * Ordered collection of "banks" addressed by a global integer offset. A
 * parallel list records each bank's start offset, and lookups binary-search
 * that list.
 */
public abstract class SeekDictionary<T> {
private final ArrayList<T> banks = new ArrayList<>();
private final TIntArrayList offsets = new TIntArrayList();
/** Creates a dictionary whose bank length is computed by the supplied function. */
public static <T> SeekDictionary<T> of(ToIntFunction<T> length) {
return new SeekDictionary<T>() {
public int length(T obj) {
return length.applyAsInt(obj);
/** Returns the most recently added bank. */
public T last() {
return banks.get(banks.size()-1);
/** Returns the start offset recorded for the last bank. */
public int lastStart() {
return offsets.get(offsets.size()-1);
/** Length of the given bank. */
public abstract int length(T obj);
/** Returns the last recorded start offset plus the length of the last bank, or 0 when there are no banks. */
public int end() {
if (banks.isEmpty()) return 0;
return (offsets.getQuick(offsets.size()-1) + length(last()));
/**
 * Appends a bank.
 * NOTE(review): the statements of both branches are not visible in this excerpt.
 */
public void add(T obj) {
if (banks.isEmpty()) {
else {
/** Returns the bank at the index that idxForOffset selects for the offset. */
public T bankForOffset(int offset) {
return banks.get(idxForOffset(offset));
/**
 * Binary search over the recorded start offsets. Returns the index of an
 * entry equal to offset, or otherwise one less than the insertion point
 * (which is -1 when every recorded start is greater than offset).
 */
public int idxForOffset(int offset) {
int high = offsets.size() - 1;
int low = 0;
while ( low <= high ) {
// Unsigned shift keeps the midpoint correct even if low + high overflows.
int mid = ( low + high ) >>> 1;
int midVal = offsets.getQuick(mid);
if ( midVal < offset ) {
low = mid + 1;
else if ( midVal > offset ) {
high = mid - 1;
else {
return mid;
return low-1;
@ -0,0 +1,104 @@
package nu.marginalia.util.btree;
import nu.marginalia.util.btree.model.BTreeContext;
import nu.marginalia.util.btree.model.BTreeHeader;
import nu.marginalia.util.multimap.MultimapFileLong;
import nu.marginalia.util.multimap.MultimapSearcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Read-side access to a B-tree stored in a MultimapFileLong. Keys are masked
 * with the context's equality mask before comparison, and all searches are
 * delegated to a MultimapSearcher created from the file.
 */
public class BTreeReader {
private final MultimapFileLong file;
private final BTreeContext ctx;
private final Logger logger = LoggerFactory.getLogger(BTreeReader.class);
private final long mask;
private final MultimapSearcher searcher;
/** Binds the reader to a file and a tree configuration. */
public BTreeReader(MultimapFileLong file, BTreeContext ctx) {
this.file = file;
this.searcher = file.createSearcher();
this.ctx = ctx;
this.mask = ctx.equalityMask();
/** Returns the size reported by the underlying file. */
public long fileSize() {
return file.size();
/** Reads the three header words starting at offset and decodes them into a BTreeHeader. */
public BTreeHeader getHeader(long offset) {
return new BTreeHeader(file.get(offset), file.get(offset+1), file.get(offset+2));
/**
 * Looks up a key in the tree described by header. Returns -1 as soon as a
 * search in an index layer yields a negative position.
 * NOTE(review): the final argument of the closing searchDataBlock call is
 * not visible in this excerpt.
 */
public long offsetForEntry(BTreeHeader header, final long keyRaw) {
final long key = keyRaw & mask;
// Trees without index layers are searched directly in the data area.
if (header.layers() == 0) {
return trivialSearch(header, key);
long p = searchEntireTopLayer(header, key);
if (p < 0) return -1;
long cumOffset = p * ctx.BLOCK_SIZE_WORDS();
// Walk the remaining index layers, from layer (layers - 2) down to layer 0.
for (int i = header.layers() - 2; i >= 0; --i) {
long offsetBase = header.indexOffsetLongs() + header.relativeLayerOffset(ctx, i);
p = searchLayerBlock(key, offsetBase+cumOffset);
if (p < 0)
return -1;
cumOffset = ctx.BLOCK_SIZE_WORDS()*(p + cumOffset);
// End of the data area: data offset plus numEntries entries of entrySize words.
long dataMax = header.dataOffsetLongs() + (long) header.numEntries() * ctx.entrySize();
return searchDataBlock(key,
header.dataOffsetLongs() + ctx.entrySize()*cumOffset,
/** Searches one block's worth of words at the index offset; the result is relative to that offset. */
private long searchEntireTopLayer(BTreeHeader header, long key) {
long offset = header.indexOffsetLongs();
return searcher.binarySearchUpperBound(key, offset, offset + ctx.BLOCK_SIZE_WORDS()) - offset;
/** Searches one index block; a negative blockOffset is passed straight back to the caller. */
private long searchLayerBlock(long key, long blockOffset) {
if (blockOffset < 0)
return blockOffset;
return searcher.binarySearchUpperBound(key, blockOffset, blockOffset + ctx.BLOCK_SIZE_WORDS()) - blockOffset;
/** Searches one data block, clamped so that it never reads past dataMax. */
private long searchDataBlock(long key, long blockOffset, long dataMax) {
if (blockOffset < 0)
return blockOffset;
long lastOffset = Math.min(blockOffset+ctx.BLOCK_SIZE_WORDS()*(long)ctx.entrySize(), dataMax);
int length = (int)(lastOffset - blockOffset);
// Single-word entries use the range overloads; the unmasked one when the mask is all ones.
if (ctx.entrySize() == 1) {
if (mask == ~0L) return searcher.binarySearchUpperBoundNoMiss(key, blockOffset, blockOffset+length);
return searcher.binarySearchUpperBoundNoMiss(key, blockOffset, blockOffset+length, mask);
// Multi-word entries: pass the entry size and the number of entries in the block.
return searcher.binarySearchUpperBoundNoMiss(key, blockOffset, ctx.entrySize(), length/ctx.entrySize(), mask);
/** Searches the whole data area of a tree that has no index layers. */
private long trivialSearch(BTreeHeader header, long key) {
long offset = header.dataOffsetLongs();
if (ctx.entrySize() == 1) {
if (mask == ~0L) {
return searcher.binarySearchUpperBoundNoMiss(key, offset, offset+header.numEntries());
else {
return searcher.binarySearchUpperBoundNoMiss(key, offset, offset+header.numEntries(), mask);
return searcher.binarySearchUpperBoundNoMiss(key, offset, ctx.entrySize(), header.numEntries(), mask);
package nu.marginalia.util.btree;
import nu.marginalia.util.btree.model.BTreeContext;
import nu.marginalia.util.btree.model.BTreeHeader;
import nu.marginalia.util.multimap.MultimapFileLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
 * Write-side counterpart of the B-tree: lays out a header and index layers in
 * a MultimapFileLong according to a BTreeContext.
 */
public class BTreeWriter {
private final Logger logger = LoggerFactory.getLogger(BTreeWriter.class);
private final BTreeContext ctx;
private final MultimapFileLong map;
/** Binds the writer to a destination map and a tree configuration. */
public BTreeWriter(MultimapFileLong map, BTreeContext ctx) {
this.map = map;
this.ctx = ctx;
/** Sums the sizes of layers 0..numLayers-1; a tree with zero layers has no index. */
private static long indexSize(BTreeContext ctx, int numWords, int numLayers) {
if (numLayers == 0) {
return 0; // Special treatment for small tables
long size = 0;
for (int layer = 0; layer < numLayers; layer++) {
size += ctx.layerSize(numWords, layer);
return size;
/**
 * Writes the header for a tree of numEntries entries at offset and returns
 * the size computed by the context.
 * NOTE(review): the use of the writeIndex callback and the statements between
 * the two returns are not visible in this excerpt.
 */
public long write(long offset, int numEntries, WriteCallback writeIndex)
throws IOException
var header = makeHeader(offset, numEntries);
header.write(map, offset);
if (header.layers() < 1) {
return ctx.calculateSize(numEntries);
return ctx.calculateSize(numEntries);
/**
 * Computes the header for a tree placed at offset: the index follows the
 * header words plus padding, and the data follows the index.
 */
public static BTreeHeader makeHeader(BTreeContext ctx, long offset, int numEntries) {
final int numLayers = ctx.numLayers(numEntries);
final int padding = BTreeHeader.getPadding(ctx, offset, numLayers);
final long indexOffset = offset + BTreeHeader.BTreeHeaderSizeLongs + padding;
final long dataOffset = indexOffset + indexSize(ctx, numEntries, numLayers);
return new BTreeHeader(numLayers, numEntries, indexOffset, dataOffset);
/** Instance convenience that uses this writer's context. */
public BTreeHeader makeHeader(long offset, int numEntries) {
return makeHeader(ctx, offset, numEntries);
/**
 * Fills in the index layers. For each layer, one index word is written per
 * stride of data entries; the stride is multiplied by the block size from one
 * layer to the next.
 */
private void writeIndex(BTreeHeader header) {
var layerOffsets = getRelativeLayerOffsets(header);
long stride = ctx.BLOCK_SIZE_WORDS();
for (int layer = 0; layer < header.layers(); layer++,
stride*=ctx.BLOCK_SIZE_WORDS()) {
long indexWord = 0;
long offsetBase = layerOffsets[layer] + header.indexOffsetLongs();
long numEntries = header.numEntries();
for (long idx = 0; idx < numEntries; idx += stride, indexWord++) {
// Position of the last entry in this stride.
long dataOffset = header.dataOffsetLongs() + (idx + (stride-1)) * ctx.entrySize();
long val;
// A complete stride is indexed by its last entry's masked key; an incomplete one by Long.MAX_VALUE.
if (idx + (stride-1) < numEntries) {
val = map.get(dataOffset) & ctx.equalityMask();
else {
val = Long.MAX_VALUE;
// Diagnostics for a negative destination index.
if (offsetBase + indexWord < 0) {
logger.error("bad put @ {}", offsetBase + indexWord);
logger.error("layer{}", layer);
logger.error("layer offsets {}", layerOffsets);
logger.error("offsetBase = {}", offsetBase);
logger.error("numEntries = {}", numEntries);
logger.error("indexWord = {}", indexWord);
map.put(offsetBase + indexWord, val);
// Pad with Long.MAX_VALUE up to the next multiple of the block size.
for (; (indexWord % ctx.BLOCK_SIZE_WORDS()) != 0; indexWord++) {
map.put(offsetBase + indexWord, Long.MAX_VALUE);
/** Collects header.relativeLayerOffset(ctx, i) for every layer i. */
private long[] getRelativeLayerOffsets(BTreeHeader header) {
long[] layerOffsets = new long[header.layers()];
for (int i = 0; i < header.layers(); i++) {
layerOffsets[i] = header.relativeLayerOffset(ctx, i);
return layerOffsets;
package nu.marginalia.util.btree;
import java.io.IOException;
/**
 * Callback handed to the B-tree writer; it is given an offset to write at.
 *
 * <p>Declared as a functional interface so that it can be implemented with a
 * lambda or method reference, and so that the compiler rejects the accidental
 * addition of a second abstract method.
 */
@FunctionalInterface
public interface WriteCallback {
    /**
     * Performs the write.
     *
     * @param offset offset supplied by the caller of the callback
     * @throws IOException if the write fails
     */
    void write(long offset) throws IOException;
}
package nu.marginalia.util.btree.model;
import nu.marginalia.util.btree.BTreeWriter;
/**
 * Configuration of a B-tree: maximum layer count, entry size, key equality
 * mask, and block size.
 * NOTE(review): the tail of the record's component list is not visible in
 * this excerpt; the constructor below shows that it also carries
 * BLOCK_SIZE_BITS and a value derived as 1 << BLOCK_SIZE_BITS.
 */
public record BTreeContext(int MAX_LAYERS,
int entrySize,
long equalityMask,
/** Convenience constructor that derives the block size in words from the block size in bits. */
public BTreeContext(int MAX_LAYERS, int entrySize, long equalityMask, int BLOCK_SIZE_BITS) {
this(MAX_LAYERS, entrySize, equalityMask, BLOCK_SIZE_BITS, 1 << BLOCK_SIZE_BITS);
/** Size of a tree with numEntries entries: the data offset of a header laid out at offset 0, plus the data words. */
public long calculateSize(int numEntries) {
var header = BTreeWriter.makeHeader(this, 0, numEntries);
return header.dataOffsetLongs() + (long)numEntries * entrySize;
/** Number of index layers for numEntries entries; 0 when they fit within two blocks, never more than MAX_LAYERS. */
public int numLayers(int numEntries) {
if (numEntries <= BLOCK_SIZE_WORDS*2) {
return 0;
for (int i = 1; i < MAX_LAYERS; i++) {
long div = (1L << (BLOCK_SIZE_BITS*i));
long frq = numEntries / div;
// Stop at the first layer count whose quotient fits within one block.
if (frq < (1L << BLOCK_SIZE_BITS)) {
// NOTE(review): this compares numEntries with (numEntries & div); confirm the intended condition.
if (numEntries == (numEntries & div)) {
return i;
return i+1;
return MAX_LAYERS;
/** Size of one index layer: block size times the number of blocks at that level. */
public long layerSize(int numEntries, int level) {
return BLOCK_SIZE_WORDS * numBlocks(numEntries, level);
/**
 * Number of blocks at the given level.
 * NOTE(review): the body of the remainder branch is not visible in this excerpt.
 */
private long numBlocks(int numWords, int level) {
long layerSize = 1L<<(BLOCK_SIZE_BITS*(level+1));
int numBlocks = 0;
numBlocks += numWords / layerSize;
if (numWords % layerSize != 0) {
return numBlocks;
package nu.marginalia.util.btree.model;
import nu.marginalia.util.multimap.MultimapFileLong;
public record BTreeHeader(int layers, int numEntries, long indexOffsetLongs, long dataOffsetLongs) {
public BTreeHeader {
assert (layers >= 0);
assert (numEntries >= 0);
assert (indexOffsetLongs >= 0);
assert (dataOffsetLongs >= 0);
assert (dataOffsetLongs >= indexOffsetLongs);
public static int BTreeHeaderSizeLongs = 3;
public BTreeHeader(long a, long b, long c) {
this((int)(a >>> 32), (int)(a & 0xFFFF_FFFFL), b, c);
public static int getPadding(BTreeContext ctx, long offset, int numLayers) {
final int padding;
if (numLayers == 0) {
padding = 0;
else {
padding = (int) (ctx.BLOCK_SIZE_WORDS() - ((offset + BTreeHeader.BTreeHeaderSizeLongs) % ctx.BLOCK_SIZE_WORDS()));
return padding;
public void write(MultimapFileLong dest, long offset) {
dest.put(offset, ((long) layers << 32L) | ((long)numEntries & 0xFFFF_FFFFL));
dest.put(offset+1, indexOffsetLongs);
dest.put(offset+2, dataOffsetLongs);
public long relativeLayerOffset(BTreeContext ctx, int n) {
long offset = 0;
for (int i = n+1; i < layers; i++) {
offset += ctx.layerSize( numEntries, i);
return offset;
package nu.marginalia.util.dict;
import nu.marginalia.util.SeekDictionary;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
import java.util.Arrays;
public class DictionaryData {
private final int DICTIONARY_BANK_SIZE;
private static final Logger logger = LoggerFactory.getLogger(DictionaryData.class);
private final SeekDictionary<DictionaryDataBank> banks = SeekDictionary.of(DictionaryDataBank::getSize);
/**
 * Creates the dictionary with an initial bank starting at index 0.
 * NOTE(review): the statement that assigns DICTIONARY_BANK_SIZE (presumably
 * from bankSize) is not visible in this excerpt.
 */
public DictionaryData(int bankSize) {
banks.add(new DictionaryDataBank(0));
public int size() {
return banks.end();
/**
 * Appends an entry to the active (last) bank. If that bank reports it is
 * full (-1), a new bank starting at the old bank's end index is created and
 * the entry is added there instead.
 * NOTE(review): the statement that registers the new bank is not visible in
 * this excerpt.
 *
 * @return the value returned by the bank's add
 */
public int add(byte[] data, int value) {
var activeBank = banks.last();
int rb = activeBank.add(data, value);
// -1 is the bank's "full" signal.
if (rb == -1) {
int end = activeBank.getEnd();
logger.debug("Switching bank @ {}", end);
var newBank = new DictionaryDataBank(end);
rb = newBank.add(data, value);
return rb;
public byte[] getBytes(int offset) {
return banks.bankForOffset(offset).getBytes(offset);
public boolean keyEquals(int offset, byte[] data) {
return banks.bankForOffset(offset).keyEquals(offset, data);
public int getValue(int offset) {
return banks.bankForOffset(offset).getValue(offset);
public class DictionaryDataBank {
private final int start_idx;
private final ByteBuffer data;
private int size;
private int[] offset;
private int[] value;
public DictionaryDataBank(int start_idx) {
this.start_idx = start_idx;
data = ByteBuffer.allocateDirect(DICTIONARY_BANK_SIZE);
offset = new int[DICTIONARY_BANK_SIZE/16];
value = new int[DICTIONARY_BANK_SIZE/16];
size = 0;
public int getStart() {
return start_idx;
public int getEnd() {
return start_idx + size;
public byte[] getBytes(int idx) {
if (idx < start_idx || idx - start_idx >= size) {
throw new IndexOutOfBoundsException(idx);
idx = idx - start_idx;
final int start;
final int end = offset[idx];
if (idx == 0) start = 0;
else start = offset[idx-1];
byte[] dst = new byte[end-start];
data.get(start, dst);
return dst;
public int getValue(int idx) {
if (idx < start_idx || idx - start_idx >= size) {
throw new IndexOutOfBoundsException(idx);
return value[idx - start_idx];
public boolean keyEquals(int idx, byte[] data) {
if (idx < start_idx || idx - start_idx >= size) {
throw new IndexOutOfBoundsException(idx);
idx = idx - start_idx;
int start;
int end = offset[idx];
if (idx == 0) {
start = 0;
else {
start = offset[idx-1];
if (data.length != end - start) {
return false;
for (int i = 0; i < data.length; i++) {
if (this.data.get(start + i) != data[i]) {
return false;
return true;
public long longHashCode(int idx) {
if (idx < start_idx || idx - start_idx >= size) {
throw new IndexOutOfBoundsException(idx);
idx = idx - start_idx;
int start;
int end = offset[idx];
if (idx == 0) {
start = 0;
else {
start = offset[idx-1];
long result = 1;
for (int i = start; i < end; i++)
result = 31 * result + data.get(i);
return result;
public int add(byte[] newData, int newValue) {
if (size == offset.length) {
logger.debug("Growing bank from {} to {}", offset.length, offset.length*2);
offset = Arrays.copyOf(offset, offset.length*2);
value = Arrays.copyOf(value, value.length*2);
if (size > 0 && offset[size-1]+newData.length >= DICTIONARY_BANK_SIZE) {
if (offset.length > size+1) {
logger.debug("Shrinking bank from {} to {}", offset.length, size - 1);
offset = Arrays.copyOf(offset, size + 1);
value = Arrays.copyOf(value, size + 1);
return -1; // Full
int dataOffset = size > 0 ? offset[size-1] : 0;
data.put(dataOffset, newData);
offset[size] = dataOffset + newData.length;
value[size] = newValue;
return start_idx + size++;
public int getSize() {
return size;
package nu.marginalia.util.dict;
import io.prometheus.client.Gauge;
import nu.marginalia.util.PrimeUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.util.concurrent.atomic.AtomicInteger;
import static java.lang.Math.round;
import static nu.marginalia.util.FileSizeUtil.readableSize;
/**
 * Spiritually influenced by GNU Trove's hash maps
 * LGPL 2.1
 */
public class DictionaryHashMap {
private static final Logger logger = LoggerFactory.getLogger(DictionaryHashMap.class);
private static final Gauge probe_count_metrics
= Gauge.build("wmsa_dictionary_hash_map_probe_count", "Probing Count")
private final int bufferCount;
private final IntBuffer[] buffers;
public static final int NO_VALUE = Integer.MIN_VALUE;
private final DictionaryData dictionaryData;
private final long hashTableSize;
private final int bufferSizeBytes;
private final int intsPerBuffer;
private final long maxProbeLength;
private AtomicInteger sz = new AtomicInteger(0);
/**
 * Sizes the table for sizeMemory int cells, spread over as many buffers as
 * needed to keep each one around 1 GiB, and creates the backing DictionaryData.
 * NOTE(review): some arguments of the log calls, and the statement that
 * invokes initializeBuffers, are not visible in this excerpt.
 */
public DictionaryHashMap(long sizeMemory) {
final int intSize = 4;
// One buffer per started GiB of cell storage.
bufferCount = 1 + (int) ((intSize*sizeMemory) / (1<<30));
buffers = new IntBuffer[bufferCount];
// Actually use a prime size for Donald Knuth reasons
// Step -1: the search moves downwards from sizeMemory.
hashTableSize = PrimeUtil.nextPrime(sizeMemory, -1);
intsPerBuffer = 1 + (int)(sizeMemory/ bufferCount);
bufferSizeBytes = intSize*intsPerBuffer;
// Probing gives up after a tenth of the requested size.
maxProbeLength = sizeMemory/10;
logger.info("Allocating dictionary hash map of size {}, capacity: {}",
readableSize((long) bufferCount * bufferSizeBytes),
logger.info("available-size:{} memory-size:{} buffer-count: {}, buffer-size:{} ints-per-buffer:{} max-probe-length:{}",
hashTableSize, sizeMemory, bufferCount, bufferSizeBytes, intsPerBuffer, maxProbeLength);
// Sanity check: the buffers together must hold at least sizeMemory cells.
if (((long) bufferCount * intsPerBuffer) < sizeMemory) {
logger.error("Buffer memory is less than requested memory: {}*{} = {} < {}; this data structure is not safe to use",
bufferSizeBytes, (long) bufferCount * bufferSizeBytes,
throw new Error("Irrecoverable logic error");
else {
logger.debug("Buffer size sanity checked passed");
// Bank size: sizeMemory/4, clamped to the range [32, 1<<30].
dictionaryData = new DictionaryData(Math.min(1<<30, Math.max(32, (int)(sizeMemory/4))));
private void initializeBuffers() {
for (int b = 0; b < bufferCount; b++) {
buffers[b] = ByteBuffer.allocateDirect(bufferSizeBytes).asIntBuffer();
for (int i = 0; i < intsPerBuffer; i++) {
buffers[b].put(i, NO_VALUE);
public int memSz() {
return dictionaryData.size();
public int size() {
return sz.get();
private int getCell(long idx) {
int buffer = (int)(idx / intsPerBuffer);
int bufferIdx = (int)(idx % intsPerBuffer);
return buffers[buffer].get(bufferIdx);
private void setCell(long idx, int val) {
int buffer = (int)(idx / intsPerBuffer);
int bufferIdx = (int)(idx % intsPerBuffer);
buffers[buffer].put(bufferIdx, val);
public int put(byte[] data, int value) {
long hash = longHash(data) & 0x7FFF_FFFF_FFFF_FFFFL;
long idx = hash % hashTableSize;
if (getCell(idx) == NO_VALUE) {
return setValue(data, value, idx);
return putRehash(data, value, idx, hash);
private int putRehash(byte[] data, int value, long idx, long hash) {
final long pStride = 1 + (hash % (hashTableSize - 2));
for (long j = 1; j < maxProbeLength; j++) {
idx = idx - pStride;
if (idx < 0) {
idx += hashTableSize;
final int val = getCell(idx);
if (val == NO_VALUE) {
return setValue(data, value, idx);
else if (dictionaryData.keyEquals(val, data)) {
return val;
throw new IllegalStateException("DictionaryHashMap full @ size " + size() + "/" + hashTableSize + ", " + round((100.0*size()) / hashTableSize) + "%");
/**
 * Stores a new entry in the dictionary data and records its index in the
 * given table cell.
 * NOTE(review): nothing visible in this excerpt updates the sz counter that
 * size() reports; a statement may be missing here — confirm.
 *
 * @return the index returned by dictionaryData.add
 */
private int setValue(byte[] data, int value, long cell) {
int di = dictionaryData.add(data, value);
setCell(cell, di);
return di;
public int get(byte[] data) {
final long hash = longHash(data) & 0x7FFF_FFFF_FFFF_FFFFL;
final long cell = hash % hashTableSize;
if (getCell(cell) == NO_VALUE) {
return NO_VALUE;
else {
int val = getCell(cell);
if (dictionaryData.keyEquals(val, data)) {
return dictionaryData.getValue(val);
return getRehash(data, cell, hash);
private int getRehash(byte[] data, long idx, long hash) {
final long pStride = 1 + (hash % (hashTableSize - 2));
for (long j = 1; j < maxProbeLength; j++) {
idx = idx - pStride;
if (idx < 0) {
idx += hashTableSize;
final var val = getCell(idx);
if (val == NO_VALUE) {
return NO_VALUE;
else if (dictionaryData.keyEquals(val, data)) {
return dictionaryData.getValue(val);
throw new IllegalStateException("DictionaryHashMap full @ size " + size() + "/" + hashTableSize + ", " + round((100.0*size()) / hashTableSize) + "%");
private long longHash(byte[] bytes) {
if (bytes == null)
return 0;
// https://cp-algorithms.com/string/string-hashing.html
int p = 127;
long m = (1L<<61)-1;
long p_power = 1;
long hash_val = 0;
for (byte element : bytes) {
hash_val = (hash_val + (element+1) * p_power) % m;
p_power = (p_power * p) % m;
return hash_val;
package nu.marginalia.util.graphics.dithering;
import lombok.AllArgsConstructor;
import net.sf.image4j.util.ConvertUtil;
import org.imgscalr.Scalr;
import java.awt.image.BufferedImage;
import java.awt.image.IndexColorModel;
import java.util.Arrays;
import java.util.Comparator;
public class FloydSteinbergDither {
private final Color[] palette;
private final int maxWidth;
private final int maxHeight;
/**
 * Creates a ditherer for the given palette and size limits.
 * NOTE(review): the rest of the stream pipeline that builds the palette is
 * not visible in this excerpt.
 *
 * @param colors    palette colours as ints
 * @param maxWidth  maximum output width; resize() skips scaling when negative
 * @param maxHeight maximum output height; resize() skips scaling when negative
 */
public FloydSteinbergDither(int[] colors, int maxWidth, int maxHeight) {
this.maxWidth = maxWidth;
this.maxHeight = maxHeight;
palette = Arrays.stream(colors)
public BufferedImage convert(BufferedImage src) {
BufferedImage out = dither(resize(src));
if (palette.length <= 16) {
int[] cmap = new int[palette.length];
for (int i = 0; i < palette.length; i++) {
cmap[i] = palette[i].toInt();
return ConvertUtil.convert4(out, cmap);
return out;
/**
 * Dithers the image into a new indexed buffer. Rows are processed in pairs:
 * one row left to right (dx = 1), then the following row right to left
 * (dx = -1).
 * NOTE(review): the statement guarded by the row-bound check after ++y is not
 * visible in this excerpt.
 */
private BufferedImage dither(BufferedImage in) {
Errors errors = new Errors(in.getWidth(), in.getHeight());
final BufferedImage out = createOutBuffer(in);
for (int y = 0; y < in.getHeight(); y++) {
// Forward pass over this row.
for (int x = 0; x < in.getWidth(); x++) {
setOutPixel(errors, out, in, x, y, 1);
// Advance to the next row for the reverse pass, unless the image has ended.
if (++y >= in.getHeight()) {
for (int x = in.getWidth()-1; x >= 0; x--) {
setOutPixel(errors, out, in, x, y, -1);
return out;
private void setOutPixel(Errors errors, BufferedImage out, BufferedImage in, int x, int y, int dx) {
final Color color = new Color(in.getRGB(x, y));
final Color adjustedColor = errors.adjust(color, x, y);
final int newColor = getNearestColorAndDiffuseError(errors,
x, dx, y,
adjustedColor, color);
out.setRGB(x, y, newColor);
private BufferedImage createOutBuffer(BufferedImage in) {
var indexModel = createIndexColorModel();
return new BufferedImage(indexModel,
indexModel.createCompatibleWritableRaster(in.getWidth(), in.getHeight()),
false, null);
/**
 * Scales the image down so that it fits within maxWidth x maxHeight. The
 * source is returned unchanged when either limit is negative or when no
 * downscaling is needed.
 * NOTE(review): the argument list of the Scalr.resize call may be incomplete
 * in this excerpt.
 */
private BufferedImage resize(BufferedImage src) {
if (maxWidth < 0 || maxHeight < 0) {
return src;
final int width = src.getWidth();
final int height = src.getHeight();
// Use the tighter of the two per-axis factors so both limits are respected.
double scaleF = Math.min(scaleFactor(width, maxWidth),
scaleFactor(height, maxHeight));
if (scaleF < 1.0) {
int newWidth = (int)Math.min(maxWidth, scaleF * width);
int newHeight = (int)Math.min(maxHeight, scaleF * height);
return Scalr.resize(src,
newWidth, newHeight);
return src;
private double scaleFactor(int actualValue, int desiredValue) {
if (actualValue <= desiredValue) {
return 1.;
return desiredValue / (double) actualValue;
private IndexColorModel createIndexColorModel() {
byte[] reds = new byte[palette.length];
byte[] greens = new byte[palette.length];
byte[] blues = new byte[palette.length];
for (int i = 0; i < palette.length; i++) {
int colorInt = palette[i].toInt();
reds[i] = (byte) ((colorInt >>> 16) & 0xFF);
greens[i] = (byte) ((colorInt >>> 8) & 0xFF);
blues[i] = (byte) ((colorInt) & 0xFF);
return new IndexColorModel(getPaletteBits(palette), palette.length, reds, greens, blues);
private int getPaletteBits(Color[] palette) {
if (palette.length <= 16) {
return 4;
else {
return 8;
private int getNearestColorAndDiffuseError(Errors errors, int x, int dx, int y, Color color, Color colorOrig) {
var match = Arrays.stream(palette).min(Comparator.comparing(c -> c.delta(color)));
assert match.isPresent();
var retC = match.get();
var error = colorOrig.minus(retC);
errors.add(x+dx, y, error.scale(7/16.));
errors.add(x+dx, y+1, error.scale(1/16.));
errors.add(x, y+1, error.scale(5/16.));
errors.add(x-dx, y+1, error.scale(3/16.));
return retC.toInt();
class Errors {
private final int width;
private final int height;
private final Color[] errors;
Errors(int width, int height) {
this.width = width;
this.height = height;
errors = new Color[width * height];
public void add(int x, int y, Color color) {
if (x > 0 && y > 0 && x + 1 < width && y + 1 < height) {
int index = getIndex(x, y);
if (errors[index] == null) {
errors[index] = color;
else {
errors[index] = errors[index].plus(color);
public Color adjust(Color in, int x, int y) {
int idx = getIndex(x, y);
if (errors[idx] != null) {
return in.plus(errors[idx]);
return in;
private int getIndex(int x, int y) {
return x * height + y;
/**
 * RGB colour with double-precision channels, so intermediate results may fall
 * outside 0..255 until toInt() clamps them.
 * NOTE(review): scale/plus/minus call a three-double constructor that is not
 * visible in this excerpt.
 */
class Color {
private final double r;
private final double g;
private final double b;
/** Unpacks a 0xRRGGBB int into channels. */
Color(int hex) {
this.b = ((hex) & 0xFF);
this.g = ((hex >>> 8) & 0xFF);
this.r = ((hex >>> 16) & 0xFF);
/** Packs the channels back into 0xRRGGBB after clamping each to 0..255. */
int toInt() {
double bv = clampByteRange(b);
double gv = clampByteRange(g);
double rv = clampByteRange(r);
return (((int)bv&0xFF) | (((int)gv & 0xFF) << 8) | (((int)rv & 0xFF) << 16));
/** Clamps v to the range [0, 255]. */
double clampByteRange(double v) {
if (v < 0) return 0;
if (v > 255) return 255;
return v;
/** Channel-wise multiplication by factor. */
public Color scale(double factor) {
return new Color(r*factor, g*factor, b*factor);
/** Channel-wise sum. */
public Color plus(Color other) {
return new Color(r+other.r, g+other.g, b+other.b);
/** Channel-wise difference. */
public Color minus(Color other) {
return new Color(r-other.r, g-other.g, b-other.b);
/**
 * Weighted Euclidean distance to other. The red and blue weights are swapped
 * depending on whether the mean red level exceeds 128.
 */
public double delta(Color other) {
double avgr = (r + other.r)/2;
double dr = r - other.r;
double dg = g - other.g;
double db = b - other.b;
if (avgr > 128) {
return Math.sqrt(2 * dr * dr + 4 * dg * dg + 3 * db * db);
else {
return Math.sqrt(3 * dr * dr + 4 * dg * dg + 2 * db * db);
package nu.marginalia.util.graphics.dithering;
/**
 * Colour palettes as arrays of 0xRRGGBB ints.
 * NOTE(review): the entries of MARGINALIA_PALETTE are not visible in this excerpt.
 */
public class Palettes {
public static int[] MARGINALIA_PALETTE = new int[] {
// Sixteen-entry palette.
public static int[] CGA_PALETTE = new int[]{
0x000000, 0xFFFFFF, 0x808080, 0xFF0000,
0x800000, 0x00FF00, 0x008000, 0x0000FF,
0x000080, 0xFFFF00, 0x808000, 0x00FFFF,
0x008080, 0xFF00FF, 0x800080, 0x404040
package nu.marginalia.util.hash;
import io.prometheus.client.Gauge;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import nu.marginalia.wmsa.edge.index.service.index.wordstable.IndexWordsTable;
import nu.marginalia.util.multimap.MultimapFileLong;
import nu.marginalia.util.PrimeUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static java.lang.Math.round;
/**
 * Spiritually influenced by GNU Trove's hash maps
 * LGPL 2.1
 */
public class LongPairHashMap {
private static final Logger logger = LoggerFactory.getLogger(LongPairHashMap.class);
private static final Gauge probe_count_metrics
= Gauge.build("wmsa_wordfile_hash_map_probe_count", "Probing Count")
private final long hashTableSize;
private final MultimapFileLong data;
private final long maxProbeLength;
private int sz = 0;
private static final int HEADER_SIZE = 2;
public LongPairHashMap(MultimapFileLong data, long size) {
this.data = data;
// Actually use a prime size for Donald Knuth reasons
hashTableSize = PrimeUtil.nextPrime(size, 1);
maxProbeLength = hashTableSize / 2;
logger.debug("Table size = " + hashTableSize);
data.put(0, IndexWordsTable.Strategy.HASH.ordinal());
data.put(1, hashTableSize);
for (int i = 2; i < hashTableSize; i++) {
data.put(HEADER_SIZE + 2L*i, 0);
/**
 * Opens an existing table; the table size is read from word 1 of {@code data}.
 *
 * @param data backing storage that already contains a table
 */
public LongPairHashMap(MultimapFileLong data) {
    this.data = data;
    hashTableSize = data.get(1);
    // Use the same probe limit as the constructor that builds the table.
    // A shorter limit on the reading side would make lookups give up before
    // reaching entries that were legitimately placed further along the probe
    // sequence.
    maxProbeLength = hashTableSize / 2;

    logger.debug("Table size = {}", hashTableSize);
}
public int size() {
return sz;
private CellData getCell(long idx) {
long bufferIdx = 2*idx + HEADER_SIZE;
long a = data.get(bufferIdx);
long b = data.get(bufferIdx+1);
return new CellData(a, b);
private void setCell(long idx, CellData cell) {
long bufferIdx = 2*idx + HEADER_SIZE;
data.put(bufferIdx, cell.first);
data.put(bufferIdx+1, cell.second);
public CellData put(CellData data) {
long hash = longHash(data.getKey()) & 0x7FFF_FFFFL;
long idx = hash% hashTableSize;
if (!getCell(hash% hashTableSize).isSet()) {
return setValue(data, hash% hashTableSize);
return putRehash(data, idx, hash);
private CellData putRehash(CellData data, long idx, long hash) {
final long pStride = 1 + (hash % (hashTableSize - 2));
for (long j = 1; j < maxProbeLength; j++) {
idx = idx - pStride;
if (idx < 0) {
idx += hashTableSize;
final var val = getCell(idx);
if (!val.isSet()) {
return setValue(data, idx);
else if (val.getKey() == data.getKey()) {
logger.error("Double write?");
return val;
throw new IllegalStateException("DictionaryHashMap full @ size " + size() + "/" + hashTableSize + ", " + round((100.0*size()) / hashTableSize) + "%, key = " + data.getKey() + ",#"+hash);
/**
 * Writes the cell data into the given table cell and returns it.
 * NOTE(review): nothing visible in this excerpt updates the sz counter that
 * size() reports; a statement may be missing here — confirm.
 */
private CellData setValue(CellData data, long cell) {
setCell(cell, data);
return data;
public CellData get(int key) {
if (hashTableSize == 0) {
return new CellData(0, 0);
final long hash = longHash(key) & 0x7FFF_FFFFL;
var val = getCell(hash % hashTableSize);
if (!val.isSet()) {
return val;
else if (val.getKey() == key) {
return val;
return getRehash(key, hash % hashTableSize, hash);
private CellData getRehash(int key, long idx, long hash) {
final long pStride = 1 + (hash % (hashTableSize - 2));
for (long j = 1; j < maxProbeLength; j++) {
idx = idx - pStride;
if (idx < 0) {
idx += hashTableSize;
final var val = getCell(idx);
if (!val.isSet()) {
return val;
else if (val.getKey() == key) {
return val;
throw new IllegalStateException("DictionaryHashMap full @ size " + size() + "/" + hashTableSize + ", " + round((100.0*size()) / hashTableSize) + "%");
private long longHash(long x) {
return x;
@Getter @EqualsAndHashCode
public static class CellData {
long first;
long second;
public CellData(long key, long offset) {
first = key | 0x8000_0000_000_000L;
second = offset;
public long getKey() {
return first & ~0x8000_0000_000_000L;
public long getOffset() {
return second;
public boolean isSet() {
return first != 0 || second != 0L;
public void close() throws Exception {
package nu.marginalia.util.multimap;
import com.upserve.uppend.blobs.NativeIO;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import static java.nio.channels.FileChannel.MapMode.READ_ONLY;
import static java.nio.channels.FileChannel.MapMode.READ_WRITE;
import static nu.marginalia.util.FileSizeUtil.readableSize;
public class MultimapFileLong implements AutoCloseable {
private final ArrayList<LongBuffer> buffers = new ArrayList<>();
private final ArrayList<MappedByteBuffer> mappedByteBuffers = new ArrayList<>();
private final FileChannel.MapMode mode;
private final int bufferSize;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final FileChannel channel;
private final long mapSize;
private final long fileLength;
private long mappedSize;
final static long WORD_SIZE = 8;
private boolean loadAggressively;
private NativeIO.Advice advice = null;
public static MultimapFileLong forReading(Path file) throws IOException {
long fileSize = Files.size(file);
int bufferSize = getBufferSize(fileSize, false);
return new MultimapFileLong(file.toFile(), READ_ONLY, Files.size(file), bufferSize);
public static MultimapFileLong forOutput(Path file, long estimatedSize) throws IOException {
return new MultimapFileLong(file.toFile(), READ_WRITE, 0, getBufferSize(estimatedSize, true));
private static int getBufferSize(long totalSize, boolean write) {
if (totalSize > Integer.MAX_VALUE/WORD_SIZE) {
return (int)(Integer.MAX_VALUE/WORD_SIZE);
else if (write && totalSize < 8*1024*1024) {
return 8*1024*1024;
else {
return (int) Math.min(totalSize, Integer.MAX_VALUE/WORD_SIZE);
/**
 * Opens the given file with the RandomAccessFile mode that matches the map
 * mode and delegates to the main constructor with aggressive loading off.
 *
 * @throws IOException if the file cannot be opened
 */
public MultimapFileLong(File file,
FileChannel.MapMode mode,
long mapSize,
int bufferSize) throws IOException {
this(new RandomAccessFile(file, translateToRAFMode(mode)), mode, mapSize, bufferSize, false);
public MultimapFileLong loadAggressively(boolean v) {
this.loadAggressively = v;
return this;
private static String translateToRAFMode(FileChannel.MapMode mode) {
if (READ_ONLY.equals(mode)) {
return "r";
} else if (READ_WRITE.equals(mode)) {
return "rw";
return "rw";
/**
 * Main constructor: records the configuration, captures the file's current
 * length and channel, and starts with nothing mapped (mappedSize = 0).
 *
 * @param mapSizeBytes    map size, in bytes
 * @param bufferSizeWords size of each mapped buffer, in 8-byte words
 * @throws IOException if the file length cannot be read
 */
public MultimapFileLong(RandomAccessFile file,
FileChannel.MapMode mode,
long mapSizeBytes,
int bufferSizeWords,
boolean loadAggressively) throws IOException {
this.mode = mode;
this.bufferSize = bufferSizeWords;
this.mapSize = mapSizeBytes;
this.fileLength = file.length();
this.loadAggressively = loadAggressively;
channel = file.getChannel();
mappedSize = 0;
logger.debug("Creating multimap file size = {} / buffer size = {}, mode = {}",
readableSize(mapSizeBytes), readableSize(8L*bufferSizeWords), mode);
/** Creates a MultimapSearcher backed by this file. */
public MultimapSearcher createSearcher() {
return new MultimapSearcher(this);
/** Creates a MultimapSorter backed by this file, passing through the temp path and the in-memory sort limit. */
public MultimapSorter createSorter(Path tmpFile, int internalSortLimit) {
return new MultimapSorter(this, tmpFile, internalSortLimit);
public void advice(NativeIO.Advice advice) {
for (var buffer : mappedByteBuffers) {
NativeIO.madvise(buffer, advice);
/** Applies a madvise hint to the first mapped buffer only. */
public void advice0(NativeIO.Advice advice) {
// NOTE(review): get(0) fails if nothing has been mapped yet — confirm callers map first.
NativeIO.madvise(mappedByteBuffers.get(0), advice);
/**
 * Applies a madvise hint to a range given in 8-byte words. The range is expected
 * to lie within a single buffer; a range spanning two buffers is logged as misaligned.
 *
 * @param startLongs  first word of the range
 * @param lengthLongs number of words in the range
 */
public void adviceRange(NativeIO.Advice advice, long startLongs, long lengthLongs) {
long endLongs = (startLongs+lengthLongs);
if (endLongs >= mappedSize)
var buff = mappedByteBuffers.get((int)(startLongs / bufferSize));
// Start and end must fall in the same buffer.
if ((int)(startLongs / bufferSize) != (int)((endLongs) / bufferSize)) {
logger.warn("Misaligned madvise, skipping");
// Offset and length are converted from words to bytes; the length is narrowed to int.
NativeIO.madviseRange(buff, advice, (startLongs % bufferSize) * WORD_SIZE, (int)(lengthLongs*WORD_SIZE));
public void pokeRange(long offset, int length) {
for (int i = 0; i < length; i += 4096/8) {
get(offset + i);
/** Visits every mapped byte buffer. */
public void force() {
// NOTE(review): presumably forces each buffer's contents to storage — confirm.
for (MappedByteBuffer buffer: mappedByteBuffers) {
/**
 * Maps additional buffers until the mapping covers word index posIdxRequired.
 * In READ_ONLY mode an index beyond mapSize is rejected and the last buffer is
 * clipped to the bytes remaining; otherwise every buffer is a full bufferSize words.
 *
 * @param posIdxRequired index, in 8-byte words, that must become addressable
 * @throws IndexOutOfBoundsException if the index lies beyond mapSize in READ_ONLY mode
 */
private void grow(long posIdxRequired) {
if (posIdxRequired*WORD_SIZE > mapSize && mode == READ_ONLY) {
throw new IndexOutOfBoundsException(posIdxRequired + " (max " + mapSize + ")");
logger.trace("Growing to encompass {}i/{}b", posIdxRequired, posIdxRequired*WORD_SIZE);
long start;
// Continue after the buffers that already exist; each spans bufferSize words.
if (buffers.isEmpty()) {
start = 0;
else {
start = (long) buffers.size() * bufferSize;
for (long posIdx = start; posIdxRequired >= posIdx; posIdx += bufferSize) {
long posBytes = posIdx * WORD_SIZE;
long bzBytes;
if (mode == READ_ONLY) {
// Never map past the end of a read-only mapping.
bzBytes = Math.min(WORD_SIZE*bufferSize, mapSize - posBytes);
else {
bzBytes = WORD_SIZE*bufferSize;
logger.trace("Allocating {}-{}", posBytes, posBytes+bzBytes);
var buffer = channel.map(mode, posBytes, bzBytes);
if (loadAggressively)
// Apply the remembered madvise hint, if any, to the new buffer.
if (advice != null) {
NativeIO.madvise(buffer, advice);
// mappedSize is tracked in words, not bytes.
mappedSize += bzBytes/WORD_SIZE;
/** Returns the backing file's length as recorded when this object was constructed. */
public long size() {
return fileLength;
/**
 * Writes {@code val} at word index {@code idx}. The buffer is selected by
 * idx / bufferSize and the position inside it by idx % bufferSize.
 *
 * @throws RuntimeException wrapping an IndexOutOfBoundsException raised by the buffer
 */
public void put(long idx, long val) {
// NOTE(review): indices at or beyond the mapped size presumably trigger grow() — confirm.
if (idx >= mappedSize)
try {
buffers.get((int)(idx / bufferSize)).put((int) (idx % bufferSize), val);
catch (IndexOutOfBoundsException ex) {
// Log the buffer, in-buffer index and capacity before rethrowing.
logger.error("Index out of bounds {} -> {}:{} cap {}", idx, buffers.get((int)(idx / bufferSize)), idx % bufferSize,
buffers.get((int)(idx / bufferSize)).capacity());
throw new RuntimeException(ex);
/**
 * Reads the value at word index {@code idx}. The buffer is selected by
 * idx / bufferSize and the position inside it by idx % bufferSize.
 *
 * @throws RuntimeException wrapping an IndexOutOfBoundsException raised by the buffer
 */
public long get(long idx) {
// NOTE(review): indices at or beyond the mapped size presumably trigger grow() — confirm.
if (idx >= mappedSize)
try {
return buffers.get((int)(idx / bufferSize)).get((int)(idx % bufferSize));
catch (IndexOutOfBoundsException ex) {
// Log the buffer, in-buffer index and capacity before rethrowing.
logger.error("Index out of bounds {} -> {}:{} cap {}", idx, buffers.get((int)(idx / bufferSize)), idx % bufferSize,
buffers.get((int)(idx / bufferSize)).capacity());
throw new RuntimeException(ex);
/** Fills all of {@code vals} from the file, starting at word index {@code idx}. */
public void read(long[] vals, long idx) {
read(vals, vals.length, idx);
/**
 * Copies {@code n} words starting at word index {@code idx} into {@code vals},
 * one buffer-sized piece at a time.
 */
public void read(long[] vals, int n, long idx) {
if (idx+n >= mappedSize) {
// Index of the buffer that holds the end of the range.
int iN = (int)((idx + n) / bufferSize);
for (int i = 0; i < n; ) {
int i0 = (int)((idx + i) / bufferSize);
int bufferOffset = (int) ((idx+i) % bufferSize);
var buffer = buffers.get(i0);
final int l;
// Before the last buffer, take everything up to the buffer's end;
// in the last buffer, take only what is still needed.
if (i0 < iN) l = bufferSize - bufferOffset;
else l = Math.min(n - i, bufferSize - bufferOffset);
buffer.get(bufferOffset, vals, i, l);
/** Writes all of {@code vals} to the file, starting at word index {@code idx}. */
public void write(long[] vals, long idx) {
write(vals, vals.length, idx);
/**
 * Copies the first {@code n} words of {@code vals} into the file starting at
 * word index {@code idx}, one buffer-sized piece at a time.
 */
public void write(long[] vals, int n, long idx) {
if (idx+n >= mappedSize) {
// Index of the buffer that holds the end of the range.
int iN = (int)((idx + n) / bufferSize);
for (int i = 0; i < n; ) {
int i0 = (int)((idx + i) / bufferSize);
int bufferOffset = (int) ((idx+i) % bufferSize);
var buffer = buffers.get(i0);
final int l;
// Before the last buffer, fill up to the buffer's end;
// in the last buffer, write only what remains.
if (i0 < iN) l = bufferSize - bufferOffset;
else l = Math.min(n - i, bufferSize - bufferOffset);
buffer.put(bufferOffset, vals, i, l);
/**
 * Copies the words between the position and limit of {@code vals} into the file
 * starting at word index {@code idx}, one buffer-sized piece at a time.
 */
public void write(LongBuffer vals, long idx) {
int n = vals.limit() - vals.position();
if (idx+n >= mappedSize) {
// Index of the buffer that holds the end of the range.
int iN = (int)((idx + n) / bufferSize);
for (int i = 0; i < n; ) {
int i0 = (int)((idx + i) / bufferSize);
int bufferOffset = (int) ((idx+i) % bufferSize);
var buffer = buffers.get(i0);
final int l;
if (i0 < iN) l = bufferSize - bufferOffset;
else l = Math.min(n - i, bufferSize - bufferOffset);
// Source offsets are relative to the source buffer's current position.
buffer.put(bufferOffset, vals, vals.position() + i, l);
/**
 * Reads (sourceEnd - sourceStart) words from {@code sourceChannel} into the mapping,
 * starting at word index {@code destOffset}, using a scattering read over slices of
 * the mapped byte buffers.
 *
 * @throws IOException if the channel reports end-of-stream before all bytes are read
 */
public void transferFromFileChannel(FileChannel sourceChannel, long destOffset, long sourceStart, long sourceEnd) throws IOException {
// NOTE(review): the word count is narrowed to int; larger ranges would overflow.
int length = (int)(sourceEnd - sourceStart);
if (destOffset+length >= mappedSize) {
// First and last buffers touched by the destination range.
int i0 = (int)((destOffset) / bufferSize);
int iN = (int)((destOffset + length) / bufferSize);
int numBuffers = iN - i0 + 1;
ByteBuffer[] buffers = new ByteBuffer[numBuffers];
for (int i = 0; i < numBuffers; i++) {
buffers[i] = mappedByteBuffers.get(i0 + i);
if (i0 != iN) {
// Multi-buffer case: trim the first buffer's head and the last buffer's tail (byte units).
int startBuf0 = (int) ((destOffset) % bufferSize) * 8;
int endBuf0 = buffers[0].capacity() - (int) ((destOffset) % bufferSize) * 8;
int endBufN = (int)((destOffset + length) % bufferSize)*8;
buffers[0] = buffers[0].slice(startBuf0, endBuf0);
buffers[numBuffers-1] = buffers[numBuffers-1].slice(0, endBufN);
else {
// Single-buffer case: one slice covering exactly the destination range.
buffers[0] = buffers[0].slice((int) ((destOffset) % bufferSize) * 8, 8*length);
// Keep reading until the full byte count has arrived.
long twb = 0;
while (twb < length * 8L) {
long rb = sourceChannel.read(buffers, 0, buffers.length);
if (rb < 0)
throw new IOException();
twb += rb;
/** Closes this multimap file. */
public void close() throws IOException {
// I want to believe
@ -0,0 +1,128 @@
package nu.marginalia.util.multimap;
import lombok.experimental.Delegate;
public class MultimapSearcher {
private final MultimapFileLong mmf;
/** Creates a searcher that reads from the given multimap file. */
public MultimapSearcher(MultimapFileLong mmf) {
this.mmf = mmf;
public boolean binarySearch(long key, long fromIndex, long toIndex) {
long low = fromIndex;
long high = toIndex - 1;
while (low <= high) {
long mid = (low + high) >>> 1;
long midVal = get(mid);
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
return true; // key found
return false; // key not found.
public long binarySearchUpperBound(long key, long fromIndex, long toIndex) {
long low = fromIndex;
long high = toIndex - 1;
while (low <= high) {
long mid = (low + high) >>> 1;
long midVal = get(mid);
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
return mid;
return low;
public long binarySearchUpperBound(long key, long fromIndex, long toIndex, long mask) {
long low = fromIndex;
long high = toIndex - 1;
while (low <= high) {
long mid = (low + high) >>> 1;
long midVal = get(mid) & mask;
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
return mid;
return low;
public long binarySearchUpperBoundNoMiss(long key, long fromIndex, long toIndex) {
long low = fromIndex;
long high = toIndex - 1;
while (low <= high) {
long mid = (low + high) >>> 1;
long midVal = get(mid);
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
return mid;
return -1;
public long binarySearchUpperBoundNoMiss(long key, long fromIndex, long toIndex, long mask) {
long low = fromIndex;
long high = toIndex - 1;
while (low <= high) {
long mid = (low + high) >>> 1;
long midVal = get(mid) & mask;
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
return mid;
return -1;
public long binarySearchUpperBoundNoMiss(long key, long fromIndex, long step, long steps, long mask) {
long low = 0;
long high = steps - 1;
while (low <= high) {
long mid = (low + high) >>> 1;
long midVal = get(fromIndex + mid*step) & mask;
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
return fromIndex + mid*step;
return -1;
@ -0,0 +1,89 @@
package nu.marginalia.util.multimap;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.LongBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import static nu.marginalia.util.multimap.MultimapFileLong.WORD_SIZE;
public class MultimapSorter {
private final Path tmpFileDir;
private final int internalSortLimit;
private final MultimapFileLong multimapFileLong;
private final long[] buffer;
/**
 * Creates a sorter for the given file and allocates the in-memory sort buffer.
 *
 * @param tmpFileDir        directory in which temp files for external sorts are created
 * @param internalSortLimit largest range, in words, that is sorted in memory
 */
public MultimapSorter(MultimapFileLong multimapFileLong, Path tmpFileDir, int internalSortLimit) {
this.multimapFileLong = multimapFileLong;
this.tmpFileDir = tmpFileDir;
this.internalSortLimit = internalSortLimit;
buffer = new long[internalSortLimit];
public void sort(long start, int length) throws IOException {
if (length <= internalSortLimit) {
multimapFileLong.read(buffer, length, start);
Arrays.sort(buffer, 0, length);
multimapFileLong.write(buffer, length, start);
else {
externalSort(start, length);
/**
 * Sorts a range too large for the in-memory buffer: sorts fixed-width runs in memory,
 * then merges runs of doubling width through a memory-mapped temp file.
 *
 * @throws IOException if the temp file cannot be created or mapped
 */
private void externalSort(long start, int length) throws IOException {
Path tmpFile = Files.createTempFile(tmpFileDir,"sort-"+start+"-"+(start+length), ".dat");
try (var raf = new RandomAccessFile(tmpFile.toFile(), "rw"); var channel = raf.getChannel()) {
var workBuffer =
channel.map(FileChannel.MapMode.READ_WRITE, 0, length * WORD_SIZE)
// Run width is a power of two no larger than either the range or the in-memory limit.
int width = Math.min(Integer.highestOneBit(length), Integer.highestOneBit(internalSortLimit));
// Do in-memory sorting up until internalSortLimit first
for (int i = 0; i < length; i += width) {
sort(start + i, Math.min(width, length-i));
// Then merge sort on disk for the rest
// NOTE(review): width*=2 and 2*width are int arithmetic; lengths above 2^30 would overflow.
for (; width < length; width*=2) {
for (int i = 0; i < length; i += 2*width) {
merge(start, i, Math.min(i+width, length), Math.min(i+2*width, length), workBuffer);
// Copy the merged pass back over the original range.
multimapFileLong.write(workBuffer, start);
finally {
void merge(long offset, int left, int right, int end, LongBuffer workBuffer) {
int i = left;
int j = right;
for (int k = left; k < end; k++) {
final long bufferI = multimapFileLong.get(offset+i);
final long bufferJ = multimapFileLong.get(offset+j);
if (i < right && (j >= end || bufferI < bufferJ)) {
workBuffer.put(k, bufferI);
else {
workBuffer.put(k, bufferJ);
package nu.marginalia.wmsa.auth;
import com.google.inject.AbstractModule;
import com.google.inject.name.Names;
import java.nio.file.Path;
/** Guice module holding the bindings for the auth service. */
public class AuthConfigurationModule extends AbstractModule {
public void configure() {
@ -0,0 +1,28 @@
package nu.marginalia.wmsa.auth;
import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;
import nu.marginalia.wmsa.configuration.MainClass;
import nu.marginalia.wmsa.configuration.ServiceDescriptor;
import nu.marginalia.wmsa.configuration.module.ConfigurationModule;
import nu.marginalia.wmsa.configuration.server.Initialization;
import java.io.IOException;
/** Entry point for the auth service. */
public class AuthMain extends MainClass {
public AuthMain(AuthService service) throws IOException {
public static void main(String... args) {
init(ServiceDescriptor.AUTH, args);
// Wire the service from the auth-specific and the shared configuration modules.
Injector injector = Guice.createInjector(
new AuthConfigurationModule(),
new ConfigurationModule());
@ -0,0 +1,105 @@
package nu.marginalia.wmsa.auth;
import com.github.jknack.handlebars.internal.Files;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import nu.marginalia.wmsa.auth.model.LoginFormModel;
import nu.marginalia.wmsa.configuration.server.*;
import nu.marginalia.wmsa.renderer.mustache.MustacheRenderer;
import nu.marginalia.wmsa.renderer.mustache.RendererFactory;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import spark.Request;
import spark.Response;
import spark.Spark;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Objects;
import java.util.Optional;
import static spark.Spark.*;
public class AuthService extends Service {
private final Logger logger = LoggerFactory.getLogger(getClass());
private String password;
private final RateLimiter rateLimiter = RateLimiter.forLogin();
private final MustacheRenderer<LoginFormModel> loginFormRenderer;
/**
 * Reads the login password from the given file, prepares the login form renderer,
 * and registers the login and is-logged-in routes.
 */
public AuthService(@Named("service-host") String ip,
@Named("service-port") Integer port,
@Named("password-file") Path topSecretPasswordFile,
RendererFactory rendererFactory,
Initialization initialization,
MetricsServer metricsServer) throws IOException {
super(ip, port, initialization, metricsServer);
try (var is = new FileReader(topSecretPasswordFile.toFile())) {
password = Files.read(is);
} catch (IOException e) {
// NOTE(review): a read failure is only logged, so the service starts with password == null.
logger.error("Could not read password from file " + topSecretPasswordFile, e);
loginFormRenderer = rendererFactory.renderer("auth/login");
Spark.path("public/api", () -> {
before((req, rsp) -> {
logger.info("{} {}", req.requestMethod(), req.pathInfo());
post("/login", this::login);
get("/login", this::loginForm);
Spark.path("api", () -> {
get("/is-logged-in", this::isLoggedIn);
private Object loginForm(Request request, Response response) throws IOException {
String redir = Objects.requireNonNull(request.queryParams("redirect"));
String service = Objects.requireNonNull(request.queryParams("service"));
return loginFormRenderer.render(new LoginFormModel(service, redir));
/**
 * Handles a login attempt: requests from an already logged-in session short-circuit,
 * rate-limited requests are halted with 429, and a matching password marks the
 * session as logged in.
 */
private Object login(Request request, Response response) {
var redir = Objects.requireNonNullElse(request.queryParams("redirect"), "/");
if (isLoggedIn(request, response)) {
return "";
if (!rateLimiter.isAllowed(Context.fromRequest(request))) {
Spark.halt(429, "Too many requests");
return null;
// NOTE(review): Objects.equals(null, null) is true, so if the password was never loaded
// a request without a "password" parameter would be accepted — confirm and guard.
if (Objects.equals(password, request.queryParams("password"))) {
request.session(true).attribute("logged-in", true);
return "";
return "<h1>Bad password!</h1>";
/**
 * Reports whether the request belongs to a logged-in session, based on the session's
 * "logged-in" attribute. A request without a session is not logged in.
 */
public boolean isLoggedIn(Request request, Response response) {
// session(false) does not create a session if none exists.
var session = request.session(false);
if (null == session) {
return false;
return Optional.ofNullable(session.attribute("logged-in"))
package nu.marginalia.wmsa.auth.api;
import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;
import nu.marginalia.wmsa.auth.AuthConfigurationModule;
import nu.marginalia.wmsa.auth.AuthMain;
import nu.marginalia.wmsa.auth.AuthService;
import nu.marginalia.wmsa.configuration.MainClass;
import nu.marginalia.wmsa.configuration.ServiceDescriptor;
import nu.marginalia.wmsa.configuration.module.ConfigurationModule;
import nu.marginalia.wmsa.configuration.module.DatabaseModule;
import nu.marginalia.wmsa.configuration.server.Initialization;
import java.io.IOException;
/** Entry point for the API service. */
public class ApiMain extends MainClass {
public ApiMain(ApiService service) throws IOException {
public static void main(String... args) {
init(ServiceDescriptor.API, args);
// Wire the service from the database and the shared configuration modules.
Injector injector = Guice.createInjector(
new DatabaseModule(),
new ConfigurationModule());
@ -0,0 +1,127 @@
package nu.marginalia.wmsa.auth.api;
import com.google.common.base.Strings;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.wmsa.auth.api.model.ApiLicense;
import nu.marginalia.wmsa.configuration.server.*;
import nu.marginalia.wmsa.edge.search.client.EdgeSearchClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import spark.Request;
import spark.Response;
import spark.Spark;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
public class ApiService extends Service {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Gson gson = new GsonBuilder().create();
private final EdgeSearchClient searchClient;
private final HikariDataSource dataSource;
private final ConcurrentHashMap<String, ApiLicense> licenseCache = new ConcurrentHashMap<>();
private final ConcurrentHashMap<ApiLicense, RateLimiter> rateLimiters = new ConcurrentHashMap<>();
/**
 * Stores the search client and data source and registers the public API routes:
 * the API root, key info, and search. Key info and search results are serialized with Gson.
 */
public ApiService(@Named("service-host") String ip,
@Named("service-port") Integer port,
Initialization initialization,
MetricsServer metricsServer,
EdgeSearchClient searchClient,
HikariDataSource dataSource)
throws IOException
super(ip, port, initialization, metricsServer);
this.searchClient = searchClient;
this.dataSource = dataSource;
Spark.get("/public/api/", (rq, rsp) -> {
logger.info("Redireting to info");
return "";
Spark.get("/public/api/:key/", this::getKeyInfo, gson::toJson);
Spark.get("/public/api/:key/search/*", this::search, gson::toJson);
/** Returns the license that belongs to the request's API key. */
private Object getKeyInfo(Request request, Response response) {
return getLicense(request);
/**
 * Runs a search for a licensed API key. The query is the single splat segment of the
 * path; "count" (default 20) and "index" (default 3) come from query parameters.
 */
private Object search(Request request, Response response) {
String[] args = request.splat();
if (args.length != 1) {
var license = getLicense(request);
if (null == license) {
return "Forbidden";
// Licenses without a positive rate have no limiter (rl == null) and are not throttled.
RateLimiter rl = getRateLimiter(license);
if (rl != null && !rl.isAllowed()) {
return "Slow down";
// NOTE(review): non-numeric values make parseInt throw NumberFormatException — confirm handling.
int count = Integer.parseInt(request.queryParamOrDefault("count", "20"));
int index = Integer.parseInt(request.queryParamOrDefault("index", "3"));
logger.info("{} Search {}", license.key, args[0]);
return searchClient.query(Context.fromRequest(request), args[0], count, index)
private RateLimiter getRateLimiter(ApiLicense license) {
if (license.rate > 0) {
return rateLimiters.computeIfAbsent(license, l -> RateLimiter.custom(license.rate));
else {
return null;
/**
 * Resolves the license for the ":key" path parameter, consulting the in-memory cache
 * first and the EC_API_KEY table second. Licenses found in the database are cached
 * under the lower-cased key.
 */
private ApiLicense getLicense(Request request) {
final String key = request.params("key");
if (Strings.isNullOrEmpty(key)) {
var cachedLicense = licenseCache.get(key.toLowerCase());
if (cachedLicense != null) {
return cachedLicense;
try (var conn = dataSource.getConnection()) {
try (var stmt = conn.prepareStatement("SELECT LICENSE,NAME,RATE FROM EC_API_KEY WHERE LICENSE_KEY=?")) {
// NOTE(review): the query uses the key as given while the cache uses the lower-cased key — confirm intended.
stmt.setString(1, key);
var rsp = stmt.executeQuery();
if (rsp.next()) {
var license = new ApiLicense(key.toLowerCase(), rsp.getString(1), rsp.getString(2), rsp.getInt(3));
licenseCache.put(key.toLowerCase(), license);
return license;
catch (Exception ex) {
logger.error("Bad request", ex);
return null; // unreachable
package nu.marginalia.wmsa.auth.api.model;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
/** An API license as loaded from the EC_API_KEY table. */
public class ApiLicense {
// The API key; stored lower-cased by ApiService.
public String key;
public String license;
public String name;
// Request rate; values of zero or below mean no rate limiter is applied.
public int rate;
package nu.marginalia.wmsa.auth.api.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
import nu.marginalia.wmsa.edge.model.search.EdgeUrlDetails;
/** A single search result as exposed through the public API. */
@AllArgsConstructor @Getter
public class ApiSearchResult {
public String url;
public String title;
public String description;
public double quality;
/** Builds an API result from the search engine's URL details; quality is the term score. */
public ApiSearchResult(EdgeUrlDetails url) {
this.url = url.url.toString();
this.title = url.getTitle();
this.description = url.getDescription();
this.quality = url.getTermScore();
package nu.marginalia.wmsa.auth.api.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.With;
import java.util.List;
/** The response to an API search: the license, the query, and the list of results. */
public class ApiSearchResults {
private final String license;
private final String query;
private final List<ApiSearchResult> results;
package nu.marginalia.wmsa.auth.client;
import com.google.inject.Inject;
import io.reactivex.rxjava3.core.Observable;
import kotlin.text.Charsets;
import nu.marginalia.wmsa.client.AbstractDynamicClient;
import nu.marginalia.wmsa.configuration.ServiceDescriptor;
import nu.marginalia.wmsa.configuration.server.Context;
import org.apache.http.HttpStatus;
import spark.Request;
import spark.Response;
import spark.Spark;
import java.net.URLEncoder;
import java.util.concurrent.TimeUnit;
/** Client for the auth service's login-status endpoint. */
public class AuthClient extends AbstractDynamicClient {
public AuthClient() {
/** Asks the auth service whether the context's session is logged in. */
public Observable<Boolean> isLoggedIn(Context ctx) {
return get(ctx, "/api/is-logged-in").map(Boolean::parseBoolean);
/**
 * Redirects to the login page unless the request is authenticated.
 * The login check is given one second to answer.
 */
public void redirectToLoginIfUnauthenticated(String domain, Request req, Response rsp) {
if (!isLoggedIn(Context.fromRequest(req)).timeout(1, TimeUnit.SECONDS).blockingFirst()) {
// NOTE(review): domain is appended unencoded, and a missing X-Extern-Url header would
// reach URLEncoder.encode as null — confirm the proxy always sets both headers.
rsp.redirect(req.headers("X-Extern-Domain") + "/auth/login?service="+domain
+"&redirect="+ URLEncoder.encode(req.headers("X-Extern-Url"), Charsets.UTF_8));
/** Checks, with a one-second timeout, that the context's session is logged in. */
public void requireLogIn(Context ctx) {
if (!isLoggedIn(ctx).timeout(1, TimeUnit.SECONDS).blockingFirst()) {
@ -0,0 +1,10 @@
package nu.marginalia.wmsa.auth.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
/** Template model for the login form: the requesting service and the redirect target. */
@Getter @AllArgsConstructor
public class LoginFormModel {
public final String service;
public final String redirect;
package nu.marginalia.wmsa.client;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.reactivex.rxjava3.core.Scheduler;
import io.reactivex.rxjava3.schedulers.Schedulers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
/**
 * Supplies an RxJava Scheduler backed by a lazily created fixed pool of 16 threads
 * that can be replaced via abort(). All access to the pool is synchronized.
 */
public class AbortingScheduler implements AutoCloseable {
private final String name;
private final ThreadFactory threadFactory;
private final Logger logger = LoggerFactory.getLogger(getClass());
// Created on first use; replaced by abort().
private ExecutorService executorService;
public AbortingScheduler(String name) {
this.name = name;
threadFactory = new ThreadFactoryBuilder()
/** Logs an exception that escaped a client IO thread. */
private void handleException(Thread thread, Throwable throwable) {
logger.error("Uncaught exception during Client IO in thread {}", thread.getName(), throwable);
/** Returns a scheduler backed by the current executor, creating the executor if needed. */
public synchronized Scheduler get() {
return Schedulers.from(getExecutorService(),
/** Replaces an existing executor with a fresh pool; does nothing if none was created yet. */
public synchronized void abort() {
if (null != executorService) {
executorService = Executors.newFixedThreadPool(16, threadFactory);
private synchronized ExecutorService getExecutorService() {
if (null == executorService) {
executorService = Executors.newFixedThreadPool(16, threadFactory);
return executorService;
public synchronized void close() {
if (null != executorService) {
package nu.marginalia.wmsa.client;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.ObservableSource;
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
import lombok.SneakyThrows;
import nu.marginalia.wmsa.client.exception.LocalException;
import nu.marginalia.wmsa.client.exception.NetworkException;
import nu.marginalia.wmsa.client.exception.RemoteException;
import nu.marginalia.wmsa.client.exception.RouteNotConfiguredException;
import nu.marginalia.wmsa.configuration.server.Context;
import okhttp3.*;
import org.apache.http.HttpHost;
import org.apache.logging.log4j.ThreadContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.slf4j.MarkerFactory;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.net.ConnectException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPOutputStream;
public abstract class AbstractClient implements AutoCloseable {
public static final String CONTEXT_OUTBOUND_REQUEST = "outbound-request";
private final Gson gson = new GsonBuilder().create();
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Marker httpMarker = MarkerFactory.getMarker("HTTP");
private final OkHttpClient client;
private boolean quiet;
private String url;
/** Sets the request timeout; the request methods apply it in seconds. */
public void setTimeout(int timeout) {
this.timeout = timeout;
private int timeout;
private volatile boolean alive;
private final Thread livenessMonitor;
/**
 * Builds the HTTP client and base URL for the target service, installs a global
 * RxJava error handler that logs errors, and creates the liveness monitor thread.
 *
 * @param timeout request timeout, applied in seconds by the request methods
 */
public AbstractClient(String host, int port, int timeout) {
logger.info("Creating client for {}", getClass().getSimpleName());
this.timeout = timeout;
client = new OkHttpClient.Builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.readTimeout(6000, TimeUnit.SECONDS)
url = new HttpHost(host, port).toURI();
// NOTE(review): RxJavaPlugins.setErrorHandler is process-wide; every client instance overwrites it.
RxJavaPlugins.setErrorHandler(e -> {
if (e.getMessage() == null) {
logger.error("Error", e);
else {
logger.error("Error {}: {}", e.getClass().getSimpleName(), e.getMessage());
livenessMonitor = new Thread(this::monitorLiveness, host + "-monitor");
logger.info("Finished creating client for {}", getClass().getSimpleName());
/** Points this client at a different host and port. */
public void setServiceRoute(String hostname, int port) {
url = new HttpHost(hostname, port).toURI();
/**
 * Body of the liveness monitor thread: loops forever, refreshing the {@code alive}
 * flag from isResponsive().
 */
private void monitorLiveness() {
Thread.sleep(100); // Wait for initialization
for (;;) {
try {
alive = isResponsive();
catch (java.util.concurrent.TimeoutException tex) {
catch (Exception ex) {
logger.warn("Oops", ex);
// The monitor thread object doubles as the lock for liveness state changes.
synchronized (livenessMonitor) {
if (alive) {
if (!alive) {
/** Closes this client. */
public void close() {
/** Returns the scheduler supplied by the concrete client. */
public abstract AbortingScheduler scheduler();
/** When quiet, errors that are not retried are rethrown without being logged. */
public void setQuiet(boolean quiet) {
this.quiet = quiet;
/** Returns the service name used in log and error messages. */
public abstract String name();
/**
 * Pings the service's /internal/ping endpoint. Errors are mapped to status 500
 * rather than propagated.
 */
public synchronized boolean isResponsive() throws java.util.concurrent.TimeoutException {
Context ctx = Context.internal("ping");
var req = ctx.paint(new Request.Builder()).url(url + "/internal/ping").get().build();
var call = client.newCall(req);
return Observable.just(call)
.flatMap(line -> validateStatus(line, req).timeout(5000, TimeUnit.SECONDS).onErrorReturn(e -> 500))
.onErrorReturn(error -> 500)
/**
 * Queries the service's /internal/ready endpoint with a 100 ms timeout.
 * Errors and timeouts are mapped to status 500 rather than propagated.
 */
public synchronized boolean isAccepting() {
Context ctx = Context.internal("ready");
var req = ctx.paint(new Request.Builder()).url(url + "/internal/ready").get().build();
var call = client.newCall(req);
return Observable.just(call)
.flatMap(line -> validateStatus(line, req))
.timeout(100, TimeUnit.MILLISECONDS)
.onErrorReturn(error -> 500)
/**
 * POSTs {@code data} with a JSON media type to the endpoint and validates the
 * response status. The outbound URL is kept in the logging ThreadContext for the
 * duration of the request.
 */
protected synchronized Observable<HttpStatusCode> post(Context ctx, String endpoint, Object data) {
RequestBody body = RequestBody.create(
MediaType.parse("application/json; charset=utf-8"),
var req = ctx.paint(new Request.Builder()).url(url + endpoint).post(body).build();
var call = client.newCall(req);
return Observable
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.flatMap(line -> validateStatus(line, req))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/**
 * POSTs {@code data} to the endpoint, requires a 200 response, and deserializes
 * the response body into {@code returnType}.
 */
protected synchronized <T> Observable<T> postGet(Context ctx, String endpoint, Object data, Class<T> returnType) {
RequestBody body = RequestBody.create(
var req = ctx.paint(new Request.Builder()).url(url + endpoint).post(body).build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.map(rsp -> validateResponseStatus(rsp, req, 200))
.map(rsp -> getEntity(rsp, returnType))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/** POSTs a string body with the given media type and validates the response status. */
protected synchronized Observable<HttpStatusCode> post(Context ctx, String endpoint, String data, MediaType mediaType) {
var body = RequestBody.create(mediaType, data);
var req = ctx.paint(new Request.Builder()).url(url + endpoint).post(body).build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
// CONTEXT_OUTBOUND_REQUEST holds the same "outbound-request" key that is removed below.
ThreadContext.put(CONTEXT_OUTBOUND_REQUEST, url + endpoint);
return c;
.flatMap(line -> validateStatus(line, req))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/** GETs the endpoint, requires a 200 response, and deserializes the body into {@code type}. */
protected synchronized <T> Observable<T> get(Context ctx, String endpoint, Class<T> type) {
var req = ctx.paint(new Request.Builder()).url(url + endpoint).get().build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.map(rsp -> validateResponseStatus(rsp, req, 200))
.map(rsp -> getEntity(rsp, type))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/**
 * GETs the endpoint, requires a 200 response, and deserializes the body as an array
 * of {@code type}, returned as a list.
 */
protected synchronized <T> Observable<List<T>> getList(Context ctx, String endpoint, Class<T> type) {
var req = ctx.paint(new Request.Builder()).url(url + endpoint).get().build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.map(rsp -> validateResponseStatus(rsp, req, 200))
// The body is parsed as T[] and wrapped as a fixed-size list.
.map(rsp -> Arrays.asList((T[])getEntity(rsp, type.arrayType())))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/** GETs the endpoint, requires a 200 response, and yields the body as bytes. */
protected synchronized Observable<byte[]> getBinary(Context ctx, String endpoint) {
var req = ctx.paint(new Request.Builder()).url(url + endpoint).get().build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.map(rsp -> validateResponseStatus(rsp, req, 200))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/** GETs the endpoint, requires a 200 response, and yields the body as text. */
protected synchronized Observable<String> get(Context ctx, String endpoint) {
var req = ctx.paint(new Request.Builder()).url(url + endpoint).get().build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.map(rsp -> validateResponseStatus(rsp, req,200))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/** Sends a DELETE to the endpoint and validates the response status. */
protected synchronized Observable<HttpStatusCode> delete(Context ctx, String endpoint) {
var req = ctx.paint(new Request.Builder()).url(url + endpoint).delete().build();
var call = client.newCall(req);
return Observable.just(call)
.map((c) -> {
ThreadContext.put("outbound-request", url + endpoint);
return c;
.flatMap(line -> validateStatus(line, req))
.timeout(timeout, TimeUnit.SECONDS)
.doFinally(() -> ThreadContext.remove("outbound-request"));
/** Pass-through hook: returns the call unchanged. */
private Call logInbound(Call outgoing) {
return outgoing;
/** Pass-through hook: returns the response unchanged. */
private Response logOutbound(Response incoming) {
return incoming;
/**
 * Verifies that the service is alive, checking twice before giving up.
 *
 * @throws RouteNotConfiguredException if the service is still not alive on the second check
 */
private void ensureAlive() {
if (!isAlive()) {
if (!isAlive()) {
throw new RouteNotConfiguredException("Route not configured for " + name());
/** Blocks until the service reports that it is accepting requests. */
public void waitReady() {
boolean accepting = isAccepting();
if (accepting) {
logger.info("Waiting for " + name());
// Poll until the service accepts requests.
do {
} while (!isAccepting());
/** Routes each error through filterRetryableExceptions to decide whether to retry. */
private ObservableSource<?> retryHandler(Observable<Throwable> error) {
return error.flatMap(this::filterRetryableExceptions);
/**
 * Classifies an error. RouteNotConfiguredException, NetworkException and
 * ConnectException (matched by exact class) are logged and answered with a delayed
 * empty observable; any other error is logged unless quiet and rethrown.
 */
private Observable<Throwable> filterRetryableExceptions(Throwable error) throws Throwable {
synchronized (livenessMonitor) {
if (error.getClass().equals(RouteNotConfiguredException.class)) {
logger.error("Network error {}", error.getMessage());
return Observable.<Throwable>empty().delay(50, TimeUnit.MILLISECONDS);
else if (error.getClass().equals(NetworkException.class)) {
logger.error("Network error {}", error.getMessage());
return Observable.<Throwable>empty().delay(1, TimeUnit.SECONDS);
else if (error.getClass().equals(ConnectException.class)) {
logger.error("Network error {}", error.getMessage());
return Observable.<Throwable>empty().delay(1, TimeUnit.SECONDS);
if (!quiet) {
// Errors with a message get a one-line log; others are logged with their stack trace.
if (error.getMessage() != null) {
logger.error("{} {}", error.getClass().getSimpleName(), error.getMessage());
else {
logger.error("Error ", error);
throw error;
private Observable<Integer> validateStatus(int status, Request request) {
if (status == org.apache.http.HttpStatus.SC_OK)
return Observable.just(status);
if (status == org.apache.http.HttpStatus.SC_ACCEPTED)
return Observable.just(status);
if (status == org.apache.http.HttpStatus.SC_CREATED)
return Observable.just(status);
return Observable.error(new RemoteException(name() + " responded status code " + status + " " + request.url()));
private Response validateResponseStatus(Response response, Request req, int expected) {
if (expected != response.code()) {
throw new RemoteException(name() + " responded status code " + response.code() + ", " + req.method() + " " + req.url().toString());
return response;
/** Returns the response's status code and closes the response. */
private int getResponseStatus(Response response) {
try (response) {
return response.code();
/** Deserializes the response body from JSON into {@code clazz} and closes the response. */
private <T> T getEntity(Response response, Class<T> clazz) {
try (response) {
return gson.fromJson(response.body().charStream(), clazz);
// NOTE(review): this catch only rethrows the same exception.
catch (Exception ex) {
throw ex;
/** Returns the response body as a string and closes the response. */
private String getText(Response response) {
try (response) {
return response.body().string();
/** Returns the response body as bytes and closes the response. */
private byte[] getBinaryEntity(Response response) {
try (response) {
return response.body().bytes();
/** Returns the liveness flag last written by the monitor thread. */
public boolean isAlive() {
return alive;
/**
 * Serializes an object to JSON.
 *
 * @throws LocalException wrapping any serialization failure
 */
private String json(Object o) {
try {
return gson.toJson(o);
catch (Exception ex) {
throw new LocalException(ex);
private byte[] compressedJson(Object o) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
GZIPOutputStream gos = new GZIPOutputStream(baos);
try {
gson.toJson(o, new OutputStreamWriter(gos));
return baos.toByteArray();
catch (Exception ex) {
throw new LocalException(ex);
@ -0,0 +1,52 @@
package nu.marginalia.wmsa.client;
import io.reactivex.rxjava3.core.Observable;
import lombok.SneakyThrows;
import nu.marginalia.wmsa.configuration.ServiceDescriptor;
import nu.marginalia.wmsa.configuration.server.Context;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
/**
 * A client bound to a ServiceDescriptor: connects to localhost on the descriptor's
 * port with a timeout value of 10, and takes its name from the descriptor.
 */
public class AbstractDynamicClient extends AbstractClient {
private final ServiceDescriptor service;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final AbortingScheduler scheduler;
public AbstractDynamicClient(@Nonnull ServiceDescriptor service) {
super("localhost", service.port, 10);
this.service = service;
this.scheduler = new AbortingScheduler(name());
public String name() {
return service.name;
public ServiceDescriptor getService() {
return service;
/** Blocks until the liveness monitor reports the service as alive. */
public void blockingWait() {
logger.info("Waiting for route to {}", service);
while (!isAlive()) {
public AbortingScheduler scheduler() {
return scheduler;
/** Calls the service's /public/who endpoint. */
public Observable<String> who(Context ctx) {
return get(ctx, "/public/who");
/** Calls the service's /internal/ping endpoint. */
public Observable<String> ping(Context ctx) {
return get(ctx, "/internal/ping");
@ -0,0 +1,19 @@
package nu.marginalia.wmsa.client;
public final class HttpStatusCode {
public final int code;
public HttpStatusCode(int code) {
this.code = code;
public boolean isGood() {
if (code == org.apache.http.HttpStatus.SC_OK)
return true;
if (code == org.apache.http.HttpStatus.SC_ACCEPTED)
return true;
if (code == org.apache.http.HttpStatus.SC_CREATED)
return true;
return false;
package nu.marginalia.wmsa.client.exception;
/**
 * {@link MessagingException} subtype.
 * NOTE(review): judging by the name it marks failures on the calling side (as opposed to
 * the remote side) -- confirm against the call sites.
 */
public class LocalException extends MessagingException {
// NOTE(review): the bodies of the next three constructors are not visible in this extract.
public LocalException() {
public LocalException(String message) {
public LocalException(Throwable cause) {
/** Forwards both the message and the cause to the superclass constructor. */
public LocalException(String message, Throwable cause) {
super(message, cause);
package nu.marginalia.wmsa.client.exception;
/**
 * Unchecked base exception of the {@code nu.marginalia.wmsa.client.exception} package.
 *
 * <p>Instances carry no stack trace: {@link #fillInStackTrace()} is overridden to do
 * nothing, so only the message and cause are available for diagnostics.
 */
public class MessagingException extends RuntimeException {
// NOTE(review): the bodies of the next three constructors are not visible in this extract.
public MessagingException() {
public MessagingException(String message) {
public MessagingException(Throwable cause) {
/** Forwards both the message and the cause to {@link RuntimeException}. */
public MessagingException(String message, Throwable cause) {
super(message, cause);
/**
 * Returns this exception without delegating to {@link Throwable#fillInStackTrace()},
 * so no stack trace is recorded when the exception is created.
 */
public Throwable fillInStackTrace() {
return this;
package nu.marginalia.wmsa.client.exception;
/**
 * {@link MessagingException} subtype.
 * NOTE(review): judging by the name it marks network-level failures -- confirm against
 * the call sites.
 */
public class NetworkException extends MessagingException {
// NOTE(review): the bodies of the next three constructors are not visible in this extract.
public NetworkException() {
public NetworkException(String message) {
public NetworkException(Throwable cause) {
/** Forwards both the message and the cause to the superclass constructor. */
public NetworkException(String message, Throwable cause) {
super(message, cause);
package nu.marginalia.wmsa.client.exception;
/**
 * {@link MessagingException} subtype.
 * NOTE(review): judging by the name it marks failures reported by the remote service,
 * such as an unexpected HTTP status -- confirm against the call sites.
 */
public class RemoteException extends MessagingException {
// NOTE(review): the bodies of the next three constructors are not visible in this extract.
public RemoteException() {
public RemoteException(String message) {
public RemoteException(Throwable cause) {
/** Forwards both the message and the cause to the superclass constructor. */
public RemoteException(String message, Throwable cause) {
super(message, cause);
package nu.marginalia.wmsa.client.exception;
/**
 * {@link MessagingException} subtype.
 * NOTE(review): judging by the name it is raised when no route to a service has been
 * configured -- confirm against the call sites.
 */
public class RouteNotConfiguredException extends MessagingException {
// NOTE(review): the bodies of the next three constructors are not visible in this extract.
public RouteNotConfiguredException() {
public RouteNotConfiguredException(String message) {
public RouteNotConfiguredException(Throwable cause) {
/** Forwards both the message and the cause to the superclass constructor. */
public RouteNotConfiguredException(String message, Throwable cause) {
super(message, cause);
package nu.marginalia.wmsa.client.exception;
/**
 * {@link MessagingException} subtype.
 *
 * <p>Shares its simple name with {@code java.util.concurrent.TimeoutException}; unlike
 * that class this one is unchecked, so take care which of the two is imported.
 */
public class TimeoutException extends MessagingException {
// NOTE(review): the bodies of the next three constructors are not visible in this extract.
public TimeoutException() {
public TimeoutException(String message) {
public TimeoutException(Throwable cause) {
/** Forwards both the message and the cause to the superclass constructor. */
public TimeoutException(String message, Throwable cause) {
super(message, cause);
package nu.marginalia.wmsa.configuration;
import io.prometheus.client.hotspot.DefaultExports;
import io.reactivex.rxjava3.exceptions.UndeliverableException;
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
import lombok.SneakyThrows;
import nu.marginalia.wmsa.client.exception.NetworkException;
import org.mariadb.jdbc.Driver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.util.Arrays;
/**
 * Base class for service entry points. The constructor installs a global RxJava error
 * handler and {@link #init} prepares JVM system properties for a given service.
 */
public abstract class MainClass {
private Logger logger = LoggerFactory.getLogger(getClass());
/**
 * Registers the RxJava error handler, which classifies errors by type before logging them.
 */
public MainClass() {
RxJavaPlugins.setErrorHandler(ex -> {
// UndeliverableException is replaced by its cause so the checks below see the cause's type.
if (ex instanceof UndeliverableException) {
ex = ex.getCause();
// NOTE(review): the handling for SocketTimeoutException and UnknownHostException is
// not visible in this extract.
if (ex instanceof SocketTimeoutException) {
else if (ex instanceof UnknownHostException) {
else if (ex instanceof NetworkException) {
logger.warn("NetworkException", ex);
else {
logger.error("Uncaught exception", ex);
/**
 * Sets system properties for the given service.
 *
 * @param service descriptor whose name is stored in the "service-name" system property
 * @param args    command line arguments; if they contain "go-no-go", the "go-no-go"
 *                system property is set to "true"
 */
protected static void init(ServiceDescriptor service, String... args) {
System.setProperty("log4j2.isThreadContextMapInheritable", "true");
System.setProperty("isThreadContextMapInheritable", "true");
System.setProperty("service-name", service.name);
// The driver instance is not used afterwards -- presumably created only to force the
// MariaDB driver class to load; TODO confirm.
org.mariadb.jdbc.Driver driver = new Driver();
if (Arrays.asList(args).contains("go-no-go")) {
System.setProperty("go-no-go", "true");
package nu.marginalia.wmsa.configuration;
import nu.marginalia.wmsa.auth.AuthMain;
import nu.marginalia.wmsa.auth.api.ApiMain;
import nu.marginalia.wmsa.configuration.command.Command;
import nu.marginalia.wmsa.configuration.command.ListCommand;
import nu.marginalia.wmsa.configuration.command.StartCommand;
import nu.marginalia.wmsa.configuration.command.VersionCommand;
import nu.marginalia.wmsa.data_store.DataStoreMain;
import nu.marginalia.wmsa.edge.archive.EdgeArchiveMain;
import nu.marginalia.wmsa.edge.assistant.EdgeAssistantMain;
import nu.marginalia.wmsa.edge.crawler.EdgeCrawlerMain;
import nu.marginalia.wmsa.edge.dating.DatingMain;
import nu.marginalia.wmsa.edge.director.EdgeDirectorMain;
import nu.marginalia.wmsa.edge.index.EdgeIndexMain;
import nu.marginalia.wmsa.edge.search.EdgeSearchMain;
import nu.marginalia.wmsa.memex.MemexMain;
import nu.marginalia.wmsa.podcasts.PodcastScraperMain;
import nu.marginalia.wmsa.renderer.RendererMain;
import nu.marginalia.wmsa.resource_store.ResourceStoreMain;
import nu.marginalia.wmsa.smhi.scraper.SmhiScraperMain;
import org.apache.logging.log4j.core.lookup.MainMapLookup;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Registry of services: each constant pairs a service name with a port number and the
 * class holding that service's main method. Also hosts the command line entry point
 * that dispatches to the list/start/version commands.
 */
public enum ServiceDescriptor {
RESOURCE_STORE("resource-store", 5000, ResourceStoreMain.class),
DATA_STORE("data-store", 5001, DataStoreMain.class),
RENDERER("renderer", 5002, RendererMain.class),
AUTH("auth", 5003, AuthMain.class),
API("api", 5004, ApiMain.class),
SMHI_SCRAPER("smhi-scraper",5012, SmhiScraperMain.class),
PODCST_SCRAPER("podcast-scraper", 5013, PodcastScraperMain.class),
EDGE_CRAWLER("edge-crawler", 5020, EdgeCrawlerMain.class),
EDGE_INDEX("edge-index", 5021, EdgeIndexMain.class),
EDGE_DIRECTOR("edge-director", 5022, EdgeDirectorMain.class),
EDGE_SEARCH("edge-search", 5023, EdgeSearchMain.class),
EDGE_ARCHIVE("edge-archive", 5024, EdgeArchiveMain.class),
EDGE_ASSISTANT("edge-assistant", 5025, EdgeAssistantMain.class),
EDGE_MEMEX("memex", 5030, MemexMain.class),
DATING("dating", 5070, DatingMain.class),
// The two test entries have port 0 and no main class (mainClass is null).
TEST_1("test-1", 0, null),
TEST_2("test-2", 0, null);
/**
 * Finds the descriptor whose {@link #name} equals the argument.
 *
 * @param name service name to look up, e.g. "edge-index"
 * @return the matching descriptor
 * @throws IllegalArgumentException with {@code name} as its message when nothing matches
 */
public static ServiceDescriptor byName(String name) {
for (var v : values()) {
if (v.name.equals(name)) {
return v;
throw new IllegalArgumentException(name);
/** Service name; also what {@link #toString()} returns. */
public final String name;
/** Class providing the service's main method; null for the test entries. */
public final Class<?> mainClass;
/** Port number associated with the service. */
public final int port;
ServiceDescriptor(String name, int port, Class<?> mainClass) {
this.name = name;
this.port = port;
this.mainClass = mainClass;
/** @return the service name */
public String toString() {
return name;
/**
 * @return the service name followed by the fully qualified name of its main class
 */
// NOTE(review): dereferences mainClass, so this throws NullPointerException for
// TEST_1 and TEST_2 -- callers must skip entries without a main class.
public String describeService() {
return String.format("%s %s", name, mainClass.getName());
/**
 * Command line entry point. Builds a map of the available commands keyed by command
 * name; with at least one argument it selects the command named by {@code args[0]},
 * falling back to a command that prints "Unknown command", and with no arguments it
 * prints a usage line listing the command names.
 */
public static void main(String... args) {
Map<String, Command> functions = Stream.of(new ListCommand(),
new StartCommand(),
new VersionCommand()
).collect(Collectors.toMap(c -> c.name, c -> c));
if(args.length > 0) {
// NOTE(review): the call that runs the selected command is not visible in this extract.
functions.getOrDefault(args[0], new Command("") {
public void execute(String... args) {
System.err.println("Unknown command");
else {
System.err.println("Usage: " + String.join("|", functions.keySet()));
package nu.marginalia.wmsa.configuration;
import java.nio.file.Files;
import java.nio.file.Path;
/**
 * Locates the WMSA home directory.
 *
 * <p>The location is resolved in this order: the {@code WMSA_HOME} system property, the
 * {@code WMSA_HOME} environment variable, and finally the built-in default
 * {@code /var/lib/wmsa}.
 */
public class WmsaHome {
    private static final String DEFAULT = "/var/lib/wmsa";
    private static final String HOME_KEY = "WMSA_HOME";

    /**
     * Resolves the home directory and verifies that it exists.
     *
     * @return the path of the home directory
     * @throws IllegalStateException when the resolved path is not an existing directory
     */
    public static Path get() {
        // The system property keeps precedence, as before; the environment variable is
        // consulted next so that the advice given in the error message actually works.
        String configured = System.getProperty(HOME_KEY);
        if (configured == null) {
            configured = System.getenv(HOME_KEY);
        }
        if (configured == null) {
            configured = DEFAULT;
        }

        Path ret = Path.of(configured);
        if (!Files.isDirectory(ret)) {
            throw new IllegalStateException("Could not find WMSA_HOME at " + ret
                    + ", either set the WMSA_HOME system property or environment variable, or ensure "
                    + DEFAULT + " exists");
        }
        return ret;
    }
}
package nu.marginalia.wmsa.configuration.command;
import nu.marginalia.wmsa.configuration.ServiceDescriptor;
import java.util.Arrays;
import java.util.Objects;
/**
 * Base class for command line sub-commands; each command is identified by its
 * {@link #name}.
 */
public abstract class Command {
/** Name identifying this command. */
public final String name;
protected Command(String name) {
this.name = name;
/**
 * Runs the command.
 *
 * @param args command line arguments passed to the command
 */
public abstract void execute(String... args);
/**
 * Resolves a service name to its {@link ServiceDescriptor} by comparing the argument
 * with each descriptor's name.
 *
 * @param arg service name to resolve
 * @return the matching descriptor, or null after printing "Unknown service" to stderr
 *         when an IllegalArgumentException is caught
 */
static ServiceDescriptor getKind(String arg) {
try {
// NOTE(review): the terminal operation of this stream is not visible in this
// extract; presumably it is what throws IllegalArgumentException when nothing matches.
return Arrays.stream(ServiceDescriptor.values())
.filter(sd -> Objects.equals(arg, sd.name))
} catch (IllegalArgumentException ex) {
System.err.println("Unknown service '" + arg + "'");
return null;
package nu.marginalia.wmsa.configuration.command;
import lombok.SneakyThrows;
import nu.marginalia.wmsa.configuration.ServiceDescriptor;
import java.util.Arrays;
import java.util.Objects;
/**
 * {@link Command} implementation.
 * NOTE(review): most of this class is not visible in this extract; presumably it lists
 * the available services -- confirm against the full source.
 */
public class ListCommand extends Command {
public ListCommand() {
public void execute(String... args) {
// Keeps only descriptors that have a main class, i.e. drops entries whose mainClass is null.
.filter(sd -> Objects.nonNull(sd.mainClass))
package nu.marginalia.wmsa.configuration.command;
import lombok.SneakyThrows;
import java.util.Arrays;
/**
 * {@link Command} that starts a service by reflectively invoking the static
 * {@code main(String[])} method of the service's main class.
 */
public class StartCommand extends Command {
public StartCommand() {
/**
 * @param args {@code args[1]} names the service; anything after it is passed on to
 *             the service's main method
 */
public void execute(String... args) {
if (args.length < 2) {
// NOTE(review): no early exit after this message is visible in this extract --
// confirm that execution stops here, since args[1] is read below.
System.err.println("Usage: start service-descriptor");
// NOTE(review): getKind can return null for an unknown service name, which would
// raise NullPointerException on this line -- confirm the intended handling.
var mainMethod = getKind(args[1]).mainClass.getMethod("main", String[].class);
// Drop the command name and service name; the remainder goes to the service.
String[] args2 = Arrays.copyOfRange(args, 2, args.length);
// The cast to Object makes the array a single argument instead of being expanded as varargs.
mainMethod.invoke(null, (Object) args2);
package nu.marginalia.wmsa.configuration.command;
import lombok.SneakyThrows;
/**
 * {@link Command} that reads the {@code _version.txt} resource via the system class loader.
 * NOTE(review): what is done with the content is not visible in this extract.
 */
public class VersionCommand extends Command {
public VersionCommand() {
public void execute(String... args) {
try (var str = ClassLoader.getSystemResourceAsStream("_version.txt")) {
// getSystemResourceAsStream returns null when the resource is not found.
if (null == str) {
System.err.println("Bad jar, missing _version.txt");
package nu.marginalia.wmsa.configuration.module;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import com.google.inject.name.Named;
import lombok.SneakyThrows;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Objects;
import static com.google.inject.name.Names.named;
/**
 * Guice module holding process-wide configuration values read from system properties.
 */
public class ConfigurationModule extends AbstractModule {
// Read once when the class is initialised; null when the property is not set.
private static final String SERVICE_NAME = System.getProperty("service-name");
// Taken from the "monitor.port" system property, defaulting to 5000.
public static int MONITOR_PORT = Integer.getInteger("monitor.port", 5000);
// Taken from the "monitor.host" system property, defaulting to the empty string.
public static String MONITOR_HOST = System.getProperty("monitor.host", "");
// NOTE(review): the bindings made in configure() are not visible in this extract.
public void configure() {
/**
 * Determines the build version from the {@code _version.txt} classpath resource.
 *
 * @return the resource's content, or the current local date-time in ISO format when
 *         the resource is missing (a message is then printed to stderr)
 */
public String buildVersion() {
try (var str = ClassLoader.getSystemResourceAsStream("_version.txt")) {
if (null == str) {
System.err.println("Missing _version.txt from classpath");
return LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME);
// Decodes with the platform default charset; the content is returned untrimmed.
return new String(str.readAllBytes());
package nu.marginalia.wmsa.configuration.module;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import lombok.SneakyThrows;
import nu.marginalia.wmsa.configuration.WmsaHome;
import org.h2.tools.RunScript;
import org.mariadb.jdbc.Driver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Properties;
/**
 * Guice module that provides the application's {@link HikariDataSource}, backed either
 * by MariaDB or by H2 depending on the "data-store-h2" system property. Connection
 * settings are read from {@code db.properties} in the WMSA home directory.
 */
public class DatabaseModule extends AbstractModule {
private static final Logger logger = LoggerFactory.getLogger(DatabaseModule.class);
// Keys that must be present in db.properties.
private static final String DB_USER_KEY="db.user";
private static final String DB_PASS_KEY ="db.pass";
private static final String DB_CONN_KEY ="db.conn";
private final Properties dbProperties;
/**
 * Instantiates the MariaDB driver and loads the database properties.
 *
 * @throws IllegalStateException propagated from {@code loadDbProperties} when the
 *         properties file is missing or lacks a required key
 */
public DatabaseModule() {
// Instance is discarded -- presumably created only to force the driver class to load; TODO confirm.
new Driver();
dbProperties = loadDbProperties();
/**
 * Opens {@code db.properties} under the WMSA home directory and validates it.
 *
 * @return the properties object
 * @throws IllegalStateException when the file does not exist or a required key is absent
 * @throws RuntimeException wrapping an IOException raised while handling the file
 */
private Properties loadDbProperties() {
// Despite its name, propDir is the path of the properties file itself.
Path propDir = WmsaHome.get().resolve("db.properties");
if (!Files.isRegularFile(propDir)) {
throw new IllegalStateException("Database properties file " + propDir + " does not exist");
try (var is = new FileInputStream(propDir.toFile())) {
// NOTE(review): the statement loading the stream into props is not visible in this extract.
var props = new Properties();
if (!props.containsKey(DB_USER_KEY)) throw new IllegalStateException(propDir + " missing required attribute " + DB_USER_KEY);
if (!props.containsKey(DB_PASS_KEY)) throw new IllegalStateException(propDir + " missing required attribute " + DB_PASS_KEY);
if (!props.containsKey(DB_CONN_KEY)) throw new IllegalStateException(propDir + " missing required attribute " + DB_CONN_KEY);
return props;
catch (IOException ex) {
throw new RuntimeException(ex);
/**
 * @return an H2-backed data source when the "data-store-h2" system property is "true",
 *         otherwise a MariaDB-backed one
 */
public HikariDataSource provideConnection() {
if (Boolean.getBoolean("data-store-h2")) {
return getH2();
else {
return getMariaDB();
/** Builds the MariaDB data source with prepared-statement caching enabled. */
private HikariDataSource getMariaDB() {
var connStr = dbProperties.getProperty(DB_CONN_KEY);
try {
// NOTE(review): the statements applying the URL and credentials to the config are
// not visible in this extract.
HikariConfig config = new HikariConfig();
config.addDataSourceProperty("cachePrepStmts", "true");
config.addDataSourceProperty("prepStmtCacheSize", "250");
config.addDataSourceProperty("prepStmtCacheSqlLimit", "2048");
return new HikariDataSource(config);
finally {
// Runs whether or not the pool was created, so this message is also logged on failure.
logger.info("Created HikariPool for {}", connStr);
/** Builds the H2 data source and runs the two bundled SQL init scripts against it. */
private HikariDataSource getH2() {
// NOTE(review): the statements configuring the H2 URL are not visible in this extract.
HikariConfig config = new HikariConfig();
var ds = new HikariDataSource(config);
// NOTE(review): the connections obtained from ds below are not closed in the visible
// code, and stream would be null if a script is missing from the classpath.
try (var stream = ClassLoader.getSystemResourceAsStream("sql/data-store-init.sql")) {
RunScript.execute(ds.getConnection(), new InputStreamReader(stream));
try (var stream = ClassLoader.getSystemResourceAsStream("sql/edge-crawler-cache.sql")) {
RunScript.execute(ds.getConnection(), new InputStreamReader(stream));
return ds;
package nu.marginalia.wmsa.configuration.module;
import com.google.inject.name.Named;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.inject.Inject;
import javax.inject.Provider;
/**
 * Provider of a host name string; the "service-host" system property, when set,
 * overrides whatever else the provider would return.
 */
public class HostnameProvider implements Provider<String> {
private static final String DEFAULT_HOSTNAME = "";
private final int monitorPort;
private final String monitorHost;
private final int timeout;
private Logger logger = LoggerFactory.getLogger(getClass());
/**
 * @param monitorPort value bound to the name "monitor-port"
 * @param monitorHost value bound to the name "monitor-host"
 * @param timeout     value bound to the name "monitor-boot-timeout"
 */
public HostnameProvider(@Named("monitor-port") Integer monitorPort,
@Named("monitor-host") String monitorHost,
@Named("monitor-boot-timeout") Integer timeout
) {
this.monitorHost = monitorHost;
this.monitorPort = monitorPort;
this.timeout = timeout;
/**
 * @return the value of the "service-host" system property when it is set
 */
// NOTE(review): what is returned when the property is not set is not visible in this extract.
public String get() {
var override = System.getProperty("service-host");
if (null != override) {
return override;
package nu.marginalia.wmsa.configuration.module;
import com.google.inject.name.Named;
import javax.inject.Inject;
/**
 * Publishes the service name as the "service-name" system property when constructed.
 */
public class LoggerConfiguration {
/**
 * @param serviceName value bound to the name "service-name"; copied into the system
 *                    property of the same name
 */
public LoggerConfiguration(@Named("service-name") String serviceName) {
System.setProperty("service-name", serviceName);
package nu.marginalia.wmsa.configuration.module;
import com.google.inject.name.Named;
import javax.inject.Inject;
import javax.inject.Provider;
/**
 * Provides the metrics port, derived from the service port by a fixed offset of 1000.
 */
public class MetricsPortProvider implements Provider<Integer> {
private final Integer servicePort;
/**
 * @param servicePort value bound to the name "service-port"
 */
public MetricsPortProvider(@Named("service-port") Integer servicePort) {
this.servicePort = servicePort;
/** @return the service port plus 1000 */
public Integer get() {
return servicePort+1000;
