diff --git a/code/common/db/src/main/resources/sql/current/11-message-queue.sql b/code/common/db/src/main/resources/sql/current/11-message-queue.sql new file mode 100644 index 00000000..97e20d5a --- /dev/null +++ b/code/common/db/src/main/resources/sql/current/11-message-queue.sql @@ -0,0 +1,20 @@ +CREATE TABLE PROC_MESSAGE( + ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id', + + RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message', + SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox', + + RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox', + FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run', + PAYLOAD TEXT COMMENT 'Message to recipient', + + OWNER_INSTANCE VARCHAR(255) COMMENT 'Instance UUID corresponding to the party that has claimed the message', + OWNER_TICK BIGINT DEFAULT -1 COMMENT 'Used by recipient to determine which messages it has processed', + + STATE ENUM('NEW', 'ACK', 'OK', 'ERR', 'DEAD') + NOT NULL DEFAULT 'NEW' COMMENT 'Processing state', + + CREATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of creation', + UPDATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of last update', + TTL INT COMMENT 'Time to live in seconds' +); diff --git a/code/common/db/src/main/resources/sql/migrations/03-message-queue.sql b/code/common/db/src/main/resources/sql/migrations/03-message-queue.sql new file mode 100644 index 00000000..d357650e --- /dev/null +++ b/code/common/db/src/main/resources/sql/migrations/03-message-queue.sql @@ -0,0 +1,23 @@ +CREATE TABLE PROC_MESSAGE( + ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id', + + RELATED_ID BIGINT COMMENT 'Unique id a related message', + SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox', + + RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox', + FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run', + PAYLOAD TEXT COMMENT 'Message to recipient', + + -- These fields are used to avoid double processing of messages + -- instance marks the unique instance of the party, and the tick marks + -- the current polling iteration. Both are necessary. + OWNER_INSTANCE VARCHAR(255) COMMENT 'Instance UUID corresponding to the party that has claimed the message', + OWNER_TICK BIGINT DEFAULT -1 COMMENT 'Used by recipient to determine which messages it has processed', + + STATE ENUM('NEW', 'ACK', 'OK', 'ERR', 'DEAD') + NOT NULL DEFAULT 'NEW' COMMENT 'Processing state', + + CREATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of creation', + UPDATED_TIME TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT 'Time of last update', + TTL INT COMMENT 'Time to live in seconds' +); diff --git a/code/common/message-queue/build.gradle b/code/common/message-queue/build.gradle new file mode 100644 index 00000000..84ea9651 --- /dev/null +++ b/code/common/message-queue/build.gradle @@ -0,0 +1,48 @@ +plugins { + id 'java' +} + + +java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(17)) + } +} + +dependencies { + implementation project(':code:common:service-client') + implementation project(':code:common:service-discovery') + implementation project(':code:common:db') + + implementation libs.lombok + annotationProcessor libs.lombok + + implementation libs.spark + implementation libs.guice + implementation libs.rxjava + + implementation libs.bundles.prometheus + implementation libs.bundles.slf4j + implementation libs.bucket4j + + testImplementation libs.bundles.slf4j.test + implementation libs.bundles.mariadb + + testImplementation libs.bundles.slf4j.test + testImplementation libs.bundles.junit + testImplementation libs.mockito + + testImplementation platform('org.testcontainers:testcontainers-bom:1.17.4') + testImplementation 'org.testcontainers:mariadb:1.17.4' + testImplementation 'org.testcontainers:junit-jupiter:1.17.4' +} + +test { + useJUnitPlatform() +} + +task fastTests(type: Test) { + useJUnitPlatform { + excludeTags "slow" + } +} diff --git a/code/common/message-queue/msgstate.svg b/code/common/message-queue/msgstate.svg new file mode 100644 index 00000000..22691893 --- /dev/null +++ b/code/common/message-queue/msgstate.svg @@ -0,0 +1,4 @@ + + + +If the message is notacknowledged, it maybe declared dead afterTTLIf the message is not...Inbox acknowledges the messageInbox acknowledges the messageNewNewMessage processingfailedMessage processing...If the message doesn'tfinish within TTL it willbe marked as deadIf the message doesn't...Message processedOK, sender mayreceive a reply in theirinboxMessage processed...AckAckOkOkErrErrDeadDeadTerminal StatesTerminal S...Intermediate StatesIntermedia...Initial StateInitial St...Message StatesMessages pass through several states through their lifecycleMessage States...Text is not SVG - cannot display \ No newline at end of file diff --git a/code/common/message-queue/readme.md b/code/common/message-queue/readme.md new file mode 100644 index 00000000..68ae2825 --- /dev/null +++ b/code/common/message-queue/readme.md @@ -0,0 +1,5 @@ +# Message Queue + +Implements a message queue using mariadb. + + \ No newline at end of file diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/MqException.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqException.java new file mode 100644 index 00000000..351f60d7 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqException.java @@ -0,0 +1,11 @@ +package nu.marginalia.mq; + +public class MqException extends Exception { + public MqException(String message) { + super(message); + } + + public MqException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java new file mode 100644 index 00000000..5f4c11aa --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessage.java @@ -0,0 +1,10 @@ +package nu.marginalia.mq; + +public record MqMessage( + long msgId, + long relatedId, + String function, + String payload, + MqMessageState state +) { +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessageState.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessageState.java new file mode 100644 index 00000000..d1d03f15 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/MqMessageState.java @@ -0,0 +1,9 @@ +package nu.marginalia.mq; + +public enum MqMessageState { + NEW, + ACK, + OK, + ERR, + DEAD +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java new file mode 100644 index 00000000..7d94b327 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java @@ -0,0 +1,185 @@ +package nu.marginalia.mq.inbox; + +import nu.marginalia.mq.MqMessage; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.persistence.MqPersistence; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.sql.SQLException; +import java.util.Collection; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +public class MqInbox { + private final Logger logger = LoggerFactory.getLogger(MqInbox.class); + + private final String inboxName; + private final String instanceUUID; + private final ExecutorService threadPool; + private final MqPersistence persistence; + + private volatile boolean run = true; + + private final int pollIntervalMs = Integer.getInteger("mq.inbox.poll-interval-ms", 1000); + private final List eventSubscribers = new ArrayList<>(); + private final LinkedBlockingQueue queue = new LinkedBlockingQueue<>(32); + + private Thread pollDbThread; + private Thread notifyThread; + + public MqInbox(MqPersistence persistence, + String inboxName, + UUID instanceUUID) + { + this.threadPool = Executors.newCachedThreadPool(); + this.persistence = persistence; + this.inboxName = inboxName; + this.instanceUUID = instanceUUID.toString(); + } + + public void subscribe(MqSubscription subscription) { + eventSubscribers.add(subscription); + } + + public void start() { + run = true; + + if (eventSubscribers.isEmpty()) { + logger.error("No subscribers for inbox {}, registering shredder", inboxName); + } + + // Add a final handler that fails any message that is not handled + eventSubscribers.add(new MqInboxShredder()); + + pollDbThread = new Thread(this::pollDb, "mq-inbox-update-thread:"+inboxName); + pollDbThread.setDaemon(true); + pollDbThread.start(); + + notifyThread = new Thread(this::notifySubscribers, "mq-inbox-notify-thread:"+inboxName); + notifyThread.setDaemon(true); + notifyThread.start(); + } + + public void stop() throws InterruptedException { + if (!run) + return; + + logger.info("Shutting down inbox {}", inboxName); + + run = false; + pollDbThread.join(); + notifyThread.join(); + + threadPool.shutdownNow(); + + while (!threadPool.awaitTermination(5, TimeUnit.SECONDS)); + } + + private void notifySubscribers() { + try { + while (run) { + + MqMessage msg = queue.poll(pollIntervalMs, TimeUnit.MILLISECONDS); + + if (msg == null) + continue; + + logger.info("Notifying subscribers of message {}", msg.msgId()); + + boolean handled = false; + + for (var eventSubscriber : eventSubscribers) { + if (eventSubscriber.filter(msg)) { + handleMessageWithSubscriber(eventSubscriber, msg); + handled = true; + break; + } + } + + if (!handled) { + logger.error("No subscriber wanted to handle message {}", msg.msgId()); + } + } + } + catch (InterruptedException ex) { + logger.error("MQ inbox notify thread interrupted", ex); + } + } + + private void handleMessageWithSubscriber(MqSubscription subscriber, MqMessage msg) { + + threadPool.execute(() -> { + try { + final var rsp = subscriber.handle(msg); + + sendResponse(msg, rsp.state(), rsp.message()); + } catch (Exception ex) { + logger.error("Message Queue subscriber threw exception", ex); + sendResponse(msg, MqMessageState.ERR); + } + }); + } + + private void sendResponse(MqMessage msg, MqMessageState mqMessageState) { + try { + persistence.updateMessageState(msg.msgId(), mqMessageState); + } + catch (SQLException ex) { + logger.error("Failed to update message state", ex); + } + } + + private void sendResponse(MqMessage msg, MqMessageState mqMessageState, String response) { + try { + persistence.sendResponse(msg.msgId(), mqMessageState, response); + } + catch (SQLException ex) { + logger.error("Failed to update message state", ex); + } + } + + public void pollDb() { + try { + for (long tick = 1; run; tick++) { + + queue.addAll(pollInbox(tick)); + + TimeUnit.MILLISECONDS.sleep(pollIntervalMs); + } + } + catch (InterruptedException ex) { + logger.error("MQ inbox update thread interrupted", ex); + } + } + + private Collection pollInbox(long tick) { + try { + return persistence.pollInbox(inboxName, instanceUUID, tick); + } + catch (SQLException ex) { + logger.error("Failed to poll inbox", ex); + return List.of(); + } + } + + + private class MqInboxShredder implements MqSubscription { + + @Override + public boolean filter(MqMessage rawMessage) { + return true; + } + + @Override + public MqInboxResponse handle(MqMessage msg) { + logger.warn("Unhandled message {}", msg.msgId()); + return MqInboxResponse.err(); + } + } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInboxResponse.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInboxResponse.java new file mode 100644 index 00000000..ba4eb6f2 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInboxResponse.java @@ -0,0 +1,22 @@ +package nu.marginalia.mq.inbox; + +import nu.marginalia.mq.MqMessageState; + +public record MqInboxResponse(String message, MqMessageState state) { + + public static MqInboxResponse ok(String message) { + return new MqInboxResponse(message, MqMessageState.OK); + } + + public static MqInboxResponse ok() { + return new MqInboxResponse("", MqMessageState.OK); + } + + public static MqInboxResponse err(String message) { + return new MqInboxResponse(message, MqMessageState.ERR); + } + + public static MqInboxResponse err() { + return new MqInboxResponse("", MqMessageState.ERR); + } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java new file mode 100644 index 00000000..ce52a26b --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqSubscription.java @@ -0,0 +1,9 @@ +package nu.marginalia.mq.inbox; + +import nu.marginalia.mq.MqMessage; + +public interface MqSubscription { + boolean filter(MqMessage rawMessage); + + MqInboxResponse handle(MqMessage msg); +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java new file mode 100644 index 00000000..e4fa2e23 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java @@ -0,0 +1,107 @@ +package nu.marginalia.mq.outbox; + +import nu.marginalia.mq.MqMessage; +import nu.marginalia.mq.persistence.MqPersistence; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.SQLException; +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +public class MqOutbox { + private final Logger logger = LoggerFactory.getLogger(MqOutbox.class); + private final MqPersistence persistence; + private final String inboxName; + private final String replyInboxName; + private final String instanceUUID; + + private final ConcurrentHashMap pendingRequests = new ConcurrentHashMap<>(); + private final ConcurrentHashMap pendingResponses = new ConcurrentHashMap<>(); + + private final int pollIntervalMs = Integer.getInteger("mq.outbox.poll-interval-ms", 1000); + private final Thread pollThread; + + private volatile boolean run = true; + + public MqOutbox(MqPersistence persistence, + String inboxName, + UUID instanceUUID) { + this.persistence = persistence; + + this.inboxName = inboxName; + this.replyInboxName = "reply:" + inboxName; + this.instanceUUID = instanceUUID.toString(); + + pollThread = new Thread(this::poll, "mq-outbox-poll-thread:" + inboxName); + pollThread.setDaemon(true); + pollThread.start(); + } + + public void stop() throws InterruptedException { + if (!run) + return; + + logger.info("Shutting down outbox {}", inboxName); + + pendingRequests.clear(); + + run = false; + pollThread.join(); + } + + private void poll() { + try { + for (long id = 1; run; id++) { + pollDb(id); + + TimeUnit.MILLISECONDS.sleep(pollIntervalMs); + } + } catch (InterruptedException ex) { + logger.error("Outbox poll thread interrupted", ex); + } + } + + private void pollDb(long tick) { + if (pendingRequests.isEmpty()) + return; + + try { + var updates = persistence.pollReplyInbox(replyInboxName, instanceUUID, tick); + + for (var message : updates) { + pendingResponses.put(message.relatedId(), message); + pendingRequests.remove(message.relatedId()); + } + + if (updates.isEmpty() || pendingResponses.isEmpty()) + return; + + logger.info("Notifying {} pending responses", pendingResponses.size()); + + synchronized (pendingResponses) { + pendingResponses.notifyAll(); + } + } + catch (SQLException ex) { + logger.error("Failed to poll inbox", ex); + } + + } + + public MqMessage send(String function, String payload) throws Exception { + var id = persistence.sendNewMessage(inboxName, replyInboxName, function, payload, null); + pendingRequests.put(id, id); + + synchronized (pendingResponses) { + while (!pendingResponses.containsKey(id)) { + pendingResponses.wait(100); + } + return pendingResponses.remove(id); + } + } + + +} \ No newline at end of file diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java new file mode 100644 index 00000000..92fffb51 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java @@ -0,0 +1,237 @@ +package nu.marginalia.mq.persistence; + +import com.google.inject.Inject; +import com.google.inject.Singleton; +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.MqMessage; + +import javax.annotation.Nullable; +import java.sql.SQLException; +import java.time.Duration; +import java.util.*; + +@Singleton +public class MqPersistence { + private final HikariDataSource dataSource; + + @Inject + public MqPersistence(HikariDataSource dataSource) { + this.dataSource = dataSource; + } + + /** Flags messages as dead if they have not been set to a terminal state within a TTL after the last update. */ + public int reapDeadMessages() throws SQLException { + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(""" + UPDATE PROC_MESSAGE + SET STATE='DEAD', UPDATED_TIME=CURRENT_TIMESTAMP(6) + WHERE STATE IN ('NEW', 'ACK') + AND TTL IS NOT NULL + AND TIMESTAMPDIFF(SECOND, UPDATED_TIME, CURRENT_TIMESTAMP(6)) > TTL + """)) { + return stmt.executeUpdate(); + } + } + + public long sendNewMessage(String recipientInboxName, + @Nullable + String senderInboxName, + String function, + String payload, + @Nullable Duration ttl + ) throws Exception { + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(""" + INSERT INTO PROC_MESSAGE(RECIPIENT_INBOX, SENDER_INBOX, FUNCTION, PAYLOAD, TTL) + VALUES(?, ?, ?, ?, ?) + """); + var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")) { + + stmt.setString(1, recipientInboxName); + + if (senderInboxName == null) stmt.setNull(2, java.sql.Types.VARCHAR); + else stmt.setString(2, senderInboxName); + + stmt.setString(3, function); + stmt.setString(4, payload); + if (ttl == null) stmt.setNull(5, java.sql.Types.BIGINT); + else stmt.setLong(5, ttl.toSeconds()); + + stmt.executeUpdate(); + var rsp = lastIdQuery.executeQuery(); + + if (!rsp.next()) { + throw new IllegalStateException("No last insert id"); + } + + return rsp.getLong(1); + } + } + + + public void updateMessageState(long id, MqMessageState mqMessageState) throws SQLException { + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(""" + UPDATE PROC_MESSAGE + SET STATE=?, UPDATED_TIME=CURRENT_TIMESTAMP(6) + WHERE ID=? + """)) { + stmt.setString(1, mqMessageState.name()); + stmt.setLong(2, id); + + if (stmt.executeUpdate() != 1) { + throw new IllegalArgumentException("No rows updated"); + } + } + } + + public long sendResponse(long id, MqMessageState mqMessageState, String message) throws SQLException { + try (var conn = dataSource.getConnection()) { + conn.setAutoCommit(false); + + try (var updateState = conn.prepareStatement(""" + UPDATE PROC_MESSAGE + SET STATE=?, UPDATED_TIME=CURRENT_TIMESTAMP(6) + WHERE ID=? + """); + var addResponse = conn.prepareStatement(""" + INSERT INTO PROC_MESSAGE(RECIPIENT_INBOX, RELATED_ID, FUNCTION, PAYLOAD) + SELECT SENDER_INBOX, ID, ?, ? + FROM PROC_MESSAGE + WHERE ID=? AND SENDER_INBOX IS NOT NULL + """); + var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()") + ) { + + updateState.setString(1, mqMessageState.name()); + updateState.setLong(2, id); + if (updateState.executeUpdate() != 1) { + throw new IllegalArgumentException("No rows updated"); + } + + addResponse.setString(1, "REPLY"); + addResponse.setString(2, message); + addResponse.setLong(3, id); + if (addResponse.executeUpdate() != 1) { + throw new IllegalArgumentException("No rows updated"); + } + + var rsp = lastIdQuery.executeQuery(); + if (!rsp.next()) { + throw new IllegalStateException("No last insert id"); + } + long newId = rsp.getLong(1); + + conn.commit(); + + return newId; + } catch (SQLException|IllegalStateException|IllegalArgumentException ex) { + conn.rollback(); + throw ex; + } finally { + conn.setAutoCommit(true); + } + } + } + + + private int markInboxMessages(String inboxName, String instanceUUID, long tick) throws SQLException { + try (var conn = dataSource.getConnection(); + var updateStmt = conn.prepareStatement(""" + UPDATE PROC_MESSAGE + SET OWNER_INSTANCE=?, OWNER_TICK=?, UPDATED_TIME=CURRENT_TIMESTAMP(6), STATE='ACK' + WHERE RECIPIENT_INBOX=? + AND OWNER_INSTANCE IS NULL AND STATE='NEW' + """); + ) { + updateStmt.setString(1, instanceUUID); + updateStmt.setLong(2, tick); + updateStmt.setString(3, inboxName); + return updateStmt.executeUpdate(); + } + } + + /** Marks unclaimed messages addressed to this inbox with instanceUUID and tick, + * then returns these messages. + */ + public Collection pollInbox(String inboxName, String instanceUUID, long tick) throws SQLException { + + int expected = markInboxMessages(inboxName, instanceUUID, tick); + if (expected == 0) { + return Collections.emptyList(); + } + + try (var conn = dataSource.getConnection(); + var queryStmt = conn.prepareStatement(""" + SELECT ID, RELATED_ID, FUNCTION, PAYLOAD, STATE FROM PROC_MESSAGE + WHERE OWNER_INSTANCE=? AND OWNER_TICK=? + """) + ) { + queryStmt.setString(1, instanceUUID); + queryStmt.setLong(2, tick); + var rs = queryStmt.executeQuery(); + + List messages = new ArrayList<>(expected); + + while (rs.next()) { + long msgId = rs.getLong(1); + long relatedId = rs.getLong(2); + + String function = rs.getString(3); + String payload = rs.getString(4); + + MqMessageState state = MqMessageState.valueOf(rs.getString(5)); + + var msg = new MqMessage(msgId, relatedId, function, payload, state); + + messages.add(msg); + } + + return messages; + } + + } + + + /** Marks unclaimed messages addressed to this inbox with instanceUUID and tick, + * then returns these messages. + */ + public Collection pollReplyInbox(String inboxName, String instanceUUID, long tick) throws SQLException { + + int expected = markInboxMessages(inboxName, instanceUUID, tick); + if (expected == 0) { + return Collections.emptyList(); + } + + try (var conn = dataSource.getConnection(); + var queryStmt = conn.prepareStatement(""" + SELECT SELF.ID, SELF.RELATED_ID, SELF.FUNCTION, SELF.PAYLOAD, PARENT.STATE FROM PROC_MESSAGE SELF + LEFT JOIN PROC_MESSAGE PARENT ON SELF.RELATED_ID=PARENT.ID + WHERE SELF.OWNER_INSTANCE=? AND SELF.OWNER_TICK=? + """) + ) { + queryStmt.setString(1, instanceUUID); + queryStmt.setLong(2, tick); + var rs = queryStmt.executeQuery(); + + List messages = new ArrayList<>(expected); + + while (rs.next()) { + long msgId = rs.getLong(1); + long relatedId = rs.getLong(2); + + String function = rs.getString(3); + String payload = rs.getString(4); + + MqMessageState state = MqMessageState.valueOf(rs.getString(5)); + + var msg = new MqMessage(msgId, relatedId, function, payload, state); + + messages.add(msg); + } + + return messages; + } + } +} diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqMessageRow.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqMessageRow.java new file mode 100644 index 00000000..933cdb62 --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqMessageRow.java @@ -0,0 +1,21 @@ +package nu.marginalia.mq.outbox; + +import nu.marginalia.mq.MqMessageState; + +import javax.annotation.Nullable; + +public record MqMessageRow ( + long id, + long relatedId, + @Nullable + String senderInbox, + String recipientInbox, + String function, + String payload, + MqMessageState state, + String ownerInstance, + long ownerTick, + long createdTime, + long updatedTime, + long ttl +) {} \ No newline at end of file diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java new file mode 100644 index 00000000..789aec15 --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqOutboxTest.java @@ -0,0 +1,177 @@ +package nu.marginalia.mq.outbox; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessage; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.inbox.MqInboxResponse; +import nu.marginalia.mq.inbox.MqInbox; +import nu.marginalia.mq.inbox.MqSubscription; +import nu.marginalia.mq.persistence.MqPersistence; +import org.junit.jupiter.api.*; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.UUID; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Tag("slow") +@Testcontainers +public class MqOutboxTest { + @Container + static MariaDBContainer> mariaDBContainer = new MariaDBContainer<>("mariadb") + .withDatabaseName("WMSA_prod") + .withUsername("wmsa") + .withPassword("wmsa") + .withInitScript("sql/current/11-message-queue.sql") + .withNetworkAliases("mariadb"); + + static HikariDataSource dataSource; + private String inboxId; + + @BeforeEach + public void setUp() { + inboxId = UUID.randomUUID().toString(); + } + @BeforeAll + public static void setUpAll() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(mariaDBContainer.getJdbcUrl()); + config.setUsername("wmsa"); + config.setPassword("wmsa"); + + dataSource = new HikariDataSource(config); + } + + @AfterAll + public static void tearDownAll() { + dataSource.close(); + } + + @Test + public void testOpenClose() throws InterruptedException { + var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + outbox.stop(); + } + + @Test + public void testSend() throws Exception { + var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + Executors.newSingleThreadExecutor().submit(() -> outbox.send("test", "Hello World")); + + TimeUnit.MILLISECONDS.sleep(100); + + var messages = MqTestUtil.getMessages(dataSource, inboxId); + assertEquals(1, messages.size()); + System.out.println(messages.get(0)); + + outbox.stop(); + } + + @Test + public void testSendAndRespond() throws Exception { + var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + + var inbox = new MqInbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + inbox.subscribe(justRespond("Alright then")); + inbox.start(); + + var rsp = outbox.send("test", "Hello World"); + + assertEquals(MqMessageState.OK, rsp.state()); + assertEquals("Alright then", rsp.payload()); + + var messages = MqTestUtil.getMessages(dataSource, inboxId); + assertEquals(1, messages.size()); + assertEquals(MqMessageState.OK, messages.get(0).state()); + + outbox.stop(); + inbox.stop(); + } + + @Test + public void testSendMultiple() throws Exception { + var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + + var inbox = new MqInbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + inbox.subscribe(echo()); + inbox.start(); + + var rsp1 = outbox.send("test", "one"); + var rsp2 = outbox.send("test", "two"); + var rsp3 = outbox.send("test", "three"); + var rsp4 = outbox.send("test", "four"); + + Thread.sleep(500); + + assertEquals(MqMessageState.OK, rsp1.state()); + assertEquals("one", rsp1.payload()); + assertEquals(MqMessageState.OK, rsp2.state()); + assertEquals("two", rsp2.payload()); + assertEquals(MqMessageState.OK, rsp3.state()); + assertEquals("three", rsp3.payload()); + assertEquals(MqMessageState.OK, rsp4.state()); + assertEquals("four", rsp4.payload()); + + var messages = MqTestUtil.getMessages(dataSource, inboxId); + assertEquals(4, messages.size()); + for (var message : messages) { + assertEquals(MqMessageState.OK, message.state()); + } + + outbox.stop(); + inbox.stop(); + } + + @Test + public void testSendAndRespondWithErrorHandler() throws Exception { + var outbox = new MqOutbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + var inbox = new MqInbox(new MqPersistence(dataSource), inboxId, UUID.randomUUID()); + + inbox.start(); + + var rsp = outbox.send("test", "Hello World"); + + assertEquals(MqMessageState.ERR, rsp.state()); + + var messages = MqTestUtil.getMessages(dataSource, inboxId); + assertEquals(1, messages.size()); + assertEquals(MqMessageState.ERR, messages.get(0).state()); + + outbox.stop(); + inbox.stop(); + } + + public MqSubscription justRespond(String response) { + return new MqSubscription() { + @Override + public boolean filter(MqMessage rawMessage) { + return true; + } + + @Override + public MqInboxResponse handle(MqMessage msg) { + return MqInboxResponse.ok(response); + } + }; + } + + public MqSubscription echo() { + return new MqSubscription() { + @Override + public boolean filter(MqMessage rawMessage) { + return true; + } + + @Override + public MqInboxResponse handle(MqMessage msg) { + return MqInboxResponse.ok(msg.payload()); + } + }; + } + +} \ No newline at end of file diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqPersistenceTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqPersistenceTest.java new file mode 100644 index 00000000..590ff64b --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqPersistenceTest.java @@ -0,0 +1,189 @@ +package nu.marginalia.mq.outbox; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.persistence.MqPersistence; +import org.junit.jupiter.api.*; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +@Tag("slow") +@Testcontainers +public class MqPersistenceTest { + @Container + static MariaDBContainer> mariaDBContainer = new MariaDBContainer<>("mariadb") + .withDatabaseName("WMSA_prod") + .withUsername("wmsa") + .withPassword("wmsa") + .withInitScript("sql/current/11-message-queue.sql") + .withNetworkAliases("mariadb"); + + static HikariDataSource dataSource; + static MqPersistence persistence; + String recipientId; + String senderId; + + @BeforeEach + public void setUp() { + senderId = UUID.randomUUID().toString(); + recipientId = UUID.randomUUID().toString(); + } + + @BeforeAll + public static void setUpAll() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(mariaDBContainer.getJdbcUrl()); + config.setUsername("wmsa"); + config.setPassword("wmsa"); + + dataSource = new HikariDataSource(config); + persistence = new MqPersistence(dataSource); + } + + @AfterAll + public static void tearDownAll() { + dataSource.close(); + } + + @Test + public void testReaper() throws Exception { + + long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(2)); + persistence.reapDeadMessages(); + + var messages = MqTestUtil.getMessages(dataSource, recipientId); + assertEquals(1, messages.size()); + assertEquals(MqMessageState.NEW, messages.get(0).state()); + + TimeUnit.SECONDS.sleep(5); + + persistence.reapDeadMessages(); + + messages = MqTestUtil.getMessages(dataSource, recipientId); + assertEquals(1, messages.size()); + assertEquals(MqMessageState.DEAD, messages.get(0).state()); + } + + @Test + public void sendWithReplyAddress() throws Exception { + + long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); + + var messages = MqTestUtil.getMessages(dataSource, recipientId); + assertEquals(1, messages.size()); + + var message = messages.get(0); + + assertEquals(id, message.id()); + assertEquals("function", message.function()); + assertEquals("payload", message.payload()); + assertEquals(MqMessageState.NEW, message.state()); + + System.out.println(message); + } + + @Test + public void sendNoReplyAddress() throws Exception { + + long id = persistence.sendNewMessage(recipientId, null, "function", "payload", Duration.ofSeconds(30)); + + var messages = MqTestUtil.getMessages(dataSource, recipientId); + assertEquals(1, messages.size()); + + var message = messages.get(0); + + assertEquals(id, message.id()); + assertNull(message.senderInbox()); + assertEquals("function", message.function()); + assertEquals("payload", message.payload()); + assertEquals(MqMessageState.NEW, message.state()); + + System.out.println(message); + } + + @Test + public void updateState() throws Exception { + + long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); + persistence.updateMessageState(id, MqMessageState.OK); + System.out.println(id); + + var messages = MqTestUtil.getMessages(dataSource, recipientId); + assertEquals(1, messages.size()); + + var message = messages.get(0); + + assertEquals(id, message.id()); + assertEquals(MqMessageState.OK, message.state()); + + System.out.println(message); + } + + @Test + public void testReply() throws Exception { + long request = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); + long response = persistence.sendResponse(request, MqMessageState.OK, "response"); + + var sentMessages = MqTestUtil.getMessages(dataSource, recipientId); + System.out.println(sentMessages); + assertEquals(1, sentMessages.size()); + + var requestMessage = sentMessages.get(0); + assertEquals(request, requestMessage.id()); + assertEquals(MqMessageState.OK, requestMessage.state()); + + + var replies = MqTestUtil.getMessages(dataSource, senderId); + System.out.println(replies); + assertEquals(1, replies.size()); + + var responseMessage = replies.get(0); + assertEquals(response, responseMessage.id()); + assertEquals(request, responseMessage.relatedId()); + assertEquals(MqMessageState.NEW, responseMessage.state()); + } + + @Test + public void testPollInbox() throws Exception { + + String instanceId = "BATMAN"; + long tick = 1234L; + + long id = persistence.sendNewMessage(recipientId, null,"function", "payload", Duration.ofSeconds(30)); + + var messagesPollFirstTime = persistence.pollInbox(recipientId, instanceId , tick); + + /** CHECK POLL RESULT */ + assertEquals(1, messagesPollFirstTime.size()); + var firstPollMessage = messagesPollFirstTime.iterator().next(); + assertEquals(id, firstPollMessage.msgId()); + assertEquals("function", firstPollMessage.function()); + assertEquals("payload", firstPollMessage.payload()); + + /** CHECK DB TABLE */ + var messages = MqTestUtil.getMessages(dataSource, recipientId); + assertEquals(1, messages.size()); + + var message = messages.get(0); + + assertEquals(id, message.id()); + assertEquals("function", message.function()); + assertEquals("payload", message.payload()); + assertEquals(MqMessageState.ACK, message.state()); + assertEquals(instanceId, message.ownerInstance()); + assertEquals(tick, message.ownerTick()); + + /** VERIFY SECOND POLL IS EMPTY */ + var messagePollSecondTime = persistence.pollInbox(recipientId, instanceId , 1); + assertEquals(0, messagePollSecondTime.size()); + } +} diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqTestUtil.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqTestUtil.java new file mode 100644 index 00000000..3fee8b20 --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/outbox/MqTestUtil.java @@ -0,0 +1,52 @@ +package nu.marginalia.mq.outbox; + +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessageState; +import org.junit.jupiter.api.Assertions; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +public class MqTestUtil { + public static List getMessages(HikariDataSource dataSource, String inbox) { + List messages = new ArrayList<>(); + + try (var conn = dataSource.getConnection(); + var stmt = conn.prepareStatement(""" + SELECT ID, RELATED_ID, + SENDER_INBOX, RECIPIENT_INBOX, + FUNCTION, PAYLOAD, + STATE, + OWNER_INSTANCE, OWNER_TICK, + CREATED_TIME, UPDATED_TIME, + TTL + FROM PROC_MESSAGE + WHERE RECIPIENT_INBOX = ? + """)) + { + stmt.setString(1, inbox); + var rsp = stmt.executeQuery(); + while (rsp.next()) { + messages.add(new MqMessageRow( + rsp.getLong("ID"), + rsp.getLong("RELATED_ID"), + rsp.getString("SENDER_INBOX"), + rsp.getString("RECIPIENT_INBOX"), + rsp.getString("FUNCTION"), + rsp.getString("PAYLOAD"), + MqMessageState.valueOf(rsp.getString("STATE")), + rsp.getString("OWNER_INSTANCE"), + rsp.getLong("OWNER_TICK"), + rsp.getTimestamp("CREATED_TIME").getTime(), + rsp.getTimestamp("UPDATED_TIME").getTime(), + rsp.getLong("TTL") + )); + } + } + catch (SQLException ex) { + Assertions.fail(ex); + } + return messages; + } +} diff --git a/settings.gradle b/settings.gradle index 90d74f99..41e0cb53 100644 --- a/settings.gradle +++ b/settings.gradle @@ -48,6 +48,7 @@ include 'code:api:assistant-api' include 'code:common:service-discovery' include 'code:common:service-client' include 'code:common:db' +include 'code:common:message-queue' include 'code:common:service' include 'code:common:config' include 'code:common:model'
Messages pass through several states through their lifecycle