Training Scripts

NanoChat provides a comprehensive set of training scripts that handle each phase of the model development pipeline. This document details the training scripts, their usage, and key parameters.

Overview

Source: README.md:155-170

text
├── scripts
│   ├── base_train.py               # Base model: train
│   ├── chat_sft.py                 # Chat model: train SFT
│   ├── chat_rl.py                  # Chat model: reinforcement learning
│   ├── mid_train.py                # Chat model: midtraining
│   ├── tok_train.py                # Tokenizer: train it

The training pipeline follows a specific order:

  1. Tokenizer Training (tok_train.py)
  2. Base Pretraining (base_train.py)
  3. Mid-training (mid_train.py)
  4. Supervised Fine-tuning (chat_sft.py)
  5. Reinforcement Learning (chat_rl.py) - Optional

Tokenizer Training

tok_train.py

Trains a BPE tokenizer on the pretraining dataset.

Usage

bash
# Basic usage
python -m scripts.tok_train --vocab_size=65536

# With custom parameters
python -m scripts.tok_train \
    --max_chars=10000000000 \
    --vocab_size=32768 \
    --doc_cap=10000

Key Parameters

  • --max_chars: Maximum characters to train on (default: 10B)
  • --vocab_size: Vocabulary size (default: 32,768)
  • --doc_cap: Maximum characters per document (default: 10,000)

Features

  • GPT-4 style regex splitting for consistent pre-tokenization (a sketch follows this list)
  • Byte-level fallback to handle any input
  • Special tokens for conversation structure
  • Progress monitoring with character count tracking
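
The regex splitting step chunks raw text before any BPE merges are learned. A minimal sketch, assuming the GPT-4 split pattern popularized by tiktoken/minbpe (the exact pattern in nanochat's tokenizer may differ slightly); it needs the third-party regex package, since the standard re module lacks \p{...} classes and possessive quantifiers:

python
import regex

# GPT-4-style split pattern (as used in minbpe); assumed here, not read from tok_train.py
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

text = "Hello world! 12345 tokens, please."
chunks = regex.findall(GPT4_SPLIT_PATTERN, text)
print(chunks)  # e.g. ['Hello', ' world', '!', ' ', '123', '45', ' tokens', ',', ' please', '.']
# BPE merges are then learned within each chunk, so merges never cross
# word/number/punctuation boundaries.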

Base Pretraining

base_train.py

Trains the foundational language model on unlabeled text data.

Usage

bash
# Single GPU
python -m scripts.base_train --depth=20 --target_param_data_ratio=20

# Multi-GPU (8 GPUs)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --target_param_data_ratio=20 --run=my_run

# Custom model size
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=26 --device_batch_size=16 --aspect_ratio=64

Model Architecture Parameters

  • --depth: Number of transformer layers (default: 20)
  • --aspect_ratio: Width-to-depth ratio; the model dimension scales as depth × aspect_ratio (default: 64)
  • --head_dim: Target attention head dimension (default: 128)
  • --max_seq_len: Maximum sequence length (default: 2048)

Training Horizon (one required)

  • --num_iterations: Explicit number of optimization steps
  • --target_flops: Target total FLOPs for training
  • --target_param_data_ratio: Tokens-to-parameters ratio (Chinchilla-optimal ≈ 20); see the worked example below
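
With --target_param_data_ratio, the number of optimization steps follows from the model size and the token batch size. A rough worked example (the ~560M parameter count for a depth-20 model is an approximation, not a value read out of base_train.py):

python
num_params = 560e6               # approx. parameters of a depth-20 model (assumed)
target_param_data_ratio = 20     # Chinchilla-style tokens per parameter
total_batch_size = 524_288       # tokens per optimization step (default)

target_tokens = target_param_data_ratio * num_params        # ~11.2B tokens
num_iterations = round(target_tokens / total_batch_size)    # ~21k steps
print(f"{target_tokens:.3g} tokens -> {num_iterations} steps")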

Optimization Parameters

  • --device_batch_size: Per-device batch size (default: 32)
  • --total_batch_size: Total batch size in tokens (default: 524,288)
  • --embedding_lr: AdamW learning rate for embeddings (default: 0.3)
  • --unembedding_lr: AdamW learning rate for output layer (default: 0.004)
  • --matrix_lr: Muon learning rate for transformer layers (default: 0.02)
  • --weight_decay: L2 regularization for AdamW (default: 0.0)

Learning Rate Schedule

  • --warmup_ratio: Fraction of training for LR warmup (default: 0.0)
  • --warmdown_ratio: Fraction of training for LR decay (default: 0.4)
  • --final_lr_frac: Final LR as fraction of initial LR (default: 0.0)
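
Together these flags describe a warmup/plateau/warmdown ("trapezoidal") multiplier on the base learning rates. A minimal sketch of that shape, assuming linear ramps (the exact schedule in base_train.py may differ in its details):

python
def lr_multiplier(step, num_iterations,
                  warmup_ratio=0.0, warmdown_ratio=0.4, final_lr_frac=0.0):
    warmup_steps = int(warmup_ratio * num_iterations)
    warmdown_steps = int(warmdown_ratio * num_iterations)
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps       # linear warmup
    if warmdown_steps > 0 and step >= num_iterations - warmdown_steps:
        # linear decay from 1.0 down to final_lr_frac over the last warmdown steps
        progress = (num_iterations - step) / warmdown_steps
        return final_lr_frac + (1.0 - final_lr_frac) * progress
    return 1.0                                 # constant plateau in between

# With the defaults, the LR is constant for the first 60% of steps and then
# decays linearly to zero over the final 40%.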

Evaluation Parameters

  • --eval_every: Evaluate validation loss every N steps (default: 250)
  • --eval_tokens: Number of tokens for validation evaluation
  • --core_metric_every: Evaluate CORE metric every N steps (default: 2000)
  • --sample_every: Sample from model every N steps (default: 2000)

Features

  • Automatic gradient accumulation to reach the target batch size (see the sketch after this list)
  • Learning rate scaling based on batch size
  • Distributed training with PyTorch DDP
  • Resume capability with checkpoint states
  • Real-time monitoring with ETA estimation
  • Comprehensive evaluation with CORE metric and sampling
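
The gradient-accumulation factor falls out of the batch-size flags. Illustrative arithmetic with the documented defaults on an 8-GPU node (not code from base_train.py):

python
device_batch_size = 32        # sequences per GPU per micro-step
max_seq_len = 2048            # tokens per sequence
world_size = 8                # number of GPUs
total_batch_size = 524_288    # target tokens per optimization step

tokens_per_micro_step = device_batch_size * max_seq_len * world_size  # 524,288
grad_accum_steps = total_batch_size // tokens_per_micro_step          # 1

# Halving --device_batch_size to 16 (e.g. on 40GB GPUs) halves
# tokens_per_micro_step to 262,144, so grad_accum_steps becomes 2: the
# optimizer still steps on 524,288-token batches, just built from more
# micro-steps.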

Mid-training

mid_train.py

Intermediate training phase that introduces conversation structure and tool use.

Usage

bash
# Multi-GPU mid-training
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train

# With custom parameters
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- \
    --device_batch_size=16 --init_lr_frac=1.0 --run=mid_run

Key Parameters

  • --device_batch_size: Per-device batch size (default: 32)
  • --total_batch_size: Total batch size in tokens (default: 524,288)
  • --max_seq_len: Maximum sequence length (default: 2048)
  • --init_lr_frac: Initial LR as fraction of base model LR (default: 1.0)

Training Data

Mid-training uses a TaskMixture that includes the following datasets (a generic mixing sketch follows the list):

  • SmolTalk: General conversations (460K rows)
  • MMLU: Multiple choice problems (100K rows)
  • GSM8K: Math problems with tool use (8K rows)
  • Identity conversations: Synthetic personality data (1K rows × 2)
  • Spelling tasks: Letter counting and spelling (280K rows)
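
As a generic illustration of how such a mixture can interleave datasets in proportion to their size (nanochat's actual TaskMixture class may differ in API and sampling details):

python
import random

def mixture(datasets, seed=0):
    """Yield examples from all datasets, weighted by remaining dataset size."""
    rng = random.Random(seed)
    pools = [list(ds) for ds in datasets]
    while any(pools):
        weights = [len(pool) for pool in pools]   # larger datasets are sampled more often
        i = rng.choices(range(len(pools)), weights=weights, k=1)[0]
        yield pools[i].pop()

# toy usage with two tiny stand-in datasets
smoltalk = [f"smoltalk_{i}" for i in range(5)]
gsm8k = [f"gsm8k_{i}" for i in range(2)]
print(list(mixture([smoltalk, gsm8k])))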

Features

  • Dynamic data generation with variable-length conversations
  • Progress tracking based on dataset completion
  • Tool use introduction through GSM8K problems
  • Learning rate decay in final 20% of training

Supervised Fine-tuning

chat_sft.py

Fine-tunes the model for instruction following and helpful dialogue.

Usage

bash
# Multi-GPU SFT
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft

# With custom parameters
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- \
    --num_epochs=2 --device_batch_size=4 --eval_metrics_every=100

Training Parameters

  • --num_epochs: Number of training epochs (default: 1)
  • --num_iterations: Override with explicit iteration count
  • --device_batch_size: Per-device batch size (default: 4)
  • --target_examples_per_step: Target examples per optimization step (default: 32)
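
The SFT gradient-accumulation factor follows from --target_examples_per_step rather than from a token count. Illustrative arithmetic with the defaults (not code from chat_sft.py):

python
device_batch_size = 4
world_size = 8                    # GPUs
target_examples_per_step = 32

examples_per_micro_step = device_batch_size * world_size                 # 32
grad_accum_steps = target_examples_per_step // examples_per_micro_step   # 1
# On a single GPU the same target is met by accumulating more micro-steps:
# world_size = 1 gives examples_per_micro_step = 4 and grad_accum_steps = 8.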

Optimization

  • --embedding_lr: AdamW learning rate for embeddings (default: 0.2)
  • --unembedding_lr: AdamW learning rate for output layer (default: 0.004)
  • --matrix_lr: Muon learning rate for transformer layers (default: 0.02)
  • --init_lr_frac: Initial LR fraction (default: 0.02)

Evaluation

  • --eval_every: Evaluate validation loss every N steps (default: 100)
  • --eval_steps: Number of batches for validation (default: 100)
  • --eval_metrics_every: Evaluate accuracy every N steps (default: 200)
  • --eval_metrics_max_problems: Max problems per evaluation (default: 1024)

Training Data

SFT uses a curated TaskMixture:

  • ARC-Easy/Challenge: Science reasoning (3.4K rows)
  • GSM8K: Mathematical problems (8K rows)
  • SmolTalk: Conversations (10K rows)
  • Identity conversations: Personality (1K rows)
  • Spelling tasks: Basic language skills (600 rows)

Features

  • Conversation tokenization with proper masking
  • Real-time evaluation on multiple choice tasks
  • Gradient accumulation for effective large batch training
  • Loss masking to only supervise assistant responses
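
Loss masking is conventionally implemented by setting target ids to the cross-entropy ignore index at non-assistant positions. A minimal sketch assuming PyTorch's standard ignore_index=-100 convention (the exact masking code in chat_sft.py may differ):

python
import torch
import torch.nn.functional as F

vocab_size = 8
# token ids for one short conversation, plus a parallel mask that is 1 where
# the token belongs to an assistant response and 0 for user/system tokens
tokens = torch.tensor([[5, 2, 7, 1, 3, 6]])
assistant_mask = torch.tensor([[0, 0, 1, 1, 1, 0]])

# next-token prediction: inputs are tokens[:, :-1], targets are tokens[:, 1:]
inputs, targets = tokens[:, :-1], tokens[:, 1:].clone()
targets[assistant_mask[:, 1:] == 0] = -100        # ignored by cross_entropy

logits = torch.randn(1, inputs.size(1), vocab_size)   # stand-in for model output
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1),
                       ignore_index=-100)
print(loss)  # only assistant positions contribute to the loss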

Reinforcement Learning

chat_rl.py

Applies reinforcement learning to improve specific capabilities.

Usage

bash
# RL training on GSM8K
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl

Features

  • Policy-gradient optimization driven by task rewards (a toy sketch follows this list)
  • Focus on mathematical reasoning (GSM8K dataset)
  • Reward signal integration from task evaluation
  • On-policy updates applied directly to the chat model
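
A toy sketch of a policy-gradient (REINFORCE-style) update with a group-mean baseline; this illustrates the general technique, not the exact chat_rl.py implementation, where sampling, rewards, and batching are more involved:

python
import torch

# suppose we sampled four completions for one GSM8K problem and scored each
# 0/1 by checking its final answer
log_probs = torch.tensor([-12.3, -10.1, -11.7, -13.0], requires_grad=True)
# per-completion log-probabilities above are stand-ins for values computed by the model
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])      # stand-in per-sample rewards

advantages = rewards - rewards.mean()             # group-mean baseline
loss = -(advantages * log_probs).mean()           # REINFORCE-style objective
loss.backward()                                   # gradients would flow into the policy
print(loss.item(), log_probs.grad)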

Common Parameters Across Scripts

Device and Precision

  • --device_type: Device type (cuda/cpu/mps, auto-detected if empty)
  • --dtype: Precision (bfloat16/float32, default: bfloat16)

Logging and Monitoring

  • --run: Wandb run name ("dummy" disables wandb)

Model Loading (mid-training, SFT, RL)

  • --source: Source checkpoint (base/mid/sft)
  • --model_tag: Specific model tag to load
  • --model_step: Specific checkpoint step to load

Usage Examples

Complete Pipeline

bash
# 1. Train tokenizer
python -m scripts.tok_train --vocab_size=65536

# 2. Base pretraining
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --target_param_data_ratio=20

# 3. Mid-training  
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train

# 4. Supervised fine-tuning
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft

# 5. Reinforcement learning (optional)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl

Memory-Constrained Training

bash
# For 40GB VRAM
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --device_batch_size=16

# For 24GB VRAM  
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=16 --device_batch_size=8

# For single GPU
python -m scripts.base_train --depth=12 --device_batch_size=4

Model Size Scaling

bash
# Larger model (d26, ~GPT-2 performance)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=26 --device_batch_size=16

# Smaller model (d12)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=12 --device_batch_size=64

Best Practices

Memory Management

  1. Monitor VRAM usage and adjust device_batch_size
  2. Use gradient accumulation instead of reducing total batch size
  3. Enable mixed precision with bfloat16 for efficiency

Hyperparameter Tuning

  1. Scale learning rates with batch size (automatic)
  2. Adjust warmup/warmdown ratios for training length
  3. Monitor validation metrics to prevent overfitting

Distributed Training

  1. Use torchrun for multi-GPU setups
  2. Ensure data shards are sufficient for training length
  3. Monitor load balancing across GPUs

Together, these scripts form a complete, configurable pipeline for developing a capable language model from scratch, from tokenizer training through RL fine-tuning.


Sources:

  • scripts/base_train.py, scripts/mid_train.py, scripts/chat_sft.py, scripts/chat_rl.py, scripts/tok_train.py
  • README.md (script documentation)
  • speedrun.sh (usage examples)