Protecting APIs from abuse
Rate limiting controls how many requests a client can make in a time window. It protects APIs from abuse, ensures fair usage, and maintains service availability.
A complete implementation needs several pieces: the limiting algorithms themselves, middleware that applies them per client, and response headers that tell clients where they stand:
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, HeaderMap},
middleware::Next,
response::{Response, IntoResponse},
Json,
};
use serde::Serialize;
/// Configuration for a fixed-window rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per window
    pub max_requests: u32,
    /// Time window duration
    pub window: Duration,
    /// Include headers in response
    pub include_headers: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window, with rate-limit headers enabled.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window: Duration::from_secs(60),
            include_headers: true,
        }
    }
}
/// Rate limiter using a fixed-window counter.
///
/// NOTE(review): previously described as a "sliding window" limiter, but
/// `check` resets the count at each window boundary, which is fixed-window
/// behavior — a burst can straddle a boundary and briefly exceed the rate.
#[derive(Clone)]
pub struct RateLimiter {
    /// Per-key window state; `Arc` shares it across clones of the limiter.
    state: Arc<RwLock<HashMap<String, WindowState>>>,
    /// Limit and window size applied to every key.
    config: RateLimitConfig,
}

/// Request count and start time of the current window for one key.
#[derive(Debug, Clone)]
struct WindowState {
    // Requests seen in the current window.
    count: u32,
    // When the current window began.
    window_start: Instant,
}

/// Outcome of a single [`RateLimiter::check`] call.
#[derive(Debug)]
pub struct RateLimitResult {
    /// Whether the request is within the limit.
    pub allowed: bool,
    /// Requests left in the current window.
    pub remaining: u32,
    /// Time until the current window resets.
    pub reset_at: Duration,
    /// How long to wait before retrying; `Some` only when denied.
    pub retry_after: Option<Duration>,
}
impl RateLimiter {
    /// Creates a limiter with the given configuration and no tracked keys.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            state: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records one request for `key` and reports whether it is allowed.
    ///
    /// The per-key counter resets whenever the configured window has
    /// elapsed since the window began (fixed-window behavior).
    pub fn check(&self, key: &str) -> RateLimitResult {
        let now = Instant::now();
        let mut guard = self.state.write().unwrap();
        let win = guard.entry(key.to_string()).or_insert_with(|| WindowState {
            count: 0,
            window_start: now,
        });

        // Start a fresh window once the previous one has fully elapsed.
        if now.duration_since(win.window_start) >= self.config.window {
            win.window_start = now;
            win.count = 0;
        }

        // Safe subtraction: the reset above guarantees the elapsed time
        // is strictly less than the window.
        let reset_at = self.config.window - now.duration_since(win.window_start);

        if win.count < self.config.max_requests {
            win.count += 1;
            RateLimitResult {
                allowed: true,
                remaining: self.config.max_requests - win.count,
                reset_at,
                retry_after: None,
            }
        } else {
            RateLimitResult {
                allowed: false,
                remaining: 0,
                reset_at,
                retry_after: Some(reset_at),
            }
        }
    }

    /// Drops entries whose window expired long ago; call periodically to
    /// bound memory use for churning client keys.
    pub fn cleanup(&self) {
        let now = Instant::now();
        self.state
            .write()
            .unwrap()
            .retain(|_, w| now.duration_since(w.window_start) < self.config.window * 2);
    }
}
/// Token bucket rate limiter (smoother rate limiting).
///
/// Each key owns a bucket holding up to `capacity` tokens that refills
/// continuously at `refill_rate` tokens per second; a request proceeds
/// only if it can withdraw its token cost, so short bursts up to
/// `capacity` are allowed while the long-run rate stays bounded.
#[derive(Clone)]
pub struct TokenBucket {
    /// Per-key bucket state, shared across clones of the limiter.
    state: Arc<RwLock<HashMap<String, BucketState>>>,
    /// Maximum tokens a bucket can hold (burst size).
    capacity: u32,
    /// Tokens added per second.
    refill_rate: f64,
}

/// Fill level and last refill time for one key's bucket.
#[derive(Debug, Clone)]
struct BucketState {
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket family in which every new key starts full.
    pub fn new(capacity: u32, refill_per_second: f64) -> Self {
        TokenBucket {
            state: Arc::new(RwLock::new(HashMap::new())),
            capacity,
            refill_rate: refill_per_second,
        }
    }

    /// Attempts to withdraw `tokens` from `key`'s bucket.
    ///
    /// Refills the bucket for the time elapsed since the last update
    /// (capped at capacity), then withdraws if enough tokens remain.
    /// Returns `true` when the request is allowed.
    pub fn try_acquire(&self, key: &str, tokens: u32) -> bool {
        let mut state = self.state.write().unwrap();
        let now = Instant::now();
        let bucket = state.entry(key.to_string()).or_insert_with(|| BucketState {
            tokens: self.capacity as f64,
            last_refill: now,
        });
        // Refill tokens based on time passed, never exceeding capacity.
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * self.refill_rate)
            .min(self.capacity as f64);
        bucket.last_refill = now;
        if bucket.tokens >= tokens as f64 {
            bucket.tokens -= tokens as f64;
            true
        } else {
            false
        }
    }

    /// Returns the tokens currently available for `key`.
    ///
    /// Fix: accounts for refill accrued since the last `try_acquire`.
    /// The previous version returned the stale stored value, which
    /// under-reported tokens for idle keys. Unknown keys report a full
    /// bucket, matching the state they would be created with.
    pub fn tokens_remaining(&self, key: &str) -> f64 {
        let state = self.state.read().unwrap();
        state
            .get(key)
            .map(|b| {
                let elapsed = b.last_refill.elapsed().as_secs_f64();
                (b.tokens + elapsed * self.refill_rate).min(self.capacity as f64)
            })
            .unwrap_or(self.capacity as f64)
    }
}
/// Leaky bucket rate limiter (constant output rate).
///
/// Each request pours one unit of "water" into a per-key bucket that
/// drains at `leak_rate` units per second; requests that would overflow
/// the bucket are rejected, smoothing throughput to the leak rate.
#[derive(Clone)]
pub struct LeakyBucket {
    /// Per-key water level, shared across clones of the limiter.
    state: Arc<RwLock<HashMap<String, LeakyState>>>,
    /// Maximum water the bucket holds (burst size).
    capacity: u32,
    /// Units drained per second.
    leak_rate: f64,
}

/// Water level and last drain time for one key's bucket.
#[derive(Debug, Clone)]
struct LeakyState {
    water_level: f64,
    last_leak: Instant,
}

impl LeakyBucket {
    /// Creates a bucket family in which every new key starts empty.
    pub fn new(capacity: u32, leak_per_second: f64) -> Self {
        LeakyBucket {
            state: Arc::new(RwLock::new(HashMap::new())),
            capacity,
            leak_rate: leak_per_second,
        }
    }

    /// Attempts to add one request to `key`'s bucket.
    ///
    /// Drains the bucket for the time elapsed since the last call, then
    /// accepts the request only if its unit of water still fits.
    pub fn try_add(&self, key: &str) -> bool {
        let mut state = self.state.write().unwrap();
        let now = Instant::now();
        let bucket = state.entry(key.to_string()).or_insert_with(|| LeakyState {
            water_level: 0.0,
            last_leak: now,
        });
        // Leak water based on time passed; the level never goes negative.
        let elapsed = now.duration_since(bucket.last_leak).as_secs_f64();
        bucket.water_level = (bucket.water_level - elapsed * self.leak_rate).max(0.0);
        bucket.last_leak = now;
        // Fix: admit only if the new unit fits within capacity. The old
        // check (`water_level < capacity`) allowed the level to reach
        // capacity plus a fraction, admitting one extra request per burst
        // as soon as any water had leaked.
        if bucket.water_level + 1.0 <= self.capacity as f64 {
            bucket.water_level += 1.0;
            true
        } else {
            false
        }
    }
}
/// Axum middleware that enforces `limiter` for every request.
///
/// Denied requests receive `429 Too Many Requests` with `X-RateLimit-*`
/// headers and a JSON error body; allowed requests are forwarded, and
/// the response optionally gains the same informational headers.
///
/// NOTE(review): intended to be wired up via
/// `axum::middleware::from_fn_with_state` — the router setup is not
/// visible here, so confirm against the application entry point.
pub async fn rate_limit_middleware(
    State(limiter): State<RateLimiter>,
    request: Request<Body>,
    next: Next,
) -> Response {
    // Extract client identifier (IP, API key, user ID)
    let key = extract_client_key(&request);
    let result = limiter.check(&key);
    if !result.allowed {
        // `retry_after` is always `Some` when a check is denied (see
        // RateLimiter::check), so the 60-second fallback is defensive.
        let retry_after = result.retry_after.map(|d| d.as_secs()).unwrap_or(60);
        return (
            StatusCode::TOO_MANY_REQUESTS,
            [
                ("X-RateLimit-Limit", limiter.config.max_requests.to_string()),
                ("X-RateLimit-Remaining", "0".to_string()),
                ("X-RateLimit-Reset", result.reset_at.as_secs().to_string()),
                ("Retry-After", retry_after.to_string()),
            ],
            Json(RateLimitError {
                error: "Too many requests".to_string(),
                retry_after,
            }),
        ).into_response();
    }
    // Add headers to response
    let mut response = next.run(request).await;
    if limiter.config.include_headers {
        let headers = response.headers_mut();
        // The values are plain numeric strings, so parsing them into
        // header values cannot fail; the unwraps below are safe.
        headers.insert(
            "X-RateLimit-Limit",
            limiter.config.max_requests.to_string().parse().unwrap(),
        );
        headers.insert(
            "X-RateLimit-Remaining",
            result.remaining.to_string().parse().unwrap(),
        );
        headers.insert(
            "X-RateLimit-Reset",
            result.reset_at.as_secs().to_string().parse().unwrap(),
        );
    }
    response
}
/// Derives a rate-limit key for a request: prefers the `X-API-Key`
/// header, then the first hop of `X-Forwarded-For`, then `"unknown"`.
fn extract_client_key(request: &Request<Body>) -> String {
    let headers = request.headers();

    // An API key identifies the client more reliably than an IP; skip it
    // if the header value is not valid UTF-8.
    if let Some(key) = headers.get("X-API-Key").and_then(|v| v.to_str().ok()) {
        return format!("api:{}", key);
    }

    // Fall back to the client IP reported by the proxy chain (the first
    // comma-separated entry of X-Forwarded-For).
    match headers.get("X-Forwarded-For").and_then(|v| v.to_str().ok()) {
        Some(forwarded) => forwarded
            .split(',')
            .next()
            .unwrap_or("unknown")
            .trim()
            .to_string(),
        None => "unknown".to_string(),
    }
}
/// JSON body returned with a `429 Too Many Requests` response.
#[derive(Serialize)]
struct RateLimitError {
    // Human-readable error message.
    error: String,
    // Seconds the client should wait before retrying.
    retry_after: u64,
}
/// Subscription tiers with increasing rate limits.
#[derive(Debug, Clone)]
pub enum UserTier {
    Free,
    Basic,
    Premium,
    Enterprise,
}

impl UserTier {
    /// Returns `(max_requests, window)` for this tier.
    ///
    /// Every tier uses a one-minute window; only the request cap varies.
    pub fn rate_limit(&self) -> (u32, Duration) {
        let per_minute = match self {
            UserTier::Free => 60,
            UserTier::Basic => 300,
            UserTier::Premium => 1_000,
            UserTier::Enterprise => 10_000,
        };
        (per_minute, Duration::from_secs(60))
    }
}
/// Per-endpoint rate limiters keyed by request path.
pub struct EndpointLimits {
    limiters: HashMap<String, RateLimiter>,
}

impl EndpointLimits {
    /// Builds limiters for the endpoints that need special caps:
    /// 30/min for `/api/search`, 10/min for `/api/upload`.
    pub fn new() -> Self {
        // Small helper so each endpoint only states its request cap.
        let per_minute = |max_requests| {
            RateLimiter::new(RateLimitConfig {
                max_requests,
                window: Duration::from_secs(60),
                include_headers: true,
            })
        };
        let limiters = HashMap::from([
            ("/api/search".to_string(), per_minute(30)),
            ("/api/upload".to_string(), per_minute(10)),
        ]);
        EndpointLimits { limiters }
    }

    /// Checks `key` against the limiter for `endpoint`, if one exists.
    /// `None` means the endpoint has no endpoint-specific limit.
    pub fn check(&self, endpoint: &str, key: &str) -> Option<RateLimitResult> {
        self.limiters
            .get(endpoint)
            .map(|limiter| limiter.check(key))
    }
}

impl Default for EndpointLimits {
    fn default() -> Self {
        Self::new()
    }
}
/// Demonstrates each limiter flavor with a burst of requests.
fn main() {
    // Fixed-window counter: 5 requests per 10-second window, so the
    // final two of seven attempts are denied.
    let window_limiter = RateLimiter::new(RateLimitConfig {
        max_requests: 5,
        window: Duration::from_secs(10),
        include_headers: true,
    });
    for attempt in 1..=7 {
        let outcome = window_limiter.check("user:123");
        println!(
            "Request {}: allowed={}, remaining={}",
            attempt, outcome.allowed, outcome.remaining
        );
    }

    // Token bucket: capacity 10, refilling at 2 tokens per second.
    println!("\n--- Token Bucket ---");
    let bucket = TokenBucket::new(10, 2.0);
    for attempt in 1..=12 {
        let permitted = bucket.try_acquire("user:456", 1);
        println!("Request {}: allowed={}", attempt, permitted);
    }

    // Leaky bucket: capacity 5, draining 1 unit per second.
    println!("\n--- Leaky Bucket ---");
    let leaky = LeakyBucket::new(5, 1.0);
    for attempt in 1..=7 {
        let permitted = leaky.try_add("user:789");
        println!("Request {}: allowed={}", attempt, permitted);
    }
}
| Algorithm | Behavior | Best For |
|-----------|----------|----------|
| Fixed Window | Reset at window boundary | Simple cases |
| Sliding Window | Smooth rate over time | Most APIs |
| Token Bucket | Allows bursts up to capacity | Bursty traffic |
| Leaky Bucket | Constant output rate | Smooth processing |
| Header | Description |
|--------|-------------|
| X-RateLimit-Limit | Max requests in window |
| X-RateLimit-Remaining | Requests left in window |
| X-RateLimit-Reset | Seconds until reset |
| Retry-After | When to retry (if limited) |
// DON'T: Limit by IP only (shared IPs, proxies)
let key = request.ip();
// DO: Use multiple identifiers
let key = if let Some(api_key) = get_api_key(&request) {
format!("api:{}", api_key)
} else if let Some(user_id) = get_user_id(&request) {
format!("user:{}", user_id)
} else {
format!("ip:{}", get_ip(&request))
};
// DON'T: Same limit for all endpoints
// /api/search (expensive) same as /api/ping (cheap)
// DO: Endpoint-specific limits
let limit = match path {
"/api/search" => 30,
"/api/upload" => 10,
_ => 100,
};
Run this code in the official Rust Playground