vegapokerbot/pkg/fsm/machine.go

136 lines
3.6 KiB
Go
Raw Normal View History

2024-05-13 16:57:55 +03:00
package fsm
import (
"errors"
"fmt"
2024-05-14 14:44:41 +03:00
"gitea.neur0tx.site/Neur0toxine/vegapokerbot/pkg/types"
2024-05-13 16:57:55 +03:00
)
// ErrPreventTransition is a sentinel that a state's Enter callback may return
// to cancel the pending transition without it counting as a failure: Move
// treats it as success (returns nil) and neither records the new state nor
// resets the machine. Return it when Enter has already done everything that
// was needed.
var ErrPreventTransition = errors.New("prevents transition; this is not an error")

// ErrStateDoesNotExist is the base error reported when a state ID is not
// registered in the machine. The concrete error wraps it together with the
// offending ID, so match it with errors.Is rather than ==.
var ErrStateDoesNotExist = errors.New("state does not exist")
2024-05-14 14:44:41 +03:00
// MachineStateRouter should be provided to IMachine. This function can do two very useful things:
// - It can modify Machine's payload.
2024-05-13 16:57:55 +03:00
// - It can act as a router by changing Machine's state via provided controls.
2024-05-14 14:44:41 +03:00
type MachineStateRouter[T any] func(*T, MachineControlsWithState[T])
// MachineStateProvider provided to every Handle call. It can be used to set the initial state of the machine or to
// update existing machine state.
type MachineStateProvider[T any] func(*T) *T
2024-05-13 16:57:55 +03:00
// IMachine is a Machine contract. The Machine should be able to do the following:
// - Move to another state (usually called by the IState itself).
// - Handle the state input.
// - Reset the machine.
type IMachine[T any] interface {
2024-05-14 14:44:41 +03:00
MachineControlsWithState[T]
2024-05-13 16:57:55 +03:00
// Handle the state input. Handle func will accept the current payload and modify it based on user input.
2024-05-14 14:44:41 +03:00
Handle(MachineStateProvider[T]) error
2024-05-13 16:57:55 +03:00
// Reset the machine to its initial state.
Reset()
}
// Machine is a finite-state machine implementation.
type Machine[T any] struct {
2024-05-14 14:44:41 +03:00
payload *T
stateRouter MachineStateRouter[T]
state StateID
initialState StateID
states *types.Map[StateID, IState[T]]
errHandler ErrorState[T]
2024-05-13 16:57:55 +03:00
}
// New machine.
2024-05-14 14:44:41 +03:00
func New[T any](initialState StateID, router MachineStateRouter[T], states []IState[T], errHandler ErrorState[T]) IMachine[T] {
stateMap := types.NewMap[StateID, IState[T]]()
2024-05-13 16:57:55 +03:00
for _, state := range states {
2024-05-14 14:44:41 +03:00
stateMap.Set(state.ID(), state)
2024-05-13 16:57:55 +03:00
}
return &Machine[T]{
2024-05-14 14:44:41 +03:00
state: initialState,
stateRouter: router,
initialState: initialState,
states: stateMap,
errHandler: errHandler,
2024-05-13 16:57:55 +03:00
}
}
// Move to another state.
// Internal: should never be called outside state callbacks.
func (m *Machine[T]) Move(id StateID, payload *T) error {
if id == m.state {
return nil
}
next, err := m.loadState(id, payload)
if next == nil || err != nil {
return err
}
cur, err := m.loadState(m.state, payload)
if err != nil {
return err
}
if cur != nil {
cur.Exit(payload)
}
if err := next.Enter(payload, m); err != nil {
if errors.Is(err, ErrPreventTransition) {
return nil
}
m.Reset()
return err
}
m.state = id
2024-05-13 17:18:22 +03:00
m.payload = payload
2024-05-13 16:57:55 +03:00
return nil
}
// Handle the input.
2024-05-14 14:44:41 +03:00
func (m *Machine[T]) Handle(provider MachineStateProvider[T]) error {
if provider != nil {
m.payload = provider(m.payload)
}
if m.stateRouter != nil {
m.stateRouter(m.payload, m)
}
2024-05-13 16:57:55 +03:00
st, err := m.loadState(m.state, m.payload)
if st == nil || err != nil {
return err
}
st.Handle(m.payload, m)
return nil
}
2024-05-14 14:44:41 +03:00
// State returns the payload the machine currently holds. It is nil before any
// provider has set it and after Reset.
func (m *Machine[T]) State() *T {
	current := m.payload
	return current
}
2024-05-13 16:57:55 +03:00
// Reset the machine.
func (m *Machine[T]) Reset() {
2024-05-14 14:44:41 +03:00
m.payload = nil
2024-05-13 16:57:55 +03:00
m.state = m.initialState
}
func (m *Machine[T]) loadState(id StateID, payload *T) (IState[T], error) {
if id == NilStateID {
return nil, nil
}
2024-05-14 14:44:41 +03:00
st, ok := m.states.Get(id)
2024-05-13 16:57:55 +03:00
if !ok {
return nil, m.fatalError(fmt.Errorf("%w: %s", ErrStateDoesNotExist, id), id, payload)
}
return st, nil
}
// fatalError reports err, which happened while the machine was in m.state and
// trying to reach id.
//
// With an error handler configured, the handler is called with the error, the
// current state, the target id, the payload and the machine itself. The
// handler takes ownership of the error, so fatalError returns nil.
//
// Without a handler (New accepts a nil one), err is returned to the caller.
// Previously it was dropped here, which made Move and Handle report success
// for unknown state IDs.
func (m *Machine[T]) fatalError(err error, id StateID, payload *T) error {
	if m.errHandler == nil {
		return err
	}
	m.errHandler.Handle(err, m.state, id, payload, m)
	return nil
}