System Architecture Overview

This document provides a comprehensive overview of the nanochat system architecture, including the core model design, training pipeline, and deployment infrastructure.

High-Level System Design

NanoChat implements an end-to-end conversational AI pipeline in roughly 8,000 lines of code. The system is designed for modularity and hackability, so each component is easy to understand and modify.

mermaid
graph TB
    subgraph "Data Pipeline"
        A[Raw Text Data] --> B[Tokenizer Training]
        B --> C[Text Tokenization]
        C --> D[Distributed DataLoader]
    end
    subgraph "Model Architecture"
        E[GPT Transformer] --> F[Embedding Layer]
        F --> G[Transformer Blocks]
        G --> H[Language Model Head]
        G --> G1[Multi-Head Attention]
        G --> G2[Feed-Forward Network]
    end
    subgraph "Training Pipeline"
        D --> I[Base Pretraining]
        I --> J[Mid-training]
        J --> K[Supervised Fine-tuning]
        K --> L[Reinforcement Learning]
    end
    subgraph "Inference & Deployment"
        L --> M[Inference Engine]
        M --> N[KV Cache Management]
        M --> O[Tool Use Integration]
        M --> P[Web Interface]
    end
    subgraph "Evaluation Framework"
        Q[Task Definitions] --> R[Multiple Choice Tasks]
        Q --> S[Generative Tasks]
        R --> T[ARC, MMLU]
        S --> U[GSM8K, HumanEval]
    end

Core Components

1. GPT Transformer Model

The heart of nanochat is a modern Transformer implementation with several key improvements over the original GPT architecture.

Source: nanochat/gpt.py:1-15

python
"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
"""

Model Configuration

Source: nanochat/gpt.py:23-32

python
@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
    n_embd: int = 768
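
For instance, the attention head dimension and the query-to-KV grouping factor follow directly from these fields. The snippet below is a small illustrative check, assuming GPTConfig is importable from nanochat.gpt:

python
# Illustrative: quantities derived from the default config.
from nanochat.gpt import GPTConfig

cfg = GPTConfig()
head_dim = cfg.n_embd // cfg.n_head        # 768 // 6 = 128
kv_groups = cfg.n_head // cfg.n_kv_head    # 6 // 6 = 1, i.e. no KV sharing by default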

2. Modern Attention Mechanism

Rotary Position Embeddings (RoPE)

Instead of learned positional embeddings, nanochat uses RoPE for better length generalization:

Source: nanochat/gpt.py:40-47

python
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)
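
The cos/sin tables passed into apply_rotary_emb are precomputed per position and channel pair. The helper below is a minimal illustrative sketch of such a precomputation; the function name, base frequency of 10000, and output shapes are assumptions, not necessarily nanochat's exact code:

python
import torch

def precompute_rotary(seq_len, head_dim, base=10000, device="cpu"):
    # Illustrative helper: one frequency per pair of channels, evaluated at every position.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)          # (seq_len, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()
    # Reshape so they broadcast against the 4-D query/key tensors
    return cos[None, None, :, :], sin[None, None, :, :]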

Group Query Attention (GQA)

GQA reduces memory usage during inference by sharing key-value heads across multiple query heads:

Source: nanochat/gpt.py:82-89

python
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA)
if kv_cache is None or Tq == Tk:
    # During training (no KV cache), attend as usual with causal attention
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
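
To see why GQA helps at inference time, compare the KV-cache footprint with and without head sharing. The numbers below are illustrative back-of-the-envelope arithmetic, not measurements from the repository:

python
# Illustrative KV-cache sizing in bf16 for a single sequence.
n_layer, seq_len, head_dim = 12, 1024, 128
bytes_per_elem = 2  # bf16

def kv_cache_bytes(n_kv_head):
    # K and V tensors, one pair per layer: 2 * layers * kv_heads * seq * head_dim
    return 2 * n_layer * n_kv_head * seq_len * head_dim * bytes_per_elem

full_mha = kv_cache_bytes(n_kv_head=6)   # every query head has its own K/V head
gqa      = kv_cache_bytes(n_kv_head=2)   # three query heads share each K/V head
print(full_mha / 2**20, gqa / 2**20)     # 36.0 MiB vs 12.0 MiB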

QK Normalization

Query and key normalization improves training stability:

Source: nanochat/gpt.py:74-76

python
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm

3. Feed-Forward Network with ReLU²

The MLP layers use ReLU² activation instead of traditional GELU/SwiGLU:

Source: nanochat/gpt.py:109-117

python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU² activation
        x = self.c_proj(x)
        return x

4. Normalization Strategy

The model uses RMSNorm with no learnable parameters, applied before each residual branch (pre-norm):

Source: nanochat/gpt.py:35-37

python
def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))

Source: nanochat/gpt.py:125-128

python
def forward(self, x, cos_sin, kv_cache):
    x = x + self.attn(norm(x), cos_sin, kv_cache)  # Pre-norm
    x = x + self.mlp(norm(x))  # Pre-norm
    return x

Training Architecture

Distributed Optimization Strategy

NanoChat uses a dual-optimizer approach, assigning different optimizers to different parameter types:

Source: nanochat/gpt.py:251-276

python
# Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel
adam_groups = [
    dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
    dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)

# Create the Muon optimizer for the linear layers
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)

  • AdamW: For embedding and output projection layers
  • Muon: For all linear transformation matrices in attention and MLP
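
A hedged sketch of how the two optimizers could be driven together inside a training step is shown below; the surrounding loop and variable names are illustrative, and the actual loop lives in scripts/base_train.py:

python
# Illustrative training-step skeleton under the dual-optimizer split.
optimizers = [adamw_optimizer, muon_optimizer]

for step in range(num_iterations):
    loss = model(x, y)                 # forward pass returning the loss (assumed signature)
    loss.backward()                    # gradients accumulate on both parameter groups
    for opt in optimizers:             # each optimizer updates only its own parameters
        opt.step()
    model.zero_grad(set_to_none=True)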

Gradient Accumulation & Scaling

The training script derives the number of gradient accumulation steps needed to reach the target total batch size (measured in tokens) across all GPUs:

Source: scripts/base_train.py:120-127

python
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
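
Once grad_accum_steps is known, the inner loop accumulates gradients over several micro-batches before taking a single optimizer step. This is a simplified sketch rather than the exact loop in scripts/base_train.py:

python
# Illustrative gradient-accumulation inner loop.
for micro_step in range(grad_accum_steps):
    x, y = next(train_loader)                  # one device-sized micro-batch
    loss = model(x, y)
    (loss / grad_accum_steps).backward()       # average the gradient over micro-batches
optimizer.step()
model.zero_grad(set_to_none=True)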

Inference Architecture

KV Cache Management

The inference engine performs autoregressive generation efficiently by caching each layer's keys and values across decode steps:

Source: nanochat/engine.py:50-85

python
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.
    Note that the .pos advances automatically after the last layer inserts.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Each of K/V is of shape (B, H, T, D) and we have one per layer
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0 # current position in time in the cache
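
At generation time the cache lets each decode step feed only the newest token while earlier keys and values are reused. The greedy decode loop below is schematic: the forward signature (accepting kv_cache and returning logits) is assumed here, and in the repository this logic is wrapped by the Engine class in nanochat/engine.py:

python
# Schematic greedy decode loop around a KV cache (illustrative).
kv_cache = KVCache(batch_size=1, num_heads=6, seq_len=2048, head_dim=128, num_layers=12)
ids = torch.tensor([prompt_tokens], device="cuda")       # prompt_tokens: list[int]
logits = model(ids, kv_cache=kv_cache)                   # prefill: process the whole prompt
for _ in range(max_new_tokens):
    next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # greedy pick
    # only the new token is fed; cached K/V cover all earlier positions
    logits = model(next_id, kv_cache=kv_cache)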

Tool Use Integration

The inference engine supports tool use (Python calculator) through special token handling:

Source: nanochat/engine.py:25-48

python
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")
    
    # Check if it's a pure math expression
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)
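
eval_with_timeout is referenced above but not shown. One way to bound evaluation time is to run eval in a separate process and kill it if it overruns; the sketch below takes that approach, and the process-based strategy and 3-second default are assumptions rather than nanochat's actual implementation:

python
import multiprocessing

def _eval_worker(expr, queue):
    try:
        queue.put(eval(expr, {"__builtins__": {}}, {}))  # evaluate with no builtins exposed
    except Exception:
        queue.put(None)

def eval_with_timeout(expr, timeout=3.0):
    # Illustrative: evaluate in a child process, terminate it if it runs too long.
    queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_eval_worker, args=(expr, queue))
    proc.start()
    proc.join(timeout)
    if proc.is_alive():
        proc.terminate()
        proc.join()
        return None
    return queue.get() if not queue.empty() else None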

Data Pipeline Architecture

Distributed Data Loading

Source: nanochat/dataloader.py:15-25

python
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, 
                                                tokenizer_batch_size=128, device="cuda", 
                                                resume_state_dict=None):
    """
    Stream pretraining text from parquet files, tokenize, yield training batches.
    
    This implementation supports approximate resume training.
    The state_dict returned can be used to approximately resume from a desired point.
    """

Tokenization Strategy

NanoChat implements a GPT-4 style BPE tokenizer with special tokens for conversation structure:

Source: nanochat/tokenizer.py:11-21

python
SPECIAL_TOKENS = [
    "<|bos|>",              # Beginning of Sequence
    "<|user_start|>",       # User messages
    "<|user_end|>",
    "<|assistant_start|>",  # Assistant messages  
    "<|assistant_end|>",
    "<|python_start|>",     # Tool invocation
    "<|python_end|>",
    "<|output_start|>",     # Tool output
    "<|output_end|>",
]
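
As an illustration of how these tokens frame a conversation, a single tool-using turn might be laid out roughly as follows. The exact rendering is handled by the tokenizer's chat formatting; this layout is a simplified sketch, not the repository's literal output:

python
# Simplified sketch of one rendered conversation turn (illustrative layout).
conversation = (
    "<|bos|>"
    "<|user_start|>What is 2 + 2?<|user_end|>"
    "<|assistant_start|>"
    "<|python_start|>2 + 2<|python_end|>"
    "<|output_start|>4<|output_end|>"
    "2 + 2 = 4<|assistant_end|>"
)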

Evaluation Framework

Task Abstraction

Source: tasks/common.py:12-35

python
class Task:
    """
    Base class of a Task. Allows for lightweight slicing of datasets.
    """
    
    @property
    def eval_type(self):
        # one of 'generative' | 'categorical'
        raise NotImplementedError

    def evaluate(self, problem, completion):
        raise NotImplementedError
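
A hypothetical subclass illustrates the contract: declare the eval_type and implement evaluate for a (problem, completion) pair. This toy example is not part of the repository:

python
# Hypothetical Task subclass (illustrative only).
class ExactMatchTask(Task):
    @property
    def eval_type(self):
        return "generative"

    def evaluate(self, problem, completion):
        # assumes each problem dict carries a reference "answer" field
        return float(completion.strip() == problem["answer"].strip())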

Multiple Choice Rendering

Source: tasks/common.py:95-110

python
def render_mc(question, letters, choices):
    """
    The common multiple choice rendering format.
    Important: letter comes AFTER choice for better binding in smaller models.
    """
    query = f"Multiple Choice question: {question}\n"
    query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
    query += "\nRespond only with the letter of the correct answer."
    return query
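
For example, calling the function above produces (output reconstructed from the code, with a made-up question):

python
print(render_mc(
    "What color is the sky on a clear day?",
    letters=["A", "B"],
    choices=["Green", "Blue"],
))
# Multiple Choice question: What color is the sky on a clear day?
# - Green=A
# - Blue=B
#
# Respond only with the letter of the correct answer.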

Performance Characteristics

FLOP Estimation

The model includes detailed FLOP counting for performance analysis:

Source: nanochat/gpt.py:218-230

python
def estimate_flops(self):
    """
    Return estimated FLOPs per token for the model (forward + backward).
    Each matmul weight parameter contributes 6 FLOPs total.
    """
    nparams = sum(p.numel() for p in self.parameters())
    nparams_embedding = self.transformer.wte.weight.numel()
    l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
    return num_flops_per_token
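
Plugging the default GPTConfig into this formula gives a rough per-token cost. The parameter counts below are derived analytically under the architecture described earlier (no biases, untied embedding and lm_head, n_kv_head equal to n_head), so treat the result as an illustrative estimate rather than a measurement:

python
# Back-of-the-envelope FLOPs per token for the default config (illustrative).
l, h, d, V, t = 12, 6, 768, 50304, 1024
q = d // h                                    # head_dim = 128

per_layer = 4 * d * d + 8 * d * d             # attention projections + MLP, no biases
nparams = l * per_layer + 2 * V * d           # plus token embedding and untied lm_head
nparams_embedding = V * d                     # wte only, as in the method above

flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
print(f"{flops_per_token / 1e9:.2f} GFLOPs/token")   # ≈ 0.85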

Model Scaling

For scaling-law analysis, parameters are counted using the Chinchilla convention, which includes the embedding parameters (unlike Kaplan et al.):

Source: nanochat/gpt.py:234-245

python
def num_scaling_params(self):
    """
    Return all parameters (Chinchilla approach for cleaner scaling laws).
    Includes embedding parameters unlike Kaplan et al.
    """
    nparams = sum(p.numel() for p in self.parameters())
    return nparams

Technology Stack

  • PyTorch: Core deep learning framework with DDP for distributed training
  • FastAPI: Web server for chat interface and API endpoints
  • tiktoken/rustbpe: Efficient tokenization implementations
  • HuggingFace Datasets: Standardized dataset loading
  • wandb: Experiment tracking and logging
  • uvicorn: ASGI web server for production deployment

Design Principles

  1. Simplicity: Minimal dependencies, clean abstractions
  2. Hackability: Easy to understand and modify each component
  3. Performance: Efficient training and inference with modern optimizations
  4. Completeness: Full pipeline from data to deployment
  5. Reproducibility: Deterministic training with proper checkpointing

The architecture demonstrates that a complete ChatGPT-style system can be implemented in a remarkably compact and understandable codebase while maintaining competitive performance.


Sources:

  • nanochat/gpt.py (model architecture)
  • nanochat/engine.py (inference engine)
  • nanochat/dataloader.py (data pipeline)
  • scripts/base_train.py (training infrastructure)
  • tasks/common.py (evaluation framework)
Last updated: 1/10/2026