System Architecture Overview
This document provides a comprehensive overview of the nanochat system architecture, including the core model design, training pipeline, and deployment infrastructure.
High-Level System Design
NanoChat implements a complete end-to-end conversational AI pipeline in ~8,000 lines of code. The system is designed with modularity and hackability in mind, making it easy to understand and modify each component.
Core Components
1. GPT Transformer Model
The heart of nanochat is a modern Transformer implementation with several key improvements over the original GPT architecture.
Source: nanochat/gpt.py:1-15
"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
"""
Model Configuration
Source: nanochat/gpt.py:23-32
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768
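As a quick illustration of how these fields relate, the sketch below derives the per-head dimension and the GQA sharing factor from a config; the specific values are hypothetical, not repo defaults.
from nanochat.gpt import GPTConfig

cfg = GPTConfig(n_layer=12, n_head=6, n_kv_head=2, n_embd=768)  # hypothetical GQA setup
head_dim = cfg.n_embd // cfg.n_head       # 128 dims per query head
gqa_group = cfg.n_head // cfg.n_kv_head   # 3 query heads share each key/value head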
2. Modern Attention Mechanism
Rotary Position Embeddings (RoPE)
Instead of learned positional embeddings, nanochat uses RoPE for better length generalization:
Source: nanochat/gpt.py:40-47
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
return torch.cat([y1, y2], 3)
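The cos/sin tables consumed by apply_rotary_emb follow the standard RoPE recipe. The sketch below is a minimal illustration of that precomputation, not the repo's exact code, and the broadcasting shape it produces is an assumption.
import torch

def precompute_rotary(head_dim, max_seq_len, base=10000.0):
    # one frequency per pair of dimensions (the two halves rotated above)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, inv_freq)      # (T, head_dim/2)
    return freqs.cos(), freqs.sin()       # broadcast against the (..., T, head_dim/2) halves

cos, sin = precompute_rotary(head_dim=128, max_seq_len=1024)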
Group Query Attention (GQA)
GQA reduces memory usage during inference by sharing key-value heads across multiple query heads:
Source: nanochat/gpt.py:82-89
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA)
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
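Conceptually, GQA shares each key/value head across a group of query heads. The sketch below uses illustrative shapes (and requires PyTorch >= 2.5 for the enable_gqa flag) to show the built-in path next to a manually expanded equivalent.
import torch
import torch.nn.functional as F

B, T, head_dim, n_head, n_kv_head = 2, 16, 64, 8, 2
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# PyTorch handles the head sharing internally:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Conceptually equivalent: replicate each KV head across its group of query heads.
k_full = k.repeat_interleave(n_head // n_kv_head, dim=1)
v_full = v.repeat_interleave(n_head // n_kv_head, dim=1)
y_ref = F.scaled_dot_product_attention(q, k_full, v_full, is_causal=True)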
QK Normalization
Query and key normalization improves training stability:
Source: nanochat/gpt.py:74-76
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
3. Feed-Forward Network with ReLU²
The MLP layers use ReLU² activation instead of traditional GELU/SwiGLU:
Source: nanochat/gpt.py:109-117
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square() # ReLU² activation
x = self.c_proj(x)
return x
4. Normalization Strategy
The model uses RMSNorm without learnable parameters, applied before each attention and MLP sublayer (pre-norm):
Source: nanochat/gpt.py:35-37
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
Source: nanochat/gpt.py:125-128
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache) # Pre-norm
x = x + self.mlp(norm(x)) # Pre-norm
return x
Training Architecture
Distributed Optimization Strategy
NanoChat uses a dual-optimizer approach, assigning a different optimizer to each kind of parameter:
Source: nanochat/gpt.py:251-276
# Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
- AdamW: For embedding and output projection layers
- Muon: For all linear transformation matrices in attention and MLP
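A minimal sketch of how parameters might be partitioned into these groups, given a constructed GPT model; the name-based filtering here is an illustration, not the repo's exact selection logic.
matrix_params, embedding_params, lm_head_params = [], [], []
for name, p in model.named_parameters():
    if "lm_head" in name:
        lm_head_params.append(p)       # output projection -> AdamW
    elif "wte" in name:
        embedding_params.append(p)     # token embedding -> AdamW
    else:
        matrix_params.append(p)        # attention/MLP matrices -> Muon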
Gradient Accumulation & Scaling
The system automatically calculates gradient accumulation steps to achieve target batch sizes across multiple GPUs:
Source: scripts/base_train.py:120-127
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
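A worked example with illustrative numbers (not the repo defaults):
device_batch_size, max_seq_len, ddp_world_size = 32, 2048, 8
total_batch_size = 1_048_576                                    # target tokens per optimizer step
tokens_per_fwdbwd = device_batch_size * max_seq_len             # 65,536 tokens per GPU per micro-step
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size    # 524,288 tokens across 8 GPUs
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # 2 micro-steps per optimizer step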
Inference Architecture
KV Cache Management
Autoregressive generation keeps a per-layer key/value cache so that each new token only attends over keys and values that have already been computed:
Source: nanochat/engine.py:50-85
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer inserts.
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
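The shape above makes the cache's memory footprint easy to estimate. The numbers below are illustrative assumptions, not nanochat defaults (with GQA, num_heads would be the smaller n_kv_head).
num_layers, batch_size, num_heads, seq_len, head_dim = 12, 8, 6, 2048, 128
bytes_per_elem = 2  # bfloat16
kv_bytes = num_layers * 2 * batch_size * num_heads * seq_len * head_dim * bytes_per_elem
print(f"{kv_bytes / 1e6:.0f} MB")  # ~604 MB for this configuration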
Tool Use Integration
The inference engine supports tool use (Python calculator) through special token handling:
Source: nanochat/engine.py:25-48
def use_calculator(expr):
"""
Evaluate a Python expression safely.
Supports both math expressions and string operations like .count()
"""
# Remove commas from numbers
expr = expr.replace(",", "")
# Check if it's a pure math expression
if all([x in "0123456789*+-/.() " for x in expr]):
if "**" in expr: # disallow power operator
return None
return eval_with_timeout(expr)
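Used on its own, the helper behaves roughly as follows; the return values assume eval_with_timeout simply evaluates the cleaned expression, and the string-operation path mentioned in the docstring is not shown in this excerpt.
use_calculator("1,234 + 5")   # -> 1239 (commas are stripped before evaluation)
use_calculator("2 ** 10")     # -> None (the power operator is disallowed)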
Data Pipeline Architecture
Distributed Data Loading
Source: nanochat/dataloader.py:15-25
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4,
tokenizer_batch_size=128, device="cuda",
resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
This implementation supports approximate resume training.
The state_dict returned can be used to approximately resume from a desired point.
"""
Tokenization Strategy
NanoChat implements a GPT-4 style BPE tokenizer with special tokens for conversation structure:
Source: nanochat/tokenizer.py:11-21
SPECIAL_TOKENS = [
"<|bos|>", # Beginning of Sequence
"<|user_start|>", # User messages
"<|user_end|>",
"<|assistant_start|>", # Assistant messages
"<|assistant_end|>",
"<|python_start|>", # Tool invocation
"<|python_end|>",
"<|output_start|>", # Tool output
"<|output_end|>",
]
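A rendered conversation would then look roughly like the following; the exact layout is an assumption inferred from the token names, so consult nanochat/tokenizer.py for the authoritative rendering.
<|bos|><|user_start|>What is 2+2?<|user_end|><|assistant_start|>4<|assistant_end|>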
Evaluation Framework
Task Abstraction
Source: tasks/common.py:12-35
class Task:
"""
Base class of a Task. Allows for lightweight slicing of datasets.
"""
@property
def eval_type(self):
# one of 'generative' | 'categorical'
raise NotImplementedError
def evaluate(self, problem, completion):
raise NotImplementedError
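A hypothetical subclass (not part of the repo) illustrates how the two hooks are filled in for a simple exact-match generative task; the problem fields are assumptions.
class ExactMatchTask(Task):
    @property
    def eval_type(self):
        return "generative"

    def evaluate(self, problem, completion):
        # assumes each problem dict carries a reference "answer" string
        return completion.strip() == problem["answer"].strip()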
Multiple Choice Rendering
Source: tasks/common.py:95-110
def render_mc(question, letters, choices):
"""
The common multiple choice rendering format.
Important: letter comes AFTER choice for better binding in smaller models.
"""
query = f"Multiple Choice question: {question}\n"
query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
query += "\nRespond only with the letter of the correct answer."
return query
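For example, render_mc("What color is the sky?", ["A", "B"], ["red", "blue"]) produces:
Multiple Choice question: What color is the sky?
- red=A
- blue=B

Respond only with the letter of the correct answer.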
Performance Characteristics
FLOP Estimation
The model includes detailed FLOP counting for performance analysis:
Source: nanochat/gpt.py:218-230
def estimate_flops(self):
"""
Return estimated FLOPs per token for the model (forward + backward).
Each matmul weight parameter contributes 6 FLOPs total.
"""
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = self.transformer.wte.weight.numel()
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
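Plugging in the default GPTConfig shown earlier (where n_kv_head == n_head, so the K/V projections are full size) gives a rough figure; the parameter count below is an approximation for illustration.
l, h, q, t = 12, 6, 768 // 6, 1024
block_matrices = 12 * l * 768 ** 2   # ~85M params in attention + MLP weight matrices
lm_head = 50304 * 768                # untied output projection, ~39M params
flops_per_token = 6 * (block_matrices + lm_head) + 12 * l * h * q * t
print(f"{flops_per_token / 1e9:.2f} GFLOPs per token")  # roughly 0.85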
Model Scaling
Parameter counting follows the Chinchilla convention (embeddings included), which gives cleaner scaling-law fits:
Source: nanochat/gpt.py:234-245
def num_scaling_params(self):
"""
Return all parameters (Chinchilla approach for cleaner scaling laws).
Includes embedding parameters unlike Kaplan et al.
"""
nparams = sum(p.numel() for p in self.parameters())
return nparams
Technology Stack
- PyTorch: Core deep learning framework with DDP for distributed training
- FastAPI: Web server for chat interface and API endpoints
- tiktoken/rustbpe: Efficient tokenization implementations
- HuggingFace Datasets: Standardized dataset loading
- wandb: Experiment tracking and logging
- uvicorn: ASGI web server for production deployment
Design Principles
- Simplicity: Minimal dependencies, clean abstractions
- Hackability: Easy to understand and modify each component
- Performance: Efficient training and inference with modern optimizations
- Completeness: Full pipeline from data to deployment
- Reproducibility: Deterministic training with proper checkpointing
The architecture demonstrates that a complete ChatGPT-style system can be implemented in a remarkably compact and understandable codebase while maintaining competitive performance.
Sources:
- nanochat/gpt.py (model architecture)
- nanochat/engine.py (inference engine)
- nanochat/dataloader.py (data pipeline)
- scripts/base_train.py (training infrastructure)
- tasks/common.py (evaluation framework)