diff --git a/code/common/db/src/main/resources/sql/current/12-message-queue.sql b/code/common/db/src/main/resources/sql/current/12-message-queue.sql index fd04f666..25bdc636 100644 --- a/code/common/db/src/main/resources/sql/current/12-message-queue.sql +++ b/code/common/db/src/main/resources/sql/current/12-message-queue.sql @@ -1,6 +1,6 @@ CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE ( ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id', - RELATED_ID BIGINT COMMENT 'Unique id a related message', + RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message', SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox', RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox', FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run', diff --git a/code/common/db/src/main/resources/sql/migrations/04-message-queue.sql b/code/common/db/src/main/resources/sql/migrations/04-message-queue.sql index fd04f666..25bdc636 100644 --- a/code/common/db/src/main/resources/sql/migrations/04-message-queue.sql +++ b/code/common/db/src/main/resources/sql/migrations/04-message-queue.sql @@ -1,6 +1,6 @@ CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE ( ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id', - RELATED_ID BIGINT COMMENT 'Unique id a related message', + RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message', SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox', RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox', FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run', diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java index 88b9601f..22b4bc85 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/outbox/MqOutbox.java @@ -7,7 +7,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.sql.SQLException; -import java.sql.Time; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -107,7 +106,7 @@ public class MqOutbox { *
* Use waitResponse(id) or pollResponse(id) to fetch the response. */ public long sendAsync(String function, String payload) throws Exception { - var id = persistence.sendNewMessage(inboxName, replyInboxName, function, payload, null); + var id = persistence.sendNewMessage(inboxName, replyInboxName, null, function, payload, null); pendingRequests.put(id, id); @@ -163,7 +162,13 @@ public class MqOutbox { } public long notify(String function, String payload) throws Exception { - return persistence.sendNewMessage(inboxName, null, function, payload, null); + return persistence.sendNewMessage(inboxName, null, null, function, payload, null); + } + public long notify(long relatedId, String function, String payload) throws Exception { + return persistence.sendNewMessage(inboxName, null, relatedId, function, payload, null); } + public void flagAsBad(long id) throws SQLException { + persistence.updateMessageState(id, MqMessageState.ERR); + } } \ No newline at end of file diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java index 198914b3..d075d445 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/persistence/MqPersistence.java @@ -52,23 +52,25 @@ public class MqPersistence { * Adds a new message to the message queue. * * @param recipientInboxName The recipient's inbox name - * @param senderInboxName (nullable) The sender's inbox name. Only needed if a reply is expected. If null, the message is not expected to be replied to. - * @param function The function to call - * @param payload The payload to send, typically JSON. - * @param ttl (nullable) The time to live of the message, in seconds. If null, the message will never set to DEAD. + * @param senderInboxName (nullable) The sender's inbox name. Only needed if a reply is expected. If null, the message is not expected to be replied to. + * @param relatedMessageId (nullable) The id of the message this message is related to. If null, the message is not related to any other message. + * @param function The function to call + * @param payload The payload to send, typically JSON. + * @param ttl (nullable) The time to live of the message, in seconds. If null, the message will never set to DEAD. * @return The id of the message */ public long sendNewMessage(String recipientInboxName, @Nullable String senderInboxName, + Long relatedMessageId, String function, String payload, @Nullable Duration ttl ) throws Exception { try (var conn = dataSource.getConnection(); var stmt = conn.prepareStatement(""" - INSERT INTO MESSAGE_QUEUE(RECIPIENT_INBOX, SENDER_INBOX, FUNCTION, PAYLOAD, TTL) - VALUES(?, ?, ?, ?, ?) + INSERT INTO MESSAGE_QUEUE(RECIPIENT_INBOX, SENDER_INBOX, RELATED_ID, FUNCTION, PAYLOAD, TTL) + VALUES(?, ?, ?, ?, ?, ?) """); var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")) { @@ -77,10 +79,13 @@ public class MqPersistence { if (senderInboxName == null) stmt.setNull(2, java.sql.Types.VARCHAR); else stmt.setString(2, senderInboxName); - stmt.setString(3, function); - stmt.setString(4, payload); - if (ttl == null) stmt.setNull(5, java.sql.Types.BIGINT); - else stmt.setLong(5, ttl.toSeconds()); + if (relatedMessageId == null) stmt.setLong(3, -1); + else stmt.setLong(3, relatedMessageId); + + stmt.setString(4, function); + stmt.setString(5, payload); + if (ttl == null) stmt.setNull(6, java.sql.Types.BIGINT); + else stmt.setLong(6, ttl.toSeconds()); stmt.executeUpdate(); var rsp = lastIdQuery.executeQuery(); diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java index 9b7d2cfa..d039f363 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java @@ -36,6 +36,14 @@ public class StateMachine { private final Map allStates = new HashMap<>(); + /* The expectedMessageId guards against spurious state changes being triggered by old messages in the queue + * + * It contains the message id of the last message that was processed, and the messages sent by the state machine to + * itself via the message queue all have relatedId set to expectedMessageId. If the state machine is unitialized or + * in a terminal state, it will accept messages with relatedIds that are equal to -1. + * */ + private long expectedMessageId = -1; + public StateMachine(MqFactory messageQueueFactory, String queueName, UUID instanceUUID, @@ -99,7 +107,7 @@ public class StateMachine { } smInbox.start(); - smOutbox.notify(transition.state(), transition.message()); + smOutbox.notify(expectedMessageId, transition.state(), transition.message()); } /** Initialize the state machine. */ @@ -112,7 +120,7 @@ public class StateMachine { } smInbox.start(); - smOutbox.notify(transition.state(), transition.message()); + smOutbox.notify(expectedMessageId, transition.state(), transition.message()); } /** Resume the state machine from the last known state. */ @@ -133,6 +141,7 @@ public class StateMachine { smInbox.start(); logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state()); + expectedMessageId = firstMessage.relatedId(); if (firstMessage.state() == MqMessageState.NEW) { // The message is not acknowledged, so starting the inbox will trigger a state transition @@ -141,10 +150,10 @@ public class StateMachine { state = resumingState; } else if (resumeState.resumeBehavior().equals(ResumeBehavior.ERROR)) { // The message is acknowledged, but the state does not support resuming - smOutbox.notify("ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function()); + smOutbox.notify(expectedMessageId, "ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function()); } else { // The message is already acknowledged, so we replay the last state - onStateTransition(firstMessage.function(), firstMessage.payload()); + onStateTransition(firstMessage); } } @@ -153,13 +162,24 @@ public class StateMachine { smOutbox.stop(); } - private void onStateTransition(String nextState, String message) { + private void onStateTransition(MqMessage msg) { + final String nextState = msg.function(); + final String data = msg.payload(); + final long messageId = msg.msgId(); + final long relatedId = msg.relatedId(); + + if (expectedMessageId != relatedId) { + // We've received a message that we didn't expect, throwing an exception will cause it to be flagged + // as an error in the message queue; the message queue will proceed + throw new IllegalStateException("Unexpected message id " + relatedId + ", expected " + expectedMessageId); + } + try { logger.info("FSM State change in {}: {}->{}({})", queueName, state == null ? "[null]" : state.name(), nextState, - message); + data); if (!allStates.containsKey(nextState)) { logger.error("Unknown state {}", nextState); @@ -173,8 +193,13 @@ public class StateMachine { } if (!state.isFinal()) { - var transition = state.next(message); - smOutbox.notify(transition.state(), transition.message()); + var transition = state.next(msg.payload()); + + expectedMessageId = messageId; + smOutbox.notify(expectedMessageId, transition.state(), transition.message()); + } + else { + expectedMessageId = -1; } } catch (Exception e) { @@ -204,7 +229,7 @@ public class StateMachine { @Override public void onNotification(MqMessage msg) { - onStateTransition(msg.function(), msg.payload()); + onStateTransition(msg); try { stateChangeListeners.forEach(l -> l.accept(msg.function(), msg.payload())); } diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java index 74f69682..4b93fa5e 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mq/persistence/MqPersistenceTest.java @@ -57,7 +57,7 @@ public class MqPersistenceTest { @Test public void testReaper() throws Exception { - long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(2)); + long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(2)); persistence.reapDeadMessages(); var messages = MqTestUtil.getMessages(dataSource, recipientId); @@ -77,7 +77,7 @@ public class MqPersistenceTest { @Test public void sendWithReplyAddress() throws Exception { - long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); + long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30)); var messages = MqTestUtil.getMessages(dataSource, recipientId); assertEquals(1, messages.size()); @@ -95,7 +95,7 @@ public class MqPersistenceTest { @Test public void sendNoReplyAddress() throws Exception { - long id = persistence.sendNewMessage(recipientId, null, "function", "payload", Duration.ofSeconds(30)); + long id = persistence.sendNewMessage(recipientId, null, null, "function", "payload", Duration.ofSeconds(30)); var messages = MqTestUtil.getMessages(dataSource, recipientId); assertEquals(1, messages.size()); @@ -114,7 +114,7 @@ public class MqPersistenceTest { @Test public void updateState() throws Exception { - long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); + long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30)); persistence.updateMessageState(id, MqMessageState.OK); System.out.println(id); @@ -131,7 +131,7 @@ public class MqPersistenceTest { @Test public void testReply() throws Exception { - long request = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); + long request = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30)); long response = persistence.sendResponse(request, MqMessageState.OK, "response"); var sentMessages = MqTestUtil.getMessages(dataSource, recipientId); @@ -159,7 +159,7 @@ public class MqPersistenceTest { String instanceId = "BATMAN"; long tick = 1234L; - long id = persistence.sendNewMessage(recipientId, null,"function", "payload", Duration.ofSeconds(30)); + long id = persistence.sendNewMessage(recipientId, null, null, "function", "payload", Duration.ofSeconds(30)); var messagesPollFirstTime = persistence.pollInbox(recipientId, instanceId , tick, 10); diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java index 79af8d07..bf4e9990 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java @@ -81,7 +81,7 @@ public class StateMachineResumeTest { var stateFactory = new StateFactory(new GsonBuilder().create()); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); - persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); + persistence.sendNewMessage(inboxId, null, -1L, "RESUMABLE", "", null); sm.resume(); @@ -102,7 +102,7 @@ public class StateMachineResumeTest { var stateFactory = new StateFactory(new GsonBuilder().create()); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); - long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); + long id = persistence.sendNewMessage(inboxId, null, -1L, "RESUMABLE", "", null); persistence.updateMessageState(id, MqMessageState.ACK); sm.resume(); @@ -125,7 +125,7 @@ public class StateMachineResumeTest { var stateFactory = new StateFactory(new GsonBuilder().create()); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); - persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); + persistence.sendNewMessage(inboxId, null, -1L, "NON-RESUMABLE", "", null); sm.resume(); @@ -146,7 +146,7 @@ public class StateMachineResumeTest { var stateFactory = new StateFactory(new GsonBuilder().create()); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); - long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); + long id = persistence.sendNewMessage(inboxId, null, null, "NON-RESUMABLE", "", null); persistence.updateMessageState(id, MqMessageState.ACK); sm.resume(); diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java index 27ae869e..e8dcaa83 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java @@ -118,4 +118,25 @@ public class StateMachineTest { MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println); } + @Test + public void testFalseTransition() throws Exception { + var stateFactory = new StateFactory(new GsonBuilder().create()); + var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new TestGraph(stateFactory)); + + // Prep the queue with a message to set the state to initial, + // and an additional message to trigger the false transition back to initial + + persistence.sendNewMessage(inboxId, null, null, "INITIAL", "", null); + persistence.sendNewMessage(inboxId, null, null, "INITIAL", "", null); + + sm.resume(); + + Thread.sleep(50); + + sm.join(); + sm.stop(); + + MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println); + } + }