Training Pipeline Architecture

This document describes the complete training pipeline for nanochat, from raw text data to a deployed ChatGPT-style conversational AI model. The pipeline consists of five distinct stages, each building upon the previous one.

Pipeline Overview

mermaid
graph TD
    A[Raw Text Data] --> B[1. Tokenizer Training]
    B --> C[2. Base Pretraining]
    C --> D[3. Mid-training]
    D --> E[4. Supervised Fine-tuning]
    E --> F[5. Reinforcement Learning]
    F --> G[Web Deployment]
    C --> C1[Base Model Evaluation]
    D --> D1[Chat Capability Assessment]
    E --> E1[SFT Performance Metrics]
    F --> F1[RL Task Performance]
    style A fill:#e1f5fe
    style B fill:#f3e5f5
    style C fill:#e8f5e8
    style D fill:#fff3e0
    style E fill:#fce4ec
    style F fill:#f1f8e9
    style G fill:#e0f2f1

The complete pipeline can be run with the speedrun.sh script for approximately $100 on 8x H100 GPUs in ~4 hours.

Source: speedrun.sh:1-15

bash
#!/bin/bash

# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.

# Example launch:
# bash speedrun.sh
# Or in screen session:
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
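The budget in the header comment is simple arithmetic: number of GPUs × hourly rate × hours. A minimal check, assuming the $3/GPU/hour rate quoted in the comment (`run_cost_usd` is a hypothetical helper), applied both to the ~4-hour budget and to the 3h51m wall-clock time reported in the README:

```python
def run_cost_usd(num_gpus: int, usd_per_gpu_hour: float, hours: float) -> float:
    """Total rental cost of one single-node run (hypothetical helper)."""
    return num_gpus * usd_per_gpu_hour * hours

# Budgeted: 8 GPUs x $3/GPU/hour x 4 hours
budget = run_cost_usd(8, 3.0, 4.0)
# Measured wall-clock time reported in the README: 3h51m
measured = run_cost_usd(8, 3.0, 3 + 51 / 60)
print(f"budget: ${budget:.2f}, measured run: ${measured:.2f}")
```

Both figures land just under the advertised $100.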

Stage 1: Tokenizer Training

The first stage trains a BPE tokenizer on ~2B characters of pretraining data, with a vocabulary size of 65,536 (2^16) tokens.

Source: speedrun.sh:55-65

bash
# Download the first ~2B characters of pretraining dataset
python -m nanochat.dataset -n 8

# train the tokenizer with vocab size 2**16 = 65536
python -m scripts.tok_train --max_chars=2000000000 --vocab_size=65536

# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval

Key Features

  • GPT-4 style regex splitting with byte fallback
  • Special tokens for conversation structure and tool use
  • Compression evaluation to validate tokenizer quality
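To make the training step concrete, here is a toy byte-level BPE trainer: it splits text into chunks with a regex, starts from the 256 byte values, and repeatedly merges the most frequent adjacent pair within chunks. This is a minimal sketch, not nanochat's tokenizer. The split regex is a simplified stand-in for the GPT-4 pattern (which relies on Unicode classes that Python's `re` does not support), there are no special tokens, and `train_bpe`, `encode`, and `decode` are hypothetical names. The bytes-per-token figure at the end is in the spirit of the compression ratio that `scripts.tok_eval` reports.

```python
import re
from collections import Counter

# Simplified stand-in for the GPT-4 split pattern (assumption): letter runs and
# short digit runs with an optional leading space, punctuation runs, whitespace.
SPLIT_PATTERN = re.compile(r" ?[A-Za-z]+| ?[0-9]{1,3}| ?[^\sA-Za-z0-9]+|\s+")

def merge_pair(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` in `ids` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, vocab_size):
    """Learn merges until the vocabulary (256 bytes + merges) reaches vocab_size."""
    chunks = [list(c.encode("utf-8")) for c in SPLIT_PATTERN.findall(text)]
    merges = {}  # (left_id, right_id) -> new_id, in the order learned
    for new_id in range(256, vocab_size):
        counts = Counter()
        for ids in chunks:  # pairs never cross chunk boundaries
            counts.update(zip(ids, ids[1:]))
        if not counts:
            break  # every chunk is already a single token
        pair = counts.most_common(1)[0][0]
        merges[pair] = new_id
        chunks = [merge_pair(ids, pair, new_id) for ids in chunks]
    return merges

def encode(text, merges):
    """Apply the learned merges, in learned order, to each chunk of the text."""
    tokens = []
    for chunk in SPLIT_PATTERN.findall(text):
        ids = list(chunk.encode("utf-8"))
        for pair, new_id in merges.items():  # dicts preserve insertion order
            ids = merge_pair(ids, pair, new_id)
        tokens.extend(ids)
    return tokens

def decode(ids, merges):
    """Map token ids back to bytes, then to text."""
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), new_id in merges.items():
        vocab[new_id] = vocab[a] + vocab[b]
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

text = "the cat sat on the mat. the cat ate the rat. " * 20
merges = train_bpe(text, vocab_size=300)
ids = encode(text, merges)
print(f"{len(merges)} merges, {len(text.encode('utf-8')) / len(ids):.2f} bytes/token")
```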

Stage 2: Base Pretraining

Base pretraining trains a language model from scratch on the downloaded text shards, using a distributed data loader, a split Muon/AdamW optimizer setup, and a warmup/warmdown learning-rate schedule.

Data Requirements

The d20 model (561M parameters) follows Chinchilla-optimal training, with a token-to-parameter ratio of 20:1:

Source: speedrun.sh:67-76

bash
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety.
python -m nanochat.dataset -n 240
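The arithmetic in the comment can be written out as a small function. The 4.8 chars/token and 250M chars/shard figures come from the comment above; `shards_needed` is a hypothetical helper, and the safety margin (216 rounded up to 240) stays a manual choice, as in the script.

```python
import math

def shards_needed(num_params, tokens_per_param=20, chars_per_token=4.8,
                  chars_per_shard=250e6):
    """Estimate the data shards a Chinchilla-style run needs (hypothetical helper)."""
    tokens = num_params * tokens_per_param   # Chinchilla: ~20 tokens per parameter
    chars = tokens * chars_per_token         # convert tokens to raw characters
    return tokens, chars, math.ceil(chars / chars_per_shard)

tokens, chars, shards = shards_needed(561e6)
print(f"{tokens/1e9:.1f}B tokens, {chars/1e9:.1f}B chars, {shards} shards")
# 11.2B tokens, 53.9B chars, 216 shards
```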

Training Configuration

Source: scripts/base_train.py:82-95

python
# Model kwargs are derived from the desired depth of the model
num_layers = args.depth
model_dim = args.depth * args.aspect_ratio

def find_num_heads(model_dim, target_head_dim):
    # Find num_heads that divides model_dim evenly, with head_dim closest to target.
    ideal = max(1, round(model_dim / target_head_dim))
    for offset in range(model_dim):
        for candidate in [ideal + offset, ideal - offset]:
            if candidate > 0 and model_dim % candidate == 0:
                return candidate
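For d20 the derivation can be worked through numerically. The sketch assumes `aspect_ratio=64` and a target head dimension of 128; neither value appears in the excerpt, so both are assumptions. The parameter estimate (`estimate_params`, a hypothetical helper) adds further simplifying assumptions: each block has 4·d² attention weights (full-size q, k, v, and output projections) and 8·d² MLP weights (4× expansion), no biases or other small parameters, and untied input/output embeddings over the 65,536-token vocabulary. Under those assumptions the estimate comes out at ~561M, consistent with the figure quoted for d20.

```python
def find_num_heads(model_dim, target_head_dim):
    # Same search as above: the divisor of model_dim closest to the ideal head count.
    ideal = max(1, round(model_dim / target_head_dim))
    for offset in range(model_dim):
        for candidate in [ideal + offset, ideal - offset]:
            if candidate > 0 and model_dim % candidate == 0:
                return candidate

def estimate_params(depth, aspect_ratio=64, vocab_size=65536):
    """Rough parameter count under the simplifying assumptions in the lead-in."""
    d = depth * aspect_ratio
    per_block = 4 * d * d + 8 * d * d   # attention (q, k, v, o) + 4x MLP
    embeddings = 2 * vocab_size * d     # untied token embedding and lm_head
    return depth * per_block + embeddings

model_dim = 20 * 64                          # depth * assumed aspect_ratio
num_heads = find_num_heads(model_dim, 128)   # assumed target head dim of 128
print(model_dim, num_heads, model_dim // num_heads, f"{estimate_params(20):,}")
# 1280 10 128 560,988,160
```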

Optimization Strategy

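Base training splits the parameters between two optimizers: Muon for the Transformer's linear (matrix) layers and AdamW for the token embedding and lm_head: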

Source: scripts/base_train.py:140-155

python
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(
    unembedding_lr=args.unembedding_lr * batch_lr_scale,
    embedding_lr=args.embedding_lr * batch_lr_scale,
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=args.weight_decay,
    adam_betas=adam_betas,
)
adamw_optimizer, muon_optimizer = optimizers
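The split can be pictured as a routing rule over parameter names and shapes. The sketch below is hypothetical: the parameter names (`wte`, `lm_head`, `blocks.*`) and the rule "embedding and lm_head to AdamW, every other 2-D matrix to Muon, anything else to AdamW" are assumptions based on the comment above, and `partition_params` only returns name lists rather than building real optimizers.

```python
def partition_params(shapes):
    """Route parameters to optimizer groups by name and shape (hypothetical rule).

    shapes: dict mapping parameter name -> shape tuple.
    Returns (adamw_names, muon_names).
    """
    adamw, muon = [], []
    for name, shape in shapes.items():
        is_embedding_or_head = name.startswith(("wte", "lm_head"))
        if len(shape) == 2 and not is_embedding_or_head:
            muon.append(name)   # hidden weight matrices
        else:
            adamw.append(name)  # embeddings, lm_head, and any non-matrix params
    return adamw, muon

d, vocab = 1280, 65536
shapes = {
    "wte.weight": (vocab, d),
    "blocks.0.attn.c_q.weight": (d, d),
    "blocks.0.mlp.c_fc.weight": (4 * d, d),
    "lm_head.weight": (vocab, d),
}
adamw_names, muon_names = partition_params(shapes)
print("AdamW:", adamw_names)
print("Muon:", muon_names)
```

Muon's orthogonalized update is designed for 2-D hidden-layer weights; embeddings, the output head, and 1-D parameters are conventionally left to AdamW, which is what this routing reflects.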

Learning Rate Scheduling

Source: scripts/base_train.py:210-220

python
def get_lr_multiplier(it):
    warmup_iters = round(args.warmup_ratio * num_iterations)
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * args.final_lr_frac
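A standalone version, with the globals passed in explicitly, makes the three phases (linear warmup, constant plateau, linear warmdown to `final_lr_frac`) easy to inspect. The settings in the example (no warmup, 20% warmdown, `final_lr_frac=0.0`) are illustrative assumptions, not necessarily the script's defaults.

```python
def lr_multiplier(it, num_iterations, warmup_ratio, warmdown_ratio, final_lr_frac):
    """Linear warmup, constant plateau, then linear warmdown to final_lr_frac."""
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters  # goes from 1 to 0
        return progress * 1.0 + (1 - progress) * final_lr_frac

for it in [0, 500, 800, 900, 1000]:
    print(it, lr_multiplier(it, 1000, warmup_ratio=0.0, warmdown_ratio=0.2, final_lr_frac=0.0))
# 0 1.0
# 500 1.0
# 800 1.0
# 900 0.5
# 1000 0.0
```

At `it == num_iterations` the multiplier equals `final_lr_frac` exactly, so the final learning rate is `final_lr_frac` times the base rate.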

Training Command

Source: speedrun.sh:82-84

bash
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --target_param_data_ratio=20 --run=$WANDB_RUN

Stage 3: Mid-training

Mid-training teaches the model conversation structure and tool use, and prepares it for instruction following and multiple-choice questions.

Data Mixture

Source: scripts/mid_train.py:112-118

python
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K multiple choice problems
    GSM8K(subset="main", split="train"), # 8K rows teaching math and calculator tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity
    SimpleSpelling(size=200000, split="train"), # 200K rows of spelling tasks
    SpellingBee(size=80000, split="train"), # 80K rows of spelling bee tasks
]) # total: 848K rows
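One simple way to implement a mixture like this is to build a flat index over every (task, row) pair and shuffle it once with a fixed seed, so the tasks are interleaved in proportion to their sizes. The sketch below assumes that design; `SimpleTaskMixture` and the list-backed toy tasks are hypothetical stand-ins, not nanochat's `TaskMixture`.

```python
import random

class SimpleTaskMixture:
    """Present several list-like tasks as one shuffled dataset (sketch)."""

    def __init__(self, tasks, seed=42):
        self.tasks = tasks
        # One entry per row of every task: (task index, row index within task).
        self.index = [(t, i) for t, task in enumerate(tasks) for i in range(len(task))]
        random.Random(seed).shuffle(self.index)  # deterministic interleaving

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        t, i = self.index[idx]
        return self.tasks[t][i]

chat = [f"chat-{i}" for i in range(6)]
math_rows = [f"math-{i}" for i in range(3)]
mixture = SimpleTaskMixture([chat, math_rows])
print(len(mixture), [mixture[k] for k in range(len(mixture))])
```

With this design, large tasks such as SmolTalk dominate the stream roughly in proportion to their row counts.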

Dynamic Data Generation

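Because conversations vary in length, mid-training uses a custom generator that renders them into a token buffer and cuts fixed-size batches of device_batch_size × max_seq_len + 1 tokens from it: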

Source: scripts/mid_train.py:120-135

python
def mid_data_generator(split):
    global last_step, approx_progress
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    
    # Use bos token as pad token for conversations
    bos_token_id = tokenizer.get_bos_token_id()
    needed_tokens = args.device_batch_size * args.max_seq_len + 1
    token_buffer = deque()
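The core of such a generator is the token buffer: tokenized conversations are appended until at least `batch_size * seq_len + 1` tokens are available, then one contiguous window is cut off and split into inputs and next-token targets (shifted by one). A minimal pure-Python sketch, assuming conversations arrive already tokenized and each starts with a BOS id; `pack_batches` is a hypothetical name, and the `last_step`/`approx_progress` bookkeeping from the excerpt is omitted.

```python
from collections import deque

def pack_batches(token_lists, batch_size, seq_len):
    """Yield (inputs, targets): batch_size rows of seq_len token ids each."""
    needed = batch_size * seq_len + 1  # +1 so the last input also has a target
    buffer = deque()
    for ids in token_lists:
        buffer.extend(ids)
        while len(buffer) >= needed:
            window = [buffer.popleft() for _ in range(needed)]
            inputs = [window[r * seq_len:(r + 1) * seq_len] for r in range(batch_size)]
            targets = [window[r * seq_len + 1:(r + 1) * seq_len + 1] for r in range(batch_size)]
            yield inputs, targets

convs = [[1, 10, 11, 12], [1, 20, 21], [1, 30, 31, 32, 33]]  # 1 plays the role of BOS
for inputs, targets in pack_batches(convs, batch_size=2, seq_len=3):
    print(inputs, targets)
# [[1, 10, 11], [12, 1, 20]] [[10, 11, 12], [1, 20, 21]]
```

Rows may span conversation boundaries; the leftover tokens of the third conversation stay in the buffer until more data arrives.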

Learning Rate Scheduling

Source: scripts/mid_train.py:175-178

python
def get_lr_multiplier(progress):
    # first 80% of training: no decay, then linearly ramp down to 0.
    return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2

Stage 4: Supervised Fine-tuning (SFT)

SFT adapts the model for high-quality instruction following using a curated mixture of tasks.

Task Mixture

Source: scripts/chat_sft.py:85-95

python
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath), # 1K synthetic identity
    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling
    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee
]) # 23K rows total

Conversation Tokenization

SFT renders each conversation into token ids plus a parallel loss mask, so that training targets only the assistant's tokens; the renderer also supports tool use:

Source: nanochat/tokenizer.py:320-340

python
def render_conversation(self, conversation, max_tokens=2048):
    """
    Tokenize a single Chat conversation.
    Returns:
    - ids: list[int] of token ids for this rendered conversation
    - mask: list[int] of same length, mask = 1 for tokens that Assistant trains on.
    """
    ids, mask = [], []
    def add_tokens(token_ids, mask_val):
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))
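The essential idea is that every token receives a mask value: 0 for tokens the model only reads (user turns and structural special tokens) and 1 for assistant tokens it should learn to produce. The sketch below uses a toy one-id-per-character "tokenizer" and hypothetical special-token ids, and it omits the tool-use segments that nanochat's renderer also handles. In this sketch the closing `<|assistant_end|>` is supervised so the model learns when to stop.

```python
SPECIAL = {"<|bos|>": 0, "<|user_start|>": 1, "<|user_end|>": 2,
           "<|assistant_start|>": 3, "<|assistant_end|>": 4}  # hypothetical ids

def toy_encode(text):
    # Stand-in tokenizer: one id per character, offset past the special ids.
    return [100 + ord(ch) for ch in text]

def render_conversation(messages, max_tokens=2048):
    """Return (ids, mask); mask is 1 only where the assistant is supervised."""
    ids, mask = [], []

    def add_tokens(token_ids, mask_val):
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))

    add_tokens(SPECIAL["<|bos|>"], 0)
    for msg in messages:
        if msg["role"] == "user":
            add_tokens(SPECIAL["<|user_start|>"], 0)
            add_tokens(toy_encode(msg["content"]), 0)
            add_tokens(SPECIAL["<|user_end|>"], 0)
        elif msg["role"] == "assistant":
            add_tokens(SPECIAL["<|assistant_start|>"], 0)
            add_tokens(toy_encode(msg["content"]), 1)
            add_tokens(SPECIAL["<|assistant_end|>"], 1)  # learn when to stop
        else:
            raise ValueError(f"unsupported role: {msg['role']}")
    return ids[:max_tokens], mask[:max_tokens]

ids, mask = render_conversation([
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "yo"},
])
print(ids)   # [0, 1, 204, 205, 2, 3, 221, 211, 4]
print(mask)  # [0, 0, 0, 0, 0, 0, 1, 1, 1]
```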

Data Collation

Source: scripts/chat_sft.py:110-125

python
def sft_data_generator(dataset, batch_size):
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")
    
    def collate_and_yield(batch):
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask in batch) - 1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
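The rest of the collation can be sketched in pure Python. Each row's inputs are `ids[:-1]` and its targets `ids[1:]`, right-padded to the longest row. A target is kept only if the token it points to is supervised (hence checking `mask[col + 1]`, matching the one-position shift); every other position, including padding, gets -1 so a loss using ignore index -1 skips it. This sketch uses nested lists instead of torch tensors, and the token ids are hypothetical. Padding inputs with `<|assistant_end|>`, as the excerpt does, is harmless because padded positions never carry a target.

```python
def collate(batch, pad_token_id):
    """batch: list of (ids, mask) pairs. Returns (inputs, targets) as nested lists."""
    ncols = max(len(ids) for ids, _ in batch) - 1
    inputs = [[pad_token_id] * ncols for _ in batch]
    targets = [[-1] * ncols for _ in batch]  # -1 is the ignore index
    for row, (ids, mask) in enumerate(batch):
        n = len(ids) - 1
        inputs[row][:n] = ids[:-1]
        for col in range(n):
            # Target at col is ids[col + 1]; keep it only if that token is supervised.
            if mask[col + 1] == 1:
                targets[row][col] = ids[col + 1]
    return inputs, targets

batch = [
    ([0, 1, 50, 2, 3, 60, 4], [0, 0, 0, 0, 0, 1, 1]),
    ([0, 1, 51, 2, 3, 61, 62, 4], [0, 0, 0, 0, 0, 1, 1, 1]),
]
inputs, targets = collate(batch, pad_token_id=4)
for row in zip(inputs, targets):
    print(row)
```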

Stage 5: Reinforcement Learning (Optional)

RL further fine-tunes the model with reward signals. It currently focuses on mathematical reasoning, and the RL model is evaluated only on GSM8K.

RL Training Setup

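The RL stage uses a policy-gradient method to improve performance on specific tasks. In speedrun.sh both the training and the evaluation commands are commented out, so this stage does not run by default: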

Source: speedrun.sh:140-145

bash
# run reinforcement learning
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
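As an illustration of the policy-gradient idea on GSM8K-style problems: sample several answers per question, score each with a 0/1 correctness reward, and weight each sample's log-probability by its advantage (its reward minus the group's mean reward). This is one common form of the method, sketched as an assumption rather than a description of scripts/chat_rl.py; it only computes advantages and the loss value from given numbers, and the inputs are toy values. In a real trainer the log-probabilities would be autograd tensors so the loss can be backpropagated.

```python
def group_advantages(rewards):
    """Advantage of each sample = its reward minus the mean reward of its group."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

def policy_gradient_loss(logprobs, rewards):
    """REINFORCE-style loss: minimizing it raises log-probs of above-average samples.

    logprobs: summed log-probability of each sampled completion.
    rewards:  1.0 if the completion's final answer was correct, else 0.0.
    """
    advantages = group_advantages(rewards)
    return -sum(a * lp for a, lp in zip(advantages, logprobs)) / len(logprobs)

rewards = [1.0, 0.0, 0.0, 1.0]          # 4 samples for one question, 2 correct
logprobs = [-12.0, -9.5, -15.0, -11.0]  # toy values
print(group_advantages(rewards))              # [0.5, -0.5, -0.5, 0.5]
print(policy_gradient_loss(logprobs, rewards))  # -0.1875
```

When every sample for a question receives the same reward, all advantages are zero and that question contributes no gradient.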

Training Infrastructure

Distributed Training

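All training stages can be launched with torchrun across multiple GPUs. Base training derives the number of gradient-accumulation steps from the target total batch size, the per-device batch, and the number of processes: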

Source: scripts/base_train.py:105-115

python
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
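The same calculation, worked through for a few configurations. The values used (total_batch_size = 524,288 tokens, max_seq_len = 2048, device_batch_size of 32 or 16) are illustrative assumptions; the smaller device batch mirrors the `--device_batch_size=16` used in the d26 example. Halving the per-device batch or the GPU count raises the accumulation steps so the total batch stays fixed.

```python
def grad_accum_steps(total_batch_size, device_batch_size, max_seq_len, world_size):
    """Micro-steps needed so that all ranks together process total_batch_size tokens."""
    tokens_per_fwdbwd = device_batch_size * max_seq_len
    world_tokens_per_fwdbwd = tokens_per_fwdbwd * world_size
    assert total_batch_size % world_tokens_per_fwdbwd == 0, "batch must divide evenly"
    return total_batch_size // world_tokens_per_fwdbwd

for world_size, device_bs in [(8, 32), (8, 16), (1, 32)]:
    steps = grad_accum_steps(524_288, device_bs, 2048, world_size)
    print(f"{world_size} GPU(s), device_batch_size={device_bs}: {steps} step(s)")
# 8 GPU(s), device_batch_size=32: 1 step(s)
# 8 GPU(s), device_batch_size=16: 2 step(s)
# 1 GPU(s), device_batch_size=32: 8 step(s)
```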

Checkpointing & Resume

Source: nanochat/checkpoint_manager.py:15-35

python
def save_checkpoint(checkpoint_dir, step, model_state_dict, optimizer_state_dicts, meta_data, rank=0):
    """
    Save model and optimizer state to enable training resumption.
    Only rank 0 saves to avoid race conditions.
    """
    if rank != 0:
        return
    
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"step_{step:06d}.pt")

Performance Monitoring

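Training progress is logged at every step. Base training, for example, reports the smoothed loss, learning-rate multiplier, step time, throughput (tokens/sec), MFU, elapsed time, and an ETA: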

Source: scripts/base_train.py:380-395

python
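# Calculate ETA based on average time per step.
# steps_done skips the first 10 steps, which are not counted in total_training_time.
eta_str = ""  # default, so the print below works before any timed steps exist
steps_done = step - 10
if steps_done > 0:
    avg_time_per_step = total_training_time / steps_done
    remaining_steps = num_iterations - step
    eta_seconds = remaining_steps * avg_time_per_step
    eta_str = f" | eta: {eta_seconds/60:.1f}m"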

print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m{eta_str}")
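Two of the logged quantities are worth spelling out. The loss is an exponential moving average whose start-up bias (the pull toward its 0.0 initial value) is removed by dividing by 1 - beta^(step + 1), the same correction Adam applies to its moment estimates; the ETA extrapolates the average duration of the timed steps. The sketch assumes beta = 0.9 and, as the excerpt suggests, that the first 10 steps are excluded from timing; `debiased_ema` and `eta_minutes` are hypothetical names.

```python
def debiased_ema(losses, beta=0.9):
    """Yield the bias-corrected exponential moving average after each step."""
    ema = 0.0
    for step, loss in enumerate(losses):
        ema = beta * ema + (1 - beta) * loss
        yield ema / (1 - beta ** (step + 1))  # undo the pull toward the 0.0 init

def eta_minutes(step, num_iterations, total_training_time, skipped_steps=10):
    """Remaining time in minutes, from the average duration of the timed steps."""
    steps_done = step - skipped_steps
    if steps_done <= 0:
        return None  # no timed steps yet
    avg_time_per_step = total_training_time / steps_done
    return (num_iterations - step) * avg_time_per_step / 60

print([round(x, 3) for x in debiased_ema([4.0, 4.0, 4.0])])  # [4.0, 4.0, 4.0]
print(eta_minutes(step=110, num_iterations=1000, total_training_time=50.0))
```

Without the correction, the first smoothed value for a constant loss of 4.0 would be 0.4, badly understating the true loss.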

Evaluation Throughout Training

CORE Metric

The CORE metric tracks base-model quality during pretraining; it is computed every core_metric_every steps and at the final step:

Source: scripts/base_train.py:275-285

python
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
    model.eval()
    with autocast_ctx:
        results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
    print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")

Chat Evaluation

During SFT, multiple-choice accuracy (for example MMLU and ARC-Easy) is measured every eval_metrics_every steps and at the last step:

Source: scripts/chat_sft.py:170-180

python
# evaluate accuracy of the multiple choice tasks
if last_step or (step > 0 and step % args.eval_metrics_every == 0):
    model.eval()
    metrics = {}
    with torch.no_grad(), autocast_ctx:
        metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=args.device_batch_size*2)
        metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=args.device_batch_size*2)

Performance Results

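The README reports the following scores for one speedrun. Most chat metrics improve from MID to SFT (ARC-Challenge dips slightly), RL raises GSM8K further, and CORE is reported only for the base model: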

Source: README.md:45-60

text
| Metric          | BASE     | MID      | SFT      | RL       |
|-----------------|----------|----------|----------|----------|
| CORE            | 0.2219   | -        | -        | -        |
| ARC-Challenge   | -        | 0.2875   | 0.2807   | -        |
| ARC-Easy        | -        | 0.3561   | 0.3876   | -        |
| GSM8K           | -        | 0.0250   | 0.0455   | 0.0758   |
| HumanEval       | -        | 0.0671   | 0.0854   | -        |
| MMLU            | -        | 0.3111   | 0.3151   | -        |
| ChatCORE        | -        | 0.0730   | 0.0884   | -        |

Total wall clock time: 3h51m

Scaling to Larger Models

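Larger models are trained by increasing --depth, downloading proportionally more data shards, and lowering --device_batch_size so that activations fit in GPU memory; gradient accumulation keeps the total batch size unchanged: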

Source: README.md:74-85

bash
# For d26 model (~GPT-2 performance):
python -m nanochat.dataset -n 450  # More data shards
torchrun --nproc_per_node=8 -m scripts.base_train -- --depth=26 --device_batch_size=16
torchrun --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16

The training pipeline demonstrates how a complete ChatGPT-style model can be trained efficiently with careful data curation, modern optimization techniques, and staged training objectives.


Sources:

  • speedrun.sh (complete training pipeline)
  • scripts/base_train.py (base pretraining)
  • scripts/mid_train.py (mid-training)
  • scripts/chat_sft.py (supervised fine-tuning)
  • scripts/chat_rl.py (reinforcement learning)
  • nanochat/checkpoint_manager.py (training infrastructure)
Last updated: 1/10/2026