MQFSM Usability WIP

This commit is contained in:
Viktor Lofgren 2023-07-06 13:02:16 +02:00
parent 413dc6ced4
commit d89db10645
15 changed files with 574 additions and 107 deletions

View File

@ -38,7 +38,15 @@ public class MqInbox {
String inboxName,
UUID instanceUUID)
{
this.threadPool = Executors.newCachedThreadPool();
this(persistence, inboxName, instanceUUID, Executors.newCachedThreadPool());
}
public MqInbox(MqPersistence persistence,
String inboxName,
UUID instanceUUID,
ExecutorService executorService)
{
this.threadPool = executorService;
this.persistence = persistence;
this.inboxName = inboxName;
this.instanceUUID = instanceUUID.toString();

View File

@ -4,6 +4,7 @@ import com.google.gson.Gson;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import nu.marginalia.mqsm.state.MachineState;
import nu.marginalia.mqsm.state.ResumeBehavior;
import nu.marginalia.mqsm.state.StateTransition;
import java.util.function.Function;
@ -18,7 +19,7 @@ public class StateFactory {
this.gson = gson;
}
public <T> MachineState create(String name, Class<T> param, Function<T, StateTransition> logic) {
public <T> MachineState create(String name, ResumeBehavior resumeBehavior, Class<T> param, Function<T, StateTransition> logic) {
return new MachineState() {
@Override
public String name() {
@ -30,6 +31,11 @@ public class StateFactory {
return logic.apply(gson.fromJson(message, param));
}
@Override
public ResumeBehavior resumeBehavior() {
return resumeBehavior;
}
@Override
public boolean isFinal() {
return false;
@ -37,7 +43,7 @@ public class StateFactory {
};
}
public MachineState create(String name, Supplier<StateTransition> logic) {
public MachineState create(String name, ResumeBehavior resumeBehavior, Supplier<StateTransition> logic) {
return new MachineState() {
@Override
public String name() {
@ -49,6 +55,11 @@ public class StateFactory {
return logic.get();
}
@Override
public ResumeBehavior resumeBehavior() {
return resumeBehavior;
}
@Override
public boolean isFinal() {
return false;

View File

@ -7,6 +7,7 @@ import nu.marginalia.mq.inbox.MqInboxResponse;
import nu.marginalia.mq.inbox.MqSubscription;
import nu.marginalia.mq.outbox.MqOutbox;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.StateGraph;
import nu.marginalia.mqsm.state.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -15,6 +16,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Executors;
/** A state machine that can be used to implement a finite state machine
* using a message queue as the persistence layer. The state machine is
@ -37,7 +39,7 @@ public class StateMachine {
public StateMachine(MqPersistence persistence, String queueName, UUID instanceUUID) {
this.queueName = queueName;
smInbox = new MqInbox(persistence, queueName, instanceUUID);
smInbox = new MqInbox(persistence, queueName, instanceUUID, Executors.newSingleThreadExecutor());
smOutbox = new MqOutbox(persistence, queueName, instanceUUID);
smInbox.subscribe(new StateEventSubscription());
@ -63,6 +65,11 @@ public class StateMachine {
}
}
/** Register the state graph */
public void registerStates(StateGraph states) {
registerStates(states.asStateList());
}
/** Wait for the state machine to reach a final state.
* (possibly forever, halting problem and so on)
*/
@ -94,29 +101,33 @@ public class StateMachine {
/** Resume the state machine from the last known state. */
public void resume() throws Exception {
if (state == null) {
var messages = smInbox.replay(1);
if (state != null) {
return;
}
if (messages.isEmpty()) {
init();
} else {
var firstMessage = messages.get(0);
var messages = smInbox.replay(1);
if (messages.isEmpty()) {
init();
return;
}
smInbox.start();
var firstMessage = messages.get(0);
var resumeState = allStates.get(firstMessage.function());
logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state());
smInbox.start();
logger.info("Resuming state machine from {}({})/{}", firstMessage.function(), firstMessage.payload(), firstMessage.state());
if (firstMessage.state() == MqMessageState.NEW) {
// The message is not acknowledged, so starting the inbox will trigger a state transition
//
// We still need to set a state here so that the join() method works
if (firstMessage.state() == MqMessageState.NEW) {
// The message is not acknowledged, so starting the inbox will trigger a state transition
// We still need to set a state here so that the join() method works
state = resumingState;
} else {
// The message is already acknowledged, so we replay the last state
onStateTransition(firstMessage.function(), firstMessage.payload());
}
}
state = resumingState;
} else if (resumeState.resumeBehavior().equals(ResumeBehavior.ERROR)) {
// The message is acknowledged, but the state does not support resuming
smOutbox.notify("ERROR", "Illegal resumption from ACK'ed state " + firstMessage.function());
} else {
// The message is already acknowledged, so we replay the last state
onStateTransition(firstMessage.function(), firstMessage.payload());
}
}

View File

@ -0,0 +1,21 @@
package nu.marginalia.mqsm.graph;
class ControlFlowException extends RuntimeException {
private final String state;
private final Object payload;
public ControlFlowException(String state, Object payload) {
this.state = state;
this.payload = payload;
}
public String getState() {
return state;
}
public Object getPayload() {
return payload;
}
public StackTraceElement[] getStackTrace() { return new StackTraceElement[0]; }
}

View File

@ -0,0 +1,14 @@
package nu.marginalia.mqsm.graph;
import nu.marginalia.mqsm.state.ResumeBehavior;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@Retention(RetentionPolicy.RUNTIME)
public @interface GraphState {
String name();
String next() default "ERROR";
ResumeBehavior resume() default ResumeBehavior.ERROR;
}

View File

@ -0,0 +1,121 @@
package nu.marginalia.mqsm.graph;
import nu.marginalia.mqsm.StateFactory;
import nu.marginalia.mqsm.state.MachineState;
import nu.marginalia.mqsm.state.StateTransition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public abstract class StateGraph {
private final StateFactory stateFactory;
private static final Logger logger = LoggerFactory.getLogger(StateGraph.class);
public StateGraph(StateFactory stateFactory) {
this.stateFactory = stateFactory;
}
public void transition(String state) {
throw new ControlFlowException(state, "");
}
public <T> void transition(String state, T payload) {
throw new ControlFlowException(state, payload);
}
public void error() {
throw new ControlFlowException("ERROR", "");
}
public <T> void error(T payload) {
throw new ControlFlowException("ERROR", payload);
}
public void error(Exception ex) {
throw new ControlFlowException("ERROR", ex.getClass().getSimpleName() + ":" + ex.getMessage());
}
public List<MachineState> asStateList() {
List<MachineState> ret = new ArrayList<>();
for (var method : getClass().getMethods()) {
var gs = method.getAnnotation(GraphState.class);
if (gs != null) {
ret.add(graphState(method, gs));
}
}
return ret;
}
private MachineState graphState(Method method, GraphState gs) {
var parameters = method.getParameterTypes();
boolean returnsVoid = method.getGenericReturnType().equals(Void.TYPE);
if (parameters.length == 0) {
return stateFactory.create(gs.name(), gs.resume(), () -> {
try {
if (returnsVoid) {
method.invoke(this);
return StateTransition.to(gs.next());
} else {
Object ret = method.invoke(this);
return stateFactory.transition(gs.next(), ret);
}
}
catch (Exception e) {
return invocationExceptionToStateTransition(gs.name(), e);
}
});
}
else if (parameters.length == 1) {
return stateFactory.create(gs.name(), gs.resume(), parameters[0], (param) -> {
try {
if (returnsVoid) {
method.invoke(this, param);
return StateTransition.to(gs.next());
} else {
Object ret = method.invoke(this, param);
return stateFactory.transition(gs.next(), ret);
}
} catch (Exception e) {
return invocationExceptionToStateTransition(gs.name(), e);
}
});
}
else {
// We permit only @GraphState-annotated methods like this:
//
// void foo();
// void foo(Object bar);
// Object foo();
// Object foo(Object bar);
throw new IllegalStateException("StateGraph " +
getClass().getSimpleName() +
" has invalid method signature for method " +
method.getName() +
": Expected 0 or 1 parameter(s) but found " +
Arrays.toString(parameters));
}
}
private StateTransition invocationExceptionToStateTransition(String state, Throwable ex) {
while (ex instanceof InvocationTargetException e) {
if (e.getCause() != null) ex = ex.getCause();
}
if (ex instanceof ControlFlowException cfe) {
return stateFactory.transition(cfe.getState(), cfe.getPayload());
} else {
logger.error("Error in state invocation " + state, ex);
return StateTransition.to("ERROR",
"Exception: " + ex.getClass().getSimpleName() + "/" + ex.getMessage());
}
}
}

View File

@ -0,0 +1,9 @@
package nu.marginalia.mqsm.graph;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@Retention(RetentionPolicy.RUNTIME)
public @interface TerminalState {
String name();
}

View File

@ -9,6 +9,9 @@ public class ErrorState implements MachineState {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return true; }
}

View File

@ -9,6 +9,9 @@ public class FinalState implements MachineState {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return true; }
}

View File

@ -4,5 +4,6 @@ public interface MachineState {
String name();
StateTransition next(String message);
ResumeBehavior resumeBehavior();
boolean isFinal();
}

View File

@ -0,0 +1,6 @@
package nu.marginalia.mqsm.state;
public enum ResumeBehavior {
RETRY,
ERROR
}

View File

@ -9,6 +9,9 @@ public class ResumingState implements MachineState {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return false; }
}

View File

@ -0,0 +1,100 @@
package nu.marginalia.mqsm;
import com.google.gson.GsonBuilder;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageRow;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.GraphState;
import nu.marginalia.mqsm.graph.StateGraph;
import nu.marginalia.mqsm.state.ResumeBehavior;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag("slow")
@Testcontainers
public class StateMachineErrorTest {
@Container
static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
.withDatabaseName("WMSA_prod")
.withUsername("wmsa")
.withPassword("wmsa")
.withInitScript("sql/current/11-message-queue.sql")
.withNetworkAliases("mariadb");
static HikariDataSource dataSource;
static MqPersistence persistence;
private String inboxId;
@BeforeEach
public void setUp() {
inboxId = UUID.randomUUID().toString();
}
@BeforeAll
public static void setUpAll() {
HikariConfig config = new HikariConfig();
config.setJdbcUrl(mariaDBContainer.getJdbcUrl());
config.setUsername("wmsa");
config.setPassword("wmsa");
dataSource = new HikariDataSource(config);
persistence = new MqPersistence(dataSource);
}
@AfterAll
public static void tearDownAll() {
dataSource.close();
}
public static class ErrorHurdles extends StateGraph {
public ErrorHurdles(StateFactory stateFactory) {
super(stateFactory);
}
@GraphState(name = "INITIAL", next = "FAILING")
public void initial() {
}
@GraphState(name = "FAILING", next = "OK", resume = ResumeBehavior.RETRY)
public void resumable() {
throw new RuntimeException("Boom!");
}
@GraphState(name = "OK", next = "END")
public void ok() {
}
}
@Test
public void smResumeResumableFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
sm.registerStates(new ErrorHurdles(stateFactory).asStateList());
sm.init();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("INITIAL", "FAILING", "ERROR"), states);
}
}

View File

@ -0,0 +1,191 @@
package nu.marginalia.mqsm;
import com.google.gson.GsonBuilder;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageRow;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.GraphState;
import nu.marginalia.mqsm.graph.StateGraph;
import nu.marginalia.mqsm.state.ResumeBehavior;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.util.List;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag("slow")
@Testcontainers
public class StateMachineResumeTest {
@Container
static MariaDBContainer<?> mariaDBContainer = new MariaDBContainer<>("mariadb")
.withDatabaseName("WMSA_prod")
.withUsername("wmsa")
.withPassword("wmsa")
.withInitScript("sql/current/11-message-queue.sql")
.withNetworkAliases("mariadb");
static HikariDataSource dataSource;
static MqPersistence persistence;
private String inboxId;
@BeforeEach
public void setUp() {
inboxId = UUID.randomUUID().toString();
}
@BeforeAll
public static void setUpAll() {
HikariConfig config = new HikariConfig();
config.setJdbcUrl(mariaDBContainer.getJdbcUrl());
config.setUsername("wmsa");
config.setPassword("wmsa");
dataSource = new HikariDataSource(config);
persistence = new MqPersistence(dataSource);
}
@AfterAll
public static void tearDownAll() {
dataSource.close();
}
public static class ResumeTrialsGraph extends StateGraph {
public ResumeTrialsGraph(StateFactory stateFactory) {
super(stateFactory);
}
@GraphState(name = "INITIAL", next = "RESUMABLE")
public void initial() {}
@GraphState(name = "RESUMABLE", next = "NON-RESUMABLE", resume = ResumeBehavior.RETRY)
public void resumable() {}
@GraphState(name = "NON-RESUMABLE", next = "OK", resume = ResumeBehavior.ERROR)
public void nonResumable() {}
@GraphState(name = "OK", next = "END")
public void ok() {}
}
@Test
public void smResumeResumableFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
sm.registerStates(new ResumeTrialsGraph(stateFactory).asStateList());
persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("RESUMABLE", "NON-RESUMABLE", "OK", "END"), states);
}
@Test
public void smResumeFromAck() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
sm.registerStates(new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("RESUMABLE", "NON-RESUMABLE", "OK", "END"), states);
}
@Test
public void smResumeNonResumableFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
sm.registerStates(new ResumeTrialsGraph(stateFactory));
persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("NON-RESUMABLE", "OK", "END"), states);
}
@Test
public void smResumeNonResumableFromAck() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
sm.registerStates(new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("NON-RESUMABLE", "ERROR"), states);
}
@Test
public void smResumeEmptyQueue() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
sm.registerStates(new ResumeTrialsGraph(stateFactory));
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("INITIAL", "RESUMABLE", "NON-RESUMABLE", "OK", "END"), states);
}
}

View File

@ -7,6 +7,9 @@ import nu.marginalia.mq.MqMessageRow;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.GraphState;
import nu.marginalia.mqsm.graph.StateGraph;
import nu.marginalia.mqsm.state.ResumeBehavior;
import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container;
@ -52,19 +55,63 @@ public class StateMachineTest {
dataSource.close();
}
public static class TestGraph extends StateGraph {
public TestGraph(StateFactory stateFactory) {
super(stateFactory);
}
@GraphState(name = "INITIAL", next = "GREET")
public String initial() {
return "World";
}
@GraphState(name = "GREET")
public void greet(String message) {
System.out.println("Hello, " + message + "!");
transition("COUNT-DOWN", 5);
}
@GraphState(name = "COUNT-DOWN", next = "END")
public void countDown(Integer from) {
if (from > 0) {
System.out.println(from);
transition("COUNT-DOWN", from - 1);
}
}
}
@Test
public void testAnnotatedStateGraph() throws Exception {
var stateFactory = new StateFactory(new GsonBuilder().create());
var graph = new TestGraph(stateFactory);
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
sm.registerStates(graph.asStateList());
sm.init();
sm.join();
sm.stop();
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
}
@Test
public void testStartStopStartStop() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("GREET", "World"));
var initial = stateFactory.create("INITIAL", ResumeBehavior.RETRY, () -> stateFactory.transition("GREET", "World"));
var greet = stateFactory.create("GREET", String.class, (String message) -> {
var greet = stateFactory.create("GREET", ResumeBehavior.RETRY, String.class, (String message) -> {
System.out.println("Hello, " + message + "!");
return stateFactory.transition("COUNT-TO-FIVE", 0);
});
var ctf = stateFactory.create("COUNT-TO-FIVE", Integer.class, (Integer count) -> {
var ctf = stateFactory.create("COUNT-TO-FIVE", ResumeBehavior.RETRY, Integer.class, (Integer count) -> {
System.out.println(count);
if (count < 5) {
return stateFactory.transition("COUNT-TO-FIVE", count + 1);
@ -89,86 +136,4 @@ public class StateMachineTest {
MqTestUtil.getMessages(dataSource, inboxId).forEach(System.out::println);
}
@Test
public void smResumeFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A"));
var stateA = stateFactory.create("A", () -> stateFactory.transition("B"));
var stateB = stateFactory.create("B", () -> stateFactory.transition("C"));
var stateC = stateFactory.create("C", () -> stateFactory.transition("END"));
sm.registerStates(initial, stateA, stateB, stateC);
persistence.sendNewMessage(inboxId, null,"B", "", null);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("B", "C", "END"), states);
}
@Test
public void smResumeFromAck() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A"));
var stateA = stateFactory.create("A", () -> stateFactory.transition("B"));
var stateB = stateFactory.create("B", () -> stateFactory.transition("C"));
var stateC = stateFactory.create("C", () -> stateFactory.transition("END"));
sm.registerStates(initial, stateA, stateB, stateC);
long id = persistence.sendNewMessage(inboxId, null,"B", "", null);
persistence.updateMessageState(id, MqMessageState.ACK);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("B", "C", "END"), states);
}
@Test
public void smResumeEmptyQueue() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create());
var initial = stateFactory.create("INITIAL", () -> stateFactory.transition("A"));
var stateA = stateFactory.create("A", () -> stateFactory.transition("B"));
var stateB = stateFactory.create("B", () -> stateFactory.transition("C"));
var stateC = stateFactory.create("C", () -> stateFactory.transition("END"));
sm.registerStates(initial, stateA, stateB, stateC);
sm.resume();
sm.join();
sm.stop();
List<String> states = MqTestUtil.getMessages(dataSource, inboxId)
.stream()
.peek(System.out::println)
.map(MqMessageRow::function)
.toList();
assertEquals(List.of("INITIAL", "A", "B", "C", "END"), states);
}
}