State Pattern

State machines with enum-based and trait-based approaches

Level: advanced
Tags: state, behavioral, state-machine, enum

What is the State Pattern?

The State pattern allows an object to alter its behavior when its internal state changes, so the object appears to change its class. In Rust, the type system helps keep state handling correct: enums force every match to cover all states, and typestate encodings can reject invalid transitions at compile time.

Key concepts:
  • Context: Maintains a reference to the current state
  • State trait: Defines behavior that varies with state
  • Concrete states: Implement state-specific behavior
  • Transitions: Moving from one state to another
The Rust perspective:
  • Enum-based state machines get exhaustive match checking at compile time; invalid (state, event) pairs are still handled at runtime
  • Typestate pattern (see Phase 8) turns invalid transitions into compile errors
  • Trait objects enable runtime state changes
  • Ownership makes state transitions explicit and safe
// Two main approaches in Rust:

// 1. Enum-based (compile-time safety, exhaustive matching)
enum ConnectionState {
    Disconnected,
    Connecting { attempt: u32 },
    Connected { session_id: String },
    Error { message: String },
}

// 2. Trait-based (runtime flexibility, polymorphism)
// `Context` stands for the type that owns the current state (not shown here).
trait State {
    fn handle(&self, context: &mut Context);
    fn next(self: Box<Self>) -> Box<dyn State>;
}
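
The typestate idea mentioned in the list above can be sketched briefly. This is a minimal illustration rather than the Phase 8 material: `Connection`, `Disconnected`, and `Connected` are hypothetical names chosen for this sketch. Each transition consumes `self`, so ownership ensures the old state cannot be used afterwards, and methods such as `send` exist only on the state where they make sense.

```rust
use std::marker::PhantomData;

// Hypothetical marker types: each one stands for a state.
struct Disconnected;
struct Connected;

// The state is part of the type, so it costs nothing at runtime.
struct Connection<S> {
    address: String,
    _state: PhantomData<S>,
}

impl Connection<Disconnected> {
    fn new(address: &str) -> Self {
        Connection { address: address.to_string(), _state: PhantomData }
    }

    // Consumes the disconnected handle; the old value cannot be reused.
    fn connect(self) -> Connection<Connected> {
        Connection { address: self.address, _state: PhantomData }
    }
}

impl Connection<Connected> {
    // Only available once connected.
    fn send(&self, data: &[u8]) -> usize {
        // A real implementation would write to a socket here.
        data.len()
    }

    fn disconnect(self) -> Connection<Disconnected> {
        Connection { address: self.address, _state: PhantomData }
    }
}

fn main() {
    let conn = Connection::<Disconnected>::new("127.0.0.1:8080");
    // conn.send(b"hi"); // would not compile: no `send` on Connection<Disconnected>
    let conn = conn.connect();
    let sent = conn.send(b"hello");
    println!("sent {} bytes to {}", sent, conn.address);
    let _conn = conn.disconnect();
}
```

The trade-off is that the state must be known statically, so this style fits poorly when the next state depends on runtime events; that is where the enum and trait-object approaches come in.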

When to Use State Pattern

Appropriate use cases:
  1. Objects with clearly defined states and transitions
  2. State-dependent behavior that changes significantly
  3. Complex conditional logic based on state
  4. Protocol implementations (TCP, HTTP, etc.)
  5. UI component states (loading, error, success)
  6. Game entity states (idle, walking, attacking)
When to avoid:
  1. Simple boolean flags suffice
  2. States don't have different behaviors
  3. Only 2-3 states with minimal logic
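
To make the two lists above concrete, here is a small sketch contrasting a case where an enum state pays off (the UI loading example) with one where a plain boolean is enough. `LoadState`, `describe`, and `Checkbox` are hypothetical names used only for illustration.

```rust
// Hypothetical UI loading state, for illustration only.
#[derive(Debug, PartialEq)]
enum LoadState {
    Loading,
    Success { body: String },
    Error { message: String },
}

// With an enum, each state carries only the data that is valid for it,
// and `match` must cover every variant.
fn describe(state: &LoadState) -> String {
    match state {
        LoadState::Loading => "loading...".to_string(),
        LoadState::Success { body } => format!("loaded {} bytes", body.len()),
        LoadState::Error { message } => format!("failed: {}", message),
    }
}

// By contrast, a single on/off distinction with no state-specific data
// is adequately served by a plain bool.
struct Checkbox {
    checked: bool,
}

impl Checkbox {
    fn toggle(&mut self) {
        self.checked = !self.checked;
    }
}

fn main() {
    let states = [
        LoadState::Loading,
        LoadState::Success { body: "hello".to_string() },
        LoadState::Error { message: "timeout".to_string() },
    ];
    for state in &states {
        println!("{}", describe(state));
    }

    let mut checkbox = Checkbox { checked: false };
    checkbox.toggle();
    println!("checked: {}", checkbox.checked);
}
```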

Real-World Example 1: TCP Connection State Machine (Networking)

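A simplified, TCP-like connection state machine. It models the active-open and active-close paths only: the Listening, SynReceived, CloseWait, Closing, and LastAck states are declared but have no transitions in this listing, and the close sequence is driven by Disconnect events rather than by real FIN/ACK segments.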

use std::time::{Duration, Instant};
use std::net::SocketAddr;

/// Connection events that trigger transitions
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
    Connect(SocketAddr),
    ConnectionEstablished { session_id: String },
    DataReceived(Vec<u8>),
    DataSent(usize),
    Disconnect,
    Timeout,
    Error(String),
    Reset,
}

/// Connection state enum with associated data
#[derive(Debug, Clone)]
pub enum ConnectionState {
    Closed,
    Listening {
        bind_addr: SocketAddr,
    },
    SynSent {
        remote: SocketAddr,
        attempt: u32,
        started_at: Instant,
    },
    SynReceived {
        remote: SocketAddr,
    },
    Established {
        remote: SocketAddr,
        session_id: String,
        bytes_sent: u64,
        bytes_received: u64,
    },
    FinWait1 {
        remote: SocketAddr,
    },
    FinWait2 {
        remote: SocketAddr,
    },
    CloseWait {
        remote: SocketAddr,
    },
    Closing {
        remote: SocketAddr,
    },
    LastAck {
        remote: SocketAddr,
    },
    TimeWait {
        started_at: Instant,
    },
}

impl ConnectionState {
    pub fn is_connected(&self) -> bool {
        matches!(self, ConnectionState::Established { .. })
    }

    pub fn is_closed(&self) -> bool {
        matches!(self, ConnectionState::Closed)
    }

    pub fn remote_addr(&self) -> Option<SocketAddr> {
        match self {
            ConnectionState::SynSent { remote, .. }
            | ConnectionState::SynReceived { remote }
            | ConnectionState::Established { remote, .. }
            | ConnectionState::FinWait1 { remote }
            | ConnectionState::FinWait2 { remote }
            | ConnectionState::CloseWait { remote }
            | ConnectionState::Closing { remote }
            | ConnectionState::LastAck { remote } => Some(*remote),
            _ => None,
        }
    }
}

/// TCP Connection with state machine
pub struct TcpConnection {
    state: ConnectionState,
    config: ConnectionConfig,
}

#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    pub max_retries: u32,
    pub connect_timeout: Duration,
    pub time_wait_duration: Duration,
}

impl Default for ConnectionConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            connect_timeout: Duration::from_secs(30),
            time_wait_duration: Duration::from_secs(60),
        }
    }
}

/// Result of a state transition
#[derive(Debug)]
pub enum TransitionResult {
    Ok,
    InvalidTransition { from: String, event: String },
    ConnectionEstablished { session_id: String },
    ConnectionClosed,
    DataReceived(Vec<u8>),
    RetryNeeded { attempt: u32 },
    Error(String),
}

impl TcpConnection {
    pub fn new(config: ConnectionConfig) -> Self {
        Self {
            state: ConnectionState::Closed,
            config,
        }
    }

    pub fn state(&self) -> &ConnectionState {
        &self.state
    }

    /// Process an event and transition to the next state
    pub fn handle_event(&mut self, event: ConnectionEvent) -> TransitionResult {
        let (new_state, result) = self.compute_transition(event);
        self.state = new_state;
        result
    }

    fn compute_transition(
        &self,
        event: ConnectionEvent,
    ) -> (ConnectionState, TransitionResult) {
        match (&self.state, event) {
            // From Closed
            (ConnectionState::Closed, ConnectionEvent::Connect(addr)) => (
                ConnectionState::SynSent {
                    remote: addr,
                    attempt: 1,
                    started_at: Instant::now(),
                },
                TransitionResult::Ok,
            ),

            // From SynSent
            (
                ConnectionState::SynSent { remote, .. },
                ConnectionEvent::ConnectionEstablished { session_id },
            ) => (
                ConnectionState::Established {
                    remote: *remote,
                    session_id: session_id.clone(),
                    bytes_sent: 0,
                    bytes_received: 0,
                },
                TransitionResult::ConnectionEstablished { session_id },
            ),

            (
                ConnectionState::SynSent {
                    remote,
                    attempt,
                    ..
                },
                ConnectionEvent::Timeout,
            ) => {
                if *attempt < self.config.max_retries {
                    (
                        ConnectionState::SynSent {
                            remote: *remote,
                            attempt: attempt + 1,
                            started_at: Instant::now(),
                        },
                        TransitionResult::RetryNeeded {
                            attempt: attempt + 1,
                        },
                    )
                } else {
                    (
                        ConnectionState::Closed,
                        TransitionResult::Error("Connection timed out".to_string()),
                    )
                }
            }

            (ConnectionState::SynSent { .. }, ConnectionEvent::Error(msg)) => {
                (ConnectionState::Closed, TransitionResult::Error(msg))
            }

            // From Established
            (
                ConnectionState::Established {
                    remote,
                    session_id,
                    bytes_sent,
                    bytes_received,
                },
                ConnectionEvent::DataReceived(data),
            ) => {
                let len = data.len() as u64;
                (
                    ConnectionState::Established {
                        remote: *remote,
                        session_id: session_id.clone(),
                        bytes_sent: *bytes_sent,
                        bytes_received: bytes_received + len,
                    },
                    TransitionResult::DataReceived(data),
                )
            }

            (
                ConnectionState::Established {
                    remote,
                    session_id,
                    bytes_sent,
                    bytes_received,
                },
                ConnectionEvent::DataSent(len),
            ) => (
                ConnectionState::Established {
                    remote: *remote,
                    session_id: session_id.clone(),
                    bytes_sent: bytes_sent + len as u64,
                    bytes_received: *bytes_received,
                },
                TransitionResult::Ok,
            ),

            (ConnectionState::Established { remote, .. }, ConnectionEvent::Disconnect) => (
                ConnectionState::FinWait1 { remote: *remote },
                TransitionResult::Ok,
            ),

            (ConnectionState::Established { .. }, ConnectionEvent::Reset) => {
                (ConnectionState::Closed, TransitionResult::ConnectionClosed)
            }

            // From FinWait1
            (ConnectionState::FinWait1 { remote }, ConnectionEvent::Disconnect) => (
                ConnectionState::FinWait2 { remote: *remote },
                TransitionResult::Ok,
            ),

            // From FinWait2
            (ConnectionState::FinWait2 { .. }, ConnectionEvent::Disconnect) => (
                ConnectionState::TimeWait {
                    started_at: Instant::now(),
                },
                TransitionResult::Ok,
            ),

            // From TimeWait
            (ConnectionState::TimeWait { started_at }, ConnectionEvent::Timeout) => {
                if started_at.elapsed() >= self.config.time_wait_duration {
                    (ConnectionState::Closed, TransitionResult::ConnectionClosed)
                } else {
                    (
                        ConnectionState::TimeWait {
                            started_at: *started_at,
                        },
                        TransitionResult::Ok,
                    )
                }
            }

            // Invalid transitions
            (state, event) => (
                self.state.clone(),
                TransitionResult::InvalidTransition {
                    from: format!("{:?}", state),
                    event: format!("{:?}", event),
                },
            ),
        }
    }

    // Convenience wrappers around `handle_event`
    pub fn connect(&mut self, addr: SocketAddr) -> TransitionResult {
        self.handle_event(ConnectionEvent::Connect(addr))
    }

    pub fn disconnect(&mut self) -> TransitionResult {
        self.handle_event(ConnectionEvent::Disconnect)
    }

    pub fn send(&mut self, data: &[u8]) -> TransitionResult {
        if self.state.is_connected() {
            self.handle_event(ConnectionEvent::DataSent(data.len()))
        } else {
            TransitionResult::Error("Not connected".to_string())
        }
    }
}

fn main() {
    let mut conn = TcpConnection::new(ConnectionConfig::default());

    println!("Initial state: {:?}", conn.state());

    // Connect
    let result = conn.connect("192.168.1.1:8080".parse().unwrap());
    println!("After connect: {:?} -> {:?}", result, conn.state());

    // Simulate connection established
    let result = conn.handle_event(ConnectionEvent::ConnectionEstablished {
        session_id: "sess_123".to_string(),
    });
    println!("Established: {:?} -> {:?}", result, conn.state());

    // Send data
    let result = conn.send(b"Hello, Server!");
    println!("After send: {:?}", result);

    // Receive data
    let result = conn.handle_event(ConnectionEvent::DataReceived(b"Hello, Client!".to_vec()));
    println!("After receive: {:?}", result);

    // Disconnect
    let result = conn.disconnect();
    println!("Disconnecting: {:?} -> {:?}", result, conn.state());
}
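
The `compute_transition` method above borrows the current state, which is why it clones `session_id` on every data event. One possible alternative, sketched here on a reduced pair of hypothetical `State` and `Event` types (not the listing's own types), is to move the old state out with `std::mem::replace` and let the transition function take it by value.

```rust
use std::mem;

// Reduced, hypothetical state and event types for illustration.
#[derive(Debug, PartialEq)]
enum State {
    Closed,
    Established { session_id: String, bytes_received: u64 },
}

#[derive(Debug)]
enum Event {
    Established { session_id: String },
    DataReceived(Vec<u8>),
    Reset,
}

struct Machine {
    state: State,
}

impl Machine {
    fn handle_event(&mut self, event: Event) {
        // Move the current state out, leaving a cheap placeholder behind.
        let old = mem::replace(&mut self.state, State::Closed);
        self.state = Self::next(old, event);
    }

    // Takes the old state by value, so owned fields move without cloning.
    fn next(state: State, event: Event) -> State {
        match (state, event) {
            (State::Closed, Event::Established { session_id }) => State::Established {
                session_id,
                bytes_received: 0,
            },
            (State::Established { session_id, bytes_received }, Event::DataReceived(data)) => {
                State::Established {
                    session_id, // moved, not cloned
                    bytes_received: bytes_received + data.len() as u64,
                }
            }
            (State::Established { .. }, Event::Reset) => State::Closed,
            // Unhandled pairs keep the state they arrived in.
            (state, _event) => state,
        }
    }
}

fn main() {
    let mut m = Machine { state: State::Closed };
    m.handle_event(Event::Established { session_id: "sess_123".to_string() });
    m.handle_event(Event::DataReceived(b"hello".to_vec()));
    println!("{:?}", m.state);
    m.handle_event(Event::Reset);
    println!("{:?}", m.state);
}
```

The cost of this design is that the machine briefly holds the placeholder state; if `next` panicked, the machine would be left in `Closed`. Whether that is acceptable depends on the application.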

Real-World Example 2: Order Processing Workflow (E-commerce)

An order state machine with business rules.

use std::time::{SystemTime, UNIX_EPOCH};

/// Order item
#[derive(Debug, Clone)]
pub struct OrderItem {
    pub product_id: String,
    pub quantity: u32,
    pub unit_price: f64,
}

/// Payment information
#[derive(Debug, Clone)]
pub struct Payment {
    pub method: PaymentMethod,
    pub amount: f64,
    pub transaction_id: Option<String>,
}

#[derive(Debug, Clone)]
pub enum PaymentMethod {
    CreditCard { last_four: String },
    PayPal { email: String },
    BankTransfer,
}

/// Shipping information
#[derive(Debug, Clone)]
pub struct ShippingInfo {
    pub address: String,
    pub carrier: String,
    pub tracking_number: Option<String>,
}

/// Order state with associated data
#[derive(Debug, Clone)]
pub enum OrderState {
    Draft {
        items: Vec<OrderItem>,
        created_at: u64,
    },
    PendingPayment {
        items: Vec<OrderItem>,
        total: f64,
        payment_due_by: u64,
    },
    PaymentProcessing {
        items: Vec<OrderItem>,
        payment: Payment,
    },
    PaymentFailed {
        items: Vec<OrderItem>,
        reason: String,
        retry_count: u32,
    },
    Confirmed {
        items: Vec<OrderItem>,
        payment: Payment,
        confirmed_at: u64,
    },
    Processing {
        items: Vec<OrderItem>,
        payment: Payment,
        started_at: u64,
    },
    ReadyToShip {
        items: Vec<OrderItem>,
        payment: Payment,
        packed_at: u64,
    },
    Shipped {
        items: Vec<OrderItem>,
        payment: Payment,
        shipping: ShippingInfo,
        shipped_at: u64,
    },
    Delivered {
        items: Vec<OrderItem>,
        payment: Payment,
        shipping: ShippingInfo,
        delivered_at: u64,
    },
    Cancelled {
        reason: String,
        cancelled_at: u64,
        refund_status: Option<RefundStatus>,
    },
    Returned {
        items: Vec<OrderItem>,
        reason: String,
        refund_status: RefundStatus,
    },
}

#[derive(Debug, Clone)]
pub enum RefundStatus {
    Pending,
    Processing,
    Completed { transaction_id: String },
    Failed { reason: String },
}

/// Order events
#[derive(Debug)]
pub enum OrderEvent {
    AddItem(OrderItem),
    RemoveItem(String),
    Checkout,
    SubmitPayment(Payment),
    PaymentSucceeded { transaction_id: String },
    PaymentFailed { reason: String },
    RetryPayment,
    Confirm,
    StartProcessing,
    PackComplete,
    Ship(ShippingInfo),
    MarkDelivered,
    Cancel { reason: String },
    RequestReturn { reason: String },
    ProcessRefund,
    RefundCompleted { transaction_id: String },
}

/// Order transition result
#[derive(Debug)]
pub enum OrderResult {
    Ok,
    ItemAdded,
    ItemRemoved,
    CheckoutReady { total: f64 },
    PaymentSubmitted,
    OrderConfirmed { order_id: String },
    Shipped { tracking: String },
    Delivered,
    Cancelled,
    RefundInitiated,
    Error(String),
    InvalidTransition { state: String, event: String },
}

fn now() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs()
}

/// Order state machine
pub struct Order {
    pub id: String,
    state: OrderState,
}

impl Order {
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            state: OrderState::Draft {
                items: Vec::new(),
                created_at: now(),
            },
        }
    }

    pub fn state(&self) -> &OrderState {
        &self.state
    }

    pub fn handle(&mut self, event: OrderEvent) -> OrderResult {
        let (new_state, result) = self.transition(event);
        self.state = new_state;
        result
    }

    fn transition(&self, event: OrderEvent) -> (OrderState, OrderResult) {
        match (&self.state, event) {
            // Draft state transitions
            (
                OrderState::Draft { items, created_at },
                OrderEvent::AddItem(item),
            ) => {
                let mut new_items = items.clone();
                new_items.push(item);
                (
                    OrderState::Draft {
                        items: new_items,
                        created_at: *created_at,
                    },
                    OrderResult::ItemAdded,
                )
            }

            (
                OrderState::Draft { items, created_at },
                OrderEvent::RemoveItem(product_id),
            ) => {
                let new_items: Vec<_> = items
                    .iter()
                    .filter(|i| i.product_id != product_id)
                    .cloned()
                    .collect();
                (
                    OrderState::Draft {
                        items: new_items,
                        created_at: *created_at,
                    },
                    OrderResult::ItemRemoved,
                )
            }

            (OrderState::Draft { items, .. }, OrderEvent::Checkout) => {
                if items.is_empty() {
                    return (self.state.clone(), OrderResult::Error("Cart is empty".into()));
                }
                let total: f64 = items
                    .iter()
                    .map(|i| i.unit_price * i.quantity as f64)
                    .sum();
                (
                    OrderState::PendingPayment {
                        items: items.clone(),
                        total,
                        payment_due_by: now() + 3600, // 1 hour
                    },
                    OrderResult::CheckoutReady { total },
                )
            }

            // Pending payment transitions
            (
                OrderState::PendingPayment { items, total, .. },
                OrderEvent::SubmitPayment(payment),
            ) => {
                if (payment.amount - *total).abs() > 0.01 {
                    return (
                        self.state.clone(),
                        OrderResult::Error("Payment amount mismatch".into()),
                    );
                }
                (
                    OrderState::PaymentProcessing {
                        items: items.clone(),
                        payment,
                    },
                    OrderResult::PaymentSubmitted,
                )
            }

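            // Cancellation before payment is captured; covers the same
            // pre-payment states that `can_cancel` reports as cancellable.
            (
                OrderState::Draft { .. }
                | OrderState::PendingPayment { .. }
                | OrderState::PaymentFailed { .. },
                OrderEvent::Cancel { reason },
            ) => (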
                OrderState::Cancelled {
                    reason,
                    cancelled_at: now(),
                    refund_status: None,
                },
                OrderResult::Cancelled,
            ),

            // Payment processing transitions
            (
                OrderState::PaymentProcessing { items, payment },
                OrderEvent::PaymentSucceeded { transaction_id },
            ) => {
                let mut payment = payment.clone();
                payment.transaction_id = Some(transaction_id);
                (
                    OrderState::Confirmed {
                        items: items.clone(),
                        payment,
                        confirmed_at: now(),
                    },
                    OrderResult::OrderConfirmed {
                        order_id: self.id.clone(),
                    },
                )
            }

            (
                OrderState::PaymentProcessing { items, .. },
                OrderEvent::PaymentFailed { reason },
            ) => (
                OrderState::PaymentFailed {
                    items: items.clone(),
                    reason,
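                    // Note: no earlier state carries a retry counter, so this
                    // always starts at 0 and the retry cap is never reached.
                    retry_count: 0,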
                },
                OrderResult::Error("Payment failed".into()),
            ),

            // Payment failed transitions
            (
                OrderState::PaymentFailed {
                    items, retry_count, ..
                },
                OrderEvent::RetryPayment,
            ) => {
                if *retry_count >= 3 {
                    return (
                        OrderState::Cancelled {
                            reason: "Max payment retries exceeded".into(),
                            cancelled_at: now(),
                            refund_status: None,
                        },
                        OrderResult::Cancelled,
                    );
                }
                let total: f64 = items.iter().map(|i| i.unit_price * i.quantity as f64).sum();
                (
                    OrderState::PendingPayment {
                        items: items.clone(),
                        total,
                        payment_due_by: now() + 3600,
                    },
                    OrderResult::Ok,
                )
            }

            // Confirmed transitions
            (
                OrderState::Confirmed { items, payment, .. },
                OrderEvent::StartProcessing,
            ) => (
                OrderState::Processing {
                    items: items.clone(),
                    payment: payment.clone(),
                    started_at: now(),
                },
                OrderResult::Ok,
            ),

            (OrderState::Confirmed { .. }, OrderEvent::Cancel { reason }) => (
                OrderState::Cancelled {
                    reason,
                    cancelled_at: now(),
                    refund_status: Some(RefundStatus::Pending),
                },
                OrderResult::RefundInitiated,
            ),

            // Processing transitions
            (
                OrderState::Processing { items, payment, .. },
                OrderEvent::PackComplete,
            ) => (
                OrderState::ReadyToShip {
                    items: items.clone(),
                    payment: payment.clone(),
                    packed_at: now(),
                },
                OrderResult::Ok,
            ),

            // Ready to ship transitions
            (
                OrderState::ReadyToShip { items, payment, .. },
                OrderEvent::Ship(shipping),
            ) => {
                let tracking = shipping.tracking_number.clone().unwrap_or_default();
                (
                    OrderState::Shipped {
                        items: items.clone(),
                        payment: payment.clone(),
                        shipping,
                        shipped_at: now(),
                    },
                    OrderResult::Shipped { tracking },
                )
            }

            // Shipped transitions
            (
                OrderState::Shipped {
                    items,
                    payment,
                    shipping,
                    ..
                },
                OrderEvent::MarkDelivered,
            ) => (
                OrderState::Delivered {
                    items: items.clone(),
                    payment: payment.clone(),
                    shipping: shipping.clone(),
                    delivered_at: now(),
                },
                OrderResult::Delivered,
            ),

            // Delivered transitions
            (
                OrderState::Delivered { items, .. },
                OrderEvent::RequestReturn { reason },
            ) => (
                OrderState::Returned {
                    items: items.clone(),
                    reason,
                    refund_status: RefundStatus::Pending,
                },
                OrderResult::RefundInitiated,
            ),

            // Invalid transition
            (state, event) => (
                self.state.clone(),
                OrderResult::InvalidTransition {
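                    // Keep only the variant name: cut at the first character
                    // that is not alphanumeric (space, `(`, or `{`).
                    state: format!("{:?}", state)
                        .split(|c: char| !c.is_alphanumeric())
                        .next()
                        .unwrap_or_default()
                        .into(),
                    event: format!("{:?}", event)
                        .split(|c: char| !c.is_alphanumeric())
                        .next()
                        .unwrap_or_default()
                        .into(),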
                },
            ),
        }
    }

    /// Check if order can be cancelled
    pub fn can_cancel(&self) -> bool {
        matches!(
            self.state,
            OrderState::Draft { .. }
                | OrderState::PendingPayment { .. }
                | OrderState::PaymentFailed { .. }
                | OrderState::Confirmed { .. }
        )
    }

    /// Get order total if available
    pub fn total(&self) -> Option<f64> {
        let items = match &self.state {
            OrderState::Draft { items, .. }
            | OrderState::PendingPayment { items, .. }
            | OrderState::PaymentProcessing { items, .. }
            | OrderState::PaymentFailed { items, .. }
            | OrderState::Confirmed { items, .. }
            | OrderState::Processing { items, .. }
            | OrderState::ReadyToShip { items, .. }
            | OrderState::Shipped { items, .. }
            | OrderState::Delivered { items, .. }
            | OrderState::Returned { items, .. } => Some(items),
            OrderState::Cancelled { .. } => None,
        };

        items.map(|items| items.iter().map(|i| i.unit_price * i.quantity as f64).sum())
    }
}

fn main() {
    let mut order = Order::new("ORD-001");

    // Add items to cart
    order.handle(OrderEvent::AddItem(OrderItem {
        product_id: "PROD-1".into(),
        quantity: 2,
        unit_price: 29.99,
    }));
    order.handle(OrderEvent::AddItem(OrderItem {
        product_id: "PROD-2".into(),
        quantity: 1,
        unit_price: 49.99,
    }));

    println!("After adding items: {:?}", order.state());
    println!("Total: ${:.2}", order.total().unwrap_or(0.0));

    // Checkout
    let result = order.handle(OrderEvent::Checkout);
    println!("Checkout: {:?}", result);

    // Submit payment
    let result = order.handle(OrderEvent::SubmitPayment(Payment {
        method: PaymentMethod::CreditCard {
            last_four: "4242".into(),
        },
        amount: 109.97,
        transaction_id: None,
    }));
    println!("Payment submitted: {:?}", result);

    // Payment succeeded
    let result = order.handle(OrderEvent::PaymentSucceeded {
        transaction_id: "txn_123456".into(),
    });
    println!("Payment succeeded: {:?}", result);

    // Process order
    order.handle(OrderEvent::StartProcessing);
    order.handle(OrderEvent::PackComplete);

    // Ship
    let result = order.handle(OrderEvent::Ship(ShippingInfo {
        address: "123 Main St".into(),
        carrier: "FedEx".into(),
        tracking_number: Some("FX123456789".into()),
    }));
    println!("Shipped: {:?}", result);

    // Deliver
    let result = order.handle(OrderEvent::MarkDelivered);
    println!("Delivered: {:?}", result);

    println!("\nFinal state: {:?}", order.state());
}
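
In the listing above, `PaymentFailed` is always created with `retry_count: 0` and `PendingPayment` carries no counter, so the `retry_count >= 3` check cannot trigger as written. One way to make such a cap effective, sketched here on a reduced hypothetical `PayState` (not the listing's `OrderState`), is to thread an attempt counter through every payment state. `MAX_ATTEMPTS` is an assumed constant mirroring the `3` in the listing.

```rust
// Reduced, hypothetical payment states; each in-flight state carries the attempt count.
#[derive(Debug, Clone, PartialEq)]
enum PayState {
    Pending { attempts: u32 },
    Processing { attempts: u32 },
    Failed { attempts: u32 },
    Confirmed,
    Cancelled,
}

#[derive(Debug, Clone, Copy)]
enum PayEvent {
    Submit,
    Succeeded,
    Failed,
    Retry,
}

const MAX_ATTEMPTS: u32 = 3; // assumed limit, mirroring the `3` in the listing

fn step(state: PayState, event: PayEvent) -> PayState {
    match (state, event) {
        // Each submission counts as one attempt.
        (PayState::Pending { attempts }, PayEvent::Submit) => {
            PayState::Processing { attempts: attempts + 1 }
        }
        (PayState::Processing { .. }, PayEvent::Succeeded) => PayState::Confirmed,
        (PayState::Processing { attempts }, PayEvent::Failed) => PayState::Failed { attempts },
        // The cap is checked against the count carried from earlier states.
        (PayState::Failed { attempts }, PayEvent::Retry) if attempts >= MAX_ATTEMPTS => {
            PayState::Cancelled
        }
        (PayState::Failed { attempts }, PayEvent::Retry) => PayState::Pending { attempts },
        // Any other pair is treated as a no-op in this sketch.
        (state, _) => state,
    }
}

fn main() {
    let mut state = PayState::Pending { attempts: 0 };
    for _ in 0..3 {
        state = step(state, PayEvent::Submit);
        state = step(state, PayEvent::Failed);
        state = step(state, PayEvent::Retry);
        println!("{:?}", state);
    }
}
```

After the third failed attempt the retry request moves the sketch to `Cancelled` instead of back to `Pending`.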

Real-World Example 3: Media Player with Trait Objects (Multimedia)

A runtime-polymorphic state pattern using trait objects.
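
One wrinkle with this style is that each state method receives mutable access to the context that owns the state, so the context cannot call a method on `self.state` while also passing `self` mutably. A common workaround, sketched here with hypothetical `Light`, `On`, and `Off` types (the dispatch code for the media player may differ), is to swap the state out temporarily with `std::mem::replace`.

```rust
use std::mem;

trait LightState {
    // Returns the state to move into; may also mutate the context.
    fn toggle(&self, light: &mut Light) -> Box<dyn LightState>;
    fn name(&self) -> &'static str;
}

struct Off;
struct On;

impl LightState for Off {
    fn toggle(&self, light: &mut Light) -> Box<dyn LightState> {
        light.toggles += 1;
        Box::new(On)
    }
    fn name(&self) -> &'static str { "Off" }
}

impl LightState for On {
    fn toggle(&self, light: &mut Light) -> Box<dyn LightState> {
        light.toggles += 1;
        Box::new(Off)
    }
    fn name(&self) -> &'static str { "On" }
}

struct Light {
    state: Box<dyn LightState>,
    toggles: u32,
}

impl Light {
    fn new() -> Self {
        Light { state: Box::new(Off), toggles: 0 }
    }

    fn toggle(&mut self) {
        // Move the current state out so `self` can be borrowed mutably.
        // `Off` is zero-sized, so boxing the placeholder does not allocate.
        let placeholder: Box<dyn LightState> = Box::new(Off);
        let current = mem::replace(&mut self.state, placeholder);
        self.state = current.toggle(self);
    }
}

fn main() {
    let mut light = Light::new();
    println!("{}", light.state.name());
    light.toggle();
    println!("{}", light.state.name());
    light.toggle();
    println!("{} after {} toggles", light.state.name(), light.toggles);
}
```

Storing the state as `Option<Box<dyn ...>>` and calling `take()` is another way to achieve the same temporary move.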

use std::time::Duration;

/// Media player context
pub struct MediaPlayer {
    state: Box<dyn PlayerState>,
    current_track: Option<Track>,
    position: Duration,
    volume: u8,
    playlist: Vec<Track>,
    playlist_index: usize,
}

#[derive(Debug, Clone)]
pub struct Track {
    pub title: String,
    pub artist: String,
    pub duration: Duration,
}

/// Player state trait (runtime polymorphism)
pub trait PlayerState: std::fmt::Debug {
    fn play(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn pause(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn seek(&self, player: &mut MediaPlayer, position: Duration) -> Box<dyn PlayerState>;

    fn name(&self) -> &'static str;
    fn can_play(&self) -> bool { false }
    fn can_pause(&self) -> bool { false }
    fn is_playing(&self) -> bool { false }
}

// Concrete states

#[derive(Debug)]
pub struct StoppedState;

#[derive(Debug)]
pub struct PlayingState;

#[derive(Debug)]
pub struct PausedState;

#[derive(Debug)]
pub struct BufferingState {
    pub target_state: BufferingTarget,
}

#[derive(Debug, Clone, Copy)]
pub enum BufferingTarget {
    Play,
    Seek(Duration),
}

impl PlayerState for StoppedState {
    fn play(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.current_track.is_none() && !player.playlist.is_empty() {
            player.current_track = Some(player.playlist[0].clone());
            player.playlist_index = 0;
        }

        if player.current_track.is_some() {
            player.position = Duration::ZERO;
            println!("Starting playback...");
            Box::new(BufferingState {
                target_state: BufferingTarget::Play,
            })
        } else {
            println!("No track to play");
            Box::new(StoppedState)
        }
    }

    fn pause(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already stopped");
        Box::new(StoppedState)
    }

    fn stop(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already stopped");
        Box::new(StoppedState)
    }

    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index + 1 < player.playlist.len() {
            player.playlist_index += 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            println!("Next track: {:?}", player.current_track);
        }
        Box::new(StoppedState)
    }

    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index > 0 {
            player.playlist_index -= 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            println!("Previous track: {:?}", player.current_track);
        }
        Box::new(StoppedState)
    }

    fn seek(&self, _player: &mut MediaPlayer, _position: Duration) -> Box<dyn PlayerState> {
        println!("Cannot seek while stopped");
        Box::new(StoppedState)
    }

    fn name(&self) -> &'static str { "Stopped" }
    fn can_play(&self) -> bool { true }
}

impl PlayerState for PlayingState {
    fn play(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already playing");
        Box::new(PlayingState)
    }

    fn pause(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Paused at {:?}", player.position);
        Box::new(PausedState)
    }

    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        player.position = Duration::ZERO;
        println!("Stopped");
        Box::new(StoppedState)
    }

    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index + 1 < player.playlist.len() {
            player.playlist_index += 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Playing next: {:?}", player.current_track);
            Box::new(BufferingState {
                target_state: BufferingTarget::Play,
            })
        } else {
            println!("End of playlist");
            player.position = Duration::ZERO;
            Box::new(StoppedState)
        }
    }

    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        // If more than 3 seconds in, restart current track
        if player.position > Duration::from_secs(3) {
            player.position = Duration::ZERO;
            println!("Restarting current track");
            Box::new(PlayingState)
        } else if player.playlist_index > 0 {
            player.playlist_index -= 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Playing previous: {:?}", player.current_track);
            Box::new(BufferingState {
                target_state: BufferingTarget::Play,
            })
        } else {
            player.position = Duration::ZERO;
            Box::new(PlayingState)
        }
    }

    fn seek(&self, _player: &mut MediaPlayer, position: Duration) -> Box<dyn PlayerState> {
        println!("Seeking to {:?}...", position);
        Box::new(BufferingState {
            target_state: BufferingTarget::Seek(position),
        })
    }

    fn name(&self) -> &'static str { "Playing" }
    fn can_pause(&self) -> bool { true }
    fn is_playing(&self) -> bool { true }
}

impl PlayerState for PausedState {
    fn play(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Resuming from {:?}", player.position);
        Box::new(PlayingState)
    }

    fn pause(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already paused");
        Box::new(PausedState)
    }

    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        player.position = Duration::ZERO;
        println!("Stopped");
        Box::new(StoppedState)
    }

    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index + 1 < player.playlist.len() {
            player.playlist_index += 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Switched to: {:?} (paused)", player.current_track);
        }
        Box::new(PausedState)
    }

    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index > 0 {
            player.playlist_index -= 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Switched to: {:?} (paused)", player.current_track);
        }
        Box::new(PausedState)
    }

    fn seek(&self, player: &mut MediaPlayer, position: Duration) -> Box<dyn PlayerState> {
        player.position = position;
        println!("Seeked to {:?} (paused)", position);
        Box::new(PausedState)
    }

    fn name(&self) -> &'static str { "Paused" }
    fn can_play(&self) -> bool { true }
}

impl PlayerState for BufferingState {
    fn play(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Still buffering...");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }

    fn pause(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Buffering cancelled, pausing");
        Box::new(PausedState)
    }

    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        player.position = Duration::ZERO;
        println!("Buffering cancelled, stopped");
        Box::new(StoppedState)
    }

    fn next(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Cannot skip while buffering");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }

    fn previous(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Cannot go back while buffering");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }

    fn seek(&self, _player: &mut MediaPlayer, _position: Duration) -> Box<dyn PlayerState> {
        println!("Cannot seek while buffering");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }

    fn name(&self) -> &'static str { "Buffering" }
}

impl BufferingState {
    /// Called when buffering completes
    pub fn buffering_complete(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        match self.target_state {
            BufferingTarget::Play => {
                println!("Buffering complete, playing");
                Box::new(PlayingState)
            }
            BufferingTarget::Seek(position) => {
                player.position = position;
                println!("Buffering complete, playing from {:?}", position);
                Box::new(PlayingState)
            }
        }
    }
}

impl MediaPlayer {
    pub fn new() -> Self {
        Self {
            state: Box::new(StoppedState),
            current_track: None,
            position: Duration::ZERO,
            volume: 50,
            playlist: Vec::new(),
            playlist_index: 0,
        }
    }

    pub fn add_track(&mut self, track: Track) {
        self.playlist.push(track);
    }

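    /// Runs a state handler and installs the state it returns.
    ///
    /// The current state is moved out first: calling `self.state.play(self)`
    /// directly borrows `self.state` immutably while `self` is borrowed
    /// mutably, which the borrow checker rejects.
    fn transition(
        &mut self,
        handler: impl FnOnce(&dyn PlayerState, &mut MediaPlayer) -> Box<dyn PlayerState>,
    ) {
        // `StoppedState` is only a placeholder while the handler runs.
        let current = std::mem::replace(&mut self.state, Box::new(StoppedState));
        self.state = handler(&*current, self);
    }

    pub fn play(&mut self) {
        self.transition(|state, player| state.play(player));
    }

    pub fn pause(&mut self) {
        self.transition(|state, player| state.pause(player));
    }

    pub fn stop(&mut self) {
        self.transition(|state, player| state.stop(player));
    }

    pub fn next(&mut self) {
        self.transition(|state, player| state.next(player));
    }

    pub fn previous(&mut self) {
        self.transition(|state, player| state.previous(player));
    }

    pub fn seek(&mut self, position: Duration) {
        self.transition(|state, player| state.seek(player, position));
    }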

    pub fn state_name(&self) -> &'static str {
        self.state.name()
    }

    pub fn is_playing(&self) -> bool {
        self.state.is_playing()
    }

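    /// Simulate buffering completion (in a real app, called by the async loader).
    ///
    /// Relies on `as_any`, which is only a stub further down, so `main` does not call this.
    pub fn on_buffering_complete(&mut self) {
        // Move the state out so `self` can be borrowed mutably during the transition.
        let current = std::mem::replace(&mut self.state, Box::new(StoppedState));
        let next = current
            .as_any()
            .downcast_ref::<BufferingState>()
            .map(|buffering| buffering.buffering_complete(self));
        // Not buffering: put the original state back unchanged.
        self.state = next.unwrap_or(current);
    }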
}

impl Default for MediaPlayer {
    fn default() -> Self {
        Self::new()
    }
}

// Helper trait for downcasting
trait AsAny {
    fn as_any(&self) -> &dyn std::any::Any;
}

impl<T: PlayerState + 'static> AsAny for T {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl dyn PlayerState {
    fn as_any(&self) -> &dyn std::any::Any {
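        // Placeholder: `AsAny` is not a supertrait of `PlayerState`, so the
        // blanket impl above is not reachable through `dyn PlayerState`.
        // Calling this stub panics, so `main` sets the state directly
        // instead of calling `on_buffering_complete`.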
        unimplemented!()
    }
}

fn main() {
    let mut player = MediaPlayer::new();

    // Add some tracks
    player.add_track(Track {
        title: "Track 1".into(),
        artist: "Artist A".into(),
        duration: Duration::from_secs(180),
    });
    player.add_track(Track {
        title: "Track 2".into(),
        artist: "Artist B".into(),
        duration: Duration::from_secs(240),
    });
    player.add_track(Track {
        title: "Track 3".into(),
        artist: "Artist A".into(),
        duration: Duration::from_secs(200),
    });

    println!("Initial state: {}", player.state_name());

    // Play
    player.play();
    println!("After play: {}", player.state_name());

    // Simulate buffering complete
    // In real app this would be async
    println!("Buffering complete...");
    player.state = Box::new(PlayingState);
    println!("Now: {}", player.state_name());

    // Pause
    player.pause();
    println!("After pause: {}", player.state_name());

    // Resume
    player.play();
    println!("After resume: {}", player.state_name());

    // Next track
    player.next();
    player.state = Box::new(PlayingState); // Simulate buffer complete
    println!("After next: {}, track: {:?}", player.state_name(), player.current_track);

    // Stop
    player.stop();
    println!("After stop: {}", player.state_name());
}
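
The `as_any` stub above panics, which is why `main` sets the state by hand. One way to make the downcast work is to declare `as_any` on the state trait itself, so every concrete state can return `self`. The following is a minimal sketch under that assumption; `State`, `Playing`, `Buffering`, `seek_to_secs`, and `finish_buffering` are simplified, hypothetical names, not the player's real types.

```rust
use std::any::Any;

// Hypothetical, simplified state trait. Because `as_any` is a required
// trait method, it is reachable through `Box<dyn State>`.
trait State {
    fn name(&self) -> &'static str;
    fn as_any(&self) -> &dyn Any;
}

struct Playing;

struct Buffering {
    seek_to_secs: u64,
}

impl State for Playing {
    fn name(&self) -> &'static str { "Playing" }
    fn as_any(&self) -> &dyn Any { self }
}

impl State for Buffering {
    fn name(&self) -> &'static str { "Buffering" }
    fn as_any(&self) -> &dyn Any { self }
}

/// Returns the next state and, if buffering just finished, the position
/// to resume from. Any other state is handed back unchanged.
fn finish_buffering(state: Box<dyn State>) -> (Box<dyn State>, Option<u64>) {
    // Copy the target out so the borrow of `state` ends before it is moved.
    let target = state
        .as_any()
        .downcast_ref::<Buffering>()
        .map(|buffering| buffering.seek_to_secs);

    match target {
        Some(secs) => (Box::new(Playing), Some(secs)),
        None => (state, None),
    }
}
```

The trade-off is one extra boilerplate method per state. An enum-based state machine avoids the downcast entirely, because a `match` already knows the concrete variant.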

Comparison: State Implementation Approaches

| Approach | Compile-time Safety | Runtime Flexibility | Performance | Best For |
|----------|---------------------|---------------------|-------------|----------|
| Enum | Highest | Low | Best | Fixed states, exhaustive matching |
| Trait objects | Medium | High | Good | Plugin states, complex hierarchies |
| Typestate | Highest | None | Best | API design, builder patterns |
| Generic states | High | Medium | Good | State-specific operations |
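
To make the typestate row concrete, here is a small sketch in which each state is a zero-sized marker type and each transition consumes the old value. The names `Player`, `Stopped`, `Playing`, and `position_secs` are illustrative assumptions, not taken from the media player example.

```rust
use std::marker::PhantomData;

// Zero-sized marker types, one per state (hypothetical names).
struct Stopped;
struct Playing;

struct Player<S> {
    position_secs: u64,
    _state: PhantomData<S>,
}

impl Player<Stopped> {
    fn new() -> Self {
        Player { position_secs: 0, _state: PhantomData }
    }

    // Consumes the stopped player and returns a playing one.
    fn play(self) -> Player<Playing> {
        Player { position_secs: self.position_secs, _state: PhantomData }
    }
}

impl Player<Playing> {
    // Only a playing player can advance its position.
    fn advance(mut self, secs: u64) -> Self {
        self.position_secs += secs;
        self
    }

    // Stopping rewinds to the start.
    fn stop(self) -> Player<Stopped> {
        Player { position_secs: 0, _state: PhantomData }
    }
}
```

With this shape, `Player::new().advance(5)` does not compile, because `advance` exists only on `Player<Playing>`. The cost is the "None" in the runtime flexibility column: the state must be known statically, so it cannot be chosen from user input without wrapping the variants in an enum.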

⚠️ Anti-patterns

// DON'T: Massive match statements scattered throughout code
fn process(state: &State, action: &Action) {
    match (state, action) {
        // 50+ match arms spread across the codebase
    }
}

// DON'T: Stringly-typed states
struct BadStateMachine {
    state: String,  // "playing", "paused", etc.
}

// DON'T: States that don't encapsulate their data
enum BadState {
    Playing,
    Paused,
    Buffering,
}
struct Player {
    state: BadState,
    // State-specific data mixed together
    buffer_progress: Option<f32>,  // Only for buffering
    pause_position: Option<u64>,   // Only for paused
}

// DO: States own their data
enum GoodState {
    Playing { position: u64 },
    Paused { position: u64 },
    Buffering { progress: f32, target: BufferingTarget },
}
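
As a rough illustration of why the `GoodState` layout helps, the sketch below uses a reduced, hypothetical `Playback` enum (no buffering variant) whose transitions carry the position from one variant to the next. Nothing has to be kept in sync in a separate field.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Playback {
    Stopped,
    Playing { position: u64 },
    Paused { position: u64 },
}

impl Playback {
    /// Pausing keeps the position; other states are left as they are.
    fn pause(self) -> Self {
        match self {
            Playback::Playing { position } => Playback::Paused { position },
            unchanged @ (Playback::Stopped | Playback::Paused { .. }) => unchanged,
        }
    }

    /// Resuming continues from the paused position; a stopped player starts at 0.
    fn play(self) -> Self {
        match self {
            Playback::Paused { position } => Playback::Playing { position },
            Playback::Stopped => Playback::Playing { position: 0 },
            playing @ Playback::Playing { .. } => playing,
        }
    }
}
```

For example, `Playback::Playing { position: 42 }.pause().play()` returns to `Playing { position: 42 }`, with the position travelling inside the state rather than in an `Option` field on the player.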

Best Practices

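  1. Encapsulate state-specific data in state variants - Keep data inside the variant that uses it instead of scattering `Option` fields across the context
  2. Use exhaustive matching - Avoid catch-all `_` arms so the compiler flags every transition you forgot to handle
  3. Make invalid states unrepresentable - Use typestate when the state is known at compile time
  4. Document state diagrams - Visualize states and transitions next to the code
  5. Test all transitions - Enumerate every state/event pair; property-based testing works well
  6. Consider events vs methods - Events are more flexible, methods are simpler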
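
Practices 2, 5, and 6 can be combined in a small event-driven sketch. The `Job` and `Event` types below are hypothetical and chosen only for illustration: the `match` lists every pair without a wildcard, and a helper walks all state/event combinations so a test can check the whole transition table.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Job { Queued, Running, Finished, Cancelled }

#[derive(Debug, Clone, Copy, PartialEq)]
enum Event { Start, Complete, Cancel }

const ALL_JOBS: [Job; 4] = [Job::Queued, Job::Running, Job::Finished, Job::Cancelled];
const ALL_EVENTS: [Event; 3] = [Event::Start, Event::Complete, Event::Cancel];

/// Event-based transition function. There is no `_` arm, so adding a
/// new state or event is a compile error until it is handled here.
fn transition(job: Job, event: Event) -> Job {
    match (job, event) {
        (Job::Queued, Event::Start) => Job::Running,
        (Job::Queued, Event::Complete) => Job::Queued,
        (Job::Queued, Event::Cancel) => Job::Cancelled,
        (Job::Running, Event::Start) => Job::Running,
        (Job::Running, Event::Complete) => Job::Finished,
        (Job::Running, Event::Cancel) => Job::Cancelled,
        // Terminal states ignore every event.
        (Job::Finished, Event::Start | Event::Complete | Event::Cancel) => Job::Finished,
        (Job::Cancelled, Event::Start | Event::Complete | Event::Cancel) => Job::Cancelled,
    }
}

/// Walks every (state, event) pair and collects the ones that change state.
fn state_changes() -> Vec<(Job, Event, Job)> {
    let mut changes = Vec::new();
    for &job in &ALL_JOBS {
        for &event in &ALL_EVENTS {
            let next = transition(job, event);
            if next != job {
                changes.push((job, event, next));
            }
        }
    }
    changes
}
```

A test can then assert invariants over `state_changes()`, for instance that no transition ever leaves `Finished` or `Cancelled`. The constant arrays must be updated by hand when a variant is added, which is a limitation of this sketch.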

Performance Characteristics

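| Approach | State Size | Transition Cost | Pattern Matching |
|----------|------------|-----------------|------------------|
| Enum | Largest variant, plus a discriminant in most cases | Move/copy | Optimized |
| Trait object | Fat pointer (data pointer + vtable pointer) to the boxed state | Virtual call, plus a heap allocation for non-zero-sized states | N/A |
| Typestate | Size of the wrapped data (marker types are zero-sized) | Move; validity checked at compile time | N/A |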
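
The size column can be checked with `std::mem::size_of`. The types below are hypothetical stand-ins for the three approaches. Exact enum sizes depend on the target and on layout optimizations, so the sketch reports sizes rather than hard-coding them.

```rust
use std::mem::size_of;

// Enum approach: must be large enough for its largest variant.
#[allow(dead_code)]
enum EnumState {
    Idle,
    Buffering { progress: f32 },
    Playing { position: u64 },
}

// Trait-object approach: states live behind `Box<dyn DynState>`.
#[allow(dead_code)]
trait DynState {
    fn name(&self) -> &'static str;
}

// Typestate approach: a marker type with no fields.
#[allow(dead_code)]
struct IdleMarker;

/// Returns (enum size, boxed trait object size, marker size) in bytes.
fn sizes() -> (usize, usize, usize) {
    (
        size_of::<EnumState>(),
        size_of::<Box<dyn DynState>>(),
        size_of::<IdleMarker>(),
    )
}
```

On any target, the boxed trait object is two pointers wide (data pointer plus vtable pointer), the marker type is zero bytes, and the enum is at least as large as its `u64` payload.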

Exercises

Beginner

  1. Implement a simple traffic light state machine
  2. Create a door lock with Locked, Unlocked, Open states

Intermediate

  1. Build a vending machine with inventory and payment states
  2. Implement an HTTP request state machine

Advanced

  1. Create a workflow engine with parallel states
  2. Implement a parser state machine with backtracking

Further Reading

Real-World Usage

  • tokio: Connection state machines
  • hyper: HTTP protocol states
  • sqlx: Transaction states
  • regex: NFA/DFA state machines
  • nom: Parser combinator states
