package fsm import ( "context" "errors" "fmt" "gitea.neur0tx.site/Neur0toxine/vegapokerbot/pkg/types" "runtime" "time" ) var ( // ErrPreventTransition should be returned from Enter if you don't want to perform a state transition and // everything that needed to be done has been done in Enter callback. ErrPreventTransition = errors.New("prevents transition; this is not an error") // ErrStateDoesNotExist will be returned if provided state ID does not exist in this machine. ErrStateDoesNotExist = errors.New("state does not exist") ) // MachineStateRouter should be provided to IMachine. This function can do two very useful things: // - It can modify Machine's payload. // - It can act as a router by changing Machine's state via provided controls. type MachineStateRouter[T any] func(*T, MachineControlsWithState[T]) // MachineStateProvider provided to every Handle call. It can be used to set the initial state of the machine or to // update existing machine state. type MachineStateProvider[T any] func(*T) *T // IMachine is a Machine contract. The Machine should be able to do the following: // - Move to another state (usually called by the IState itself). // - Handle the state input. // - Reset the machine. type IMachine[T any] interface { MachineControlsWithState[T] // Enqueue the state input. Handle func will accept the current payload and modify it based on user input. Enqueue(MachineStateProvider[T]) // Reset the machine to its initial state. Reset() } // Machine is a finite-state machine implementation. type Machine[T any] struct { ctx context.Context cancel context.CancelFunc payload *T transitions types.Queue[MachineStateProvider[T]] stateRouter MachineStateRouter[T] state StateID initialState StateID states *types.Map[StateID, IState[T]] errHandler ErrorState[T] handleNow bool } // New machine. func New[T any](initialState StateID, router MachineStateRouter[T], states []IState[T], errHandler ErrorState[T]) IMachine[T] { stateMap := types.NewMap[StateID, IState[T]]() for _, state := range states { stateMap.Set(state.ID(), state) } ctx, cancel := context.WithCancel(context.Background()) m := &Machine[T]{ ctx: ctx, cancel: cancel, state: initialState, transitions: types.NewQueue[MachineStateProvider[T]](), stateRouter: router, initialState: initialState, states: stateMap, errHandler: errHandler, } go m.handleQueue() runtime.SetFinalizer(m, func(m *Machine[T]) { m.cancel() }) return m } // Move to another state. // Internal: should never be called outside state callbacks. func (m *Machine[T]) Move(id StateID, payload *T) error { if id == m.state { return nil } next, err := m.loadState(id, payload) if next == nil || err != nil { return err } cur, err := m.loadState(m.state, payload) if err != nil { return err } if cur != nil { cur.Exit(payload) } if err := next.Enter(payload, m); err != nil { if errors.Is(err, ErrPreventTransition) { return nil } m.Reset() return err } m.state = id m.payload = payload return nil } func (m *Machine[T]) MoveForHandling(id StateID, payload *T) error { m.handleNow = true return m.Move(id, payload) } func (m *Machine[T]) Enqueue(provider MachineStateProvider[T]) { m.transitions.Enqueue(provider) } func (m *Machine[T]) handleQueue() { defer func() { if r := recover(); r != nil { go m.handleQueue() } }() for { select { case <-m.ctx.Done(): return case <-time.After(time.Nanosecond * 10): } transition := m.transitions.Dequeue() if transition == nil { continue } m.handle(transition) } } // Handle the input. func (m *Machine[T]) handle(provider MachineStateProvider[T]) { if provider != nil { m.payload = provider(m.payload) } if m.stateRouter != nil { m.stateRouter(m.payload, m) } st, err := m.loadState(m.state, m.payload) if st == nil || err != nil { return } for { st.Handle(m.payload, m) if m.handleNow { // MoveForHandling was called, trying to handle again. m.handleNow = false continue } return } } // State of the Machine. func (m *Machine[T]) State() *T { return m.payload } // Reset the machine. func (m *Machine[T]) Reset() { m.payload = nil m.state = m.initialState } func (m *Machine[T]) loadState(id StateID, payload *T) (IState[T], error) { if id == NilStateID { return nil, nil } st, ok := m.states.Get(id) if !ok { return nil, m.fatalError(fmt.Errorf("%w: %s", ErrStateDoesNotExist, id), id, payload) } return st, nil } func (m *Machine[T]) fatalError(err error, id StateID, payload *T) error { if m.errHandler != nil { m.errHandler.Handle(err, m.state, id, payload, m) } return nil }