187 lines
4.6 KiB
Go
187 lines
4.6 KiB
Go
package fsm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"gitea.neur0tx.site/Neur0toxine/vegapokerbot/pkg/types"
|
|
"runtime"
|
|
"time"
|
|
)
|
|
|
|
var (
	// ErrStateDoesNotExist is reported when a state ID is not registered in the machine.
	ErrStateDoesNotExist = errors.New("state does not exist")

	// ErrPreventTransition is a sentinel rather than a failure. Return it from a
	// state's Enter callback to cancel the transition once Enter has already done
	// everything that was needed.
	ErrPreventTransition = errors.New("prevents transition; this is not an error")
)
|
|
|
|
// MachineStateRouter should be provided to IMachine. This function can do two very useful things:
|
|
// - It can modify Machine's payload.
|
|
// - It can act as a router by changing Machine's state via provided controls.
|
|
type MachineStateRouter[T any] func(*T, MachineControlsWithState[T])
|
|
|
|
// MachineStateProvider is supplied with every input. It receives the machine's
// current payload (nil when nothing has been set yet) and returns the payload
// to use from then on. This lets it either create the initial state or update
// the existing one.
type MachineStateProvider[T any] func(current *T) (next *T)
|
|
|
|
// IMachine is a Machine contract. The Machine should be able to do the following:
|
|
// - Move to another state (usually called by the IState itself).
|
|
// - Handle the state input.
|
|
// - Reset the machine.
|
|
type IMachine[T any] interface {
|
|
MachineControlsWithState[T]
|
|
// Enqueue the state input. Handle func will accept the current payload and modify it based on user input.
|
|
Enqueue(MachineStateProvider[T])
|
|
// Reset the machine to its initial state.
|
|
Reset()
|
|
}
|
|
|
|
// Machine is a finite-state machine implementation. Inputs are queued with
// Enqueue and consumed by a background goroutine (handleQueue) started in New.
//
// NOTE(review): no mutex guards these fields, yet they are read and written
// both by the worker goroutine and through exported methods (State, Reset,
// Move, ...). Confirm that callers only use those methods from state
// callbacks running on the worker goroutine.
type Machine[T any] struct {
	// ctx is cancelled through cancel to stop the worker goroutine.
	ctx    context.Context
	cancel context.CancelFunc

	// payload is the data passed to every state callback; State returns it.
	payload *T

	// transitions holds inputs queued by Enqueue until the worker dequeues them.
	transitions types.Queue[MachineStateProvider[T]]

	// stateRouter, when non-nil, runs before the current state handles an input.
	stateRouter MachineStateRouter[T]

	// state is the ID of the current state; Reset restores initialState.
	state        StateID
	initialState StateID

	// states holds every registered state, keyed by its ID.
	states *types.Map[StateID, IState[T]]

	// errHandler, when non-nil, receives errors reported through fatalError.
	errHandler ErrorState[T]

	// handleNow is set by MoveForHandling so that handle runs another Handle pass.
	handleNow bool
}
|
|
|
|
// New machine.
|
|
func New[T any](initialState StateID, router MachineStateRouter[T], states []IState[T], errHandler ErrorState[T]) IMachine[T] {
|
|
stateMap := types.NewMap[StateID, IState[T]]()
|
|
for _, state := range states {
|
|
stateMap.Set(state.ID(), state)
|
|
}
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
m := &Machine[T]{
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
state: initialState,
|
|
transitions: types.NewQueue[MachineStateProvider[T]](),
|
|
stateRouter: router,
|
|
initialState: initialState,
|
|
states: stateMap,
|
|
errHandler: errHandler,
|
|
}
|
|
go m.handleQueue()
|
|
runtime.SetFinalizer(m, func(m *Machine[T]) {
|
|
m.cancel()
|
|
})
|
|
return m
|
|
}
|
|
|
|
// Move transitions the machine to the state with the given id, carrying
// payload into that state.
//
// Internal: should never be called outside state callbacks.
//
// Step by step:
//   - Moving to the state the machine is already in is a no-op that returns nil.
//   - If the target cannot be loaded (including id == NilStateID), nothing
//     changes and whatever loadState returned is returned.
//   - The current state's Exit runs before the next state's Enter.
//   - If Enter returns ErrPreventTransition, the machine keeps its current
//     state and payload, and Move returns nil.
//   - If Enter returns any other error, the machine is Reset and that error is
//     returned.
//   - Otherwise the machine records id as the current state and payload as its
//     payload.
//
// NOTE(review): on ErrPreventTransition the current state's Exit has already
// run even though the machine stays in that state. Confirm this is intended.
// NOTE(review): Enter receives the machine itself. If an Enter callback calls
// Move, the state it moves to is overwritten by `m.state = id` once this call
// resumes.
func (m *Machine[T]) Move(id StateID, payload *T) error {
	if id == m.state {
		return nil
	}
	next, err := m.loadState(id, payload)
	if next == nil || err != nil {
		return err
	}
	cur, err := m.loadState(m.state, payload)
	if err != nil {
		return err
	}
	// The current state may be NilStateID, in which case there is nothing to exit.
	if cur != nil {
		cur.Exit(payload)
	}
	if err := next.Enter(payload, m); err != nil {
		// Enter asked to stay put: it has already done all required work.
		if errors.Is(err, ErrPreventTransition) {
			return nil
		}
		m.Reset()
		return err
	}
	m.state = id
	m.payload = payload
	return nil
}
|
|
|
|
func (m *Machine[T]) MoveForHandling(id StateID, payload *T) error {
|
|
m.handleNow = true
|
|
return m.Move(id, payload)
|
|
}
|
|
|
|
// Enqueue hands provider to the machine's transitions queue. The worker
// goroutine started by New later dequeues it and passes it to handle, so the
// provider runs on the worker goroutine, not on the caller's.
//
// A nil provider is effectively dropped: the worker treats a nil dequeued
// value the same as an empty queue.
func (m *Machine[T]) Enqueue(provider MachineStateProvider[T]) {
	m.transitions.Enqueue(provider)
}
|
|
|
|
func (m *Machine[T]) handleQueue() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
go m.handleQueue()
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-m.ctx.Done():
|
|
return
|
|
case <-time.After(time.Nanosecond * 10):
|
|
}
|
|
transition := m.transitions.Dequeue()
|
|
if transition == nil {
|
|
continue
|
|
}
|
|
m.handle(transition)
|
|
}
|
|
}
|
|
|
|
// Handle the input.
|
|
func (m *Machine[T]) handle(provider MachineStateProvider[T]) {
|
|
if provider != nil {
|
|
m.payload = provider(m.payload)
|
|
}
|
|
if m.stateRouter != nil {
|
|
m.stateRouter(m.payload, m)
|
|
}
|
|
st, err := m.loadState(m.state, m.payload)
|
|
if st == nil || err != nil {
|
|
return
|
|
}
|
|
for {
|
|
st.Handle(m.payload, m)
|
|
if m.handleNow { // MoveForHandling was called, trying to handle again.
|
|
m.handleNow = false
|
|
continue
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// State of the Machine.
|
|
func (m *Machine[T]) State() *T {
|
|
return m.payload
|
|
}
|
|
|
|
// Reset the machine.
|
|
func (m *Machine[T]) Reset() {
|
|
m.payload = nil
|
|
m.state = m.initialState
|
|
}
|
|
|
|
func (m *Machine[T]) loadState(id StateID, payload *T) (IState[T], error) {
|
|
if id == NilStateID {
|
|
return nil, nil
|
|
}
|
|
st, ok := m.states.Get(id)
|
|
if !ok {
|
|
return nil, m.fatalError(fmt.Errorf("%w: %s", ErrStateDoesNotExist, id), id, payload)
|
|
}
|
|
return st, nil
|
|
}
|
|
|
|
func (m *Machine[T]) fatalError(err error, id StateID, payload *T) error {
|
|
if m.errHandler != nil {
|
|
m.errHandler.Handle(err, m.state, id, payload, m)
|
|
}
|
|
return nil
|
|
}
|