(mqsm) guard against spurious transitions from unexpected messages

This commit is contained in:
Viktor Lofgren 2023-07-12 22:44:05 +02:00
parent bf783dad7a
commit 6c88f00a9d
8 changed files with 90 additions and 34 deletions

View File

@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE ( CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE (
ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id', ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id',
RELATED_ID BIGINT COMMENT 'Unique id a related message', RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message',
SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox', SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox',
RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox', RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox',
FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run', FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run',

View File

@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE ( CREATE TABLE IF NOT EXISTS MESSAGE_QUEUE (
ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id', ID BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT 'Unique id',
RELATED_ID BIGINT COMMENT 'Unique id a related message', RELATED_ID BIGINT NOT NULL DEFAULT -1 COMMENT 'Unique id a related message',
SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox', SENDER_INBOX VARCHAR(255) COMMENT 'Name of the sender inbox',
RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox', RECIPIENT_INBOX VARCHAR(255) NOT NULL COMMENT 'Name of the recipient inbox',
FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run', FUNCTION VARCHAR(255) NOT NULL COMMENT 'Which function to run',

View File

@ -7,7 +7,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Time;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -107,7 +106,7 @@ public class MqOutbox {
* <br> * <br>
* Use waitResponse(id) or pollResponse(id) to fetch the response. */ * Use waitResponse(id) or pollResponse(id) to fetch the response. */
public long sendAsync(String function, String payload) throws Exception { public long sendAsync(String function, String payload) throws Exception {
var id = persistence.sendNewMessage(inboxName, replyInboxName, function, payload, null); var id = persistence.sendNewMessage(inboxName, replyInboxName, null, function, payload, null);
pendingRequests.put(id, id); pendingRequests.put(id, id);
@ -163,7 +162,13 @@ public class MqOutbox {
} }
public long notify(String function, String payload) throws Exception { public long notify(String function, String payload) throws Exception {
return persistence.sendNewMessage(inboxName, null, function, payload, null); return persistence.sendNewMessage(inboxName, null, null, function, payload, null);
}
public long notify(long relatedId, String function, String payload) throws Exception {
return persistence.sendNewMessage(inboxName, null, relatedId, function, payload, null);
} }
public void flagAsBad(long id) throws SQLException {
persistence.updateMessageState(id, MqMessageState.ERR);
}
} }

View File

@ -52,23 +52,25 @@ public class MqPersistence {
* Adds a new message to the message queue. * Adds a new message to the message queue.
* *
* @param recipientInboxName The recipient's inbox name * @param recipientInboxName The recipient's inbox name
* @param senderInboxName (nullable) The sender's inbox name. Only needed if a reply is expected. If null, the message is not expected to be replied to. * @param senderInboxName (nullable) The sender's inbox name. Only needed if a reply is expected. If null, the message is not expected to be replied to.
* @param function The function to call * @param relatedMessageId (nullable) The id of the message this message is related to. If null, the message is not related to any other message.
* @param payload The payload to send, typically JSON. * @param function The function to call
* @param ttl (nullable) The time to live of the message, in seconds. If null, the message will never set to DEAD. * @param payload The payload to send, typically JSON.
* @param ttl (nullable) The time to live of the message, in seconds. If null, the message will never be set to DEAD.
* @return The id of the message * @return The id of the message
*/ */
public long sendNewMessage(String recipientInboxName, public long sendNewMessage(String recipientInboxName,
@Nullable @Nullable
String senderInboxName, String senderInboxName,
Long relatedMessageId,
String function, String function,
String payload, String payload,
@Nullable Duration ttl @Nullable Duration ttl
) throws Exception { ) throws Exception {
try (var conn = dataSource.getConnection(); try (var conn = dataSource.getConnection();
var stmt = conn.prepareStatement(""" var stmt = conn.prepareStatement("""
INSERT INTO MESSAGE_QUEUE(RECIPIENT_INBOX, SENDER_INBOX, FUNCTION, PAYLOAD, TTL) INSERT INTO MESSAGE_QUEUE(RECIPIENT_INBOX, SENDER_INBOX, RELATED_ID, FUNCTION, PAYLOAD, TTL)
VALUES(?, ?, ?, ?, ?) VALUES(?, ?, ?, ?, ?, ?)
"""); """);
var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")) { var lastIdQuery = conn.prepareStatement("SELECT LAST_INSERT_ID()")) {
@ -77,10 +79,13 @@ public class MqPersistence {
if (senderInboxName == null) stmt.setNull(2, java.sql.Types.VARCHAR); if (senderInboxName == null) stmt.setNull(2, java.sql.Types.VARCHAR);
else stmt.setString(2, senderInboxName); else stmt.setString(2, senderInboxName);
stmt.setString(3, function); if (relatedMessageId == null) stmt.setLong(3, -1);
stmt.setString(4, payload); else stmt.setLong(3, relatedMessageId);
if (ttl == null) stmt.setNull(5, java.sql.Types.BIGINT);
else stmt.setLong(5, ttl.toSeconds()); stmt.setString(4, function);
stmt.setString(5, payload);
if (ttl == null) stmt.setNull(6, java.sql.Types.BIGINT);
else stmt.setLong(6, ttl.toSeconds());
stmt.executeUpdate(); stmt.executeUpdate();
var rsp = lastIdQuery.executeQuery(); var rsp = lastIdQuery.executeQuery();

View File

@ -36,6 +36,14 @@ public class StateMachine {
private final Map<String, MachineState> allStates = new HashMap<>(); private final Map<String, MachineState> allStates = new HashMap<>();
/* The expectedMessageId guards against spurious state changes being triggered by old messages in the queue
*
* It contains the message id of the last message that was processed, and the messages sent by the state machine to
* itself via the message queue all have relatedId set to expectedMessageId. If the state machine is uninitialized or
* in a terminal state, it will accept messages with relatedIds that are equal to -1.
* */
private long expectedMessageId = -1;
public StateMachine(MqFactory messageQueueFactory, public StateMachine(MqFactory messageQueueFactory,
String queueName, String queueName,
UUID instanceUUID, UUID instanceUUID,
@ -99,7 +107,7 @@ public class StateMachine {
} }
smInbox.start(); smInbox.start();
smOutbox.notify(transition.state(), transition.message()); smOutbox.notify(expectedMessageId, transition.state(), transition.message());
} }
/** Initialize the state machine. */ /** Initialize the state machine. */
@ -112,7 +120,7 @@ public class StateMachine {
} }
smInbox.start(); smInbox.start();
smOutbox.notify(transition.state(), transition.message()); smOutbox.notify(expectedMessageId, transition.state(), transition.message());
} }
/** Resume the state machine from the last known state. */ /** Resume the state machine from the last known state. */
@ -133,6 +141,7 @@ public class StateMachine {
smInbox.start(); smInbox.start();
logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state()); logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state());
expectedMessageId = firstMessage.relatedId();
if (firstMessage.state() == MqMessageState.NEW) { if (firstMessage.state() == MqMessageState.NEW) {
// The message is not acknowledged, so starting the inbox will trigger a state transition // The message is not acknowledged, so starting the inbox will trigger a state transition
@ -141,10 +150,10 @@ public class StateMachine {
state = resumingState; state = resumingState;
} else if (resumeState.resumeBehavior().equals(ResumeBehavior.ERROR)) { } else if (resumeState.resumeBehavior().equals(ResumeBehavior.ERROR)) {
// The message is acknowledged, but the state does not support resuming // The message is acknowledged, but the state does not support resuming
smOutbox.notify("ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function()); smOutbox.notify(expectedMessageId, "ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function());
} else { } else {
// The message is already acknowledged, so we replay the last state // The message is already acknowledged, so we replay the last state
onStateTransition(firstMessage.function(), firstMessage.payload()); onStateTransition(firstMessage);
} }
} }
@ -153,13 +162,24 @@ public class StateMachine {
smOutbox.stop(); smOutbox.stop();
} }
private void onStateTransition(String nextState, String message) { private void onStateTransition(MqMessage msg) {
final String nextState = msg.function();
final String data = msg.payload();
final long messageId = msg.msgId();
final long relatedId = msg.relatedId();
if (expectedMessageId != relatedId) {
// We've received a message that we didn't expect. Throwing an exception will cause it to be flagged
// as an error in the message queue; the message queue will proceed
throw new IllegalStateException("Unexpected message id " + relatedId + ", expected " + expectedMessageId);
}
try { try {
logger.info("FSM State change in {}: {}->{}({})", logger.info("FSM State change in {}: {}->{}({})",
queueName, queueName,
state == null ? "[null]" : state.name(), state == null ? "[null]" : state.name(),
nextState, nextState,
message); data);
if (!allStates.containsKey(nextState)) { if (!allStates.containsKey(nextState)) {
logger.error("Unknown state {}", nextState); logger.error("Unknown state {}", nextState);
@ -173,8 +193,13 @@ public class StateMachine {
} }
if (!state.isFinal()) { if (!state.isFinal()) {
var transition = state.next(message); var transition = state.next(msg.payload());
smOutbox.notify(transition.state(), transition.message());
expectedMessageId = messageId;
smOutbox.notify(expectedMessageId, transition.state(), transition.message());
}
else {
expectedMessageId = -1;
} }
} }
catch (Exception e) { catch (Exception e) {
@ -204,7 +229,7 @@ public class StateMachine {
@Override @Override
public void onNotification(MqMessage msg) { public void onNotification(MqMessage msg) {
onStateTransition(msg.function(), msg.payload()); onStateTransition(msg);
try { try {
stateChangeListeners.forEach(l -> l.accept(msg.function(), msg.payload())); stateChangeListeners.forEach(l -> l.accept(msg.function(), msg.payload()));
} }

View File

@ -57,7 +57,7 @@ public class MqPersistenceTest {
@Test @Test
public void testReaper() throws Exception { public void testReaper() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(2)); long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(2));
persistence.reapDeadMessages(); persistence.reapDeadMessages();
var messages = MqTestUtil.getMessages(dataSource, recipientId); var messages = MqTestUtil.getMessages(dataSource, recipientId);
@ -77,7 +77,7 @@ public class MqPersistenceTest {
@Test @Test
public void sendWithReplyAddress() throws Exception { public void sendWithReplyAddress() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30));
var messages = MqTestUtil.getMessages(dataSource, recipientId); var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size()); assertEquals(1, messages.size());
@ -95,7 +95,7 @@ public class MqPersistenceTest {
@Test @Test
public void sendNoReplyAddress() throws Exception { public void sendNoReplyAddress() throws Exception {
long id = persistence.sendNewMessage(recipientId, null, "function", "payload", Duration.ofSeconds(30)); long id = persistence.sendNewMessage(recipientId, null, null, "function", "payload", Duration.ofSeconds(30));
var messages = MqTestUtil.getMessages(dataSource, recipientId); var messages = MqTestUtil.getMessages(dataSource, recipientId);
assertEquals(1, messages.size()); assertEquals(1, messages.size());
@ -114,7 +114,7 @@ public class MqPersistenceTest {
@Test @Test
public void updateState() throws Exception { public void updateState() throws Exception {
long id = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); long id = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30));
persistence.updateMessageState(id, MqMessageState.OK); persistence.updateMessageState(id, MqMessageState.OK);
System.out.println(id); System.out.println(id);
@ -131,7 +131,7 @@ public class MqPersistenceTest {
@Test @Test
public void testReply() throws Exception { public void testReply() throws Exception {
long request = persistence.sendNewMessage(recipientId, senderId, "function", "payload", Duration.ofSeconds(30)); long request = persistence.sendNewMessage(recipientId, senderId, null, "function", "payload", Duration.ofSeconds(30));
long response = persistence.sendResponse(request, MqMessageState.OK, "response"); long response = persistence.sendResponse(request, MqMessageState.OK, "response");
var sentMessages = MqTestUtil.getMessages(dataSource, recipientId); var sentMessages = MqTestUtil.getMessages(dataSource, recipientId);
@ -159,7 +159,7 @@ public class MqPersistenceTest {
String instanceId = "BATMAN"; String instanceId = "BATMAN";
long tick = 1234L; long tick = 1234L;
long id = persistence.sendNewMessage(recipientId, null,"function", "payload", Duration.ofSeconds(30)); long id = persistence.sendNewMessage(recipientId, null, null, "function", "payload", Duration.ofSeconds(30));
var messagesPollFirstTime = persistence.pollInbox(recipientId, instanceId , tick, 10); var messagesPollFirstTime = persistence.pollInbox(recipientId, instanceId , tick, 10);

View File

@ -81,7 +81,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); persistence.sendNewMessage(inboxId, null, -1L, "RESUMABLE", "", null);
sm.resume(); sm.resume();
@ -102,7 +102,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); long id = persistence.sendNewMessage(inboxId, null, -1L, "RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK); persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume(); sm.resume();
@ -125,7 +125,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); persistence.sendNewMessage(inboxId, null, -1L, "NON-RESUMABLE", "", null);
sm.resume(); sm.resume();
@ -146,7 +146,7 @@ public class StateMachineResumeTest {
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory)); var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); long id = persistence.sendNewMessage(inboxId, null, null, "NON-RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK); persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume(); sm.resume();

View File

@ -118,4 +118,25 @@ public class StateMachineTest {
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println); MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
} }
@Test
public void testFalseTransition() throws Exception {
var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(messageQueueFactory, inboxId, UUID.randomUUID(), new TestGraph(stateFactory));
// Prep the queue with a message to set the state to initial,
// and an additional message to trigger the false transition back to initial
persistence.sendNewMessage(inboxId, null, null, "INITIAL", "", null);
persistence.sendNewMessage(inboxId, null, null, "INITIAL", "", null);
sm.resume();
Thread.sleep(50);
sm.join();
sm.stop();
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
}
} }