State machines: enum-based and trait-based approaches
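The State pattern lets an object alter its behavior when its internal state changes, so that the object appears to change its class. Rust's type system supports this well: enums make every state (and the data it carries) explicit and force exhaustive handling, while the typestate variant — encoding each state as a separate type — rejects invalid transitions at compile time. The enum and trait-object examples below still detect invalid transitions at runtime.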
Key concepts:

// Two main approaches in Rust:
// 1. Enum-based approach: every state is listed up front, so the compiler
//    can insist that each `match` handles all of them.
/// Lifecycle of a simple connection; each variant holds only the data that
/// exists in that state.
enum ConnectionState {
    /// No connection is open.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting { attempt: u32 },
    /// Connected, identified by `session_id`.
    Connected { session_id: String },
    /// The connection failed with `message`.
    Error { message: String },
}
// 2. Trait-based approach: each state is its own type behind a trait object,
//    chosen at runtime (polymorphism).
/// Behavior shared by every state in the trait-based approach.
/// NOTE: `Context` is a placeholder type that this snippet does not define.
trait State {
    /// Performs this state's work against the shared context.
    fn handle(&self, context: &mut Context);
    /// Consumes the boxed state and yields the state that follows it.
    fn next(self: Box<Self>) -> Box<dyn State>;
}
Example: a simplified, TCP-like connection modeled as an enum-based state machine.
use std::time::{Duration, Instant};
use std::net::SocketAddr;
/// Inputs that drive [`TcpConnection`] from one state to the next.
#[derive(Debug, Clone)]
pub enum ConnectionEvent {
    /// Open a connection to the given peer.
    Connect(SocketAddr),
    /// The handshake completed; the peer assigned `session_id`.
    ConnectionEstablished { session_id: String },
    /// Payload bytes arrived from the peer.
    DataReceived(Vec<u8>),
    /// This many bytes were sent to the peer.
    DataSent(usize),
    /// Begin (or continue) closing the connection.
    Disconnect,
    /// A timer expired.
    Timeout,
    /// A failure described by the message.
    Error(String),
    /// The connection was reset.
    Reset,
}
/// States of the TCP-like connection. Each variant carries only the data
/// that is meaningful while the connection is in that state.
#[derive(Debug, Clone)]
pub enum ConnectionState {
    /// No connection; the initial state.
    Closed,
    /// Listening on `bind_addr`.
    Listening { bind_addr: SocketAddr },
    /// SYN sent to `remote`; `attempt` numbers the current try and
    /// `started_at` records when it began.
    SynSent {
        remote: SocketAddr,
        attempt: u32,
        started_at: Instant,
    },
    /// SYN received from `remote`.
    SynReceived { remote: SocketAddr },
    /// Open connection with per-direction byte counters.
    Established {
        remote: SocketAddr,
        session_id: String,
        bytes_sent: u64,
        bytes_received: u64,
    },
    /// First stage of an active close.
    FinWait1 { remote: SocketAddr },
    /// Second stage of an active close.
    FinWait2 { remote: SocketAddr },
    /// Peer has closed its side.
    CloseWait { remote: SocketAddr },
    /// Both sides closing simultaneously.
    Closing { remote: SocketAddr },
    /// Waiting for the final acknowledgement.
    LastAck { remote: SocketAddr },
    /// Lingering after close; `started_at` marks entry into this state.
    TimeWait { started_at: Instant },
}
impl ConnectionState {
pub fn is_connected(&self) -> bool {
matches!(self, ConnectionState::Established { .. })
}
pub fn is_closed(&self) -> bool {
matches!(self, ConnectionState::Closed)
}
pub fn remote_addr(&self) -> Option<SocketAddr> {
match self {
ConnectionState::SynSent { remote, .. }
| ConnectionState::SynReceived { remote }
| ConnectionState::Established { remote, .. }
| ConnectionState::FinWait1 { remote }
| ConnectionState::FinWait2 { remote }
| ConnectionState::CloseWait { remote }
| ConnectionState::Closing { remote }
| ConnectionState::LastAck { remote } => Some(*remote),
_ => None,
}
}
}
/// A single TCP-like connection driven by [`ConnectionEvent`]s.
///
/// `state` is the current position in the state machine; `config` supplies
/// the limits consulted while transitioning.
pub struct TcpConnection {
    state: ConnectionState,
    config: ConnectionConfig,
}
/// Tunables for a [`TcpConnection`].
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Total SYN attempts allowed (the first attempt counts as one).
    pub max_retries: u32,
    /// Connection timeout.
    /// NOTE(review): not read by the transition logic in this example.
    pub connect_timeout: Duration,
    /// How long to stay in `TimeWait` before closing.
    pub time_wait_duration: Duration,
}

impl Default for ConnectionConfig {
    /// Three attempts, a 30 s connect timeout and a 60 s TIME-WAIT.
    fn default() -> Self {
        ConnectionConfig {
            time_wait_duration: Duration::from_secs(60),
            connect_timeout: Duration::from_secs(30),
            max_retries: 3,
        }
    }
}
/// Outcome reported by [`TcpConnection::handle_event`].
#[derive(Debug)]
pub enum TransitionResult {
    /// The transition happened; nothing else to report.
    Ok,
    /// The event is not accepted in the current state (state unchanged).
    InvalidTransition { from: String, event: String },
    /// The handshake completed with this session id.
    ConnectionEstablished { session_id: String },
    /// The connection is now closed.
    ConnectionClosed,
    /// Payload delivered to the caller.
    DataReceived(Vec<u8>),
    /// A connect attempt timed out; attempt number `attempt` is starting.
    RetryNeeded { attempt: u32 },
    /// A failure described by the message.
    Error(String),
}
impl TcpConnection {
pub fn new(config: ConnectionConfig) -> Self {
Self {
state: ConnectionState::Closed,
config,
}
}
pub fn state(&self) -> &ConnectionState {
&self.state
}
/// Process an event and transition to the next state
pub fn handle_event(&mut self, event: ConnectionEvent) -> TransitionResult {
let (new_state, result) = self.compute_transition(event);
self.state = new_state;
result
}
fn compute_transition(
&self,
event: ConnectionEvent,
) -> (ConnectionState, TransitionResult) {
match (&self.state, event) {
// From Closed
(ConnectionState::Closed, ConnectionEvent::Connect(addr)) => (
ConnectionState::SynSent {
remote: addr,
attempt: 1,
started_at: Instant::now(),
},
TransitionResult::Ok,
),
// From SynSent
(
ConnectionState::SynSent { remote, attempt, .. },
ConnectionEvent::ConnectionEstablished { session_id },
) => (
ConnectionState::Established {
remote: *remote,
session_id: session_id.clone(),
bytes_sent: 0,
bytes_received: 0,
},
TransitionResult::ConnectionEstablished { session_id },
),
(
ConnectionState::SynSent {
remote,
attempt,
started_at,
},
ConnectionEvent::Timeout,
) => {
if *attempt < self.config.max_retries {
(
ConnectionState::SynSent {
remote: *remote,
attempt: attempt + 1,
started_at: Instant::now(),
},
TransitionResult::RetryNeeded {
attempt: attempt + 1,
},
)
} else {
(
ConnectionState::Closed,
TransitionResult::Error("Connection timed out".to_string()),
)
}
}
(ConnectionState::SynSent { .. }, ConnectionEvent::Error(msg)) => {
(ConnectionState::Closed, TransitionResult::Error(msg))
}
// From Established
(
ConnectionState::Established {
remote,
session_id,
bytes_sent,
bytes_received,
},
ConnectionEvent::DataReceived(data),
) => {
let len = data.len() as u64;
(
ConnectionState::Established {
remote: *remote,
session_id: session_id.clone(),
bytes_sent: *bytes_sent,
bytes_received: bytes_received + len,
},
TransitionResult::DataReceived(data),
)
}
(
ConnectionState::Established {
remote,
session_id,
bytes_sent,
bytes_received,
},
ConnectionEvent::DataSent(len),
) => (
ConnectionState::Established {
remote: *remote,
session_id: session_id.clone(),
bytes_sent: bytes_sent + len as u64,
bytes_received: *bytes_received,
},
TransitionResult::Ok,
),
(ConnectionState::Established { remote, .. }, ConnectionEvent::Disconnect) => (
ConnectionState::FinWait1 { remote: *remote },
TransitionResult::Ok,
),
(ConnectionState::Established { .. }, ConnectionEvent::Reset) => {
(ConnectionState::Closed, TransitionResult::ConnectionClosed)
}
// From FinWait1
(ConnectionState::FinWait1 { remote }, ConnectionEvent::Disconnect) => (
ConnectionState::FinWait2 { remote: *remote },
TransitionResult::Ok,
),
// From FinWait2
(ConnectionState::FinWait2 { .. }, ConnectionEvent::Disconnect) => (
ConnectionState::TimeWait {
started_at: Instant::now(),
},
TransitionResult::Ok,
),
// From TimeWait
(ConnectionState::TimeWait { started_at }, ConnectionEvent::Timeout) => {
if started_at.elapsed() >= self.config.time_wait_duration {
(ConnectionState::Closed, TransitionResult::ConnectionClosed)
} else {
(
ConnectionState::TimeWait {
started_at: *started_at,
},
TransitionResult::Ok,
)
}
}
// Invalid transitions
(state, event) => (
self.state.clone(),
TransitionResult::InvalidTransition {
from: format!("{:?}", state),
event: format!("{:?}", event),
},
),
}
}
/// Helper methods for common operations
pub fn connect(&mut self, addr: SocketAddr) -> TransitionResult {
self.handle_event(ConnectionEvent::Connect(addr))
}
pub fn disconnect(&mut self) -> TransitionResult {
self.handle_event(ConnectionEvent::Disconnect)
}
pub fn send(&mut self, data: &[u8]) -> TransitionResult {
if self.state.is_connected() {
self.handle_event(ConnectionEvent::DataSent(data.len()))
} else {
TransitionResult::Error("Not connected".to_string())
}
}
}
fn main() {
    let mut conn = TcpConnection::new(ConnectionConfig::default());
    println!("Initial state: {:?}", conn.state());

    // Closed -> SynSent
    let outcome = conn.connect("192.168.1.1:8080".parse().unwrap());
    println!("After connect: {:?} -> {:?}", outcome, conn.state());

    // SynSent -> Established (the peer completed the handshake)
    let handshake = ConnectionEvent::ConnectionEstablished {
        session_id: "sess_123".to_string(),
    };
    let outcome = conn.handle_event(handshake);
    println!("Established: {:?} -> {:?}", outcome, conn.state());

    // Traffic in both directions only updates the byte counters.
    let outcome = conn.send(b"Hello, Server!");
    println!("After send: {:?}", outcome);

    let inbound = b"Hello, Client!".to_vec();
    let outcome = conn.handle_event(ConnectionEvent::DataReceived(inbound));
    println!("After receive: {:?}", outcome);

    // Established -> FinWait1
    let outcome = conn.disconnect();
    println!("Disconnecting: {:?} -> {:?}", outcome, conn.state());
}
Example: an order-processing state machine that enforces business rules (enum-based).
use std::time::{SystemTime, UNIX_EPOCH};
use std::collections::HashMap;
/// One line of an order.
#[derive(Debug, Clone)]
pub struct OrderItem {
    /// Product identifier (used by `RemoveItem` to find lines to drop).
    pub product_id: String,
    /// Number of units.
    pub quantity: u32,
    /// Price of one unit; the line total is `unit_price * quantity`.
    pub unit_price: f64,
}
/// A payment submitted against an order.
/// NOTE(review): `f64` money is subject to rounding; integer minor units
/// (e.g. cents) would be safer.
#[derive(Debug, Clone)]
pub struct Payment {
    /// How the customer pays.
    pub method: PaymentMethod,
    /// Amount submitted; compared against the order total on submission.
    pub amount: f64,
    /// Processor transaction id, filled in once the payment succeeds.
    pub transaction_id: Option<String>,
}

/// Supported payment instruments.
#[derive(Debug, Clone)]
pub enum PaymentMethod {
    /// Card identified by its last four digits.
    CreditCard { last_four: String },
    /// PayPal account e-mail.
    PayPal { email: String },
    /// Bank transfer; carries no extra data.
    BankTransfer,
}
/// Where and how an order is shipped.
#[derive(Debug, Clone)]
pub struct ShippingInfo {
    /// Delivery address.
    pub address: String,
    /// Carrier name.
    pub carrier: String,
    /// Carrier tracking number, if one was assigned.
    pub tracking_number: Option<String>,
}
/// Lifecycle of an order. Each variant carries only the data that exists at
/// that point; timestamps are Unix seconds as returned by `now()`.
#[derive(Debug, Clone)]
pub enum OrderState {
    /// Cart is still being edited.
    Draft {
        items: Vec<OrderItem>,
        created_at: u64,
    },
    /// Checked out; awaiting a payment for `total`.
    PendingPayment {
        items: Vec<OrderItem>,
        total: f64,
        /// Deadline for payment.
        payment_due_by: u64,
    },
    /// Payment submitted and awaiting the processor's verdict.
    PaymentProcessing {
        items: Vec<OrderItem>,
        payment: Payment,
    },
    /// The processor declined the payment.
    PaymentFailed {
        items: Vec<OrderItem>,
        reason: String,
        retry_count: u32,
    },
    /// Paid and confirmed.
    Confirmed {
        items: Vec<OrderItem>,
        payment: Payment,
        confirmed_at: u64,
    },
    /// Being picked and packed.
    Processing {
        items: Vec<OrderItem>,
        payment: Payment,
        started_at: u64,
    },
    /// Packed and waiting for the carrier.
    ReadyToShip {
        items: Vec<OrderItem>,
        payment: Payment,
        packed_at: u64,
    },
    /// Handed to the carrier.
    Shipped {
        items: Vec<OrderItem>,
        payment: Payment,
        shipping: ShippingInfo,
        shipped_at: u64,
    },
    /// Received by the customer.
    Delivered {
        items: Vec<OrderItem>,
        payment: Payment,
        shipping: ShippingInfo,
        delivered_at: u64,
    },
    /// Terminal: cancelled. `refund_status` is `None` when no payment had
    /// been captured.
    Cancelled {
        reason: String,
        cancelled_at: u64,
        refund_status: Option<RefundStatus>,
    },
    /// A delivered order the customer sent back.
    Returned {
        items: Vec<OrderItem>,
        reason: String,
        refund_status: RefundStatus,
    },
}
/// Progress of a refund.
#[derive(Debug, Clone)]
pub enum RefundStatus {
    /// Requested but not started.
    Pending,
    /// In progress with the payment processor.
    Processing,
    /// Done; carries the refund transaction id.
    Completed { transaction_id: String },
    /// The refund could not be completed.
    Failed { reason: String },
}
/// Inputs that drive an [`Order`] through its lifecycle.
#[derive(Debug)]
pub enum OrderEvent {
    /// Add a line to a draft cart.
    AddItem(OrderItem),
    /// Drop every cart line whose `product_id` equals this value.
    RemoveItem(String),
    /// Freeze the cart and compute the total.
    Checkout,
    /// Pay for a checked-out order.
    SubmitPayment(Payment),
    /// The processor accepted the payment.
    PaymentSucceeded { transaction_id: String },
    /// The processor declined the payment.
    PaymentFailed { reason: String },
    /// Try paying again after a failure.
    RetryPayment,
    /// NOTE(review): not handled by `Order::transition` in this example.
    Confirm,
    /// Begin fulfilment of a confirmed order.
    StartProcessing,
    /// Packing finished.
    PackComplete,
    /// Hand the order to a carrier.
    Ship(ShippingInfo),
    /// The customer received the order.
    MarkDelivered,
    /// Cancel the order with a reason.
    Cancel { reason: String },
    /// The customer wants to return a delivered order.
    RequestReturn { reason: String },
    /// NOTE(review): not handled by `Order::transition` in this example.
    ProcessRefund,
    /// NOTE(review): not handled by `Order::transition` in this example.
    RefundCompleted { transaction_id: String },
}
/// Outcome reported by `Order::handle`.
#[derive(Debug)]
pub enum OrderResult {
    /// Transition done; nothing else to report.
    Ok,
    /// A cart line was added.
    ItemAdded,
    /// Matching cart lines were removed.
    ItemRemoved,
    /// Checkout succeeded; `total` is due.
    CheckoutReady { total: f64 },
    /// Payment accepted for processing.
    PaymentSubmitted,
    /// Payment succeeded and the order is confirmed.
    OrderConfirmed { order_id: String },
    /// Shipped with this tracking number (empty if none was given).
    Shipped { tracking: String },
    /// The order was delivered.
    Delivered,
    /// Cancelled without a refund.
    Cancelled,
    /// A refund was started (paid cancellation or return).
    RefundInitiated,
    /// The event was understood but rejected.
    Error(String),
    /// The event is not accepted in the current state.
    InvalidTransition { state: String, event: String },
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Panics if the system clock is set before 1970-01-01.
fn now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
/// Order state machine.
///
/// Wraps an [`OrderState`] and advances it in response to [`OrderEvent`]s.
/// All transition rules live in `Order::transition`.
pub struct Order {
    /// Caller-supplied order identifier (echoed in `OrderConfirmed`).
    pub id: String,
    /// Current lifecycle state.
    state: OrderState,
    /// Payment retries accepted so far.
    ///
    /// The count is kept on the order because `PendingPayment` and
    /// `PaymentProcessing` have no field to carry it. It is copied into
    /// `PaymentFailed::retry_count` on every failure, so the retry limit can
    /// actually be reached.
    payment_retries: u32,
}

impl Order {
    /// Retries allowed after a failed payment before the order is cancelled.
    const MAX_PAYMENT_RETRIES: u32 = 3;
    /// Seconds a checked-out order waits for payment.
    const PAYMENT_WINDOW_SECS: u64 = 3600;
    /// Largest accepted difference between the submitted amount and the total.
    const PAYMENT_TOLERANCE: f64 = 0.01;

    /// Creates an empty draft order stamped with the current time.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            state: OrderState::Draft {
                items: Vec::new(),
                created_at: now(),
            },
            payment_retries: 0,
        }
    }

    /// Current state of the order.
    pub fn state(&self) -> &OrderState {
        &self.state
    }

    /// Applies `event`, stores the resulting state, and returns the outcome.
    ///
    /// Events not accepted in the current state leave it unchanged and return
    /// [`OrderResult::InvalidTransition`].
    pub fn handle(&mut self, event: OrderEvent) -> OrderResult {
        let was_failed = matches!(self.state, OrderState::PaymentFailed { .. });
        let (new_state, result) = self.transition(event);
        // `PaymentFailed -> PendingPayment` only happens on an accepted
        // `RetryPayment`; count it so the retry limit eventually applies.
        if was_failed && matches!(new_state, OrderState::PendingPayment { .. }) {
            self.payment_retries += 1;
        }
        self.state = new_state;
        result
    }

    /// Pure transition function: computes `(next_state, result)` for `event`
    /// without mutating `self`.
    fn transition(&self, event: OrderEvent) -> (OrderState, OrderResult) {
        match (&self.state, event) {
            // ---- Draft: cart editing and checkout ----
            (OrderState::Draft { items, created_at }, OrderEvent::AddItem(item)) => {
                let mut new_items = items.clone();
                new_items.push(item);
                (
                    OrderState::Draft {
                        items: new_items,
                        created_at: *created_at,
                    },
                    OrderResult::ItemAdded,
                )
            }
            (OrderState::Draft { items, created_at }, OrderEvent::RemoveItem(product_id)) => {
                // Drops every line with this product id; reports success even
                // when nothing matched.
                let new_items: Vec<_> = items
                    .iter()
                    .filter(|i| i.product_id != product_id)
                    .cloned()
                    .collect();
                (
                    OrderState::Draft {
                        items: new_items,
                        created_at: *created_at,
                    },
                    OrderResult::ItemRemoved,
                )
            }
            (OrderState::Draft { items, .. }, OrderEvent::Checkout) => {
                if items.is_empty() {
                    return (self.state.clone(), OrderResult::Error("Cart is empty".into()));
                }
                let total: f64 = items
                    .iter()
                    .map(|i| i.unit_price * f64::from(i.quantity))
                    .sum();
                (
                    OrderState::PendingPayment {
                        items: items.clone(),
                        total,
                        payment_due_by: now() + Self::PAYMENT_WINDOW_SECS,
                    },
                    OrderResult::CheckoutReady { total },
                )
            }

            // ---- Cancellation before any money was captured (no refund) ----
            // These are exactly the unpaid states `can_cancel` reports.
            (OrderState::Draft { .. }, OrderEvent::Cancel { reason })
            | (OrderState::PendingPayment { .. }, OrderEvent::Cancel { reason })
            | (OrderState::PaymentFailed { .. }, OrderEvent::Cancel { reason }) => (
                OrderState::Cancelled {
                    reason,
                    cancelled_at: now(),
                    refund_status: None,
                },
                OrderResult::Cancelled,
            ),

            // ---- PendingPayment ----
            // NOTE(review): `payment_due_by` is recorded but not enforced here.
            (
                OrderState::PendingPayment { items, total, .. },
                OrderEvent::SubmitPayment(payment),
            ) => {
                if (payment.amount - *total).abs() > Self::PAYMENT_TOLERANCE {
                    return (
                        self.state.clone(),
                        OrderResult::Error("Payment amount mismatch".into()),
                    );
                }
                (
                    OrderState::PaymentProcessing {
                        items: items.clone(),
                        payment,
                    },
                    OrderResult::PaymentSubmitted,
                )
            }

            // ---- PaymentProcessing ----
            (
                OrderState::PaymentProcessing { items, payment },
                OrderEvent::PaymentSucceeded { transaction_id },
            ) => {
                let mut payment = payment.clone();
                payment.transaction_id = Some(transaction_id);
                (
                    OrderState::Confirmed {
                        items: items.clone(),
                        payment,
                        confirmed_at: now(),
                    },
                    OrderResult::OrderConfirmed {
                        order_id: self.id.clone(),
                    },
                )
            }
            (
                OrderState::PaymentProcessing { items, .. },
                OrderEvent::PaymentFailed { reason },
            ) => (
                OrderState::PaymentFailed {
                    items: items.clone(),
                    reason,
                    // Carry the running count instead of resetting to 0, so
                    // the limit checked on `RetryPayment` can trigger.
                    retry_count: self.payment_retries,
                },
                OrderResult::Error("Payment failed".into()),
            ),

            // ---- PaymentFailed ----
            (OrderState::PaymentFailed { items, retry_count, .. }, OrderEvent::RetryPayment) => {
                if *retry_count >= Self::MAX_PAYMENT_RETRIES {
                    return (
                        OrderState::Cancelled {
                            reason: "Max payment retries exceeded".into(),
                            cancelled_at: now(),
                            refund_status: None,
                        },
                        OrderResult::Cancelled,
                    );
                }
                let total: f64 = items
                    .iter()
                    .map(|i| i.unit_price * f64::from(i.quantity))
                    .sum();
                (
                    OrderState::PendingPayment {
                        items: items.clone(),
                        total,
                        payment_due_by: now() + Self::PAYMENT_WINDOW_SECS,
                    },
                    OrderResult::Ok,
                )
            }

            // ---- Confirmed ----
            (OrderState::Confirmed { items, payment, .. }, OrderEvent::StartProcessing) => (
                OrderState::Processing {
                    items: items.clone(),
                    payment: payment.clone(),
                    started_at: now(),
                },
                OrderResult::Ok,
            ),
            // Already paid: cancelling starts a refund.
            (OrderState::Confirmed { .. }, OrderEvent::Cancel { reason }) => (
                OrderState::Cancelled {
                    reason,
                    cancelled_at: now(),
                    refund_status: Some(RefundStatus::Pending),
                },
                OrderResult::RefundInitiated,
            ),

            // ---- Processing ----
            (OrderState::Processing { items, payment, .. }, OrderEvent::PackComplete) => (
                OrderState::ReadyToShip {
                    items: items.clone(),
                    payment: payment.clone(),
                    packed_at: now(),
                },
                OrderResult::Ok,
            ),

            // ---- ReadyToShip ----
            (OrderState::ReadyToShip { items, payment, .. }, OrderEvent::Ship(shipping)) => {
                let tracking = shipping.tracking_number.clone().unwrap_or_default();
                (
                    OrderState::Shipped {
                        items: items.clone(),
                        payment: payment.clone(),
                        shipping,
                        shipped_at: now(),
                    },
                    OrderResult::Shipped { tracking },
                )
            }

            // ---- Shipped ----
            (
                OrderState::Shipped {
                    items,
                    payment,
                    shipping,
                    ..
                },
                OrderEvent::MarkDelivered,
            ) => (
                OrderState::Delivered {
                    items: items.clone(),
                    payment: payment.clone(),
                    shipping: shipping.clone(),
                    delivered_at: now(),
                },
                OrderResult::Delivered,
            ),

            // ---- Delivered ----
            (OrderState::Delivered { items, .. }, OrderEvent::RequestReturn { reason }) => (
                OrderState::Returned {
                    items: items.clone(),
                    reason,
                    refund_status: RefundStatus::Pending,
                },
                OrderResult::RefundInitiated,
            ),

            // ---- Anything else is rejected; state unchanged ----
            (state, event) => (
                self.state.clone(),
                OrderResult::InvalidTransition {
                    state: Self::variant_name(state),
                    event: Self::variant_name(&event),
                },
            ),
        }
    }

    /// Extracts the enum variant name from a value's `Debug` rendering.
    ///
    /// Stops at the first character that cannot appear in an identifier, so
    /// struct-like (`Draft { .. }`), tuple-like (`AddItem(OrderItem { .. })`)
    /// and unit variants all yield just the name (`Draft`, `AddItem`).
    fn variant_name(value: &impl std::fmt::Debug) -> String {
        let rendered = format!("{:?}", value);
        let name = rendered
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .next()
            .unwrap_or_default();
        name.to_string()
    }

    /// `true` in the states where a `Cancel` event is accepted.
    pub fn can_cancel(&self) -> bool {
        matches!(
            self.state,
            OrderState::Draft { .. }
                | OrderState::PendingPayment { .. }
                | OrderState::PaymentFailed { .. }
                | OrderState::Confirmed { .. }
        )
    }

    /// Sum of `unit_price * quantity` over the order's items.
    ///
    /// Returns `None` once cancelled, because `Cancelled` keeps no items.
    pub fn total(&self) -> Option<f64> {
        match &self.state {
            OrderState::Cancelled { .. } => None,
            OrderState::Draft { items, .. }
            | OrderState::PendingPayment { items, .. }
            | OrderState::PaymentProcessing { items, .. }
            | OrderState::PaymentFailed { items, .. }
            | OrderState::Confirmed { items, .. }
            | OrderState::Processing { items, .. }
            | OrderState::ReadyToShip { items, .. }
            | OrderState::Shipped { items, .. }
            | OrderState::Delivered { items, .. }
            | OrderState::Returned { items, .. } => Some(
                items
                    .iter()
                    .map(|i| i.unit_price * f64::from(i.quantity))
                    .sum(),
            ),
        }
    }
}
fn main() {
    let mut order = Order::new("ORD-001");

    // Fill the cart; per-item results are not inspected.
    let cart = vec![
        OrderItem {
            product_id: "PROD-1".into(),
            quantity: 2,
            unit_price: 29.99,
        },
        OrderItem {
            product_id: "PROD-2".into(),
            quantity: 1,
            unit_price: 49.99,
        },
    ];
    for line in cart {
        order.handle(OrderEvent::AddItem(line));
    }
    println!("After adding items: {:?}", order.state());
    println!("Total: ${:.2}", order.total().unwrap_or(0.0));

    // Checkout, then pay exactly the total.
    let checkout = order.handle(OrderEvent::Checkout);
    println!("Checkout: {:?}", checkout);

    let card = PaymentMethod::CreditCard {
        last_four: "4242".into(),
    };
    let submitted = order.handle(OrderEvent::SubmitPayment(Payment {
        method: card,
        amount: 109.97,
        transaction_id: None,
    }));
    println!("Payment submitted: {:?}", submitted);

    let confirmed = order.handle(OrderEvent::PaymentSucceeded {
        transaction_id: "txn_123456".into(),
    });
    println!("Payment succeeded: {:?}", confirmed);

    // Fulfilment steps whose results are not printed.
    order.handle(OrderEvent::StartProcessing);
    order.handle(OrderEvent::PackComplete);

    let parcel = ShippingInfo {
        address: "123 Main St".into(),
        carrier: "FedEx".into(),
        tracking_number: Some("FX123456789".into()),
    };
    let shipped = order.handle(OrderEvent::Ship(parcel));
    println!("Shipped: {:?}", shipped);

    let delivered = order.handle(OrderEvent::MarkDelivered);
    println!("Delivered: {:?}", delivered);

    println!("\nFinal state: {:?}", order.state());
}
Example: a media player whose states are trait objects (runtime polymorphism).
use std::time::Duration;
/// Media player context.
///
/// Owns the current state object plus the data that every state reads or
/// updates: the loaded track, the playback position and the playlist cursor.
pub struct MediaPlayer {
    /// Current behavior; replaced by whatever each command returns.
    state: Box<dyn PlayerState>,
    /// Track that is playing or cued, if any.
    current_track: Option<Track>,
    /// Playback position within `current_track`.
    position: Duration,
    /// Output volume. NOTE(review): set to 50 in `new` but not read anywhere yet.
    volume: u8,
    /// Tracks reachable with `next` / `previous`.
    playlist: Vec<Track>,
    /// Index into `playlist` of the current track.
    playlist_index: usize,
}

/// A playable item.
#[derive(Debug, Clone)]
pub struct Track {
    pub title: String,
    pub artist: String,
    pub duration: Duration,
}

/// Player state trait (runtime polymorphism).
///
/// Each command receives the player mutably and returns the state the player
/// should switch to (possibly a fresh instance of the same state).
pub trait PlayerState: std::fmt::Debug {
    fn play(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn pause(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState>;
    fn seek(&self, player: &mut MediaPlayer, position: Duration) -> Box<dyn PlayerState>;
    /// Display name of the state.
    fn name(&self) -> &'static str;
    fn can_play(&self) -> bool { false }
    fn can_pause(&self) -> bool { false }
    fn is_playing(&self) -> bool { false }
    /// Called when an asynchronous buffer load finishes.
    ///
    /// Returns the state to switch to, or `None` when this state was not
    /// waiting on a buffer (the player then keeps its current state). A trait
    /// hook is used because a `dyn PlayerState` cannot be downcast to a
    /// concrete state without extra `Any` plumbing.
    fn on_buffering_complete(&self, _player: &mut MediaPlayer) -> Option<Box<dyn PlayerState>> {
        None
    }
}

// Concrete states
#[derive(Debug)]
pub struct StoppedState;
#[derive(Debug)]
pub struct PlayingState;
#[derive(Debug)]
pub struct PausedState;
/// Waiting for data; `target_state` says what to do once it arrives.
#[derive(Debug)]
pub struct BufferingState {
    pub target_state: BufferingTarget,
}

/// What a `BufferingState` resolves into when buffering completes.
#[derive(Debug, Clone, Copy)]
pub enum BufferingTarget {
    /// Start playing from the current position.
    Play,
    /// Jump to the given position, then play.
    Seek(Duration),
}

impl PlayerState for StoppedState {
    fn play(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        // Nothing loaded yet: cue the first playlist entry if there is one.
        if player.current_track.is_none() && !player.playlist.is_empty() {
            player.current_track = Some(player.playlist[0].clone());
            player.playlist_index = 0;
        }
        if player.current_track.is_some() {
            player.position = Duration::ZERO;
            println!("Starting playback...");
            Box::new(BufferingState {
                target_state: BufferingTarget::Play,
            })
        } else {
            println!("No track to play");
            Box::new(StoppedState)
        }
    }
    fn pause(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already stopped");
        Box::new(StoppedState)
    }
    fn stop(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already stopped");
        Box::new(StoppedState)
    }
    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index + 1 < player.playlist.len() {
            player.playlist_index += 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            println!("Next track: {:?}", player.current_track);
        }
        Box::new(StoppedState)
    }
    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index > 0 {
            player.playlist_index -= 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            println!("Previous track: {:?}", player.current_track);
        }
        Box::new(StoppedState)
    }
    fn seek(&self, _player: &mut MediaPlayer, _position: Duration) -> Box<dyn PlayerState> {
        println!("Cannot seek while stopped");
        Box::new(StoppedState)
    }
    fn name(&self) -> &'static str { "Stopped" }
    fn can_play(&self) -> bool { true }
}

impl PlayerState for PlayingState {
    fn play(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already playing");
        Box::new(PlayingState)
    }
    fn pause(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Paused at {:?}", player.position);
        Box::new(PausedState)
    }
    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        player.position = Duration::ZERO;
        println!("Stopped");
        Box::new(StoppedState)
    }
    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index + 1 < player.playlist.len() {
            player.playlist_index += 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Playing next: {:?}", player.current_track);
            Box::new(BufferingState {
                target_state: BufferingTarget::Play,
            })
        } else {
            println!("End of playlist");
            player.position = Duration::ZERO;
            Box::new(StoppedState)
        }
    }
    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        // More than 3 seconds in: restart the current track instead.
        if player.position > Duration::from_secs(3) {
            player.position = Duration::ZERO;
            println!("Restarting current track");
            Box::new(PlayingState)
        } else if player.playlist_index > 0 {
            player.playlist_index -= 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Playing previous: {:?}", player.current_track);
            Box::new(BufferingState {
                target_state: BufferingTarget::Play,
            })
        } else {
            player.position = Duration::ZERO;
            Box::new(PlayingState)
        }
    }
    fn seek(&self, _player: &mut MediaPlayer, position: Duration) -> Box<dyn PlayerState> {
        println!("Seeking to {:?}...", position);
        Box::new(BufferingState {
            target_state: BufferingTarget::Seek(position),
        })
    }
    fn name(&self) -> &'static str { "Playing" }
    fn can_pause(&self) -> bool { true }
    fn is_playing(&self) -> bool { true }
}

impl PlayerState for PausedState {
    fn play(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Resuming from {:?}", player.position);
        Box::new(PlayingState)
    }
    fn pause(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Already paused");
        Box::new(PausedState)
    }
    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        player.position = Duration::ZERO;
        println!("Stopped");
        Box::new(StoppedState)
    }
    fn next(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index + 1 < player.playlist.len() {
            player.playlist_index += 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Switched to: {:?} (paused)", player.current_track);
        }
        Box::new(PausedState)
    }
    fn previous(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        if player.playlist_index > 0 {
            player.playlist_index -= 1;
            player.current_track = Some(player.playlist[player.playlist_index].clone());
            player.position = Duration::ZERO;
            println!("Switched to: {:?} (paused)", player.current_track);
        }
        Box::new(PausedState)
    }
    fn seek(&self, player: &mut MediaPlayer, position: Duration) -> Box<dyn PlayerState> {
        player.position = position;
        println!("Seeked to {:?} (paused)", position);
        Box::new(PausedState)
    }
    fn name(&self) -> &'static str { "Paused" }
    fn can_play(&self) -> bool { true }
}

impl PlayerState for BufferingState {
    fn play(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Still buffering...");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }
    fn pause(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Buffering cancelled, pausing");
        Box::new(PausedState)
    }
    fn stop(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        player.position = Duration::ZERO;
        println!("Buffering cancelled, stopped");
        Box::new(StoppedState)
    }
    fn next(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Cannot skip while buffering");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }
    fn previous(&self, _player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        println!("Cannot go back while buffering");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }
    fn seek(&self, _player: &mut MediaPlayer, _position: Duration) -> Box<dyn PlayerState> {
        println!("Cannot seek while buffering");
        Box::new(BufferingState {
            target_state: self.target_state,
        })
    }
    fn name(&self) -> &'static str { "Buffering" }
    fn on_buffering_complete(&self, player: &mut MediaPlayer) -> Option<Box<dyn PlayerState>> {
        // Resolves to the inherent `BufferingState::buffering_complete`.
        Some(self.buffering_complete(player))
    }
}

impl BufferingState {
    /// Resolves buffering into playback, applying a pending seek if any.
    pub fn buffering_complete(&self, player: &mut MediaPlayer) -> Box<dyn PlayerState> {
        match self.target_state {
            BufferingTarget::Play => {
                println!("Buffering complete, playing");
                Box::new(PlayingState)
            }
            BufferingTarget::Seek(position) => {
                player.position = position;
                println!("Buffering complete, playing from {:?}", position);
                Box::new(PlayingState)
            }
        }
    }
}

impl MediaPlayer {
    /// Creates a stopped player with an empty playlist and volume 50.
    pub fn new() -> Self {
        Self {
            state: Box::new(StoppedState),
            current_track: None,
            position: Duration::ZERO,
            volume: 50,
            playlist: Vec::new(),
            playlist_index: 0,
        }
    }

    /// Appends `track` to the playlist.
    pub fn add_track(&mut self, track: Track) {
        self.playlist.push(track);
    }

    /// Runs one state method and installs the state it returns.
    ///
    /// The current state is moved out of `self` first: calling
    /// `self.state.play(self)` directly borrows `self` immutably (through
    /// `self.state`) and mutably at once, which does not compile. While `f`
    /// runs, `self.state` holds a `StoppedState` placeholder; none of the
    /// state methods read `player.state`.
    fn dispatch<F>(&mut self, f: F)
    where
        F: FnOnce(&dyn PlayerState, &mut MediaPlayer) -> Box<dyn PlayerState>,
    {
        let current = std::mem::replace(&mut self.state, Box::new(StoppedState));
        self.state = f(&*current, self);
    }

    pub fn play(&mut self) {
        self.dispatch(|state, player| state.play(player));
    }

    pub fn pause(&mut self) {
        self.dispatch(|state, player| state.pause(player));
    }

    pub fn stop(&mut self) {
        self.dispatch(|state, player| state.stop(player));
    }

    pub fn next(&mut self) {
        self.dispatch(|state, player| state.next(player));
    }

    pub fn previous(&mut self) {
        self.dispatch(|state, player| state.previous(player));
    }

    pub fn seek(&mut self, position: Duration) {
        self.dispatch(|state, player| state.seek(player, position));
    }

    /// Display name of the current state.
    pub fn state_name(&self) -> &'static str {
        self.state.name()
    }

    /// `true` while in the playing state.
    pub fn is_playing(&self) -> bool {
        self.state.is_playing()
    }

    /// Simulates buffering completion (in a real app, called by the async loader).
    ///
    /// No-op unless the player is currently buffering.
    pub fn on_buffering_complete(&mut self) {
        let current = std::mem::replace(&mut self.state, Box::new(StoppedState));
        self.state = match current.on_buffering_complete(self) {
            Some(next) => next,
            None => current,
        };
    }
}

impl Default for MediaPlayer {
    fn default() -> Self {
        Self::new()
    }
}
fn main() {
    let mut player = MediaPlayer::new();

    // Load a three-track playlist: (title, artist, length in seconds).
    let catalog = [
        ("Track 1", "Artist A", 180),
        ("Track 2", "Artist B", 240),
        ("Track 3", "Artist A", 200),
    ];
    for &(title, artist, secs) in &catalog {
        player.add_track(Track {
            title: title.into(),
            artist: artist.into(),
            duration: Duration::from_secs(secs),
        });
    }

    println!("Initial state: {}", player.state_name());

    player.play();
    println!("After play: {}", player.state_name());

    // Buffering would finish asynchronously in a real app; fake it by
    // installing the playing state directly.
    println!("Buffering complete...");
    player.state = Box::new(PlayingState);
    println!("Now: {}", player.state_name());

    player.pause();
    println!("After pause: {}", player.state_name());

    player.play();
    println!("After resume: {}", player.state_name());

    // Skip ahead, again faking the end of buffering.
    player.next();
    player.state = Box::new(PlayingState);
    println!("After next: {}, track: {:?}", player.state_name(), player.current_track);

    player.stop();
    println!("After stop: {}", player.state_name());
}
| Approach | Compile-time Safety | Runtime Flexibility | Performance | Best For |
|----------|---------------------|---------------------|-------------|----------|
| Enum | Highest | Low | Best | Fixed states, exhaustive matching |
| Trait objects | Medium | High | Good | Plugin states, complex hierarchies |
| Typestate | Highest | None | Best | API design, builder patterns |
| Generic states | High | Medium | Good | State-specific operations |
// DON'T: Massive match statements scattered throughout code
// Why: when many functions each re-match on (state, action), the transition
// rules are duplicated and drift apart. Prefer one transition function per
// machine that returns the next state.
// NOTE: `State` and `Action` are placeholders and the match arms are elided;
// this snippet is illustrative only.
fn process(state: &State, action: &Action) {
match (state, action) {
// 50+ match arms spread across the codebase
}
}
// DON'T: Stringly-typed states
// Why: any string compiles, so typos and unknown states surface only at
// runtime, and the compiler cannot check that every state is handled.
struct BadStateMachine {
state: String, // "playing", "paused", etc.
}
// DON'T: States that don't encapsulate their data
// Why: the `Option` fields below are meaningful only in some states, yet
// nothing stops e.g. `pause_position` being `Some` while `state` is `Playing`.
// (`BadState` has no buffering variant at all for `buffer_progress` to belong to.)
enum BadState {
Playing,
Paused,
}
struct Player {
state: BadState,
// State-specific data mixed together
buffer_progress: Option<f32>, // Only for buffering
pause_position: Option<u64>, // Only for paused
}
// DO: States own their data
// Each variant carries exactly the fields valid in that state, so invalid
// combinations cannot be constructed.
// NOTE(review): `BufferTarget` is not defined in this snippet; it presumably
// corresponds to `BufferingTarget` from the media-player example — confirm.
enum GoodState {
Playing { position: u64 },
Paused { position: u64 },
Buffering { progress: f32, target: BufferTarget },
}
| Approach | State Size | Transition Cost | Pattern Matching |
|----------|------------|-----------------|------------------|
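| Enum | Largest variant plus discriminant | Move/copy of the value | Optimized |
| Trait object | Fat pointer (data + vtable) plus a heap allocation | Virtual call, plus a new `Box` per transition | N/A |
| Typestate | No runtime tag (the state lives in the type) | None at runtime; checked at compile time | N/A |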
Locked, Unlocked, and Open states

Run this code in the official Rust Playground.