Benji's Blog

Optionally typechecked StateMachines

Many things can be modelled as finite state machines, particularly things where you’d naturally use “state” in the name, e.g. the current state of an order or a delivery status. We often model these as enums.

enum OrderStatus {
    Pending,
    CheckingOut,
    Purchased,
    Shipped,
    Cancelled,
    Delivered,
    Failed,
    Refunded
}

Enums are great for restricting our order status to only valid states. However, usually there are only certain transitions that are valid. We can’t go from Delivered to Failed. Nor would we go straight from Pending to Delivered. Maybe we can transition from Purchased to either Shipped or Cancelled.

Using enum values alone, we cannot restrict the transitions to only those that we want. It would be nice to let the compiler help us out by rejecting invalid transitions in our code.
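To see what we're missing, here's a minimal sketch of the enum approach with the rules kept in a lookup table. The EnumTransitions wrapper, the ALLOWED table and the transition() helper are names made up for this illustration, and the table only lists a couple of example transitions. The point is that an invalid transition compiles fine and is only caught when the code runs.

```java
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;

final class EnumTransitions {
    enum OrderStatus { Pending, CheckingOut, Purchased, Shipped, Cancelled, Delivered, Failed, Refunded }

    // Hypothetical lookup table: the allowed target states for each source state.
    static final Map<OrderStatus, Set<OrderStatus>> ALLOWED = new EnumMap<>(OrderStatus.class);
    static {
        ALLOWED.put(OrderStatus.Pending, EnumSet.of(OrderStatus.CheckingOut, OrderStatus.Cancelled));
        ALLOWED.put(OrderStatus.Purchased, EnumSet.of(OrderStatus.Shipped, OrderStatus.Cancelled));
    }

    // Returns the new state, or throws if the table does not allow the move.
    static OrderStatus transition(OrderStatus from, OrderStatus to) {
        Set<OrderStatus> targets = ALLOWED.getOrDefault(from, EnumSet.noneOf(OrderStatus.class));
        if (!targets.contains(to)) {
            throw new IllegalStateException(from + " -> " + to + " is not allowed");
        }
        return to;
    }

    public static void main(String[] args) {
        System.out.println(transition(OrderStatus.Purchased, OrderStatus.Shipped));
        try {
            // Compiles without complaint; the mistake only shows up at runtime.
            transition(OrderStatus.Delivered, OrderStatus.Failed);
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

The compiler has no idea that Delivered to Failed is a mistake, because every OrderStatus value has the same type.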

We can, however, achieve this if we use a class hierarchy to represent our states instead, and it can still be fairly concise. There are other reasons for using regular classes too: they allow us to store state and even capture state from the surrounding context.

Here’s a way we could model the above enum as a class hierarchy with the valid transitions.

interface OrderStatus extends State<OrderStatus> {}
static class Pending     implements OrderStatus, BiTransitionTo<CheckingOut, Cancelled> {}
static class CheckingOut implements OrderStatus, BiTransitionTo<Purchased, Cancelled> {}
static class Purchased   implements OrderStatus, BiTransitionTo<Shipped, Failed> {}
static class Shipped     implements OrderStatus, BiTransitionTo<Refunded, Delivered> {}
static class Delivered   implements OrderStatus, TransitionTo<Refunded> {}
static class Cancelled   implements OrderStatus {}
static class Failed      implements OrderStatus {}
static class Refunded    implements OrderStatus {}

We’ve declared an OrderStatus interface and then created an implementation of OrderStatus for each valid state. We’ve then encoded the valid transitions as further interfaces that each state implements: TransitionTo<State>, BiTransitionTo<State1,State2>, or TriTransitionTo<State1,State2,State3>, depending on the number of valid transitions from that state. We need differently named interfaces for different numbers of transitions because Java doesn’t support a variable number of generic type parameters.

Compile-time checking valid transitions

Now we can create the TransitionTo/BiTransitionTo interfaces, which give us the functionality to transition to a new state (but only if it is valid).

We might imagine an API like this, where we can choose which state to transition to:

new Pending()
    .transitionTo(CheckingOut.class)
    .transitionTo(Purchased.class)
    .transitionTo(Refunded.class) // <-- can we make this line fail to compile?

Because of type erasure this turns out to be a little tricky, but not impossible.

Let's try to implement the BiTransitionTo interface with its two valid transitions.

public interface BiTransitionTo<T, U> {
    default T transitionTo(Class<T> type) { ... }
    default U transitionTo(Class<U> type) { ... }
}

Both of these transitionTo methods have the same erasure. So we can't do it quite like this. However, if we can encourage the consumer of our API to pass a lambda, there is a way to work around this same erasure problem.

So how about this API, where instead of passing class literals we pass constructor references? It looks similarly clean, but constructor references are basically lambdas, so we can work around the erasure problem.

new Pending()
    .transition(CheckingOut::new)
    .transition(Purchased::new)
    .transition(Refunded::new) // <-- Now we can make this fail to compile

To make this work, the trick is to create a new interface type for each valid transition within our BiTransitionTo interface:

public interface BiTransitionTo<T, U> {
    interface OneTransition<T> extends Supplier<T> { }
    default T transition(OneTransition<T> constructor) { ... }
    interface TwoTransition<T> extends Supplier<T> { }
    default U transition(TwoTransition<U> constructor) { ... }
}

Supplier<T> is a functional interface in the java.util.function package, and a no-args constructor reference can be used wherever one is expected. By creating two interfaces that extend it we can overload the transition() method twice, allowing both methods to be passed a constructor reference without the two methods having the same erasure.
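Here's a self-contained sketch of that trick that you can paste into a file and compile. The CompileTimeTransitions wrapper class is just a container for the example, the transition() bodies simply call get() (an assumption, since the real bodies are elided above), and only a few of the order states are included.

```java
import java.util.function.Supplier;

final class CompileTimeTransitions {
    interface BiTransitionTo<T, U> {
        // Two distinct functional interfaces, so the overloads below differ after erasure.
        interface OneTransition<T> extends Supplier<T> {}
        interface TwoTransition<T> extends Supplier<T> {}

        // Assumed bodies: just invoke the constructor reference.
        default T transition(OneTransition<T> constructor) { return constructor.get(); }
        default U transition(TwoTransition<U> constructor) { return constructor.get(); }
    }

    interface OrderStatus {}
    static class Pending     implements OrderStatus, BiTransitionTo<CheckingOut, Cancelled> {}
    static class CheckingOut implements OrderStatus, BiTransitionTo<Purchased, Cancelled> {}
    static class Purchased   implements OrderStatus {}
    static class Cancelled   implements OrderStatus {}
    static class Refunded    implements OrderStatus {}

    public static void main(String[] args) {
        // Each call picks the one overload whose Supplier type matches the constructor.
        Purchased purchased = new Pending()
            .transition(CheckingOut::new)
            .transition(Purchased::new);
        Cancelled cancelled = new Pending().transition(Cancelled::new);

        System.out.println(purchased.getClass().getSimpleName());
        System.out.println(cancelled.getClass().getSimpleName());

        // new Pending().transition(Refunded::new); // rejected: Refunded matches neither overload
    }
}
```

As far as I can tell this relies on each state class having exactly one constructor. With several constructors the reference is no longer resolved before overload selection, and the two transition() methods become ambiguous.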

Runtime checking

Sometimes we might not be able to know at compile-time what state our state machine is in. Perhaps a Customer has a field of type OrderStatus that we serialize to a database. We would need to be able to try a transition at runtime, and fail in some manner if the transition is not valid.

This is also possible using the TransitionTo<NewState> approach outlined above. Since the type arguments of a class's generic supertypes are available at runtime, we can implement a tryTransition() method that uses reflection to check which transitions are permitted.

First we'll need a way of finding the valid transition types. We'll add it to our State base interface.

default List<Class<?>> validTransitionTypes() {
    return asList(getClass().getGenericInterfaces())
        .stream()
        .filter(type -> type instanceof ParameterizedType)
        .map(type -> (ParameterizedType) type)
        .filter(TransitionTo::isTransition)
        .flatMap(type -> asList(type.getActualTypeArguments()).stream())
        .map(type -> (Class<?>) type)
        .collect(toList());
}

Note the isTransition filter. Since we have multiple transition interfaces (TransitionTo<T>, BiTransitionTo<T,U>, TriTransitionTo<T,U,V>, etc.), we need a way of marking them all as specifying transitions. I've used an annotation:

@Retention(RUNTIME)
@Target(ElementType.TYPE)
public @interface Transition {
}

static boolean isTransition(ParameterizedType type) {
    Class<?> cls = (Class<?>) type.getRawType();
    return cls.getAnnotationsByType(Transition.class).length > 0;
}

@Transition
public interface TriTransitionTo<T, U, V> ...

Once we have validTransitionTypes() we can find which transitions are valid at runtime.

static class Pending implements OrderStatus, BiTransitionTo<CheckingOut, Cancelled> {}

@Test
public void finding_valid_transitions_at_runtime() {
    Pending pending = new Pending();
    assertEquals(
        asList(CheckingOut.class, Cancelled.class),
        pending.validTransitionTypes()
    );
}
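For anyone who wants to try the reflection part in isolation, here's a compact sketch that should compile on its own. It's a simplification under a few assumptions: State is not generic here, the annotation check is inlined into the stream, the transition interfaces are empty markers, and the RuntimeTransitionTypes wrapper class exists only to hold the example.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.ParameterizedType;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

final class RuntimeTransitionTypes {
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.TYPE)
    @interface Transition {}

    // Marker-only versions of the transition interfaces (no transition() methods here).
    @Transition interface TransitionTo<T> {}
    @Transition interface BiTransitionTo<T, U> {}

    interface State {
        default List<Class<?>> validTransitionTypes() {
            return Arrays.stream(getClass().getGenericInterfaces())
                .filter(type -> type instanceof ParameterizedType)
                .map(type -> (ParameterizedType) type)
                // Keep only interfaces marked with @Transition.
                .filter(type -> ((Class<?>) type.getRawType()).isAnnotationPresent(Transition.class))
                // The type arguments of those interfaces are the permitted target states.
                .flatMap(type -> Arrays.stream(type.getActualTypeArguments()))
                .<Class<?>>map(type -> (Class<?>) type)
                .collect(Collectors.toList());
        }
    }

    interface OrderStatus extends State {}
    static class Pending     implements OrderStatus, BiTransitionTo<CheckingOut, Cancelled> {}
    static class CheckingOut implements OrderStatus, TransitionTo<Cancelled> {}
    static class Cancelled   implements OrderStatus {}

    public static void main(String[] args) {
        System.out.println(new Pending().validTransitionTypes());
        System.out.println(new Cancelled().validTransitionTypes());
    }
}
```

Note that this only looks at interfaces declared directly on the state class, so transitions inherited from a superclass would not be found by this sketch.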

Now that we have the valid types, tryTransition() needs to check whether the requested transition is to one of those types.

This is a little tricky, but since we're passing a lambda we can treat it as a lambda type token and use reflection to find out which type it constructs.

Our implementation then looks something like this:

interface NextState<T> extends Supplier<T>, MethodFinder {
    default Class<T> type() {
        return (Class<T>) getContainingClass();
    }
}
default <T> T tryTransition(NextState<T> desired) {
    if (validTransitionTypes().contains(desired.type())) {
        return desired.get();
    }

    throw new IllegalStateTransitionException();
}
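MethodFinder isn't shown here, so below is a sketch of one way to get the same effect using only the JDK. It is a stand-in, not the MethodFinder implementation: it makes the functional interface Serializable and reads the constructed class out of java.lang.invoke.SerializedLambda. The LambdaTypeTokens wrapper and the hard-coded validTransitionTypes() override are assumptions made to keep the example short.

```java
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;

final class LambdaTypeTokens {
    static class IllegalStateTransitionException extends RuntimeException {}

    // Serializable lambdas carry a description of the method or constructor they wrap.
    interface NextState<T> extends Supplier<T>, Serializable {
        default Class<?> type() {
            try {
                Method writeReplace = getClass().getDeclaredMethod("writeReplace");
                writeReplace.setAccessible(true);
                SerializedLambda lambda = (SerializedLambda) writeReplace.invoke(this);
                // For a constructor reference, implClass names the class being constructed.
                String name = lambda.getImplClass().replace('/', '.');
                return Class.forName(name, false, getClass().getClassLoader());
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Could not inspect constructor reference", e);
            }
        }
    }

    interface OrderStatus {
        // Stand-in for the reflective validTransitionTypes() lookup.
        default List<Class<?>> validTransitionTypes() { return Collections.emptyList(); }

        default <T extends OrderStatus> T tryTransition(NextState<T> desired) {
            if (validTransitionTypes().contains(desired.type())) {
                return desired.get();
            }
            throw new IllegalStateTransitionException();
        }
    }

    static class Pending implements OrderStatus {
        @Override
        public List<Class<?>> validTransitionTypes() {
            return Arrays.<Class<?>>asList(CheckingOut.class, Cancelled.class);
        }
    }
    static class CheckingOut implements OrderStatus {}
    static class Cancelled   implements OrderStatus {}
    static class Refunded    implements OrderStatus {}

    public static void main(String[] args) {
        OrderStatus state = new Pending().tryTransition(CheckingOut::new);
        System.out.println(state.getClass().getSimpleName());
        try {
            new Pending().tryTransition(Refunded::new);
        } catch (IllegalStateTransitionException e) {
            System.out.println("Refunded rejected");
        }
    }
}
```

The nice property is that the check happens before the new state is constructed. The cost is a reflective call on every transition, and it only works when the caller passes a plain constructor reference.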

We can make it a bit nicer by allowing the caller to specify the exception to throw on error, like an Optional's orElseThrow. We can also allow the caller to ignore failed transitions.

@Test
public void runtime_checked_transition() {
    OrderStatus state = new Pending();
    assertTrue(state instanceof Pending);
    state = state
        .tryTransition(CheckingOut::new)
        .unchecked();
    assertTrue(state instanceof CheckingOut);
}

Since we've transitioned into a known state (or thrown an exception) with tryTransition, we can then chain compile-time checked transitions on the end.

@Test
public void runtime_then_compile_time_checked_transition() {
    OrderStatus state = new Pending();
    assertTrue(state instanceof Pending);
    state = state
        .tryTransition(CheckingOut::new)
        .unchecked()
        .transition(Purchased::new); // This will be permitted if the tryTransition succeeds.
    assertTrue(state instanceof Purchased);
}

We can even let people ignore transition failures if they wish, just by catching the exception and returning the original value.

@Test
public void runtime_checked_transition_ignoring_failure() {
    OrderStatus state = new Pending();
    assertTrue(state instanceof Pending);
    state = state
        .tryTransition(Refunded::new)
        .ignoreIfInvalid();
    assertFalse(state instanceof Refunded);
    assertTrue(state instanceof Pending);
}
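The tests above only show the calling side, so here's a rough sketch of what the object returned by tryTransition() might look like. Only unchecked() and ignoreIfInvalid() come from the tests; the Attempt class, its orElseThrow() method and the allowed() set are hypothetical names. To avoid reflection, this version takes the target class as an explicit argument, which the API above does not.

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Supplier;

final class TransitionAttempts {
    static class IllegalStateTransitionException extends RuntimeException {}

    // Hypothetical result type: S is the state family, T the requested state.
    static final class Attempt<S, T extends S> {
        private final S current;
        private final Supplier<T> next;
        private final boolean valid;

        Attempt(S current, Supplier<T> next, boolean valid) {
            this.current = current;
            this.next = next;
            this.valid = valid;
        }

        // Like Optional.orElseThrow: the caller chooses the exception.
        <E extends Exception> T orElseThrow(Supplier<E> failure) throws E {
            if (!valid) throw failure.get();
            return next.get();
        }

        // Throws an unchecked exception on an invalid transition.
        T unchecked() {
            return orElseThrow(IllegalStateTransitionException::new);
        }

        // Stays in the current state when the transition is not allowed.
        S ignoreIfInvalid() {
            return valid ? next.get() : current;
        }
    }

    interface OrderStatus {
        default Set<Class<?>> allowed() { return Collections.emptySet(); }

        // Assumption: the target class is passed explicitly instead of being read from the lambda.
        default <T extends OrderStatus> Attempt<OrderStatus, T> tryTransition(Class<T> type, Supplier<T> next) {
            return new Attempt<>(this, next, allowed().contains(type));
        }
    }

    static class Pending implements OrderStatus {
        @Override
        public Set<Class<?>> allowed() {
            Set<Class<?>> targets = new HashSet<>();
            targets.add(CheckingOut.class);
            targets.add(Cancelled.class);
            return targets;
        }
    }
    static class CheckingOut implements OrderStatus {}
    static class Cancelled   implements OrderStatus {}
    static class Refunded    implements OrderStatus {}

    public static void main(String[] args) {
        OrderStatus state = new Pending();
        CheckingOut checkingOut = state.tryTransition(CheckingOut.class, CheckingOut::new).unchecked();
        OrderStatus unchanged = state.tryTransition(Refunded.class, Refunded::new).ignoreIfInvalid();
        System.out.println(checkingOut.getClass().getSimpleName() + ", " + (unchanged == state));
    }
}
```

Because unchecked() returns the specific requested type rather than OrderStatus, a compile-time checked transition() could be chained straight onto it.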

Adding Behaviour

Since our states are classes, we can add behaviour to them.

For instance we could add a notifyProgress() method to our OrderStatus, with different implementations in each state.

interface OrderStatus extends State<OrderStatus> {
    default void notifyProgress(Customer customer, EmailSender sender) {}
}
static class Purchased implements OrderStatus, BiTransitionTo<Shipped, Failed> {
    public void notifyProgress(Customer customer, EmailSender emailSender) {
        emailSender.sendEmail("fulfillment@mycompany.com", "Customer order pending");
        emailSender.sendEmail(customer.email(), "Your order is on its way");
    }
}
...
OrderStatus status = new Pending();
status.notifyProgress(customer, sender); // Does nothing
status = status
    .tryTransition(CheckingOut::new)
    .unchecked()
    .transition(Purchased::new);
status.notifyProgress(customer, sender); // sends emails

Then we can call notifyProgress on any OrderStatus instance and it will notify differently depending on which implementation is active.

Internal Transitions

One of the ways to make the most of the typechecked transitions is to make the transitions from within the states themselves, e.g. in a state machine for the regex "A+B" the A state can transition to A, to B, or to NoMatch.

If we do this we can make them typechecked even though we don't know in advance what string we're matching.

static class A implements APlusB, TriTransitionTo<A, B, NoMatch> {
    public APlusB match(String s) {
        if (s.length() < 1) return transition(NoMatch::new);
        if (s.charAt(0) == 'A') return transition(A::new).match(s.substring(1));
        if (s.charAt(0) == 'B') return transition(B::new).match(s.substring(1));
        return transition(NoMatch::new);
    }
}
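To show how the A state might fit into a complete matcher, here's a self-contained sketch. The A class is the one above; the Start, B and Match states, the matches() helper and the shortened One/Two/Three interface names are assumptions added for illustration, as are the transition() bodies.

```java
import java.util.function.Supplier;

final class APlusBMatcher {
    interface BiTransitionTo<T, U> {
        interface One<T> extends Supplier<T> {}
        interface Two<T> extends Supplier<T> {}
        default T transition(One<T> next) { return next.get(); }
        default U transition(Two<U> next) { return next.get(); }
    }

    interface TriTransitionTo<T, U, V> {
        interface One<T> extends Supplier<T> {}
        interface Two<T> extends Supplier<T> {}
        interface Three<T> extends Supplier<T> {}
        default T transition(One<T> next) { return next.get(); }
        default U transition(Two<U> next) { return next.get(); }
        default V transition(Three<V> next) { return next.get(); }
    }

    interface APlusB { APlusB match(String s); }

    // Assumed start state: the first character must be an 'A'.
    static class Start implements APlusB, BiTransitionTo<A, NoMatch> {
        public APlusB match(String s) {
            if (!s.isEmpty() && s.charAt(0) == 'A') return transition(A::new).match(s.substring(1));
            return transition(NoMatch::new);
        }
    }

    static class A implements APlusB, TriTransitionTo<A, B, NoMatch> {
        public APlusB match(String s) {
            if (s.length() < 1) return transition(NoMatch::new);
            if (s.charAt(0) == 'A') return transition(A::new).match(s.substring(1));
            if (s.charAt(0) == 'B') return transition(B::new).match(s.substring(1));
            return transition(NoMatch::new);
        }
    }

    // Assumed final step: after the 'B' the input must be exhausted.
    static class B implements APlusB, BiTransitionTo<Match, NoMatch> {
        public APlusB match(String s) {
            if (s.isEmpty()) return transition(Match::new);
            return transition(NoMatch::new);
        }
    }

    // Terminal states: nothing left to transition to.
    static class Match   implements APlusB { public APlusB match(String s) { return this; } }
    static class NoMatch implements APlusB { public APlusB match(String s) { return this; } }

    static boolean matches(String s) {
        return new Start().match(s) instanceof Match;
    }

    public static void main(String[] args) {
        System.out.println(matches("AAB"));
        System.out.println(matches("B"));
    }
}
```

If we tried to write transition(Match::new) inside A, it should fail to compile, because Match is not one of A's three declared targets.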

Full example here

Capturing State

If we use non-static classes we could also capture state from the enclosing class. Supposing these OrderStatuses are contained within an Order class that already has an EmailSender available, we'd no longer need to pass in the emailSender and the customer to the notifyProgress() method.

class Order {
    EmailSender emailSender;
    Customer customer;
    class Purchased implements OrderStatus, BiTransitionTo<Shipped, Failed> {
        public void notifyProgress() {
            emailSender.sendEmail("fulfillment@mycompany.com", "Customer order pending");
            emailSender.sendEmail(customer.email(), "Your order is on its way");
        }
    }
}

Guards

Another feature we might want is the ability to execute some code before or after transitioning into a new state. This is something we can add to our base State interface. Let's add two methods, beforeTransition() and afterTransition():

interface State<T extends State<T>> {
    default void afterTransition(T from) {}
    default void beforeTransition(T to) {}
}

We can then update our transition implementation to invoke these guard methods before and after a transition occurs.

We could use this to log all transitions into the Failed state.

class Failed implements OrderStatus {
    @Override
    public void afterTransition(OrderStatus from) {
        failureLog.warning("Oh bother! failed from " + from.getClass().getSimpleName());
    }
}
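The updated transition implementation isn't shown above, so here's a sketch of how the guard calls might be wired in. It's one possible shape, not necessarily how the real code does it: this TransitionTo takes the state family as an extra type parameter so that it can call the guards, and failureLog is a plain list instead of a logger.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

final class GuardedTransitions {
    // Stand-in for a real logger so the effect is easy to inspect.
    static final List<String> failureLog = new ArrayList<>();

    interface State<S extends State<S>> {
        default void beforeTransition(S to) {}
        default void afterTransition(S from) {}
    }

    // Assumption: the state family S is an extra type parameter here.
    interface TransitionTo<S extends State<S>, T extends S> {
        @SuppressWarnings("unchecked")
        default T transition(Supplier<T> constructor) {
            S from = (S) this; // assumes every implementor is itself a state of family S
            T to = constructor.get();
            from.beforeTransition(to); // guard on the state we are leaving
            to.afterTransition(from);  // guard on the state we are entering
            return to;
        }
    }

    interface OrderStatus extends State<OrderStatus> {}

    static class Purchased implements OrderStatus, TransitionTo<OrderStatus, Failed> {}

    static class Failed implements OrderStatus {
        @Override
        public void afterTransition(OrderStatus from) {
            failureLog.add("failed from " + from.getClass().getSimpleName());
        }
    }

    public static void main(String[] args) {
        new Purchased().transition(Failed::new);
        System.out.println(failureLog);
    }
}
```

Calling beforeTransition() on the old state and afterTransition() on the new one means each state only needs to know about its own entry and exit.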

We could also combine state capturing and guard methods to build a stateful-state machine that updates its state on transition instead of just returning the new state. Here's an example where we use a guard method to mutate the state of lightSwitch after each transition.

class LightExample {
    Switch lightSwitch = new Off();

    public class Switch implements State

Show me the code

The code is on GitHub if you'd like to play with it or see full executable examples.