Rate Limiting

Protecting APIs from abuse

Level: intermediate · Tags: rate-limit, throttling, middleware

What is Rate Limiting?

Rate limiting controls how many requests a client can make in a time window. It protects APIs from abuse, ensures fair usage, and maintains service availability.

The Problem

Implementing rate limiting requires:

  • Identification: Who is making the request?
  • Counting: Track requests per client
  • Storage: Where to store counts (memory, Redis)
  • Response: What to do when the limit is exceeded

Example Code

use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use axum::{
    body::Body,
    extract::State,
    http::{Request, StatusCode},
    middleware::Next,
    response::{Response, IntoResponse},
    Json,
};
use serde::Serialize;

/// Rate limit configuration
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per window
    pub max_requests: u32,
    /// Time window duration
    pub window: Duration,
    /// Include headers in response
    pub include_headers: bool,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        RateLimitConfig {
            max_requests: 100,
            window: Duration::from_secs(60),
            include_headers: true,
        }
    }
}

/// Rate limiter using a fixed window counter (the count resets at window boundaries)
#[derive(Clone)]
pub struct RateLimiter {
    state: Arc<RwLock<HashMap<String, WindowState>>>,
    config: RateLimitConfig,
}

#[derive(Debug, Clone)]
struct WindowState {
    count: u32,
    window_start: Instant,
}

#[derive(Debug)]
pub struct RateLimitResult {
    pub allowed: bool,
    pub remaining: u32,
    pub reset_at: Duration,
    pub retry_after: Option<Duration>,
}

impl RateLimiter {
    pub fn new(config: RateLimitConfig) -> Self {
        RateLimiter {
            state: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    pub fn check(&self, key: &str) -> RateLimitResult {
        let mut state = self.state.write().unwrap();
        let now = Instant::now();

        let entry = state.entry(key.to_string()).or_insert_with(|| WindowState {
            count: 0,
            window_start: now,
        });

        // Check if window has expired
        if now.duration_since(entry.window_start) >= self.config.window {
            entry.count = 0;
            entry.window_start = now;
        }

        let time_in_window = now.duration_since(entry.window_start);
        let reset_at = self.config.window - time_in_window;

        if entry.count >= self.config.max_requests {
            return RateLimitResult {
                allowed: false,
                remaining: 0,
                reset_at,
                retry_after: Some(reset_at),
            };
        }

        entry.count += 1;
        let remaining = self.config.max_requests - entry.count;

        RateLimitResult {
            allowed: true,
            remaining,
            reset_at,
            retry_after: None,
        }
    }

    /// Clean up expired entries (call periodically)
    pub fn cleanup(&self) {
        let mut state = self.state.write().unwrap();
        let now = Instant::now();

        state.retain(|_, v| {
            now.duration_since(v.window_start) < self.config.window * 2
        });
    }
}

/// Token bucket rate limiter (allows bursts up to capacity, refills at a steady rate)
#[derive(Clone)]
pub struct TokenBucket {
    state: Arc<RwLock<HashMap<String, BucketState>>>,
    capacity: u32,
    refill_rate: f64, // tokens per second
}

#[derive(Debug, Clone)]
struct BucketState {
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucket {
    pub fn new(capacity: u32, refill_per_second: f64) -> Self {
        TokenBucket {
            state: Arc::new(RwLock::new(HashMap::new())),
            capacity,
            refill_rate: refill_per_second,
        }
    }

    pub fn try_acquire(&self, key: &str, tokens: u32) -> bool {
        let mut state = self.state.write().unwrap();
        let now = Instant::now();

        let bucket = state.entry(key.to_string()).or_insert_with(|| BucketState {
            tokens: self.capacity as f64,
            last_refill: now,
        });

        // Refill tokens based on time passed
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * self.refill_rate)
            .min(self.capacity as f64);
        bucket.last_refill = now;

        if bucket.tokens >= tokens as f64 {
            bucket.tokens -= tokens as f64;
            true
        } else {
            false
        }
    }

    pub fn tokens_remaining(&self, key: &str) -> f64 {
        let state = self.state.read().unwrap();
        state.get(key).map(|b| b.tokens).unwrap_or(self.capacity as f64)
    }
}

/// Leaky bucket rate limiter (constant output rate)
#[derive(Clone)]
pub struct LeakyBucket {
    state: Arc<RwLock<HashMap<String, LeakyState>>>,
    capacity: u32,
    leak_rate: f64, // requests per second
}

#[derive(Debug, Clone)]
struct LeakyState {
    water_level: f64,
    last_leak: Instant,
}

impl LeakyBucket {
    pub fn new(capacity: u32, leak_per_second: f64) -> Self {
        LeakyBucket {
            state: Arc::new(RwLock::new(HashMap::new())),
            capacity,
            leak_rate: leak_per_second,
        }
    }

    pub fn try_add(&self, key: &str) -> bool {
        let mut state = self.state.write().unwrap();
        let now = Instant::now();

        let bucket = state.entry(key.to_string()).or_insert_with(|| LeakyState {
            water_level: 0.0,
            last_leak: now,
        });

        // Leak water based on time passed
        let elapsed = now.duration_since(bucket.last_leak).as_secs_f64();
        bucket.water_level = (bucket.water_level - elapsed * self.leak_rate).max(0.0);
        bucket.last_leak = now;

        if bucket.water_level < self.capacity as f64 {
            bucket.water_level += 1.0;
            true
        } else {
            false
        }
    }
}

/// Middleware for axum
pub async fn rate_limit_middleware(
    State(limiter): State<RateLimiter>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Extract client identifier (IP, API key, user ID)
    let key = extract_client_key(&request);

    let result = limiter.check(&key);

    if !result.allowed {
        let retry_after = result.retry_after.map(|d| d.as_secs()).unwrap_or(60);

        return (
            StatusCode::TOO_MANY_REQUESTS,
            [
                ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
                ("X-RateLimit-Remaining", "0".to_string()),
                ("X-RateLimit-Reset", result.reset_at.as_secs().to_string()),
                ("Retry-After", retry_after.to_string()),
            ],
            Json(RateLimitError {
                error: "Too many requests".to_string(),
                retry_after,
            }),
        ).into_response();
    }

    // Add headers to response
    let mut response = next.run(request).await;

    if limiter.config.include_headers {
        let headers = response.headers_mut();
        headers.insert(
            "X-RateLimit-Limit",
            limiter.config.max_requests.to_string().parse().unwrap(),
        );
        headers.insert(
            "X-RateLimit-Remaining",
            result.remaining.to_string().parse().unwrap(),
        );
        headers.insert(
            "X-RateLimit-Reset",
            result.reset_at.as_secs().to_string().parse().unwrap(),
        );
    }

    response
}

fn extract_client_key(request: &Request<Body>) -> String {
    // Try API key header first
    if let Some(api_key) = request.headers().get("X-API-Key") {
        if let Ok(key) = api_key.to_str() {
            return format!("api:{}", key);
        }
    }

    // Fall back to IP address
    request
        .headers()
        .get("X-Forwarded-For")
        .and_then(|h| h.to_str().ok())
        .map(|s| s.split(',').next().unwrap_or("unknown").trim().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}

#[derive(Serialize)]
struct RateLimitError {
    error: String,
    retry_after: u64,
}

/// Tiered rate limiting (different limits for different users)
#[derive(Debug, Clone)]
pub enum UserTier {
    Free,
    Basic,
    Premium,
    Enterprise,
}

impl UserTier {
    pub fn rate_limit(&self) -> (u32, Duration) {
        match self {
            UserTier::Free => (60, Duration::from_secs(60)),       // 60/min
            UserTier::Basic => (300, Duration::from_secs(60)),     // 300/min
            UserTier::Premium => (1000, Duration::from_secs(60)),  // 1000/min
            UserTier::Enterprise => (10000, Duration::from_secs(60)), // 10000/min
        }
    }
}

/// Endpoint-specific rate limits
pub struct EndpointLimits {
    limiters: HashMap<String, RateLimiter>,
}

impl EndpointLimits {
    pub fn new() -> Self {
        let mut limiters = HashMap::new();

        // Different limits for different endpoints
        limiters.insert(
            "/api/search".to_string(),
            RateLimiter::new(RateLimitConfig {
                max_requests: 30,
                window: Duration::from_secs(60),
                include_headers: true,
            }),
        );

        limiters.insert(
            "/api/upload".to_string(),
            RateLimiter::new(RateLimitConfig {
                max_requests: 10,
                window: Duration::from_secs(60),
                include_headers: true,
            }),
        );

        EndpointLimits { limiters }
    }

    pub fn check(&self, endpoint: &str, key: &str) -> Option<RateLimitResult> {
        self.limiters.get(endpoint).map(|l| l.check(key))
    }
}

impl Default for EndpointLimits {
    fn default() -> Self {
        Self::new()
    }
}

fn main() {
    // Fixed window example (count resets when the window expires)
    let limiter = RateLimiter::new(RateLimitConfig {
        max_requests: 5,
        window: Duration::from_secs(10),
        include_headers: true,
    });

    for i in 1..=7 {
        let result = limiter.check("user:123");
        println!(
            "Request {}: allowed={}, remaining={}",
            i, result.allowed, result.remaining
        );
    }

    // Token bucket example
    println!("\n--- Token Bucket ---");
    let bucket = TokenBucket::new(10, 2.0); // 10 capacity, 2 tokens/sec

    for i in 1..=12 {
        let allowed = bucket.try_acquire("user:456", 1);
        println!("Request {}: allowed={}", i, allowed);
    }

    // Leaky bucket example
    println!("\n--- Leaky Bucket ---");
    let leaky = LeakyBucket::new(5, 1.0); // 5 capacity, 1 leak/sec

    for i in 1..=7 {
        let allowed = leaky.try_add("user:789");
        println!("Request {}: allowed={}", i, allowed);
    }
}

Why This Works

  1. Multiple algorithms: Choose based on use case
  2. Headers: Inform clients of their limit status
  3. Tiered limits: Different users, different limits
  4. Cleanup: Prevent memory growth
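Point 4 needs a scheduler: `cleanup` only helps if something actually calls it. A minimal stdlib-only sketch of running it on a background thread (`spawn_cleanup` is a hypothetical helper; in the example above, the closure would capture a clone of the limiter and call `limiter.cleanup()`):

```rust
use std::thread;
use std::time::Duration;

// Hypothetical helper: invoke `cleanup` every `every` on a detached
// background thread. Because RateLimiter is Clone (Arc inside), a
// clone moved into the closure shares state with the original.
fn spawn_cleanup<F: Fn() + Send + 'static>(every: Duration, cleanup: F) -> thread::JoinHandle<()> {
    thread::spawn(move || loop {
        thread::sleep(every);
        cleanup();
    })
}

fn main() {
    let handle = spawn_cleanup(Duration::from_millis(50), || {
        println!("cleanup tick");
    });
    // Let a couple of ticks fire, then exit; the loop ends with the process.
    thread::sleep(Duration::from_millis(120));
    drop(handle);
}
```

In an async server you would more likely spawn an equivalent loop on the runtime (e.g. a `tokio::spawn` task with an interval timer) rather than a raw thread.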

Rate Limiting Algorithms

| Algorithm | Behavior | Best For |
|-----------|----------|----------|
| Fixed Window | Reset at window boundary | Simple cases |
| Sliding Window | Smooth rate over time | Most APIs |
| Token Bucket | Allows bursts up to capacity | Bursty traffic |
| Leaky Bucket | Constant output rate | Smooth processing |
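The counter in the example code is the fixed-window variant from this table; the sliding-window row can be implemented exactly with a log of timestamps, at the cost of O(limit) memory per client. A standalone sketch (`SlidingWindowLog` is illustrative, not part of the example above):

```rust
use std::collections::VecDeque;
use std::time::{Duration, Instant};

/// Sliding window log: record each request's timestamp and count only
/// those inside the trailing window. Exact, but stores up to
/// `max_requests` instants per client.
struct SlidingWindowLog {
    timestamps: VecDeque<Instant>,
    max_requests: usize,
    window: Duration,
}

impl SlidingWindowLog {
    fn new(max_requests: usize, window: Duration) -> Self {
        SlidingWindowLog { timestamps: VecDeque::new(), max_requests, window }
    }

    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        // Drop timestamps that have fallen out of the trailing window
        while let Some(&front) = self.timestamps.front() {
            if now.duration_since(front) >= self.window {
                self.timestamps.pop_front();
            } else {
                break;
            }
        }
        if self.timestamps.len() < self.max_requests {
            self.timestamps.push_back(now);
            true
        } else {
            false
        }
    }
}

fn main() {
    let mut log = SlidingWindowLog::new(3, Duration::from_secs(1));
    let results: Vec<bool> = (0..5).map(|_| log.try_acquire()).collect();
    println!("{:?}", results); // [true, true, true, false, false]
}
```

Unlike the fixed window, this variant cannot admit a double burst straddling a window boundary, which is why the table recommends it for most APIs.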

Response Headers

| Header | Description |
|--------|-------------|
| X-RateLimit-Limit | Max requests in window |
| X-RateLimit-Remaining | Requests left in window |
| X-RateLimit-Reset | Seconds until reset |
| Retry-After | When to retry (if limited) |
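On the client side, these headers drive backoff: on a 429, honor `Retry-After`; otherwise keep going. A minimal sketch (header values arrive as strings; `backoff_delay` is a hypothetical helper, and 60 s is an assumed fallback matching the middleware above):

```rust
// Hypothetical client-side handling: return how many seconds to wait
// before retrying, or None if the response was not rate limited.
fn backoff_delay(status: u16, retry_after: Option<&str>) -> Option<u64> {
    if status != 429 {
        return None; // not rate limited; no delay needed
    }
    // Fall back to a default delay if the header is missing or malformed
    Some(retry_after.and_then(|v| v.parse().ok()).unwrap_or(60))
}

fn main() {
    assert_eq!(backoff_delay(429, Some("7")), Some(7));
    assert_eq!(backoff_delay(429, Some("oops")), Some(60));
    assert_eq!(backoff_delay(200, None), None);
    println!("ok");
}
```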

⚠️ Anti-patterns

// DON'T: Limit by IP only (shared IPs, proxies)
let key = request.ip();

// DO: Use multiple identifiers
let key = if let Some(api_key) = get_api_key(&request) {
    format!("api:{}", api_key)
} else if let Some(user_id) = get_user_id(&request) {
    format!("user:{}", user_id)
} else {
    format!("ip:{}", get_ip(&request))
};

// DON'T: Same limit for all endpoints
// /api/search (expensive) same as /api/ping (cheap)

// DO: Endpoint-specific limits
let limit = match path {
    "/api/search" => 30,
    "/api/upload" => 10,
    _ => 100,
};

Exercises

  1. Implement Redis-backed rate limiting for distributed systems
  2. Add rate limiting based on request cost (not just count)
  3. Create a circuit breaker that triggers on rate limit errors
  4. Implement gradual backoff for repeated violations
