diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java index 00b30cad..20184f32 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mq/inbox/MqInbox.java @@ -38,7 +38,15 @@ public class MqInbox { String inboxName, UUID instanceUUID) { - this.threadPool = Executors.newCachedThreadPool(); + this(persistence, inboxName, instanceUUID, Executors.newCachedThreadPool()); + } + + public MqInbox(MqPersistence persistence, + String inboxName, + UUID instanceUUID, + ExecutorService executorService) + { + this.threadPool = executorService; this.persistence = persistence; this.inboxName = inboxName; this.instanceUUID = instanceUUID.toString(); diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java index 8dccde4b..09c02ea7 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateFactory.java @@ -4,6 +4,7 @@ import com.google.gson.Gson; import com.google.inject.Inject; import com.google.inject.Singleton; import nu.marginalia.mqsm.state.MachineState; +import nu.marginalia.mqsm.state.ResumeBehavior; import nu.marginalia.mqsm.state.StateTransition; import java.util.function.Function; @@ -18,7 +19,7 @@ public class StateFactory { this.gson = gson; } - public MachineState create(String name, Class param, Function logic) { + public MachineState create(String name, ResumeBehavior resumeBehavior, Class param, Function logic) { return new MachineState() { @Override public String name() { @@ -30,6 +31,11 @@ public class StateFactory { return logic.apply(gson.fromJson(message, param)); } + @Override + public ResumeBehavior resumeBehavior() { + return resumeBehavior; + } + @Override public boolean isFinal() { return false; @@ -37,7 +43,7 @@ public class StateFactory { }; } - public MachineState create(String name, Supplier logic) { + public MachineState create(String name, ResumeBehavior resumeBehavior, Supplier logic) { return new MachineState() { @Override public String name() { @@ -49,6 +55,11 @@ public class StateFactory { return logic.get(); } + @Override + public ResumeBehavior resumeBehavior() { + return resumeBehavior; + } + @Override public boolean isFinal() { return false; diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java index cb7d1f33..827005ed 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/StateMachine.java @@ -7,6 +7,7 @@ import nu.marginalia.mq.inbox.MqInboxResponse; import nu.marginalia.mq.inbox.MqSubscription; import nu.marginalia.mq.outbox.MqOutbox; import nu.marginalia.mq.persistence.MqPersistence; +import nu.marginalia.mqsm.graph.StateGraph; import nu.marginalia.mqsm.state.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -15,6 +16,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.Executors; /** A state machine that can be used to implement a finite state machine * using a message queue as the persistence layer. The state machine is @@ -37,7 +39,7 @@ public class StateMachine { public StateMachine(MqPersistence persistence, String queueName, UUID instanceUUID) { this.queueName = queueName; - smInbox = new MqInbox(persistence, queueName, instanceUUID); + smInbox = new MqInbox(persistence, queueName, instanceUUID, Executors.newSingleThreadExecutor()); smOutbox = new MqOutbox(persistence, queueName, instanceUUID); smInbox.subscribe(new StateEventSubscription()); @@ -63,6 +65,11 @@ public class StateMachine { } } + /** Register the state graph */ + public void registerStates(StateGraph states) { + registerStates(states.asStateList()); + } + /** Wait for the state machine to reach a final state. * (possibly forever, halting problem and so on) */ @@ -94,29 +101,33 @@ public class StateMachine { /** Resume the state machine from the last known state. */ public void resume() throws Exception { - if (state == null) { - var messages = smInbox.replay(1); + if (state != null) { + return; + } - if (messages.isEmpty()) { - init(); - } else { - var firstMessage = messages.get(0); + var messages = smInbox.replay(1); + if (messages.isEmpty()) { + init(); + return; + } - smInbox.start(); + var firstMessage = messages.get(0); + var resumeState = allStates.get(firstMessage.function()); - logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state()); + smInbox.start(); + logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state()); - if (firstMessage.state() == MqMessageState.NEW) { - // The message is not acknowledged, so starting the inbox will trigger a state transition - // - // We still need to set a state here so that the join() method works + if (firstMessage.state() == MqMessageState.NEW) { + // The message is not acknowledged, so starting the inbox will trigger a state transition + // We still need to set a state here so that the join() method works - state = resumingState; - } else { - // The message is already acknowledged, so we replay the last state - onStateTransition(firstMessage.function(), firstMessage.payload()); - } - } + state = resumingState; + } else if (resumeState.resumeBehavior().equals(ResumeBehavior.ERROR)) { + // The message is acknowledged, but the state does not support resuming + smOutbox.notify("ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function()); + } else { + // The message is already acknowledged, so we replay the last state + onStateTransition(firstMessage.function(), firstMessage.payload()); } } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/ControlFlowException.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/ControlFlowException.java new file mode 100644 index 00000000..aece44ea --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/ControlFlowException.java @@ -0,0 +1,21 @@ +package nu.marginalia.mqsm.graph; + +class ControlFlowException extends RuntimeException { + private final String state; + private final Object payload; + + public ControlFlowException(String state, Object payload) { + this.state = state; + this.payload = payload; + } + + public String getState() { + return state; + } + + public Object getPayload() { + return payload; + } + + public StackTraceElement[] getStackTrace() { return new StackTraceElement[0]; } +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/GraphState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/GraphState.java new file mode 100644 index 00000000..b79b71aa --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/GraphState.java @@ -0,0 +1,14 @@ +package nu.marginalia.mqsm.graph; + + +import nu.marginalia.mqsm.state.ResumeBehavior; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +@Retention(RetentionPolicy.RUNTIME) +public @interface GraphState { + String name(); + String next() default "ERROR"; + ResumeBehavior resume() default ResumeBehavior.ERROR; +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/StateGraph.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/StateGraph.java new file mode 100644 index 00000000..df8f4318 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/StateGraph.java @@ -0,0 +1,121 @@ +package nu.marginalia.mqsm.graph; + +import nu.marginalia.mqsm.StateFactory; +import nu.marginalia.mqsm.state.MachineState; +import nu.marginalia.mqsm.state.StateTransition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public abstract class StateGraph { + private final StateFactory stateFactory; + private static final Logger logger = LoggerFactory.getLogger(StateGraph.class); + + public StateGraph(StateFactory stateFactory) { + this.stateFactory = stateFactory; + } + + public void transition(String state) { + throw new ControlFlowException(state, ""); + } + + public void transition(String state, T payload) { + throw new ControlFlowException(state, payload); + } + + public void error() { + throw new ControlFlowException("ERROR", ""); + } + public void error(T payload) { + throw new ControlFlowException("ERROR", payload); + } + public void error(Exception ex) { + throw new ControlFlowException("ERROR", ex.getClass().getSimpleName() + ":" + ex.getMessage()); + } + + public List asStateList() { + List ret = new ArrayList<>(); + + for (var method : getClass().getMethods()) { + var gs = method.getAnnotation(GraphState.class); + if (gs != null) { + ret.add(graphState(method, gs)); + } + } + + return ret; + } + + private MachineState graphState(Method method, GraphState gs) { + + var parameters = method.getParameterTypes(); + boolean returnsVoid = method.getGenericReturnType().equals(Void.TYPE); + + if (parameters.length == 0) { + return stateFactory.create(gs.name(), gs.resume(), () -> { + try { + if (returnsVoid) { + method.invoke(this); + return StateTransition.to(gs.next()); + } else { + Object ret = method.invoke(this); + return stateFactory.transition(gs.next(), ret); + } + } + catch (Exception e) { + return invocationExceptionToStateTransition(gs.name(), e); + } + }); + } + else if (parameters.length == 1) { + return stateFactory.create(gs.name(), gs.resume(), parameters[0], (param) -> { + try { + if (returnsVoid) { + method.invoke(this, param); + return StateTransition.to(gs.next()); + } else { + Object ret = method.invoke(this, param); + return stateFactory.transition(gs.next(), ret); + } + } catch (Exception e) { + return invocationExceptionToStateTransition(gs.name(), e); + } + }); + } + else { + // We permit only @GraphState-annotated methods like this: + // + // void foo(); + // void foo(Object bar); + // Object foo(); + // Object foo(Object bar); + + throw new IllegalStateException("StateGraph " + + getClass().getSimpleName() + + " has invalid method signature for method " + + method.getName() + + ": Expected 0 or 1 parameter(s) but found " + + Arrays.toString(parameters)); + } + } + + private StateTransition invocationExceptionToStateTransition(String state, Throwable ex) { + while (ex instanceof InvocationTargetException e) { + if (e.getCause() != null) ex = ex.getCause(); + } + + if (ex instanceof ControlFlowException cfe) { + return stateFactory.transition(cfe.getState(), cfe.getPayload()); + } else { + logger.error("Error in state invocation " + state, ex); + return StateTransition.to("ERROR", + "Exception: " + ex.getClass().getSimpleName() + "/" + ex.getMessage()); + } + } + +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/TerminalState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/TerminalState.java new file mode 100644 index 00000000..5ae062b7 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/graph/TerminalState.java @@ -0,0 +1,9 @@ +package nu.marginalia.mqsm.graph; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +@Retention(RetentionPolicy.RUNTIME) +public @interface TerminalState { + String name(); +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java index 4f1fef96..dcb19125 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ErrorState.java @@ -9,6 +9,9 @@ public class ErrorState implements MachineState { throw new UnsupportedOperationException(); } + @Override + public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; } + @Override public boolean isFinal() { return true; } } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java index 5ee7d435..dc2362fe 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/FinalState.java @@ -9,6 +9,9 @@ public class FinalState implements MachineState { throw new UnsupportedOperationException(); } + @Override + public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; } + @Override public boolean isFinal() { return true; } } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java index 4bba33cf..11efc7c5 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/MachineState.java @@ -4,5 +4,6 @@ public interface MachineState { String name(); StateTransition next(String message); + ResumeBehavior resumeBehavior(); boolean isFinal(); } diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumeBehavior.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumeBehavior.java new file mode 100644 index 00000000..a82446f8 --- /dev/null +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumeBehavior.java @@ -0,0 +1,6 @@ +package nu.marginalia.mqsm.state; + +public enum ResumeBehavior { + RETRY, + ERROR +} diff --git a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java index 36a474e2..ce01bb79 100644 --- a/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java +++ b/code/common/message-queue/src/main/java/nu/marginalia/mqsm/state/ResumingState.java @@ -9,6 +9,9 @@ public class ResumingState implements MachineState { throw new UnsupportedOperationException(); } + @Override + public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; } + @Override public boolean isFinal() { return false; } } diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineErrorTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineErrorTest.java new file mode 100644 index 00000000..6c6298eb --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineErrorTest.java @@ -0,0 +1,100 @@ +package nu.marginalia.mqsm; + +import com.google.gson.GsonBuilder; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessageRow; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.MqTestUtil; +import nu.marginalia.mq.persistence.MqPersistence; +import nu.marginalia.mqsm.graph.GraphState; +import nu.marginalia.mqsm.graph.StateGraph; +import nu.marginalia.mqsm.state.ResumeBehavior; +import org.junit.jupiter.api.*; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Tag("slow") +@Testcontainers +public class StateMachineErrorTest { + @Container + static MariaDBContainer mariaDBContainer = new MariaDBContainer<>("mariadb") + .withDatabaseName("WMSA_prod") + .withUsername("wmsa") + .withPassword("wmsa") + .withInitScript("sql/current/11-message-queue.sql") + .withNetworkAliases("mariadb"); + + static HikariDataSource dataSource; + static MqPersistence persistence; + private String inboxId; + + @BeforeEach + public void setUp() { + inboxId = UUID.randomUUID().toString(); + } + @BeforeAll + public static void setUpAll() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(mariaDBContainer.getJdbcUrl()); + config.setUsername("wmsa"); + config.setPassword("wmsa"); + + dataSource = new HikariDataSource(config); + persistence = new MqPersistence(dataSource); + } + + @AfterAll + public static void tearDownAll() { + dataSource.close(); + } + + public static class ErrorHurdles extends StateGraph { + + public ErrorHurdles(StateFactory stateFactory) { + super(stateFactory); + } + + @GraphState(name = "INITIAL", next = "FAILING") + public void initial() { + + } + @GraphState(name = "FAILING", next = "OK", resume = ResumeBehavior.RETRY) + public void resumable() { + throw new RuntimeException("Boom!"); + } + @GraphState(name = "OK", next = "END") + public void ok() { + + } + + } + + @Test + public void smResumeResumableFromNew() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + sm.registerStates(new ErrorHurdles(stateFactory).asStateList()); + + sm.init(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("INITIAL", "FAILING", "ERROR"), states); + } + +} diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java new file mode 100644 index 00000000..6913e13a --- /dev/null +++ b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineResumeTest.java @@ -0,0 +1,191 @@ +package nu.marginalia.mqsm; + +import com.google.gson.GsonBuilder; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import nu.marginalia.mq.MqMessageRow; +import nu.marginalia.mq.MqMessageState; +import nu.marginalia.mq.MqTestUtil; +import nu.marginalia.mq.persistence.MqPersistence; +import nu.marginalia.mqsm.graph.GraphState; +import nu.marginalia.mqsm.graph.StateGraph; +import nu.marginalia.mqsm.state.ResumeBehavior; +import org.junit.jupiter.api.*; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Tag("slow") +@Testcontainers +public class StateMachineResumeTest { + @Container + static MariaDBContainer mariaDBContainer = new MariaDBContainer<>("mariadb") + .withDatabaseName("WMSA_prod") + .withUsername("wmsa") + .withPassword("wmsa") + .withInitScript("sql/current/11-message-queue.sql") + .withNetworkAliases("mariadb"); + + static HikariDataSource dataSource; + static MqPersistence persistence; + private String inboxId; + + @BeforeEach + public void setUp() { + inboxId = UUID.randomUUID().toString(); + } + @BeforeAll + public static void setUpAll() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(mariaDBContainer.getJdbcUrl()); + config.setUsername("wmsa"); + config.setPassword("wmsa"); + + dataSource = new HikariDataSource(config); + persistence = new MqPersistence(dataSource); + } + + @AfterAll + public static void tearDownAll() { + dataSource.close(); + } + + public static class ResumeTrialsGraph extends StateGraph { + + public ResumeTrialsGraph(StateFactory stateFactory) { + super(stateFactory); + } + + @GraphState(name = "INITIAL", next = "RESUMABLE") + public void initial() {} + @GraphState(name = "RESUMABLE", next = "NON-RESUMABLE", resume = ResumeBehavior.RETRY) + public void resumable() {} + @GraphState(name = "NON-RESUMABLE", next = "OK", resume = ResumeBehavior.ERROR) + public void nonResumable() {} + + @GraphState(name = "OK", next = "END") + public void ok() {} + + } + + @Test + public void smResumeResumableFromNew() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + sm.registerStates(new ResumeTrialsGraph(stateFactory).asStateList()); + + persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("RESUMABLE", "NON-RESUMABLE", "OK", "END"), states); + } + + @Test + public void smResumeFromAck() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + sm.registerStates(new ResumeTrialsGraph(stateFactory)); + + long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); + persistence.updateMessageState(id, MqMessageState.ACK); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("RESUMABLE", "NON-RESUMABLE", "OK", "END"), states); + } + + + @Test + public void smResumeNonResumableFromNew() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + sm.registerStates(new ResumeTrialsGraph(stateFactory)); + + persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("NON-RESUMABLE", "OK", "END"), states); + } + + @Test + public void smResumeNonResumableFromAck() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + sm.registerStates(new ResumeTrialsGraph(stateFactory)); + + long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); + persistence.updateMessageState(id, MqMessageState.ACK); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("NON-RESUMABLE", "ERROR"), states); + } + + @Test + public void smResumeEmptyQueue() throws Exception { + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + var stateFactory = new StateFactory(new GsonBuilder().create()); + + sm.registerStates(new ResumeTrialsGraph(stateFactory)); + + sm.resume(); + + sm.join(); + sm.stop(); + + List states = MqTestUtil.getMessages(dataSource, inboxId) + .stream() + .peek(System.out::println) + .map(MqMessageRow::function) + .toList(); + + assertEquals(List.of("INITIAL", "RESUMABLE", "NON-RESUMABLE", "OK", "END"), states); + } +} diff --git a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java index 06cc658c..789b13ad 100644 --- a/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java +++ b/code/common/message-queue/src/test/java/nu/marginalia/mqsm/StateMachineTest.java @@ -7,6 +7,9 @@ import nu.marginalia.mq.MqMessageRow; import nu.marginalia.mq.MqMessageState; import nu.marginalia.mq.MqTestUtil; import nu.marginalia.mq.persistence.MqPersistence; +import nu.marginalia.mqsm.graph.GraphState; +import nu.marginalia.mqsm.graph.StateGraph; +import nu.marginalia.mqsm.state.ResumeBehavior; import org.junit.jupiter.api.*; import org.testcontainers.containers.MariaDBContainer; import org.testcontainers.junit.jupiter.Container; @@ -52,19 +55,63 @@ public class StateMachineTest { dataSource.close(); } + public static class TestGraph extends StateGraph { + public TestGraph(StateFactory stateFactory) { + super(stateFactory); + } + + @GraphState(name = "INITIAL", next = "GREET") + public String initial() { + return "World"; + } + + @GraphState(name = "GREET") + public void greet(String message) { + System.out.println("Hello, " + message + "!"); + + transition("COUNT-DOWN", 5); + } + + @GraphState(name = "COUNT-DOWN", next = "END") + public void countDown(Integer from) { + if (from > 0) { + System.out.println(from); + transition("COUNT-DOWN", from - 1); + } + } + } + + @Test + public void testAnnotatedStateGraph() throws Exception { + var stateFactory = new StateFactory(new GsonBuilder().create()); + var graph = new TestGraph(stateFactory); + + + var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); + sm.registerStates(graph.asStateList()); + + sm.init(); + + sm.join(); + sm.stop(); + + MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println); + + } + @Test public void testStartStopStartStop() throws Exception { var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); var stateFactory = new StateFactory(new GsonBuilder().create()); - var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("GREET", "World")); + var initial = stateFactory.create("INITIAL", ResumeBehavior.RETRY, () -> stateFactory.transition("GREET", "World")); - var greet = stateFactory.create("GREET", String.class, (String message) -> { + var greet = stateFactory.create("GREET", ResumeBehavior.RETRY, String.class, (String message) -> { System.out.println("Hello, " + message + "!"); return stateFactory.transition("COUNT-TO-FIVE", 0); }); - var ctf = stateFactory.create("COUNT-TO-FIVE", Integer.class, (Integer count) -> { + var ctf = stateFactory.create("COUNT-TO-FIVE", ResumeBehavior.RETRY, Integer.class, (Integer count) -> { System.out.println(count); if (count < 5) { return stateFactory.transition("COUNT-TO-FIVE", count + 1); @@ -89,86 +136,4 @@ public class StateMachineTest { MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println); } - @Test - public void smResumeFromNew() throws Exception { - var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); - var stateFactory = new StateFactory(new GsonBuilder().create()); - - var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A")); - var stateA = stateFactory.create("A", () -> stateFactory.transition("B")); - var stateB = stateFactory.create("B", () -> stateFactory.transition("C")); - var stateC = stateFactory.create("C", () -> stateFactory.transition("END")); - - sm.registerStates(initial, stateA, stateB, stateC); - persistence.sendNewMessage(inboxId, null,"B", "", null); - - sm.resume(); - - sm.join(); - sm.stop(); - - List states = MqTestUtil.getMessages(dataSource, inboxId) - .stream() - .peek(System.out::println) - .map(MqMessageRow::function) - .toList(); - - assertEquals(List.of("B", "C", "END"), states); - } - - @Test - public void smResumeFromAck() throws Exception { - var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); - var stateFactory = new StateFactory(new GsonBuilder().create()); - - var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A")); - var stateA = stateFactory.create("A", () -> stateFactory.transition("B")); - var stateB = stateFactory.create("B", () -> stateFactory.transition("C")); - var stateC = stateFactory.create("C", () -> stateFactory.transition("END")); - - sm.registerStates(initial, stateA, stateB, stateC); - - long id = persistence.sendNewMessage(inboxId, null,"B", "", null); - persistence.updateMessageState(id, MqMessageState.ACK); - - sm.resume(); - - sm.join(); - sm.stop(); - - List states = MqTestUtil.getMessages(dataSource, inboxId) - .stream() - .peek(System.out::println) - .map(MqMessageRow::function) - .toList(); - - assertEquals(List.of("B", "C", "END"), states); - } - - - @Test - public void smResumeEmptyQueue() throws Exception { - var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); - var stateFactory = new StateFactory(new GsonBuilder().create()); - - var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A")); - var stateA = stateFactory.create("A", () -> stateFactory.transition("B")); - var stateB = stateFactory.create("B", () -> stateFactory.transition("C")); - var stateC = stateFactory.create("C", () -> stateFactory.transition("END")); - - sm.registerStates(initial, stateA, stateB, stateC); - - sm.resume(); - - sm.join(); - sm.stop(); - - List states = MqTestUtil.getMessages(dataSource, inboxId) - .stream() - .peek(System.out::println) - .map(MqMessageRow::function) - .toList(); - - assertEquals(List.of("INITIAL", "A", "B", "C", "END"), states); - } }