GPT Transformer Model
The GPT model in nanochat is a modern implementation of the Transformer architecture with several key improvements over the original GPT design. This document provides detailed information about the model architecture, components, and implementation.
Architecture Overview
Source: nanochat/gpt.py:1-15
"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
"""
Model Configuration
The model configuration is defined as a simple dataclass:
Source: nanochat/gpt.py:23-32
@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6  # number of query heads
    n_kv_head: int = 6  # number of key/value heads (GQA)
    n_embd: int = 768
Configuration Parameters
- sequence_len: Maximum sequence length the model can handle
- vocab_size: Size of the vocabulary (typically padded for efficiency)
- n_layer: Number of transformer blocks
- n_head: Number of attention heads (query heads)
- n_kv_head: Number of key-value heads for Group Query Attention
- n_embd: Model dimension (embedding size)
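A quick illustration (not taken from the repository) of how these defaults relate: the attention head dimension follows from n_embd and n_head, and the ratio n_head / n_kv_head controls how many query heads share each key/value head.
# illustrative sketch, assuming GPTConfig from above is in scope
config = GPTConfig()                               # all defaults
head_dim = config.n_embd // config.n_head          # 768 // 6 = 128
kv_group = config.n_head // config.n_kv_head       # 1 -> no KV sharing by default
print(head_dim, kv_group)                          # 128 1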
Core Components
1. Normalization
The model uses RMSNorm without learnable parameters:
Source: nanochat/gpt.py:35-37
def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))
This approach simplifies the model while maintaining training stability.
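For reference, F.rms_norm with no weight and no explicit eps divides each vector by its root-mean-square, using the dtype's default epsilon. A minimal sketch of the equivalence (illustrative, not from gpt.py):
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 768)
eps = torch.finfo(x.dtype).eps
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
assert torch.allclose(F.rms_norm(x, (x.size(-1),)), manual, atol=1e-5)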
2. Rotary Position Embeddings (RoPE)
Instead of learned positional embeddings, the model uses RoPE for better length generalization:
Source: nanochat/gpt.py:40-47
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]  # split up last dim into two halves
    y1 = x1 * cos + x2 * sin  # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)
RoPE Precomputation
Rotary embeddings are precomputed during model initialization:
Source: nanochat/gpt.py:195-210
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
    # autodetect the device from model embeddings
    if device is None:
        device = self.transformer.wte.weight.device
    # stride the channels
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    # stride the time steps
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    # calculate the rotation frequencies at each (time, channel) pair
    freqs = torch.outer(t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()
    cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]  # add batch and head dims
    return cos, sin
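The returned tensors have shape (1, seq_len, 1, head_dim / 2), so they broadcast against queries and keys of shape (B, T, n_head, head_dim) inside apply_rotary_emb. A small illustrative check (assuming the two functions above are in scope):
B, T, n_head, head_dim = 2, 16, 6, 128
cos = torch.ones(1, T, 1, head_dim // 2, dtype=torch.bfloat16)   # cos=1, sin=0 is the identity rotation
sin = torch.zeros(1, T, 1, head_dim // 2, dtype=torch.bfloat16)
q = torch.randn(B, T, n_head, head_dim, dtype=torch.bfloat16)
out = apply_rotary_emb(q, cos, sin)
assert out.shape == q.shape and torch.equal(out, q)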
3. Causal Self-Attention with GQA
The attention mechanism supports Group Query Attention for efficient inference:
Source: nanochat/gpt.py:49-63
class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
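With GQA, only the key/value projections shrink; c_q and c_proj keep their full size. A hypothetical configuration (not the defaults) makes this concrete:
# hypothetical GQA setting for illustration: 6 query heads sharing 2 KV heads
cfg = GPTConfig(n_head=6, n_kv_head=2, n_embd=768)
attn = CausalSelfAttention(cfg, layer_idx=0)
print(attn.c_q.weight.shape)   # torch.Size([768, 768])  -> 6 query heads x 128
print(attn.c_k.weight.shape)   # torch.Size([256, 768])  -> 2 key heads x 128
print(attn.c_v.weight.shape)   # torch.Size([256, 768])  -> 2 value heads x 128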
Attention Forward Pass
The attention mechanism handles both training and inference scenarios:
Source: nanochat/gpt.py:65-85
def forward(self, x, cos_sin, kv_cache):
    B, T, C = x.size()
    # Project the input to get queries, keys, and values
    q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
    k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
    v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
    # Apply Rotary Embeddings to queries and keys
    cos, sin = cos_sin
    q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
    q, k = norm(q), norm(k)  # QK norm
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    # Apply KV cache: insert current k,v into cache, get the full view so far
    if kv_cache is not None:
        k, v = kv_cache.insert_kv(self.layer_idx, k, v)
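The excerpt stops at the KV-cache insertion. As a rough sketch only (not verbatim from gpt.py), the remainder of the forward pass computes causal attention over the (possibly cached) keys and values and projects the result back to the model dimension; with PyTorch 2.5+, the grouped-query case can be delegated to scaled_dot_product_attention:
# sketch of the remaining steps; q, k, v are (B, heads, T, head_dim) here
y = F.scaled_dot_product_attention(
    q, k, v,
    is_causal=(kv_cache is None),                 # plain causal mask during training
    enable_gqa=(self.n_kv_head != self.n_head),   # requires PyTorch >= 2.5
)
# NOTE: decoding against a prefilled cache needs more careful masking; see gpt.py
y = y.transpose(1, 2).contiguous().view(B, T, -1)   # merge heads back together
y = self.c_proj(y)                                   # output projection
return y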
4. Multi-Layer Perceptron (MLP)
The MLP uses ReLU² activation instead of traditional GELU:
Source: nanochat/gpt.py:109-117
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU² activation
        x = self.c_proj(x)
        return x
The ReLU² (squared ReLU) activation is cheap to compute and has been reported to match GELU quality in Transformer MLPs, while keeping the implementation a one-liner.
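Concretely, ReLU²(x) = max(x, 0)², i.e. zero for non-positive inputs and x² otherwise. A trivial equivalence check (illustrative only):
x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
relu2 = F.relu(x).square()
assert torch.equal(relu2, torch.where(x > 0, x * x, torch.zeros_like(x)))
print(relu2)   # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])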
5. Transformer Block
Each transformer block combines attention and MLP with pre-normalization:
Source: nanochat/gpt.py:120-128
class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)  # Pre-norm residual
        x = x + self.mlp(norm(x))  # Pre-norm residual
        return x
Main GPT Model
Model Structure
Source: nanochat/gpt.py:131-155
class GPT(nn.Module):
    def __init__(self, config, pad_vocab_size_to=64):
        super().__init__()
        self.config = config
        # Pad vocab_size for DDP efficiency
        padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(padded_vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
        # Precompute rotary embeddings
        self.rotary_seq_len = config.sequence_len * 10  # 10X over-compute for safety
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
Weight Initialization
The model uses a simple, explicit initialization scheme for stable training:
Source: nanochat/gpt.py:161-188
def init_weights(self):
    """
    Initialize the full model for maximum clarity.
    wte (embedding): normal, std=1.0
    lm_head: normal, std=0.001
    for each block:
        attn.c_q: uniform, std=1/sqrt(n_embd)
        attn.c_k: uniform, std=1/sqrt(n_embd)
        attn.c_v: uniform, std=1/sqrt(n_embd)
        attn.c_proj: zeros
        mlp.c_fc: uniform, std=1/sqrt(n_embd)
        mlp.c_proj: zeros
    """
    # Embedding and unembedding
    torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
    torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
    # Transformer blocks: uniform init with bound = sqrt(3) * std
    n_embd = self.config.n_embd
    s = 3**0.5 * n_embd**-0.5  # sqrt(3) multiplier for uniform distribution
    for block in self.transformer.h:
        torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
        torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
        torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
        torch.nn.init.zeros_(block.attn.c_proj.weight)  # projections are zero
        torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
        torch.nn.init.zeros_(block.mlp.c_proj.weight)
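The sqrt(3) factor comes from the variance of a uniform distribution: U(-s, s) has standard deviation s / sqrt(3), so choosing s = sqrt(3) / sqrt(n_embd) yields the target std of 1 / sqrt(n_embd). A quick numeric check (illustrative):
n_embd = 768
s = 3**0.5 * n_embd**-0.5
print(s)                                 # 0.0625, the uniform bound
print(s / 3**0.5)                        # ~0.0361 = 1/sqrt(768), the target std
w = torch.empty(3072, n_embd).uniform_(-s, s)
print(w.std())                           # empirically close to 0.0361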
Forward Pass
Source: nanochat/gpt.py:280-310
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
    B, T = idx.size()
    # Grab the rotary embeddings for the current sequence length
    assert T <= self.cos.size(1), f"Sequence length grew beyond rotary cache: {T} > {self.cos.size(1)}"
    T0 = 0 if kv_cache is None else kv_cache.get_pos()
    cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
    # Forward the trunk of the Transformer
    x = self.transformer.wte(idx)
    x = norm(x)  # norm after token embedding
    for block in self.transformer.h:
        x = block(x, cos_sin, kv_cache)
    x = norm(x)
    # Forward the lm_head (compute logits)
    softcap = 15  # smoothly cap the logits to the range [-softcap, softcap]
    logits = self.lm_head(x)
    logits = logits[..., :self.config.vocab_size]  # slice to remove padding
    logits = logits.float()  # switch to fp32 for logit softcap and loss computation
    logits = softcap * torch.tanh(logits / softcap)  # squash the logits
    if targets is not None:
        # training: compute and return the loss
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                               ignore_index=-1, reduction=loss_reduction)
        return loss
    else:
        # inference: return logits
        return logits
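The logit softcap is a smooth saturation: values well inside ±15 are nearly unchanged, while extreme values are squashed asymptotically toward ±15. A small illustration (not from the repository):
softcap = 15.0
logits = torch.tensor([0.5, 5.0, 15.0, 50.0, 500.0])
capped = softcap * torch.tanh(logits / softcap)
print(capped)   # ~[0.50, 4.82, 11.42, 14.96, 15.00]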
Performance Analysis
FLOP Estimation
The model includes detailed FLOP counting for performance analysis:
Source: nanochat/gpt.py:218-230
def estimate_flops(self):
    """
    Return the estimated FLOPs per token for the model (forward + backward).
    Each matmul weight parameter contributes 6 FLOPs per token (2 forward + 4 backward).
    """
    nparams = sum(p.numel() for p in self.parameters())
    nparams_embedding = self.transformer.wte.weight.numel()
    l, h, q, t = (self.config.n_layer, self.config.n_head,
                  self.config.n_embd // self.config.n_head, self.config.sequence_len)
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
    return num_flops_per_token
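As a rough worked example for the default configuration (12 layers, n_embd=768, head_dim=128, sequence_len=1024, padded vocabulary of 50,304, no biases or learnable norms), the estimate lands around 8.5e8 FLOPs per token:
# back-of-the-envelope check for the default config (illustrative numbers)
n_layer, n_head, n_embd, seq_len, padded_vocab = 12, 6, 768, 1024, 50304
head_dim = n_embd // n_head                                         # 128
block_params = 12 * n_embd**2                                       # 4 attention + 8 MLP matrices per block
non_emb_params = n_layer * block_params + padded_vocab * n_embd     # blocks + lm_head (wte excluded)
flops_per_token = 6 * non_emb_params + 12 * n_layer * n_head * head_dim * seq_len
print(flops_per_token)                                              # 854,654,976 ≈ 8.5e8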
Parameter Counting
Source: nanochat/gpt.py:234-245
def num_scaling_params(self):
    """
    Return all parameters (Chinchilla approach for cleaner scaling laws).
    Includes embedding parameters unlike Kaplan et al.
    """
    nparams = sum(p.numel() for p in self.parameters())
    return nparams
Optimizer Setup
The model uses a dual-optimizer strategy, pairing AdamW (for the embedding and lm_head parameters) with Muon (for the transformer matrix parameters):
Source: nanochat/gpt.py:247-280
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
                     weight_decay=0.0, adam_betas=(0.8, 0.95)):
    # Separate parameters into 3 groups
    matrix_params = list(self.transformer.h.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())
    # Scale LR by ∝1/√dmodel for AdamW parameters
    model_dim = self.config.n_embd
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    # Create AdamW optimizer for embedding and lm_head
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    adamw_optimizer = AdamWFactory(adam_groups, ...)
    # Create Muon optimizer for linear layers
    muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95)
    return [adamw_optimizer, muon_optimizer]
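The 1/√dmodel scaling keeps the AdamW learning rates calibrated against the 768-dim reference width. For a hypothetical wider model with n_embd=1280, both AdamW learning rates shrink by (1280/768)^-0.5 ≈ 0.775:
# illustrative LR scaling for a hypothetical 1280-dim model
unembedding_lr, embedding_lr = 0.004, 0.2
model_dim = 1280
dmodel_lr_scale = (model_dim / 768) ** -0.5      # ~0.775
print(unembedding_lr * dmodel_lr_scale)          # ~0.0031
print(embedding_lr * dmodel_lr_scale)            # ~0.155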
Key Design Decisions
1. No Bias Terms: Simplifies the model while maintaining performance
2. Pre-normalization: Applies RMSNorm before attention and MLP layers
3. Untied Weights: Separate weights for input embeddings and output projection
4. QK Normalization: Improves training stability
5. Group Query Attention: Reduces memory usage during inference
6. ReLU² Activation: Computationally efficient alternative to GELU
7. Rotary Embeddings: Better length generalization than learned positions
This modern Transformer implementation balances simplicity, performance, and training stability while incorporating state-of-the-art techniques for efficient inference.
Sources:
- nanochat/gpt.py (complete model implementation)
- Architecture design decisions and optimizations