Training Pipeline Architecture
This document describes the complete training pipeline for nanochat, from raw text data to a deployed ChatGPT-style conversational AI model. The pipeline consists of five distinct stages, each building upon the previous one.
Pipeline Overview
The complete pipeline can be run with the speedrun.sh script for approximately $100 on 8x H100 GPUs in ~4 hours.
Source: speedrun.sh:1-15
#!/bin/bash
# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
# Example launch:
# bash speedrun.sh
# Or in screen session:
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
Stage 1: Tokenizer Training
The first stage trains a BPE tokenizer on ~2B characters of pretraining data with a vocabulary size of 65,536 tokens.
Source: speedrun.sh:55-65
# Download the first ~2B characters of pretraining dataset
python -m nanochat.dataset -n 8
# train the tokenizer with vocab size 2**16 = 65536
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
Key Features
- GPT-4 style regex splitting with byte fallback
- Special tokens for conversation structure and tool use
- Compression evaluation to validate tokenizer quality
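The compression metric reported by tok_eval is essentially characters per token: the more characters a tokenizer packs into each token, the cheaper training and inference become. A minimal sketch of that computation (the function name here is illustrative, not nanochat's API):

```python
def compression_ratio(text: str, token_ids: list[int]) -> float:
    # Higher is better: more characters represented by each token.
    return len(text) / len(token_ids)

# Toy example: a 20-character string tokenized into 4 (made-up) token ids.
text = "the quick brown fox."
fake_ids = [101, 2374, 987, 55]
ratio = compression_ratio(text, fake_ids)  # 20 chars / 4 tokens = 5.0
```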
Stage 2: Base Pretraining
Base pretraining creates a language model from scratch using the distributed data loader and modern training techniques.
Data Requirements
The d20 model (561M parameters) follows Chinchilla-optimal training with a 20:1 token-to-parameter ratio:
Source: speedrun.sh:67-76
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety.
python -m nanochat.dataset -n 240
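The shard arithmetic in the comments above can be reproduced directly (constants taken from those comments; this is back-of-the-envelope bookkeeping, not nanochat code):

```python
import math

params = 561e6             # d20 model size
tokens = 20 * params       # Chinchilla: ~20 tokens per parameter -> ~11.2B tokens
chars = tokens * 4.8       # assumed 4.8 chars/token -> ~54B characters
shards = math.ceil(chars / 250e6)  # 250M chars per shard

print(shards)  # 216; speedrun.sh rounds up to 240 for safety
```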
Training Configuration
Source: scripts/base_train.py:82-95
# Model kwargs are derived from the desired depth of the model
num_layers = args.depth
model_dim = args.depth * args.aspect_ratio
def find_num_heads(model_dim, target_head_dim):
# Find num_heads that divides model_dim evenly, with head_dim closest to target.
ideal = max(1, round(model_dim / target_head_dim))
for offset in range(model_dim):
for candidate in [ideal + offset, ideal - offset]:
if candidate > 0 and model_dim % candidate == 0:
return candidate
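For intuition, the helper can be exercised standalone. With a model_dim of 1280 (e.g. depth 20 at an aspect ratio of 64, values used here purely for illustration) and a target head_dim of 128:

```python
def find_num_heads(model_dim, target_head_dim):
    # Find num_heads that divides model_dim evenly, with head_dim closest to target.
    ideal = max(1, round(model_dim / target_head_dim))
    for offset in range(model_dim):
        for candidate in [ideal + offset, ideal - offset]:
            if candidate > 0 and model_dim % candidate == 0:
                return candidate

print(find_num_heads(1280, 128))  # 10 heads -> head_dim = 128
print(find_num_heads(1536, 128))  # 12 heads -> head_dim = 128
```

The search widens symmetrically around the ideal head count, so the first divisor it finds is the one whose head_dim is closest to the target.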
Optimization Strategy
Base training uses a dual-optimizer approach:
Source: scripts/base_train.py:140-155
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(
unembedding_lr=args.unembedding_lr * batch_lr_scale,
embedding_lr=args.embedding_lr * batch_lr_scale,
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=args.weight_decay,
adam_betas=adam_betas,
)
adamw_optimizer, muon_optimizer = optimizers
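The split follows a common rule of thumb: 2D weight matrices go to Muon, while embeddings, the unembedding (lm_head), and non-matrix parameters go to AdamW. A minimal sketch of such a partition (illustrative only; nanochat's `setup_optimizers` handles this internally and may differ in detail):

```python
from types import SimpleNamespace

def partition_params(named_parameters):
    # 2D matrices (excluding embedding / lm_head weights) -> Muon; rest -> AdamW.
    muon, adamw = [], []
    for name, p in named_parameters:
        if p.ndim == 2 and "embed" not in name and "lm_head" not in name:
            muon.append(name)
        else:
            adamw.append(name)
    return muon, adamw

# Toy parameter list standing in for model.named_parameters():
fake = [
    ("embed.weight", SimpleNamespace(ndim=2)),
    ("blocks.0.attn.weight", SimpleNamespace(ndim=2)),
    ("blocks.0.norm.scale", SimpleNamespace(ndim=1)),
    ("lm_head.weight", SimpleNamespace(ndim=2)),
]
muon, adamw = partition_params(fake)
# muon gets only the attention matrix; the other three go to AdamW.
```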
Learning Rate Scheduling
Source: scripts/base_train.py:210-220
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
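Plugging in a few steps makes the trapezoidal shape concrete. The numbers below (1000 iterations, 10 warmup steps, 200 warmdown steps, final_lr_frac of 0) are illustrative, not nanochat's defaults:

```python
def get_lr_multiplier(it, num_iterations=1000, warmup_iters=10,
                      warmdown_iters=200, final_lr_frac=0.0):
    # Trapezoid: linear warmup, flat plateau, linear warmdown to final_lr_frac.
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac

print(get_lr_multiplier(0))    # 0.1  (warming up)
print(get_lr_multiplier(500))  # 1.0  (plateau)
print(get_lr_multiplier(900))  # 0.5  (halfway through warmdown)
```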
Training Command
Source: speedrun.sh:82-84
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN
Stage 3: Mid-training
Mid-training teaches the model conversation structure and tool use, preparing it for instruction following.
Data Mixture
Source: scripts/mid_train.py:112-118
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K multiple choice problems
GSM8K(subset="main", split="train"), # 8K rows teaching math and calculator tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity
SimpleSpelling(size=200000, split="train"), # 200K rows of spelling tasks
SpellingBee(size=80000, split="train"), # 80K rows of spelling bee tasks
]) # total: 848K rows
Dynamic Data Generation
Mid-training uses a custom data generator that handles variable-length conversations:
Source: scripts/mid_train.py:120-135
def mid_data_generator(split):
global last_step, approx_progress
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
# Use bos token as pad token for conversations
bos_token_id = tokenizer.get_bos_token_id()
needed_tokens = args.device_batch_size * args.max_seq_len + 1
token_buffer = deque()
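The generator accumulates whole tokenized conversations in a flat buffer and slices off fixed-size batches once enough tokens are available. A self-contained sketch of that packing pattern (simplified; mid_train.py's separator/padding handling differs in detail):

```python
from collections import deque

def pack_batches(conversations, batch_size, max_seq_len, bos_token_id=0):
    # Append each tokenized conversation to a deque; whenever enough tokens
    # have accumulated, pop exactly one batch worth (+1 for shifted targets).
    needed = batch_size * max_seq_len + 1
    buf = deque()
    for convo_tokens in conversations:
        buf.append(bos_token_id)      # bos marks the start of each conversation
        buf.extend(convo_tokens)
        while len(buf) >= needed:
            yield [buf.popleft() for _ in range(needed)]

# Two fake "conversations" worth of token ids:
convos = [[1, 2, 3, 4], [5, 6, 7, 8, 9, 10]]
batches = list(pack_batches(convos, batch_size=2, max_seq_len=2))
# Each batch is a flat list of 2*2+1 = 5 token ids.
```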
Learning Rate Scheduling
Source: scripts/mid_train.py:175-178
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
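A quick sanity check of the ramp: the multiplier stays at 1 through 80% of training, then falls linearly, reaching roughly 0.5 at 90% progress:

```python
def get_lr_multiplier(progress):
    # Constant for the first 80% of training, then linear decay to 0.
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

print(get_lr_multiplier(0.5))  # 1
print(get_lr_multiplier(0.9))  # ~0.5 (up to floating-point rounding)
```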
Stage 4: Supervised Fine-tuning (SFT)
SFT adapts the model for high-quality instruction following using a curated mixture of tasks.
Task Mixture
Source: scripts/chat_sft.py:85-95
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K synthetic identity
SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling
SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee
]) # 23K rows total
Conversation Tokenization
SFT uses a conversation renderer that masks the loss to assistant tokens and supports tool use:
Source: nanochat/tokenizer.py:320-340
def render_conversation(self, conversation, max_tokens=2048):
"""
Tokenize a single Chat conversation.
Returns:
- ids: list[int] of token ids for this rendered conversation
- mask: list[int] of same length, mask = 1 for tokens that Assistant trains on.
"""
ids, mask = [], []
def add_tokens(token_ids, mask_val):
if isinstance(token_ids, int):
token_ids = [token_ids]
ids.extend(token_ids)
mask.extend([mask_val] * len(token_ids))
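The net effect of the mask is that loss is computed only on assistant tokens. A toy illustration of the ids/mask construction for one user/assistant exchange (the token ids are made up; the real renderer also interleaves special structural tokens):

```python
def render_toy_conversation(turns):
    # Each turn: (role, token_ids). Mask = 1 only for assistant content,
    # so training ignores user text entirely.
    ids, mask = [], []
    def add(token_ids, mask_val):
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))
    for role, toks in turns:
        add(toks, 1 if role == "assistant" else 0)
    return ids, mask

ids, mask = render_toy_conversation([
    ("user", [11, 12, 13]),   # pretend user question
    ("assistant", [21, 22]),  # pretend assistant answer
])
# ids  -> [11, 12, 13, 21, 22]
# mask -> [0, 0, 0, 1, 1]
```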
Data Collation
Source: scripts/chat_sft.py:110-125
def sft_data_generator(dataset, batch_size):
pad_token_id = tokenizer.encode_special("<|assistant_end|>")
def collate_and_yield(batch):
nrows = len(batch)
ncols = max(len(ids) for ids, mask in batch) - 1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
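The rest of the collate step (elided above) fills each row with the one-token-shifted input/target pair. A pure-Python sketch of the same logic, under the same conventions (fixed pad id, ignore index -1, targets kept only where the mask is 1):

```python
def collate(batch, pad_token_id=0, ignore_index=-1):
    # batch: list of (ids, mask) pairs of varying length.
    nrows = len(batch)
    ncols = max(len(ids) for ids, _ in batch) - 1  # inputs/targets shifted by one
    inputs = [[pad_token_id] * ncols for _ in range(nrows)]
    targets = [[ignore_index] * ncols for _ in range(nrows)]
    for i, (ids, mask) in enumerate(batch):
        for j in range(len(ids) - 1):
            inputs[i][j] = ids[j]
            # Next-token target, kept only where the mask trains the assistant:
            targets[i][j] = ids[j + 1] if mask[j + 1] == 1 else ignore_index
    return inputs, targets

batch = [([5, 6, 7], [0, 1, 1]), ([5, 6], [0, 1])]
inputs, targets = collate(batch)
# inputs  -> [[5, 6], [5, 0]]   (second row right-padded)
# targets -> [[6, 7], [6, -1]]  (padding positions ignored by the loss)
```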
Stage 5: Reinforcement Learning (Optional)
RL fine-tunes the model using reward signals, currently focusing on mathematical reasoning.
RL Training Setup
The RL stage uses policy gradient methods to improve performance on specific tasks:
Source: speedrun.sh:140-145
# run reinforcement learning
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
Training Infrastructure
Distributed Training
All training stages support distributed training across multiple GPUs:
Source: scripts/base_train.py:105-115
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
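Concretely (with illustrative numbers): a total batch of 524,288 tokens, device batch 16, sequence length 2,048, and 8 GPUs gives two gradient-accumulation micro-steps per optimizer step:

```python
total_batch_size = 524_288   # tokens per optimizer step (illustrative)
device_batch_size = 16       # sequences per GPU per forward/backward
max_seq_len = 2_048
ddp_world_size = 8

tokens_per_fwdbwd = device_batch_size * max_seq_len           # 32,768
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # 262,144
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print(grad_accum_steps)  # 2
```

The assert mirrors the one in base_train.py: the total batch size must be an exact multiple of what one synchronized forward/backward pass across all ranks can process.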
Checkpointing & Resume
Source: nanochat/checkpoint_manager.py:15-35
def save_checkpoint(checkpoint_dir, step, model_state_dict, optimizer_state_dicts, meta_data, rank=0):
"""
Save model and optimizer state to enable training resumption.
Only rank 0 saves to avoid race conditions.
"""
if rank != 0:
return
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"step_{step:06d}.pt")
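Resuming is then a matter of locating the latest `step_NNNNNN.pt` file. A sketch of that lookup (the filename pattern mirrors the format string above; actual loading via torch.load is omitted):

```python
import os
import re

def latest_checkpoint(checkpoint_dir):
    # Return the path with the highest step number, or None if none exist.
    pattern = re.compile(r"step_(\d{6})\.pt$")
    best_step, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        m = pattern.match(name)
        if m and int(m.group(1)) > best_step:
            best_step = int(m.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path
```

Zero-padding the step number to six digits keeps lexicographic and numeric order in agreement, which also makes the files sort sensibly in a directory listing.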
Performance Monitoring
Each training stage includes comprehensive logging:
Source: scripts/base_train.py:380-395
# Calculate ETA based on average time per step
steps_done = step - 10
if steps_done > 0:
avg_time_per_step = total_training_time / steps_done
remaining_steps = num_iterations - step
eta_seconds = remaining_steps * avg_time_per_step
eta_str = f" | eta: {eta_seconds/60:.1f}m"
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
Evaluation Throughout Training
CORE Metric
The CORE metric evaluates base model quality:
Source: scripts/base_train.py:275-285
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
Chat Evaluation
Multi-task evaluation for instruction-following capabilities:
Source: scripts/chat_sft.py:170-180
# evaluate accuracy of the multiple choice tasks
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
model.eval()
metrics = {}
with torch.no_grad(), autocast_ctx:
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2)
Performance Results
The complete pipeline produces models with measurable improvements at each stage:
Source: README.md:45-60
| Metric | BASE | MID | SFT | RL |
|-----------------|----------|----------|----------|----------|
| CORE | 0.2219 | - | - | - |
| ARC-Challenge | - | 0.2875 | 0.2807 | - |
| ARC-Easy | - | 0.3561 | 0.3876 | - |
| GSM8K | - | 0.0250 | 0.0455 | 0.0758 |
| HumanEval | - | 0.0671 | 0.0854 | - |
| MMLU | - | 0.3111 | 0.3151 | - |
| ChatCORE | - | 0.0730 | 0.0884 | - |
Total wall clock time: 3h51m
Scaling to Larger Models
The pipeline supports larger models by adjusting key hyperparameters:
Source: README.md:74-85
# For d26 model (~GPT-2 performance):
python -m nanochat.dataset -n 450 # More data shards
torchrun --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
torchrun --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
The training pipeline demonstrates how a complete ChatGPT-style model can be trained efficiently with careful data curation, modern optimization techniques, and staged training objectives.
Sources:
- speedrun.sh (complete training pipeline)
- scripts/base_train.py (base pretraining)
- scripts/mid_train.py (mid-training)
- scripts/chat_sft.py (supervised fine-tuning)
- scripts/chat_rl.py (reinforcement learning)
- nanochat/checkpoint_manager.py (training infrastructure)