package fsm

import (
	"errors"
	"fmt"
	"sync"
)

var (
	// ErrPreventTransition is a sentinel that a state's Enter callback may return to cancel the
	// pending transition once it has already done everything it needed to do. Move treats it as
	// success, not as a failure. Note that the current state's Exit has already run by then.
	ErrPreventTransition = errors.New("prevents transition; this is not an error")
	// ErrStateDoesNotExist reports that a requested state ID is not registered with the machine.
	ErrStateDoesNotExist = errors.New("state does not exist")
)

// MachineHandleInput is the optional hook given to New and run by Handle before the current
// state's Handle callback. It receives the machine payload and the machine controls, so it can:
//   - update the payload with incoming input data;
//   - act as a router, switching the machine's state through the controls.
type MachineHandleInput[T any] func(*T, MachineControls[*T])

// IMachine is the contract implemented by Machine. A machine can:
//   - move to another state (usually requested by the IState itself);
//   - process input through Handle;
//   - return to its initial configuration through Reset.
type IMachine[T any] interface {
	MachineControls[*T]

	// Handle processes input: the current payload is passed to the hooks and the current state,
	// which may modify it.
	Handle() error

	// Reset returns the machine to its initial state.
	Reset()
}

// Machine is a finite-state machine implementation of IMachine.
type Machine[T any] struct {
	lock sync.Mutex // held for the duration of Handle

	payload   *T                    // current payload handed to hooks and states
	preHandle MachineHandleInput[T] // optional hook run at the start of Handle
	state     StateID               // ID of the current state

	initialState   StateID // state restored by Reset
	initialPayload T       // template that Reset copies into payload

	states     map[StateID]IState[T] // registered states, keyed by their ID
	errHandler ErrorState[T]         // optional receiver of fatal errors such as unknown states
}

// New builds a Machine that starts in initialState with its own copy of initialPayload.
//
// preHandle and errHandler may be nil. Each state is registered under the ID returned by its ID
// method; if two states report the same ID, the later one in states wins.
func New[T any](initialState StateID, initialPayload T, preHandle MachineHandleInput[T], states []IState[T], errHandler ErrorState[T]) IMachine[T] {
	m := &Machine[T]{
		preHandle:      preHandle,
		state:          initialState,
		initialState:   initialState,
		initialPayload: initialPayload,
		states:         make(map[StateID]IState[T], len(states)),
		errHandler:     errHandler,
	}
	for _, st := range states {
		m.states[st.ID()] = st
	}

	// Point payload at a private (shallow) copy so writes through it never modify initialPayload.
	start := initialPayload
	m.payload = &start
	return m
}

// Move to another state.
// Internal: should never be called outside state callbacks. func (m *Machine[T]) Move(id StateID, payload *T) error { if id == m.state { return nil } next, err := m.loadState(id, payload) if next == nil || err != nil { return err } cur, err := m.loadState(m.state, payload) if err != nil { return err } defer m.lock.Unlock() m.lock.Lock() if cur != nil { cur.Exit(payload) } if err := next.Enter(payload, m); err != nil { if errors.Is(err, ErrPreventTransition) { return nil } m.Reset() return err } m.state = id m.payload = payload return nil } // Handle the input. func (m *Machine[T]) Handle() error { defer m.lock.Unlock() m.lock.Lock() if m.preHandle != nil { m.preHandle(m.payload, m) } st, err := m.loadState(m.state, m.payload) if st == nil || err != nil { return err } st.Handle(m.payload, m) return nil } // Reset the machine. func (m *Machine[T]) Reset() { pl := m.initialPayload m.payload = &pl m.state = m.initialState } func (m *Machine[T]) loadState(id StateID, payload *T) (IState[T], error) { if id == NilStateID { return nil, nil } st, ok := m.states[id] if !ok { return nil, m.fatalError(fmt.Errorf("%w: %s", ErrStateDoesNotExist, id), id, payload) } return st, nil } func (m *Machine[T]) fatalError(err error, id StateID, payload *T) error { if m.errHandler != nil { m.errHandler.Handle(err, m.state, id, payload, m) } return nil }