# Training Scripts

NanoChat provides a comprehensive set of training scripts that handle each phase of the model development pipeline. This document details the training scripts, their usage, and key parameters.

## Overview

Source: README.md:155-170

```
├── scripts
│   ├── base_train.py   # Base model: train
│   ├── chat_sft.py     # Chat model: train SFT
│   ├── chat_rl.py      # Chat model: reinforcement learning
│   ├── mid_train.py    # Chat model: midtraining
│   ├── tok_train.py    # Tokenizer: train it
```

The training pipeline follows a specific order:

1. Tokenizer training (`tok_train.py`)
2. Base pretraining (`base_train.py`)
3. Mid-training (`mid_train.py`)
4. Supervised fine-tuning (`chat_sft.py`)
5. Reinforcement learning (`chat_rl.py`), optional

## Tokenizer Training (`tok_train.py`)

Trains a BPE tokenizer on the pretraining dataset.

### Usage

```bash
# Basic usage
python -m scripts.tok_train --vocab_size=65536

# With custom parameters
python -m scripts.tok_train \
    --max_chars=10000000000 \
    --vocab_size=32768 \
    --doc_cap=10000
```

### Key Parameters

- `--max_chars`: Maximum characters to train on (default: 10B)
- `--vocab_size`: Vocabulary size (default: 32,768)
- `--doc_cap`: Maximum characters per document (default: 10,000)

### Features

- GPT-4 style regex splitting for consistent tokenization
- Byte-level fallback to handle any input
- Special tokens for conversation structure
- Progress monitoring with character count tracking
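
The first two features above can be illustrated in a few lines of Python. The split pattern below is tiktoken's published cl100k_base (GPT-4) pattern, and the byte-level fallback is simply UTF-8 encoding of each chunk; treat this as an illustrative sketch rather than NanoChat's actual tokenizer code.

```python
# Illustrative sketch of GPT-4-style splitting plus byte-level fallback.
# Requires the third-party `regex` package (the stdlib `re` lacks \p{...}).
import regex as re

# tiktoken's published cl100k_base (GPT-4) split pattern
GPT4_SPLIT_PATTERN = (
    r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}|"""
    r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)

def pretokenize(text: str) -> list[bytes]:
    """Split text into chunks, then fall back to raw UTF-8 bytes per chunk,
    so any input (emoji, code, other scripts) stays representable."""
    chunks = re.findall(GPT4_SPLIT_PATTERN, text)
    return [chunk.encode("utf-8") for chunk in chunks]

print(pretokenize("Hello world, it's 2024!"))
```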

## Base Pretraining (`base_train.py`)

Trains the foundational language model on unlabeled text data.

### Usage

```bash
# Single GPU
python -m scripts.base_train --depth=20 --target_param_data_ratio=20

# Multi-GPU (8 GPUs)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --target_param_data_ratio=20 --run=my_run

# Custom model size
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=26 --device_batch_size=16 --aspect_ratio=64
```

### Model Architecture Parameters

- `--depth`: Number of transformer layers (default: 20)
- `--aspect_ratio`: Model dimension scaling factor (default: 64)
- `--head_dim`: Target attention head dimension (default: 128)
- `--max_seq_len`: Maximum sequence length (default: 2048)
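
A rough sketch of how these sizing knobs typically interact (this mirrors the common width = depth × aspect_ratio convention; check base_train.py for the exact rule NanoChat applies):

```python
# Hypothetical sizing arithmetic for --depth, --aspect_ratio and --head_dim.
depth = 20          # --depth: number of transformer layers
aspect_ratio = 64   # --aspect_ratio: model width per unit of depth
head_dim = 128      # --head_dim: target size of each attention head

model_dim = depth * aspect_ratio           # 1280 for the default d20 model
num_heads = max(1, model_dim // head_dim)  # 10 heads of dimension 128
print(f"model_dim={model_dim}, num_heads={num_heads}")
```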

### Training Horizon (one required)

- `--num_iterations`: Explicit number of optimization steps
- `--target_flops`: Target total FLOPs for training
- `--target_param_data_ratio`: Data-to-parameter ratio (Chinchilla-optimal is ~20)
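
All three options resolve to a total token budget. The sketch below shows the Chinchilla-style conversion for `--target_param_data_ratio`; the parameter count is an illustrative placeholder (the script derives it from the model it actually builds):

```python
# Sketch: turning a data-to-parameter ratio into an iteration count.
num_params = 560e6             # placeholder, roughly a d20-scale model
target_param_data_ratio = 20   # Chinchilla-optimal is ~20 tokens per parameter
total_batch_size = 524_288     # tokens consumed per optimization step

target_tokens = target_param_data_ratio * num_params
num_iterations = round(target_tokens / total_batch_size)
print(f"~{target_tokens / 1e9:.1f}B tokens -> {num_iterations} steps")
```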

### Optimization Parameters

- `--device_batch_size`: Per-device batch size (default: 32)
- `--total_batch_size`: Total batch size in tokens (default: 524,288)
- `--embedding_lr`: AdamW learning rate for embeddings (default: 0.3)
- `--unembedding_lr`: AdamW learning rate for output layer (default: 0.004)
- `--matrix_lr`: Muon learning rate for transformer layers (default: 0.02)
- `--weight_decay`: L2 regularization for AdamW (default: 0.0)
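
The gap between `--device_batch_size` and `--total_batch_size` is closed by gradient accumulation (listed under Features below). A sketch of the usual bookkeeping, with illustrative variable names rather than the script's own:

```python
# Sketch: micro-batches accumulated per rank before each optimizer step.
device_batch_size = 32      # sequences per GPU per forward/backward pass
max_seq_len = 2048          # tokens per sequence
world_size = 8              # number of GPUs
total_batch_size = 524_288  # target tokens per optimization step

tokens_per_micro_step = device_batch_size * max_seq_len * world_size
assert total_batch_size % tokens_per_micro_step == 0
grad_accum_steps = total_batch_size // tokens_per_micro_step
print(grad_accum_steps)  # 1 here; halving device_batch_size doubles this
```

This is why shrinking `--device_batch_size` for memory reasons does not change the effective batch size, only the number of accumulation steps per update.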

### Learning Rate Schedule

- `--warmup_ratio`: Fraction of training for LR warmup (default: 0.0)
- `--warmdown_ratio`: Fraction of training for LR decay (default: 0.4)
- `--final_lr_frac`: Final LR as a fraction of the initial LR (default: 0.0)
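
Together these three flags describe a trapezoidal schedule: optional linear warmup, a constant stretch, then a linear warmdown toward `final_lr_frac` of the peak. A minimal sketch of that shape (the script's exact implementation may differ in detail):

```python
# Sketch of a warmup / constant / warmdown LR multiplier.
def lr_multiplier(step: int, num_iterations: int,
                  warmup_ratio: float = 0.0,
                  warmdown_ratio: float = 0.4,
                  final_lr_frac: float = 0.0) -> float:
    warmup_steps = int(warmup_ratio * num_iterations)
    warmdown_steps = int(warmdown_ratio * num_iterations)
    if warmup_steps and step < warmup_steps:
        return (step + 1) / warmup_steps                 # linear warmup
    if warmdown_steps and step >= num_iterations - warmdown_steps:
        progress = (num_iterations - step) / warmdown_steps
        return final_lr_frac + (1.0 - final_lr_frac) * progress  # warmdown
    return 1.0                                           # constant middle

print([round(lr_multiplier(s, 10), 2) for s in range(10)])
```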

### Evaluation Parameters

- `--eval_every`: Evaluate validation loss every N steps (default: 250)
- `--eval_tokens`: Number of tokens for validation evaluation
- `--core_metric_every`: Evaluate the CORE metric every N steps (default: 2000)
- `--sample_every`: Sample from the model every N steps (default: 2000)

### Features

- Automatic gradient accumulation to achieve target batch size
- Learning rate scaling based on batch size
- Distributed training with PyTorch DDP
- Resume capability with checkpoint states
- Real-time monitoring with ETA estimation
- Comprehensive evaluation with CORE metric and sampling

## Mid-training (`mid_train.py`)

Intermediate training phase that introduces conversation structure and tool use.

### Usage

```bash
# Multi-GPU mid-training
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train

# With custom parameters
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- \
    --device_batch_size=16 --init_lr_frac=1.0 --run=mid_run
```

### Key Parameters

- `--device_batch_size`: Per-device batch size (default: 32)
- `--total_batch_size`: Total batch size in tokens (default: 524,288)
- `--max_seq_len`: Maximum sequence length (default: 2048)
- `--init_lr_frac`: Initial LR as a fraction of the base model LR (default: 1.0)

### Training Data

Mid-training uses a TaskMixture including:
- SmolTalk: General conversations (460K rows)
- MMLU: Multiple choice problems (100K rows)
- GSM8K: Math problems with tool use (8K rows)
- Identity conversations: Synthetic personality data (1K rows × 2)
- Spelling tasks: Letter counting and spelling (280K rows)
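
A TaskMixture of this kind can be pictured as a size-weighted interleaving of the datasets listed above. The sketch below is a generic stand-in with hypothetical names, not NanoChat's TaskMixture class:

```python
# Hypothetical sketch of mixing several datasets in proportion to their size.
import random

def task_mixture(datasets: dict[str, list], seed: int = 0):
    """Yield (task_name, example) pairs until every dataset is exhausted."""
    rng = random.Random(seed)
    pools = {name: list(rows) for name, rows in datasets.items()}
    for rows in pools.values():
        rng.shuffle(rows)
    while any(pools.values()):
        names = [n for n, rows in pools.items() if rows]
        weights = [len(pools[n]) for n in names]  # bigger tasks sampled more often
        name = rng.choices(names, weights=weights, k=1)[0]
        yield name, pools[name].pop()

mix = task_mixture({"smoltalk": ["a"] * 4, "gsm8k": ["b"] * 2})
print([name for name, _ in mix])
```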

### Features

- Dynamic data generation with variable-length conversations
- Progress tracking based on dataset completion
- Tool use introduction through GSM8K problems
- Learning rate decay in final 20% of training

## Supervised Fine-tuning (`chat_sft.py`)

Fine-tunes the model for instruction following and helpful dialogue.

### Usage

```bash
# Multi-GPU SFT
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft

# With custom parameters
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- \
    --num_epochs=2 --device_batch_size=4 --eval_metrics_every=100
```

### Training Parameters

- `--num_epochs`: Number of training epochs (default: 1)
- `--num_iterations`: Override with an explicit iteration count
- `--device_batch_size`: Per-device batch size (default: 4)
- `--target_examples_per_step`: Target examples per optimization step (default: 32)
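
Unlike base training, SFT accumulates gradients by example count rather than by token count; roughly (illustrative arithmetic, not the script verbatim):

```python
# Sketch: examples-per-step bookkeeping for SFT.
device_batch_size = 4            # --device_batch_size
world_size = 8                   # number of GPUs
target_examples_per_step = 32    # --target_examples_per_step

examples_per_micro_step = device_batch_size * world_size
grad_accum_steps = max(1, target_examples_per_step // examples_per_micro_step)
print(grad_accum_steps)  # 1 with the defaults above; 8 on a single GPU
```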

### Optimization

- `--embedding_lr`: AdamW learning rate for embeddings (default: 0.2)
- `--unembedding_lr`: AdamW learning rate for output layer (default: 0.004)
- `--matrix_lr`: Muon learning rate for transformer layers (default: 0.02)
- `--init_lr_frac`: Initial LR fraction (default: 0.02)

### Evaluation

- `--eval_every`: Evaluate validation loss every N steps (default: 100)
- `--eval_steps`: Number of batches for validation (default: 100)
- `--eval_metrics_every`: Evaluate accuracy every N steps (default: 200)
- `--eval_metrics_max_problems`: Max problems per evaluation (default: 1024)

### Training Data

SFT uses a curated TaskMixture:
- ARC-Easy/Challenge: Science reasoning (3.4K rows)
- GSM8K: Mathematical problems (8K rows)
- SmolTalk: Conversations (10K rows)
- Identity conversations: Personality (1K rows)
- Spelling tasks: Basic language skills (600 rows)

### Features

- Conversation tokenization with proper masking
- Real-time evaluation on multiple choice tasks
- Gradient accumulation for effective large batch training
- Loss masking to only supervise assistant responses
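
Loss masking is what keeps the model from being trained to imitate user or system turns: labels for non-assistant tokens are set to an ignore index so the cross-entropy skips them. A minimal sketch using PyTorch's -100 convention (NanoChat's own conversation rendering and masking code may differ in detail):

```python
# Sketch: hide everything except assistant tokens from the loss.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # F.cross_entropy ignores targets with this value

def build_labels(token_ids: list[int], is_assistant: list[bool]) -> torch.Tensor:
    """Copy token ids, but mask user/system tokens out of the loss."""
    labels = [tid if keep else IGNORE_INDEX
              for tid, keep in zip(token_ids, is_assistant)]
    return torch.tensor(labels)

logits = torch.randn(6, 32)  # (seq_len, vocab_size) dummy logits
labels = build_labels([5, 9, 2, 7, 7, 1],
                      [False, False, False, True, True, True])
loss = F.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX)
print(loss.item())
```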

## Reinforcement Learning (`chat_rl.py`)

Applies reinforcement learning to improve specific capabilities.

### Usage

```bash
# RL training on GSM8K
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl
```

### Features

- Policy gradient methods for reward-based optimization
- Focus on mathematical reasoning (GSM8K dataset)
- Reward signal integration from task evaluation
- Policy/value network training
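
As an illustration of the reward-weighted update, here is a generic REINFORCE-style loss with a mean baseline; it shows the general shape of a policy-gradient objective, not necessarily the exact formulation used in chat_rl.py:

```python
# Generic policy-gradient loss sketch: maximize reward-weighted log-probs.
import torch
import torch.nn.functional as F

def policy_gradient_loss(logits, sampled_ids, rewards):
    """logits: (B, T, V), sampled_ids: (B, T), rewards: (B,)"""
    logprobs = F.log_softmax(logits, dim=-1)
    token_logprobs = logprobs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
    advantages = rewards - rewards.mean()  # simple mean baseline
    # negative sign: the optimizer minimizes, we want to maximize the objective
    return -(advantages.unsqueeze(-1) * token_logprobs).mean()

B, T, V = 2, 5, 11
loss = policy_gradient_loss(torch.randn(B, T, V),
                            torch.randint(0, V, (B, T)),
                            torch.tensor([1.0, 0.0]))
print(loss.item())
```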

## Common Parameters Across Scripts

### Device and Precision

- `--device_type`: Device type (cuda/cpu/mps, auto-detected if empty)
- `--dtype`: Precision (bfloat16/float32, default: bfloat16)
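
In practice the precision flag maps onto PyTorch mixed precision. A minimal sketch of how bfloat16 autocast is typically wired up (illustrative, not lifted from the scripts):

```python
# Sketch: bfloat16 mixed precision via torch.autocast.
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16  # --dtype=bfloat16

x = torch.randn(8, 16, device=device_type)
layer = torch.nn.Linear(16, 16).to(device_type)

with torch.autocast(device_type=device_type, dtype=dtype):
    y = layer(x)  # matmuls run in bfloat16, master weights stay float32
print(y.dtype)
```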

### Logging and Monitoring

- `--run`: Wandb run name ("dummy" disables wandb)

### Model Loading (mid-training, SFT, RL)

- `--source`: Source checkpoint (base/mid/sft)
- `--model_tag`: Specific model tag to load
- `--model_step`: Specific checkpoint step to load

## Usage Examples

### Complete Pipeline

```bash
# 1. Train tokenizer
python -m scripts.tok_train --vocab_size=65536

# 2. Base pretraining
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --target_param_data_ratio=20

# 3. Mid-training
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train

# 4. Supervised fine-tuning
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft

# 5. Reinforcement learning (optional)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl
```

### Memory-Constrained Training

```bash
# For 40GB VRAM
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=20 --device_batch_size=16

# For 24GB VRAM
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=16 --device_batch_size=8

# For a single GPU
python -m scripts.base_train --depth=12 --device_batch_size=4
```

### Model Size Scaling

```bash
# Larger model (d26, ~GPT-2 performance)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=26 --device_batch_size=16

# Smaller model (d12)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
    --depth=12 --device_batch_size=64
```

## Best Practices

### Memory Management

- Monitor VRAM usage and adjust `--device_batch_size` accordingly
- Use gradient accumulation instead of reducing the total batch size
- Enable mixed precision with bfloat16 for efficiency

### Hyperparameter Tuning

- Scale learning rates with batch size (automatic)
- Adjust warmup/warmdown ratios for training length
- Monitor validation metrics to prevent overfitting

### Distributed Training

- Use torchrun for multi-GPU setups
- Ensure data shards are sufficient for training length
- Monitor load balancing across GPUs
The training scripts provide a complete, configurable pipeline for developing capable language models from scratch with careful attention to efficiency and best practices.
Sources:
- scripts/base_train.py, scripts/mid_train.py, scripts/chat_sft.py, scripts/chat_rl.py, scripts/tok_train.py
- README.md (script documentation)
- speedrun.sh (usage examples)