MQFSM Usability WIP

This commit is contained in:
Viktor Lofgren 2023-07-06 13:33:11 +02:00
parent d89db10645
commit f0a8ca440f
13 changed files with 119 additions and 141 deletions

View File

@ -3,8 +3,8 @@ package nu.marginalia.mqsm;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.inject.Inject; import com.google.inject.Inject;
import com.google.inject.Singleton; import com.google.inject.Singleton;
import nu.marginalia.mqsm.graph.ResumeBehavior;
import nu.marginalia.mqsm.state.MachineState; import nu.marginalia.mqsm.state.MachineState;
import nu.marginalia.mqsm.state.ResumeBehavior;
import nu.marginalia.mqsm.state.StateTransition; import nu.marginalia.mqsm.state.StateTransition;
import java.util.function.Function; import java.util.function.Function;
@ -74,4 +74,52 @@ public class StateFactory {
public StateTransition transition(String state, Object message) { public StateTransition transition(String state, Object message) {
return StateTransition.to(state, gson.toJson(message)); return StateTransition.to(state, gson.toJson(message));
} }
public static class ErrorState implements MachineState {
@Override
public String name() { return "ERROR"; }
@Override
public StateTransition next(String message) {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return true; }
}
public static class FinalState implements MachineState {
@Override
public String name() { return "END"; }
@Override
public StateTransition next(String message) {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return true; }
}
public static class ResumingState implements MachineState {
@Override
public String name() { return "RESUMING"; }
@Override
public StateTransition next(String message) {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return false; }
}
} }

View File

@ -7,7 +7,8 @@ import nu.marginalia.mq.inbox.MqInboxResponse;
import nu.marginalia.mq.inbox.MqSubscription; import nu.marginalia.mq.inbox.MqSubscription;
import nu.marginalia.mq.outbox.MqOutbox; import nu.marginalia.mq.outbox.MqOutbox;
import nu.marginalia.mq.persistence.MqPersistence; import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.StateGraph; import nu.marginalia.mqsm.graph.ResumeBehavior;
import nu.marginalia.mqsm.graph.AbstractStateGraph;
import nu.marginalia.mqsm.state.*; import nu.marginalia.mqsm.state.*;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -30,13 +31,16 @@ public class StateMachine {
private final String queueName; private final String queueName;
private MachineState state; private MachineState state;
private final MachineState errorState = new ErrorState(); private final MachineState errorState = new StateFactory.ErrorState();
private final MachineState finalState = new FinalState(); private final MachineState finalState = new StateFactory.FinalState();
private final MachineState resumingState = new ResumingState(); private final MachineState resumingState = new StateFactory.ResumingState();
private final Map<String, MachineState> allStates = new HashMap<>(); private final Map<String, MachineState> allStates = new HashMap<>();
public StateMachine(MqPersistence persistence, String queueName, UUID instanceUUID) { public StateMachine(MqPersistence persistence,
String queueName,
UUID instanceUUID,
AbstractStateGraph stateGraph) {
this.queueName = queueName; this.queueName = queueName;
smInbox = new MqInbox(persistence, queueName, instanceUUID, Executors.newSingleThreadExecutor()); smInbox = new MqInbox(persistence, queueName, instanceUUID, Executors.newSingleThreadExecutor());
@ -45,28 +49,24 @@ public class StateMachine {
smInbox.subscribe(new StateEventSubscription()); smInbox.subscribe(new StateEventSubscription());
registerStates(List.of(errorState, finalState, resumingState)); registerStates(List.of(errorState, finalState, resumingState));
registerStates(stateGraph);
for (var declaredState : stateGraph.declaredStates()) {
if (!allStates.containsKey(declaredState)) {
throw new IllegalArgumentException("State " + declaredState + " is not defined in the state graph");
}
}
} }
/** Register the state graph */ /** Register the state graph */
public void registerStates(MachineState... states) { void registerStates(List<MachineState> states) {
if (state != null) {
throw new IllegalStateException("Cannot register states after state machine has been initialized");
}
for (var state : states) { for (var state : states) {
allStates.put(state.name(), state); allStates.put(state.name(), state);
} }
} }
/** Register the state graph */ /** Register the state graph */
public void registerStates(List<MachineState> states) { void registerStates(AbstractStateGraph states) {
for (var state : states) {
allStates.put(state.name(), state);
}
}
/** Register the state graph */
public void registerStates(StateGraph states) {
registerStates(states.asStateList()); registerStates(states.asStateList());
} }

View File

@ -1,22 +1,20 @@
package nu.marginalia.mqsm.graph; package nu.marginalia.mqsm.graph;
import nu.marginalia.mqsm.StateFactory;
import nu.marginalia.mqsm.state.MachineState; import nu.marginalia.mqsm.state.MachineState;
import nu.marginalia.mqsm.StateFactory;
import nu.marginalia.mqsm.state.StateTransition; import nu.marginalia.mqsm.state.StateTransition;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.List;
public abstract class StateGraph { public abstract class AbstractStateGraph {
private final StateFactory stateFactory; private final StateFactory stateFactory;
private static final Logger logger = LoggerFactory.getLogger(StateGraph.class); private static final Logger logger = LoggerFactory.getLogger(AbstractStateGraph.class);
public StateGraph(StateFactory stateFactory) { public AbstractStateGraph(StateFactory stateFactory) {
this.stateFactory = stateFactory; this.stateFactory = stateFactory;
} }
@ -38,6 +36,19 @@ public abstract class StateGraph {
throw new ControlFlowException("ERROR", ex.getClass().getSimpleName() + ":" + ex.getMessage()); throw new ControlFlowException("ERROR", ex.getClass().getSimpleName() + ":" + ex.getMessage());
} }
public Set<String> declaredStates() {
Set<String> ret = new HashSet<>();
for (var method : getClass().getMethods()) {
var gs = method.getAnnotation(GraphState.class);
if (gs != null) {
ret.add(gs.name());
ret.add(gs.next());
}
}
return ret;
}
public List<MachineState> asStateList() { public List<MachineState> asStateList() {
List<MachineState> ret = new ArrayList<>(); List<MachineState> ret = new ArrayList<>();

View File

@ -1,8 +1,6 @@
package nu.marginalia.mqsm.graph; package nu.marginalia.mqsm.graph;
import nu.marginalia.mqsm.state.ResumeBehavior;
import java.lang.annotation.Retention; import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;

View File

@ -0,0 +1,8 @@
package nu.marginalia.mqsm.graph;
public enum ResumeBehavior {
/** Retry the state on resume */
RETRY,
/** Jump to ERROR on resume if the message has been acknowledged */
ERROR
}

View File

@ -1,17 +0,0 @@
package nu.marginalia.mqsm.state;
public class ErrorState implements MachineState {
@Override
public String name() { return "ERROR"; }
@Override
public StateTransition next(String message) {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return true; }
}

View File

@ -1,17 +0,0 @@
package nu.marginalia.mqsm.state;
public class FinalState implements MachineState {
@Override
public String name() { return "END"; }
@Override
public StateTransition next(String message) {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return true; }
}

View File

@ -1,9 +1,13 @@
package nu.marginalia.mqsm.state; package nu.marginalia.mqsm.state;
import nu.marginalia.mqsm.graph.ResumeBehavior;
public interface MachineState { public interface MachineState {
String name(); String name();
StateTransition next(String message); StateTransition next(String message);
ResumeBehavior resumeBehavior(); ResumeBehavior resumeBehavior();
boolean isFinal(); boolean isFinal();
} }

View File

@ -1,6 +0,0 @@
package nu.marginalia.mqsm.state;
public enum ResumeBehavior {
RETRY,
ERROR
}

View File

@ -1,17 +0,0 @@
package nu.marginalia.mqsm.state;
public class ResumingState implements MachineState {
@Override
public String name() { return "RESUMING"; }
@Override
public StateTransition next(String message) {
throw new UnsupportedOperationException();
}
@Override
public ResumeBehavior resumeBehavior() { return ResumeBehavior.RETRY; }
@Override
public boolean isFinal() { return false; }
}

View File

@ -4,12 +4,11 @@ import com.google.gson.GsonBuilder;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageRow; import nu.marginalia.mq.MqMessageRow;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil; import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence; import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.GraphState; import nu.marginalia.mqsm.graph.GraphState;
import nu.marginalia.mqsm.graph.StateGraph; import nu.marginalia.mqsm.graph.AbstractStateGraph;
import nu.marginalia.mqsm.state.ResumeBehavior; import nu.marginalia.mqsm.graph.ResumeBehavior;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer; import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
@ -55,7 +54,7 @@ public class StateMachineErrorTest {
dataSource.close(); dataSource.close();
} }
public static class ErrorHurdles extends StateGraph { public static class ErrorHurdles extends AbstractStateGraph {
public ErrorHurdles(StateFactory stateFactory) { public ErrorHurdles(StateFactory stateFactory) {
super(stateFactory); super(stateFactory);
@ -71,17 +70,15 @@ public class StateMachineErrorTest {
} }
@GraphState(name = "OK", next = "END") @GraphState(name = "OK", next = "END")
public void ok() { public void ok() {
} }
} }
@Test @Test
public void smResumeResumableFromNew() throws Exception { public void smResumeResumableFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new ErrorHurdles(stateFactory));
sm.registerStates(new ErrorHurdles(stateFactory).asStateList());
sm.init(); sm.init();

View File

@ -8,8 +8,8 @@ import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil; import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence; import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.GraphState; import nu.marginalia.mqsm.graph.GraphState;
import nu.marginalia.mqsm.graph.StateGraph; import nu.marginalia.mqsm.graph.AbstractStateGraph;
import nu.marginalia.mqsm.state.ResumeBehavior; import nu.marginalia.mqsm.graph.ResumeBehavior;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer; import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
@ -55,7 +55,7 @@ public class StateMachineResumeTest {
dataSource.close(); dataSource.close();
} }
public static class ResumeTrialsGraph extends StateGraph { public static class ResumeTrialsGraph extends AbstractStateGraph {
public ResumeTrialsGraph(StateFactory stateFactory) { public ResumeTrialsGraph(StateFactory stateFactory) {
super(stateFactory); super(stateFactory);
@ -75,10 +75,8 @@ public class StateMachineResumeTest {
@Test @Test
public void smResumeResumableFromNew() throws Exception { public void smResumeResumableFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
sm.registerStates(new ResumeTrialsGraph(stateFactory).asStateList());
persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null);
@ -98,10 +96,8 @@ public class StateMachineResumeTest {
@Test @Test
public void smResumeFromAck() throws Exception { public void smResumeFromAck() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
sm.registerStates(new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null); long id = persistence.sendNewMessage(inboxId, null,"RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK); persistence.updateMessageState(id, MqMessageState.ACK);
@ -123,10 +119,8 @@ public class StateMachineResumeTest {
@Test @Test
public void smResumeNonResumableFromNew() throws Exception { public void smResumeNonResumableFromNew() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
sm.registerStates(new ResumeTrialsGraph(stateFactory));
persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null);
@ -146,10 +140,8 @@ public class StateMachineResumeTest {
@Test @Test
public void smResumeNonResumableFromAck() throws Exception { public void smResumeNonResumableFromAck() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
sm.registerStates(new ResumeTrialsGraph(stateFactory));
long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null); long id = persistence.sendNewMessage(inboxId, null,"NON-RESUMABLE", "", null);
persistence.updateMessageState(id, MqMessageState.ACK); persistence.updateMessageState(id, MqMessageState.ACK);
@ -170,10 +162,8 @@ public class StateMachineResumeTest {
@Test @Test
public void smResumeEmptyQueue() throws Exception { public void smResumeEmptyQueue() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new ResumeTrialsGraph(stateFactory));
sm.registerStates(new ResumeTrialsGraph(stateFactory));
sm.resume(); sm.resume();

View File

@ -3,19 +3,15 @@ package nu.marginalia.mqsm;
import com.google.gson.GsonBuilder; import com.google.gson.GsonBuilder;
import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import nu.marginalia.mq.MqMessageRow;
import nu.marginalia.mq.MqMessageState;
import nu.marginalia.mq.MqTestUtil; import nu.marginalia.mq.MqTestUtil;
import nu.marginalia.mq.persistence.MqPersistence; import nu.marginalia.mq.persistence.MqPersistence;
import nu.marginalia.mqsm.graph.GraphState; import nu.marginalia.mqsm.graph.GraphState;
import nu.marginalia.mqsm.graph.StateGraph; import nu.marginalia.mqsm.graph.AbstractStateGraph;
import nu.marginalia.mqsm.state.ResumeBehavior;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.testcontainers.containers.MariaDBContainer; import org.testcontainers.containers.MariaDBContainer;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.junit.jupiter.Testcontainers;
import java.util.List;
import java.util.UUID; import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@ -55,7 +51,7 @@ public class StateMachineTest {
dataSource.close(); dataSource.close();
} }
public static class TestGraph extends StateGraph { public static class TestGraph extends AbstractStateGraph {
public TestGraph(StateFactory stateFactory) { public TestGraph(StateFactory stateFactory) {
super(stateFactory); super(stateFactory);
} }
@ -87,8 +83,8 @@ public class StateMachineTest {
var graph = new TestGraph(stateFactory); var graph = new TestGraph(stateFactory);
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID()); var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), graph);
sm.registerStates(graph.asStateList()); sm.registerStates(graph);
sm.init(); sm.init();
@ -101,34 +97,17 @@ public class StateMachineTest {
@Test @Test
public void testStartStopStartStop() throws Exception { public void testStartStopStartStop() throws Exception {
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID());
var stateFactory = new StateFactory(new GsonBuilder().create()); var stateFactory = new StateFactory(new GsonBuilder().create());
var sm = new StateMachine(persistence, inboxId, UUID.randomUUID(), new TestGraph(stateFactory));
var initial = stateFactory.create("INITIAL", ResumeBehavior.RETRY, () -> stateFactory.transition("GREET", "World"));
var greet = stateFactory.create("GREET", ResumeBehavior.RETRY, String.class, (String message) -> {
System.out.println("Hello, " + message + "!");
return stateFactory.transition("COUNT-TO-FIVE", 0);
});
var ctf = stateFactory.create("COUNT-TO-FIVE", ResumeBehavior.RETRY, Integer.class, (Integer count) -> {
System.out.println(count);
if (count < 5) {
return stateFactory.transition("COUNT-TO-FIVE", count + 1);
} else {
return stateFactory.transition("END");
}
});
sm.registerStates(initial, greet, ctf);
sm.init(); sm.init();
Thread.sleep(300); Thread.sleep(150);
sm.stop(); sm.stop();
var sm2 = new StateMachine(persistence, inboxId, UUID.randomUUID()); System.out.println("-------------------- ");
sm2.registerStates(initial, greet, ctf);
var sm2 = new StateMachine(persistence, inboxId, UUID.randomUUID(), new TestGraph(stateFactory));
sm2.resume(); sm2.resume();
sm2.join(); sm2.join();
sm2.stop(); sm2.stop();