Procedural Macros: Derive, Attribute, and Function-like
Procedural macros are Rust's powerful compile-time metaprogramming system that operates on token streams, allowing you to write code that generates code. Unlike declarative macros that use pattern matching, procedural macros are Rust functions that take tokens as input, manipulate them programmatically, and return new tokens as output.
Think of procedural macros as compiler plugins that can inspect, modify, and generate Rust code during compilation. They enable sophisticated code generation, custom derive implementations, and compile-time transformations that would be impossible with declarative macros.
// The three types of procedural macros (illustrative snippets: the macros,
// `database`, `Error`, and `id` must come from elsewhere for this to compile).
// 1. Derive macro - adds trait implementations; a derive only emits new items
//    next to the type and leaves the annotated struct itself unchanged
#[derive(Builder, Serialize, Debug)]
struct User {
name: String,
email: String,
}
// 2. Attribute macro - transforms items: it receives the annotated item and
//    its output replaces that item
#[cached(ttl = 60)]
async fn fetch_user(id: u64) -> Result<User, Error> {
database::get_user(id).await
}
// 3. Function-like macro - looks like macro_rules! but more powerful
//    (`sql!` is hypothetical; this `let` is a fragment and must appear
//    inside a function body)
let query = sql! {
SELECT * FROM users WHERE id = {id}
};
Key characteristics:
- `syn` for parsing, `quote` for generation, `proc_macro2` for unit testing outside the compiler
- Must live in a dedicated crate with `proc-macro = true` in its `[lib]` section

One of the most common procedural macros is deriving builder patterns for complex structs. This example shows a production-quality implementation similar to the derive_builder crate:
// In proc-macro crate: builder_derive/src/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
/// Derive a builder pattern for structs
///
/// # Example
///
/// #[derive(Builder)]
/// struct User {
/// name: String,
/// #[builder(default = "0")]
/// age: u32,
/// email: Option
/// }
///
/// let user = UserBuilder::default()
/// .name("Alice".to_string())
/// .age(30)
/// .build()?;
///
#[proc_macro_derive(Builder, attributes(builder))]
pub fn derive_builder(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let builder_name = syn::Ident::new(
&format!("{}Builder", name),
name.span()
);
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return syn::Error::new_spanned(
input,
"Builder can only be derived for structs with named fields"
).to_compile_error().into();
}
},
_ => {
return syn::Error::new_spanned(
input,
"Builder can only be derived for structs"
).to_compile_error().into();
}
};
// Generate builder struct fields (all Optional)
let builder_fields = fields.iter().map(|f| {
let name = &f.ident;
let ty = &f.ty;
// Check if field is already Option<T>
if is_option_type(ty) {
quote! { #name: #ty }
} else {
quote! { #name: ::std::option::Option<#ty> }
}
});
// Generate builder methods (setters)
let builder_methods = fields.iter().map(|f| {
let name = &f.ident;
let ty = &f.ty;
let inner_ty = if is_option_type(ty) {
extract_option_inner(ty)
} else {
ty.clone()
};
quote! {
pub fn #name(mut self, #name: #inner_ty) -> Self {
self.#name = ::std::option::Option::Some(#name);
self
}
}
});
// Generate build() method with validation
let build_fields = fields.iter().map(|f| {
let name = &f.ident;
let ty = &f.ty;
// Parse #[builder(default = "value")] attribute
let default_value = get_builder_default(&f.attrs);
if is_option_type(ty) {
// Already optional, use as-is
quote! {
#name: self.#name
}
} else if let Some(default) = default_value {
// Has default value
let default_tokens: proc_macro2::TokenStream = default.parse().unwrap();
quote! {
#name: self.#name.unwrap_or_else(|| #default_tokens)
}
} else {
// Required field
quote! {
#name: self.#name.ok_or_else(|| {
::std::format!("Field '{}' is required", stringify!(#name))
})?
}
}
});
let expanded = quote! {
impl #name {
pub fn builder() -> #builder_name {
#builder_name::default()
}
}
#[derive(Default)]
pub struct #builder_name {
#(#builder_fields,)*
}
impl #builder_name {
#(#builder_methods)*
pub fn build(self) -> ::std::result::Result<#name, ::std::string::String> {
::std::result::Result::Ok(#name {
#(#build_fields,)*
})
}
}
};
TokenStream::from(expanded)
}
// Helper functions
/// True when `ty` is a path type whose final segment is named `Option`
/// (e.g. `Option<T>`, `std::option::Option<T>`). Only the name is checked.
fn is_option_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "Option"),
        _ => false,
    }
}
/// Return `T` for an `...::Option<T>` path type, or a clone of `ty` itself
/// when it is not an `Option` with a type argument.
fn extract_option_inner(ty: &Type) -> Type {
    let inner = match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .filter(|segment| segment.ident == "Option")
            .and_then(|segment| match &segment.arguments {
                syn::PathArguments::AngleBracketed(bracketed) => bracketed.args.first(),
                _ => None,
            })
            .and_then(|arg| match arg {
                syn::GenericArgument::Type(inner_ty) => Some(inner_ty),
                _ => None,
            }),
        _ => None,
    };
    inner.unwrap_or(ty).clone()
}
/// Find the string in the first `#[builder(default = "...")]` attribute.
/// Attributes that do not parse as a single `name = value` pair, or whose value
/// is not a string literal, are skipped.
fn get_builder_default(attrs: &[syn::Attribute]) -> Option<String> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("builder"))
        .filter_map(|attr| attr.parse_args::<syn::MetaNameValue>().ok())
        .filter(|name_value| name_value.path.is_ident("default"))
        .find_map(|name_value| match name_value.value {
            syn::Expr::Lit(syn::ExprLit {
                lit: syn::Lit::Str(lit_str),
                ..
            }) => Some(lit_str.value()),
            _ => None,
        })
}
// Real-world usage in API clients
// `#[derive(Builder)]` generates `ApiRequestBuilder` and `ApiRequest::builder()`.
// NOTE(review): `HashMap` must be in scope (`use std::collections::HashMap;`).
#[derive(Builder, Debug)]
struct ApiRequest {
endpoint: String,
method: String,
// Used when `.timeout_secs(..)` is never called.
#[builder(default = "60")]
timeout_secs: u64,
// `Option` fields stay `None` unless their setter is called.
headers: Option<HashMap<String, String>>,
body: Option<String>,
}
/// Builds a request with only the required fields set. `build()?` propagates
/// the `String` error ("Field '...' is required") for a missing required field.
fn make_request() -> Result<(), String> {
let request = ApiRequest::builder()
.endpoint("/api/users".to_string())
.method("GET".to_string())
.build()?;
println!("Request: {:?}", request);
Ok(())
}
Why this works in production:
Attribute macros transform functions, methods, or structs. This example implements automatic caching with LRU eviction:
// In proc-macro crate: cached_derive/src/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, AttributeArgs, ItemFn, ReturnType};
/// Add automatic memoization to functions
///
/// # Example
///
/// #[cached(ttl = 60, size = 100)]
/// fn expensive_computation(x: i32) -> i32 {
/// // Complex calculation
/// x * x
/// }
///
#[proc_macro_attribute]
pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as AttributeArgs);
let input_fn = parse_macro_input!(input as ItemFn);
// Parse attributes
let mut ttl_secs: u64 = 300; // Default 5 minutes
let mut max_size: usize = 100;
for arg in args {
if let syn::NestedMeta::Meta(syn::Meta::NameValue(nv)) = arg {
if nv.path.is_ident("ttl") {
if let syn::Lit::Int(lit) = nv.lit {
ttl_secs = lit.base10_parse().unwrap();
}
} else if nv.path.is_ident("size") {
if let syn::Lit::Int(lit) = nv.lit {
max_size = lit.base10_parse().unwrap();
}
}
}
}
let fn_name = &input_fn.sig.ident;
let fn_inputs = &input_fn.sig.inputs;
let fn_output = &input_fn.sig.output;
let fn_block = &input_fn.block;
let fn_vis = &input_fn.vis;
let fn_attrs = &input_fn.attrs;
// Generate cache key type from function parameters
let param_names: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
return Some(&pat_ident.ident);
}
}
None
}).collect();
let param_types: Vec<_> = input_fn.sig.inputs.iter().filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
return Some(&*pat_type.ty);
}
None
}).collect();
// Extract return type
let return_type = match fn_output {
ReturnType::Type(_, ty) => ty.clone(),
ReturnType::Default => {
return syn::Error::new_spanned(
fn_output,
"Cached functions must have an explicit return type"
).to_compile_error().into();
}
};
let cache_name = syn::Ident::new(
&format!("__{}_CACHE", fn_name.to_string().to_uppercase()),
fn_name.span()
);
let original_fn_name = syn::Ident::new(
&format!("{}_uncached", fn_name),
fn_name.span()
);
let expanded = quote! {
// Store cache in thread-local or static
::lazy_static::lazy_static! {
static ref #cache_name: ::std::sync::Mutex<
::lru::LruCache<(#(#param_types,)*), (#return_type, ::std::time::Instant)>
> = ::std::sync::Mutex::new(
::lru::LruCache::new(::std::num::NonZeroUsize::new(#max_size).unwrap())
);
}
// Original function renamed
#(#fn_attrs)*
#fn_vis fn #original_fn_name(#fn_inputs) #fn_output {
#fn_block
}
// Wrapper function with caching
#(#fn_attrs)*
#fn_vis fn #fn_name(#fn_inputs) #fn_output {
let cache_key = (#(#param_names.clone(),)*);
// Check cache
{
let mut cache = #cache_name.lock().unwrap();
if let Some((cached_value, cached_time)) = cache.get(&cache_key) {
let elapsed = cached_time.elapsed().as_secs();
if elapsed < #ttl_secs {
return cached_value.clone();
}
}
}
// Compute and cache
let result = #original_fn_name(#(#param_names,)*);
{
let mut cache = #cache_name.lock().unwrap();
cache.put(cache_key, (result.clone(), ::std::time::Instant::now()));
}
result
}
};
TokenStream::from(expanded)
}
// Real-world usage in API services
use lru::LruCache;
use lazy_static::lazy_static;
use std::time::{Duration, Instant};
// Cached for 300 s, up to 1000 entries. The wrapper caches whatever the
// function returns, so `Result<User, Error>` must be `Clone`.
// NOTE(review): `Err` results are cached for the full TTL too — confirm intended.
#[cached(ttl = 300, size = 1000)]
fn fetch_user_from_db(user_id: u64) -> Result<User, Error> {
// Expensive database query
database::query_user(user_id)
}
// Two parameters: the cache key is the tuple `(user_id, item_id)`.
#[cached(ttl = 60, size = 100)]
fn calculate_recommendation_score(user_id: u64, item_id: u64) -> f64 {
// Complex ML inference
ml_model::predict(user_id, item_id)
}
// The macro automatically:
// - Creates an LRU cache for each function (a `lazy_static` `Mutex<LruCache>`)
// - Generates the cache key from a tuple of cloned parameters
// - Checks TTL before returning cached values
// - Thread-safe with one Mutex per cached function
// - A cache hit skips the function body, but still pays for a mutex lock and a
//   clone of the cached value (not literally zero overhead)
Production benefits:
The async-trait crate is one of the most widely used procedural macros in the Rust ecosystem. Here's how it works:
// In proc-macro crate: async_trait_impl/src/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemTrait, TraitItem, ReturnType, ImplItem};
/// Enable async fn in trait definitions
///
/// Each `async fn` in the trait loses its `async` keyword. It instead returns
/// `Pin<Box<dyn Future<Output = T> + Send + '_>>`, where `T` defaults to `()`
/// when no return type is written. A default method body is wrapped in
/// `Box::pin(async move { ... })`. All other trait items pass through unchanged.
///
/// # Example
///
/// #[async_trait]
/// trait Repository {
///     async fn find(&self, id: u64) -> Result<User, Error>;
///     async fn save(&mut self, user: User) -> Result<(), Error>;
/// }
///
/// # Limitations
///
/// The future borrows through the elided lifetime `'_`:
/// * For `&self` / `&mut self` methods it binds to the receiver.
/// * For methods without a reference receiver it follows ordinary elision
///   rules and may not compile.
///
/// Unlike the published `async-trait` crate, no extra lifetimes or `Sync`
/// bounds are generated.
#[proc_macro_attribute]
pub fn async_trait(_args: TokenStream, input: TokenStream) -> TokenStream {
    let mut input_trait = parse_macro_input!(input as ItemTrait);
    for item in &mut input_trait.items {
        let method = match item {
            TraitItem::Fn(method) => method,
            _ => continue,
        };
        // `take()` strips the `async` keyword; methods that were not async are skipped.
        if method.sig.asyncness.take().is_none() {
            continue;
        }
        let output: Box<syn::Type> = match &method.sig.output {
            ReturnType::Type(_, ty) => ty.clone(),
            ReturnType::Default => Box::new(syn::parse_quote!(())),
        };
        // Every receiver kind (&self, &mut self, self, none) gets the same
        // `Send` bound, so no per-receiver branching is needed.
        method.sig.output = syn::parse_quote! {
            -> ::std::pin::Pin<::std::boxed::Box<
                dyn ::std::future::Future<Output = #output> + ::std::marker::Send + '_
            >>
        };
        // If the method has a default implementation, box its body as a future.
        if let Some(block) = &method.default {
            let wrapped: syn::Block = syn::parse_quote! {
                {
                    ::std::boxed::Box::pin(async move #block)
                }
            };
            method.default = Some(wrapped);
        }
    }
    TokenStream::from(quote! { #input_trait })
}
// Similar transformation for impl blocks
/// Companion to `async_trait` for `impl` blocks.
///
/// Every `async fn` becomes a plain `fn` returning
/// `Pin<Box<dyn Future<Output = T> + Send + '_>>`, with its body moved into
/// `Box::pin(async move { ... })`. Other items are emitted unchanged.
#[proc_macro_attribute]
pub fn async_trait_impl(_args: TokenStream, input: TokenStream) -> TokenStream {
    let mut item_impl = parse_macro_input!(input as syn::ItemImpl);
    let async_fns = item_impl.items.iter_mut().filter_map(|item| match item {
        ImplItem::Fn(f) if f.sig.asyncness.is_some() => Some(f),
        _ => None,
    });
    for f in async_fns {
        f.sig.asyncness = None;
        let ret: Box<syn::Type> = if let ReturnType::Type(_, ty) = &f.sig.output {
            ty.clone()
        } else {
            Box::new(syn::parse_quote!(()))
        };
        f.sig.output = syn::parse_quote! {
            -> ::std::pin::Pin<::std::boxed::Box<
                dyn ::std::future::Future<Output = #ret> + ::std::marker::Send + '_
            >>
        };
        let body = &f.block;
        let pinned: syn::Block = syn::parse_quote! {
            {
                ::std::boxed::Box::pin(async move #body)
            }
        };
        f.block = pinned;
    }
    quote! { #item_impl }.into()
}
// Real-world usage in async services
use async_trait::async_trait;
// This is the published `async-trait` crate's macro (imported above), not the
// simplified `async_trait` function shown earlier in this section.
#[async_trait]
trait UserRepository {
async fn find_by_id(&self, id: u64) -> Result<User, Error>;
async fn find_by_email(&self, email: &str) -> Result<User, Error>;
async fn save(&self, user: &User) -> Result<(), Error>;
async fn delete(&self, id: u64) -> Result<(), Error>;
}
/// `UserRepository` backed by a PostgreSQL connection pool through `sqlx`.
struct PostgresUserRepository {
pool: PgPool,
}
#[async_trait]
impl UserRepository for PostgresUserRepository {
// NOTE(review): `id as i64` silently wraps ids above `i64::MAX` to negative
// values; `i64::try_from(id)` with an explicit error would be safer.
async fn find_by_id(&self, id: u64) -> Result<User, Error> {
sqlx::query_as!(User, "SELECT * FROM users WHERE id = $1", id as i64)
.fetch_one(&self.pool)
.await
.map_err(Into::into)
}
async fn find_by_email(&self, email: &str) -> Result<User, Error> {
sqlx::query_as!(User, "SELECT * FROM users WHERE email = $1", email)
.fetch_one(&self.pool)
.await
.map_err(Into::into)
}
// Upsert: inserts the row, or overwrites name/email when `id` already exists.
async fn save(&self, user: &User) -> Result<(), Error> {
sqlx::query!(
"INSERT INTO users (id, name, email) VALUES ($1, $2, $3)
ON CONFLICT (id) DO UPDATE SET name = $2, email = $3",
user.id as i64,
user.name,
user.email
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn delete(&self, id: u64) -> Result<(), Error> {
sqlx::query!("DELETE FROM users WHERE id = $1", id as i64)
.execute(&self.pool)
.await?;
Ok(())
}
}
// Mock implementation for testing
// In-memory `UserRepository` keyed by user id.
// NOTE(review): `.lock().await` requires an async mutex (e.g. `tokio::sync::Mutex`);
// `std::sync::Mutex::lock` is not awaitable — confirm the import.
struct MockUserRepository {
users: Arc<Mutex<HashMap<u64, User>>>,
}
#[async_trait]
impl UserRepository for MockUserRepository {
async fn find_by_id(&self, id: u64) -> Result<User, Error> {
self.users
.lock()
.await
.get(&id)
.cloned()
.ok_or(Error::NotFound)
}
// Linear scan over all stored users.
async fn find_by_email(&self, email: &str) -> Result<User, Error> {
self.users
.lock()
.await
.values()
.find(|u| u.email == email)
.cloned()
.ok_or(Error::NotFound)
}
// Replaces any existing entry with the same id.
async fn save(&self, user: &User) -> Result<(), Error> {
self.users.lock().await.insert(user.id, user.clone());
Ok(())
}
// Deleting an unknown id is not an error.
async fn delete(&self, id: u64) -> Result<(), Error> {
self.users.lock().await.remove(&id);
Ok(())
}
}
Why this is essential:
Database ORMs heavily rely on derive macros for type-safe SQL generation. Here's a simplified version of patterns from Diesel and SQLx:
// In proc-macro crate: orm_derive/src/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Data, Fields};
/// Derive database model traits
///
/// # Example
///
/// #[derive(Model)]
/// #[table_name = "users"]
/// struct User {
/// #[primary_key]
/// id: i64,
/// #[column(name = "full_name")]
/// name: String,
/// email: String,
/// #[column(default)]
/// created_at: chrono::DateTime
/// }
///
#[proc_macro_derive(Model, attributes(table_name, primary_key, column))]
pub fn derive_model(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
// Extract table name from attribute
let table_name = extract_table_name(&input.attrs)
.unwrap_or_else(|| to_snake_case(&name.to_string()));
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => panic!("Model must have named fields"),
},
_ => panic!("Model can only be derived for structs"),
};
// Find primary key field
let primary_key = fields.iter().find(|f| {
f.attrs.iter().any(|attr| attr.path().is_ident("primary_key"))
}).map(|f| f.ident.as_ref().unwrap());
// Generate column mappings
let column_mappings: Vec<_> = fields.iter().map(|f| {
let field_name = f.ident.as_ref().unwrap();
let column_name = extract_column_name(&f.attrs)
.unwrap_or_else(|| to_snake_case(&field_name.to_string()));
let field_ty = &f.ty;
quote! {
(#column_name, stringify!(#field_name), stringify!(#field_ty))
}
}).collect();
// Generate SELECT query
let column_list: Vec<_> = fields.iter().map(|f| {
let field_name = f.ident.as_ref().unwrap();
let column_name = extract_column_name(&f.attrs)
.unwrap_or_else(|| to_snake_case(&field_name.to_string()));
quote! { #column_name }
}).collect();
// Generate field assignments for query results
let field_assignments: Vec<_> = fields.iter().enumerate().map(|(i, f)| {
let field_name = f.ident.as_ref().unwrap();
quote! {
#field_name: row.get(#i)?
}
}).collect();
// Generate INSERT query
let insert_placeholders: Vec<_> = (1..=fields.len())
.map(|i| format!("${}", i))
.collect();
let insert_values: Vec<_> = fields.iter().map(|f| {
let field_name = f.ident.as_ref().unwrap();
quote! { &self.#field_name }
}).collect();
let expanded = quote! {
impl #name {
pub const TABLE_NAME: &'static str = #table_name;
pub fn table_name() -> &'static str {
#table_name
}
pub fn column_names() -> Vec<&'static str> {
vec![#(#column_list),*]
}
pub async fn find_by_id(
pool: &sqlx::PgPool,
id: i64
) -> Result<Self, sqlx::Error> {
let query = format!(
"SELECT {} FROM {} WHERE id = $1",
Self::column_names().join(", "),
#table_name
);
sqlx::query(&query)
.bind(id)
.fetch_one(pool)
.await
.and_then(|row| {
Ok(Self {
#(#field_assignments),*
})
})
}
pub async fn find_all(
pool: &sqlx::PgPool
) -> Result<Vec<Self>, sqlx::Error> {
let query = format!(
"SELECT {} FROM {}",
Self::column_names().join(", "),
#table_name
);
sqlx::query(&query)
.fetch_all(pool)
.await?
.into_iter()
.map(|row| {
Ok(Self {
#(#field_assignments),*
})
})
.collect()
}
pub async fn insert(&self, pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
let query = format!(
"INSERT INTO {} ({}) VALUES ({})",
#table_name,
Self::column_names().join(", "),
vec![#(#insert_placeholders),*].join(", ")
);
sqlx::query(&query)
#(.bind(#insert_values))*
.execute(pool)
.await?;
Ok(())
}
pub async fn update(&self, pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
// Generate UPDATE query based on primary key
todo!("Implement UPDATE logic")
}
pub async fn delete(pool: &sqlx::PgPool, id: i64) -> Result<(), sqlx::Error> {
let query = format!("DELETE FROM {} WHERE id = $1", #table_name);
sqlx::query(&query).bind(id).execute(pool).await?;
Ok(())
}
}
// Query builder integration
impl #name {
pub fn query() -> QueryBuilder<#name> {
QueryBuilder::new(#table_name)
}
}
};
TokenStream::from(expanded)
}
// Helper functions
/// Read the table name from `#[table_name = "users"]` (the form used in the
/// examples) or `#[table_name("users")]`. Returns the value of the first such
/// attribute carrying a string literal, or `None` if there is none.
fn extract_table_name(attrs: &[syn::Attribute]) -> Option<String> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("table_name"))
        .find_map(|attr| match &attr.meta {
            // #[table_name = "users"]
            syn::Meta::NameValue(name_value) => match &name_value.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(lit),
                    ..
                }) => Some(lit.value()),
                _ => None,
            },
            // #[table_name("users")]
            syn::Meta::List(_) => attr.parse_args::<syn::LitStr>().ok().map(|lit| lit.value()),
            // Bare `#[table_name]` carries no value.
            _ => None,
        })
}
/// Return the string from the first `#[column(name = "...")]` attribute.
/// Other `column` forms (e.g. `#[column(default)]`) are ignored.
fn extract_column_name(attrs: &[syn::Attribute]) -> Option<String> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("column"))
        .filter_map(|attr| attr.parse_args::<syn::Meta>().ok())
        .find_map(|meta| match meta {
            syn::Meta::NameValue(name_value) if name_value.path.is_ident("name") => {
                match name_value.value {
                    syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(lit_str),
                        ..
                    }) => Some(lit_str.value()),
                    _ => None,
                }
            }
            _ => None,
        })
}
/// Convert a `CamelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the
/// first, so acronyms are split per letter (`HTTPServer` -> `h_t_t_p_server`).
/// The full Unicode lowercase mapping is kept, which can expand one character
/// into several (e.g. `'İ'` -> `"i\u{307}"`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield more than one char; keep all of them.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
// Real-world usage
#[derive(Model, Debug, Clone)]
#[table_name = "users"]
struct User {
#[primary_key]
id: i64,
// Stored in the `full_name` column.
#[column(name = "full_name")]
name: String,
email: String,
created_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Model, Debug, Clone)]
#[table_name = "posts"]
struct Post {
#[primary_key]
id: i64,
user_id: i64,
title: String,
content: String,
// `column` is a declared helper attribute, so this is accepted, but
// `extract_column_name` only reads `name = "..."`; `default` has no effect.
#[column(default)]
published: bool,
}
// NOTE(review): `?` on the generated `sqlx::Error` results requires
// `Error: From<sqlx::Error>`.
async fn example_usage(pool: &PgPool) -> Result<(), Error> {
// Find user by ID
let user = User::find_by_id(pool, 1).await?;
println!("User: {:?}", user);
// Get all posts
let posts = Post::find_all(pool).await?;
println!("Found {} posts", posts.len());
// Insert new user
// `insert` binds every field, including `id`, so `id: 0` is written as-is
// rather than being generated by the database.
let new_user = User {
id: 0,
name: "Alice".to_string(),
email: "alice@example.com".to_string(),
created_at: chrono::Utc::now(),
};
new_user.insert(pool).await?;
Ok(())
}
Production advantages:
The serde crate's derive macros are among the most sophisticated in Rust. Here's a simplified version showing the key patterns:
// In proc-macro crate: serde_derive/src/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Data, Fields};
/// Derive Serialize trait
///
/// # Example
///
/// #[derive(Serialize)]
/// struct User {
/// id: u64,
/// #[serde(rename = "full_name")]
/// name: String,
/// #[serde(skip_serializing_if = "Option::is_none")]
/// email: Option
/// }
///
#[proc_macro_derive(Serialize, attributes(serde))]
pub fn derive_serialize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let serialize_body = match &input.data {
Data::Struct(data) => {
match &data.fields {
Fields::Named(fields) => {
let field_serializations: Vec<_> = fields.named.iter().map(|f| {
let field_name = f.ident.as_ref().unwrap();
let field_str = field_name.to_string();
// Check for #[serde(rename = "...")] attribute
let serialized_name = extract_serde_rename(&f.attrs)
.unwrap_or_else(|| field_str.clone());
// Check for #[serde(skip)]
if has_serde_skip(&f.attrs) {
return quote! {};
}
// Check for #[serde(skip_serializing_if = "path")]
if let Some(skip_fn) = extract_skip_serializing_if(&f.attrs) {
let skip_fn_tokens: proc_macro2::TokenStream =
skip_fn.parse().unwrap();
quote! {
if !#skip_fn_tokens(&self.#field_name) {
serializer.serialize_field(
#serialized_name,
&self.#field_name
)?;
}
}
} else {
quote! {
serializer.serialize_field(
#serialized_name,
&self.#field_name
)?;
}
}
}).collect();
let num_fields = fields.named.len();
quote! {
let mut serializer = serializer.serialize_struct(
stringify!(#name),
#num_fields
)?;
#(#field_serializations)*
serializer.end()
}
}
Fields::Unnamed(fields) => {
let field_serializations: Vec<_> = fields.unnamed.iter().enumerate()
.map(|(i, _)| {
let index = syn::Index::from(i);
quote! {
serializer.serialize_field(&self.#index)?;
}
}).collect();
let num_fields = fields.unnamed.len();
quote! {
let mut serializer = serializer.serialize_tuple(#num_fields)?;
#(#field_serializations)*
serializer.end()
}
}
Fields::Unit => {
quote! {
serializer.serialize_unit_struct(stringify!(#name))
}
}
}
}
Data::Enum(data) => {
let variant_serializations: Vec<_> = data.variants.iter().map(|v| {
let variant_name = &v.ident;
let variant_str = variant_name.to_string();
match &v.fields {
Fields::Named(fields) => {
let field_names: Vec<_> = fields.named.iter()
.map(|f| f.ident.as_ref().unwrap())
.collect();
let field_serializations: Vec<_> = field_names.iter()
.map(|fname| {
let fname_str = fname.to_string();
quote! {
serializer.serialize_field(#fname_str, #fname)?;
}
}).collect();
quote! {
Self::#variant_name { #(#field_names),* } => {
let mut serializer = serializer.serialize_struct_variant(
stringify!(#name),
0,
#variant_str,
#field_serializations.len()
)?;
#(#field_serializations)*
serializer.end()
}
}
}
Fields::Unnamed(fields) => {
let field_bindings: Vec<_> = (0..fields.unnamed.len())
.map(|i| syn::Ident::new(&format!("f{}", i), v.ident.span()))
.collect();
let field_serializations: Vec<_> = field_bindings.iter()
.map(|binding| {
quote! {
serializer.serialize_field(#binding)?;
}
}).collect();
quote! {
Self::#variant_name(#(#field_bindings),*) => {
let mut serializer = serializer.serialize_tuple_variant(
stringify!(#name),
0,
#variant_str,
#field_bindings.len()
)?;
#(#field_serializations)*
serializer.end()
}
}
}
Fields::Unit => {
quote! {
Self::#variant_name => {
serializer.serialize_unit_variant(
stringify!(#name),
0,
#variant_str
)
}
}
}
}
}).collect();
quote! {
match self {
#(#variant_serializations),*
}
}
}
Data::Union(_) => {
panic!("Serialize cannot be derived for unions");
}
};
let expanded = quote! {
impl #impl_generics serde::Serialize for #name #ty_generics #where_clause {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
#serialize_body
}
}
};
TokenStream::from(expanded)
}
// Helper functions for serde attributes
/// True if any `#[serde(...)]` attribute consists of exactly the bare path
/// `skip`. Combined forms such as `#[serde(skip, rename = "x")]` do not parse
/// as a single `Meta` and are not detected.
fn has_serde_skip(attrs: &[syn::Attribute]) -> bool {
    attrs.iter().any(|attr| {
        attr.path().is_ident("serde")
            && matches!(
                attr.parse_args::<syn::Meta>(),
                Ok(syn::Meta::Path(path)) if path.is_ident("skip")
            )
    })
}
/// Return the string from the first `#[serde(rename = "...")]` attribute, if any.
fn extract_serde_rename(attrs: &[syn::Attribute]) -> Option<String> {
    for attr in attrs.iter().filter(|a| a.path().is_ident("serde")) {
        let name_value = match attr.parse_args::<syn::Meta>() {
            Ok(syn::Meta::NameValue(nv)) if nv.path.is_ident("rename") => nv,
            _ => continue,
        };
        if let syn::Expr::Lit(syn::ExprLit {
            lit: syn::Lit::Str(renamed),
            ..
        }) = name_value.value
        {
            return Some(renamed.value());
        }
    }
    None
}
/// Return the predicate path string from the first
/// `#[serde(skip_serializing_if = "...")]` attribute, if any.
fn extract_skip_serializing_if(attrs: &[syn::Attribute]) -> Option<String> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("serde"))
        .filter_map(|attr| match attr.parse_args::<syn::Meta>() {
            Ok(syn::Meta::NameValue(nv)) if nv.path.is_ident("skip_serializing_if") => Some(nv),
            _ => None,
        })
        .find_map(|nv| match nv.value {
            syn::Expr::Lit(syn::ExprLit {
                lit: syn::Lit::Str(predicate),
                ..
            }) => Some(predicate.value()),
            _ => None,
        })
}
// Real-world usage
#[derive(Serialize, Debug)]
struct User {
id: u64,
// Written under the key "userName".
#[serde(rename = "userName")]
name: String,
// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
email: Option<String>,
// Never serialized.
#[serde(skip)]
password_hash: String,
}
// NOTE(review): the simplified `derive_serialize` above only reads field-level
// attributes, so `tag`/`content` take effect only with the real serde derive.
#[derive(Serialize, Debug)]
#[serde(tag = "type", content = "data")]
enum Message {
Text(String),
Image { url: String, width: u32, height: u32 },
Video { url: String, duration: u32 },
}
fn api_response_example() -> String {
let user = User {
id: 42,
name: "Alice".to_string(),
email: Some("alice@example.com".to_string()),
password_hash: "secret".to_string(),
};
// Serializes to: {"id":42,"userName":"Alice","email":"alice@example.com"}
// Note: password_hash is skipped
// `unwrap()` panics if serialization fails.
serde_json::to_string(&user).unwrap()
}
Why serde patterns matter:
Procedural macros operate on TokenStream, Rust's representation of code as tokens:
use proc_macro::TokenStream;
// Every proc macro receives TokenStream and returns TokenStream
#[proc_macro]
pub fn my_macro(input: TokenStream) -> TokenStream {
// Input is the tokens passed to the macro
// Output becomes code in the caller's location
// You can inspect tokens manually (not recommended)
for token in input {
println!("Token: {:?}", token);
}
// Or use syn to parse into AST (recommended)
let parsed = syn::parse::<DeriveInput>(input).unwrap();
// Generate code with quote! (recommended)
let output = quote! {
// Generated code here
};
TokenStream::from(output)
}
TokenStream characteristics:
- Use `syn` to parse tokens into meaningful AST structures
- Use `quote!` to generate new token streams

`syn` is the de-facto standard for parsing Rust syntax in procedural macros:
use syn::{
parse_macro_input,
DeriveInput,
Data,
Fields,
Type,
Expr,
parse::{Parse, ParseStream},
};
// Parse derive input (most common)
#[proc_macro_derive(MyTrait)]
pub fn derive_my_trait(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
// DeriveInput gives you:
let name = &input.ident; // Struct/enum name
let generics = &input.generics; // Generic parameters
let attrs = &input.attrs; // Attributes
let vis = &input.vis; // Visibility
match &input.data {
Data::Struct(data_struct) => {
match &data_struct.fields {
Fields::Named(fields) => {
// Struct with named fields
for field in &fields.named {
let field_name = &field.ident;
let field_ty = &field.ty;
let field_attrs = &field.attrs;
// Process each field
}
}
Fields::Unnamed(fields) => {
// Tuple struct
for (i, field) in fields.unnamed.iter().enumerate() {
// Access by index
}
}
Fields::Unit => {
// Unit struct
}
}
}
Data::Enum(data_enum) => {
// Enum variants
for variant in &data_enum.variants {
let variant_name = &variant.ident;
let variant_fields = &variant.fields;
// Process each variant
}
}
Data::Union(_) => {
// Unions (rarely used)
}
}
// ... generate code
TokenStream::new()
}
// Custom parsing for attribute macros
struct MyAttrArgs {
timeout: u64,
retries: u32,
}
impl Parse for MyAttrArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut timeout = 30;
let mut retries = 3;
// Parse name = value pairs
while !input.is_empty() {
let name: syn::Ident = input.parse()?;
input.parse::<syn::Token![=]>()?;
if name == "timeout" {
let lit: syn::LitInt = input.parse()?;
timeout = lit.base10_parse()?;
} else if name == "retries" {
let lit: syn::LitInt = input.parse()?;
retries = lit.base10_parse()?;
}
// Optional comma
if input.peek(syn::Token![,]) {
input.parse::<syn::Token![,]>()?;
}
}
Ok(MyAttrArgs { timeout, retries })
}
}
#[proc_macro_attribute]
pub fn retry(args: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as MyAttrArgs);
let input_fn = parse_macro_input!(item as syn::ItemFn);
// Use parsed args
let timeout = args.timeout;
let retries = args.retries;
// ... generate code
TokenStream::new()
}
Key syn patterns:
- `parse_macro_input!` for error handling
- `DeriveInput` for derive macros
- `ItemFn`, `ItemStruct`, `ItemEnum` for attribute macros
- `Parse` implementations for complex attribute syntax
- Pattern matching on the `Data` and `Fields` enums
- Reporting problems with `syn::Error`

`quote!` allows you to write Rust code that generates Rust code:
use quote::{quote, format_ident};
// Basic usage
// NOTE: these are fragments; the `let` statements must live inside a function,
// and `Type` needs `use syn::Type;`.
let tokens = quote! {
fn hello() {
println!("Hello, world!");
}
};
// Interpolation with #
// `#name` splices an identifier; `#field_type` splices a parsed syn type.
let name = format_ident!("MyStruct");
let field_type = syn::parse_str::<Type>("String").unwrap();
let tokens = quote! {
struct #name {
field: #field_type,
}
};
// Repetition with #(...)*
let field_names = vec![
format_ident!("field1"),
format_ident!("field2"),
format_ident!("field3"),
];
let tokens = quote! {
struct MyStruct {
#(pub #field_names: String,)*
}
};
// Expands to:
// struct MyStruct {
// pub field1: String,
// pub field2: String,
// pub field3: String,
// }
// Using #(...),* for comma-separated lists
// Expands to: fn call_function(a: i32, b: i32, c: i32) { }
let param_names = vec![format_ident!("a"), format_ident!("b"), format_ident!("c")];
let tokens = quote! {
fn call_function(#(#param_names: i32),*) {
// ...
}
};
quote! best practices:
- `format_ident!` for generating identifiers
- `#(...)*` for repetition
- `#(...),*` for comma-separated lists

Procedural macros should provide excellent error messages:
use syn::spanned::Spanned;
#[proc_macro_derive(MyTrait)]
pub fn derive_my_trait(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
// Check that it's a struct
let data_struct = match &input.data {
Data::Struct(s) => s,
_ => {
// Use the input's span for the error location
return syn::Error::new_spanned(
&input,
"MyTrait can only be derived for structs"
).to_compile_error().into();
}
};
// Check named fields
let fields = match &data_struct.fields {
Fields::Named(f) => &f.named,
_ => {
return syn::Error::new_spanned(
&data_struct.fields,
"MyTrait requires named fields"
).to_compile_error().into();
}
};
// Validate each field
for field in fields {
let field_ty = &field.ty;
// Check if type is supported
if !is_supported_type(field_ty) {
return syn::Error::new_spanned(
field_ty,
format!("Type {:?} is not supported by MyTrait", field_ty)
).to_compile_error().into();
}
}
// Generate code...
TokenStream::new()
}
Error handling best practices:
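- Use `syn::Error::new_spanned` for precise error locations
- Turn errors into compiler diagnostics with `to_compile_error()` instead of panicking

When to use which kind of macro:
- Derive macros (`#[derive(MyTrait)]`) to implement your own traits
- Attribute macros (`#[my_attr]`) to modify functions or types
- `macro_rules!` for simple patterns

// BAD: Generic error without location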
// Deliberate anti-pattern, kept as-is for contrast with `good` below.
#[proc_macro_derive(Bad)]
pub fn bad(input: TokenStream) -> TokenStream {
// `unwrap()` turns a parse failure into a macro panic: the user gets an
// opaque panic message at the derive site instead of a spanned error.
let input = syn::parse::<DeriveInput>(input).unwrap(); // Panic!
if input.generics.params.is_empty() {
// Same problem: no span, no hint about how to fix it.
panic!("Must have generics!"); // Bad error
}
TokenStream::new()
}
// GOOD: Specific errors with spans
#[proc_macro_derive(Good)]
pub fn good(input: TokenStream) -> TokenStream {
let input = match syn::parse::<DeriveInput>(input) {
Ok(input) => input,
Err(e) => return e.to_compile_error().into(),
};
if input.generics.params.is_empty() {
return syn::Error::new_spanned(
&input,
"Good requires at least one generic parameter. \
Try: #[derive(Good)] struct MyStruct<T> { ... }"
).to_compile_error().into();
}
TokenStream::new()
}
// BAD: Only testing happy path
// NOTE(review): this module and the GOOD one below are both named `tests`;
// they cannot coexist in one file. `quote!` and `derive_impl` also need to be
// brought into scope (e.g. `use super::*;`).
#[cfg(test)]
mod tests {
#[test]
fn test_basic() {
let input = quote! {
struct Simple {
field: String,
}
};
assert!(derive_impl(input).is_ok());
}
}
// GOOD: Test all edge cases
// NOTE(review): assumes a testable `derive_impl(proc_macro2::TokenStream)` in the
// parent module, plus `use super::*;` / `use quote::quote;` here. The tests
// disagree on its return type: `.is_ok()` implies a `Result`, while
// `.to_string()` + "compile_error" implies a `TokenStream` — pick one.
#[cfg(test)]
mod tests {
// Generic type parameter must be carried into the generated impl.
#[test]
fn test_generic_struct() {
let input = quote! {
struct Generic<T> { field: T }
};
assert!(derive_impl(input).is_ok());
}
// Lifetime parameters must be carried through as well.
#[test]
fn test_lifetime_params() {
let input = quote! {
struct WithLifetime<'a> { field: &'a str }
};
assert!(derive_impl(input).is_ok());
}
// `where` clauses must survive expansion.
#[test]
fn test_where_clause() {
let input = quote! {
struct WithWhere<T> where T: Clone { field: T }
};
assert!(derive_impl(input).is_ok());
}
// Unsupported input must produce a `compile_error!`, not a panic.
#[test]
fn test_enum_error() {
let input = quote! { enum NotSupported {} };
let output = derive_impl(input).to_string();
assert!(output.contains("compile_error"));
}
}
// BAD: Name collisions possible
// NOTE(review): a local declared in a generated body can only clash with user
// code that the macro splices into that same body (e.g. a user expression
// referring to `result`); this snippet splices none, so it is illustrative.
// Fragment: `#name` and the `let` need an enclosing macro function.
let output = quote! {
impl MyTrait for #name {
fn method(&self) {
let result = 42; // Might collide with user code!
println!("{}", result);
}
}
};
// GOOD: Use fresh identifiers
// A `__private_`-prefixed name makes accidental clashes unlikely. For true
// hygiene of locals, create the ident with `proc_macro2::Span::mixed_site()`.
// NOTE(review): embedding `#name` (usually CamelCase) in a local's name may
// trigger the `non_snake_case` lint.
use quote::format_ident;
let result_var = format_ident!("__private_result_{}", name);
let output = quote! {
impl MyTrait for #name {
fn method(&self) {
let #result_var = 42; // Collision-resistant (not strictly hygienic)
println!("{}", #result_var);
}
}
};
Procedural macros execute during compilation:
// The timings below are rough illustrative figures, not measurements; real
// cost depends on the macro, struct size, and machine.
// Simple derive: ~1-10ms overhead per struct
#[derive(SimpleDerive)]
struct User {
name: String,
}
// Complex derive with validation: ~10-50ms per struct
#[derive(ComplexDerive)]
#[validate(all_fields)]
struct ComplexData {
field1: String,
field2: i32,
// ... 50 fields
}
// Many derives: the per-derive costs add up (additive, not multiplicative)
#[derive(
Debug, // ~1ms
Clone, // ~1ms
Serialize, // ~5ms
Deserialize,// ~5ms
Builder, // ~3ms
Validate, // ~2ms
)]
struct HeavyStruct { /* ... */ }
// Total: ~17ms compile time for this one struct
Implement a procedural macro that derives Display for enums:
// Create this derive macro
// Exercise: derive `std::fmt::Display` for a fieldless enum. The expected text
// is not specified here — presumably the variant name (e.g. "Active").
#[derive(Display)]
enum Status {
Active,
Inactive,
Pending,
}
// Should generate Display implementation
Create an attribute macro that measures function execution time:
// Implement this attribute
// Exercise: wrap the body with `Instant::now()` / `.elapsed()` and log the
// duration. `sleep` waits at least the requested time, so expect ~1000ms or more.
#[timed]
fn expensive_operation() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
// Should log: "expensive_operation took 1000ms"
Implement a builder derive macro with compile-time and runtime validation:
// Implement this advanced builder
// Exercise: `required` fields are checked in `build()`. A `validate` string is
// an expression over the field's name that must evaluate to `true` before
// `build()` succeeds.
#[derive(Builder)]
#[builder(validate)]
struct User {
#[builder(required)]
name: String,
#[builder(default = "18", validate = "age > 0 && age < 150")]
age: u32,
#[builder(validate = "email.contains('@')")]
email: String,
}
The most widely used procedural macros in Rust, starting with `serde` for serialization:
use serde::{Serialize, Deserialize};
// NOTE(review): `HashMap` must be in scope (`use std::collections::HashMap;`).
#[derive(Serialize, Deserialize, Debug)]
struct Config {
// Read and written under the key "server_host".
#[serde(rename = "server_host")]
host: String,
// When the key is missing on deserialization, `default_port()` is called.
#[serde(default = "default_port")]
port: u16,
#[serde(skip_serializing_if = "Option::is_none")]
api_key: Option<String>,
// Collects any remaining keys into this map.
#[serde(flatten)]
extra: HashMap<String, serde_json::Value>,
}
/// Default for `Config::port` when it is absent from the input.
fn default_port() -> u16 { 8080 }
`#[tokio::main]` simplifies async main functions:
#[tokio::main]
async fn main() {
let response = reqwest::get("https://api.example.com").await.unwrap();
println!("Status: {}", response.status());
}
`#[async_trait]` enables async methods in traits:
use async_trait::async_trait;
// Key-value storage abstraction; each method becomes a boxed future under `#[async_trait]`.
#[async_trait]
trait Storage {
async fn get(&self, key: &str) -> Result<Vec<u8>, Error>;
async fn put(&self, key: &str, value: Vec<u8>) -> Result<(), Error>;
}
`thiserror` simplifies error type creation:
use thiserror::Error;
// `#[error(...)]` supplies the `Display` text for each variant.
#[derive(Error, Debug)]
enum ApiError {
// `#[from]` also generates `From<sqlx::Error>`, so `?` converts automatically.
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Not found: {resource} with id {id}")]
NotFound { resource: String, id: u64 },
}
`clap` derives command-line argument parsing:
use clap::Parser;
// Parsed with `Cli::parse()` (from `clap::Parser`).
#[derive(Parser, Debug)]
#[command(name = "myapp")]
#[command(about = "A sample CLI application")]
struct Cli {
// `-c` / `--config`; `None` when not given.
#[arg(short, long)]
config: Option<String>,
// `-p` / `--port`; 8080 when not given.
#[arg(short, long, default_value_t = 8080)]
port: u16,
}
Run this code in the official Rust Playground