GPT Transformer Model

The GPT model in nanochat is a modern implementation of the Transformer architecture with several key improvements over the original GPT design. This document provides detailed information about the model architecture, components, and implementation.

Architecture Overview

Source: nanochat/gpt.py:1-15

python
"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
"""

Model Configuration

The model configuration is defined as a simple dataclass:

Source: nanochat/gpt.py:23-32

python
@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
    n_embd: int = 768

Configuration Parameters

  • sequence_len: Maximum sequence length the model can handle
  • vocab_size: Size of the vocabulary (typically padded for efficiency)
  • n_layer: Number of transformer blocks
  • n_head: Number of attention heads (query heads)
  • n_kv_head: Number of key-value heads for Group Query Attention
  • n_embd: Model dimension (embedding size)
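
For a concrete example, the sketch below constructs a small model from this dataclass; the specific values are illustrative and not taken from any nanochat training script.

python
from nanochat.gpt import GPT, GPTConfig

# Illustrative values only (not nanochat defaults): a small config with GQA enabled
config = GPTConfig(
    sequence_len=1024,
    vocab_size=32000,
    n_layer=4,
    n_head=4,        # query heads
    n_kv_head=2,     # 2 query heads share each key/value head
    n_embd=256,      # must be divisible by n_head (head_dim = 64)
)
model = GPT(config)
model.init_weights()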

Core Components

1. Normalization

The model uses RMSNorm without learnable parameters:

Source: nanochat/gpt.py:35-37

python
def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))

This approach simplifies the model while maintaining training stability.
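
As a quick sanity check, F.rms_norm with no weight argument matches the textbook RMSNorm formula x / sqrt(mean(x²) + eps); a minimal sketch, assuming a recent PyTorch (2.4+) where the default eps is torch.finfo(x.dtype).eps:

python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 768)
out = F.rms_norm(x, (x.size(-1),))  # no weight argument -> no learnable parameters

# Manual reference: x / sqrt(mean(x^2) + eps)
eps = torch.finfo(x.dtype).eps  # assumed default eps for F.rms_norm in recent PyTorch
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
print(torch.allclose(out, manual, atol=1e-6))  # True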

2. Rotary Position Embeddings (RoPE)

Instead of learned positional embeddings, the model uses RoPE for better length generalization:

Source: nanochat/gpt.py:40-47

python
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)

RoPE Precomputation

Rotary embeddings are precomputed during model initialization:

Source: nanochat/gpt.py:195-210

python
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
    # autodetect the device from model embeddings
    if device is None:
        device = self.transformer.wte.weight.device
    # stride the channels
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    # stride the time steps
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    # calculate the rotation frequencies at each (time, channel) pair
    freqs = torch.outer(t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()
    cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
    cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims
    return cos, sin
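
Putting the two pieces together: the (1, seq_len, 1, head_dim/2) cos/sin tables broadcast over the batch and head dimensions of a (B, T, n_head, head_dim) tensor. A small shape sketch, assuming apply_rotary_emb is importable from nanochat.gpt as shown above:

python
import torch
from nanochat.gpt import apply_rotary_emb

B, T, n_head, head_dim = 2, 16, 6, 128

# Same recipe as _precompute_rotary_embeddings, kept in float32 for the demo
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(T, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]  # (1, T, 1, head_dim/2)

q = torch.randn(B, T, n_head, head_dim)
q_rot = apply_rotary_emb(q, cos, sin)  # broadcasts over batch and head dims
print(q_rot.shape)  # torch.Size([2, 16, 6, 128])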

3. Causal Self-Attention with GQA

The attention mechanism supports Group-Query Attention (GQA), which reduces the key/value projection width and KV-cache size for more efficient inference:

Source: nanochat/gpt.py:49-63

python
class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
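
To see the GQA savings concretely: with n_embd=768 and n_head=6 the head dimension is 128, and shrinking n_kv_head from 6 to 2 cuts the key/value projection width (and the per-layer KV cache) by a factor of 3. The n_kv_head=2 value below is illustrative, not a nanochat default.

python
n_embd, n_head, n_kv_head = 768, 6, 2   # n_kv_head=2 chosen for illustration
head_dim = n_embd // n_head             # 128
q_dim = n_head * head_dim               # 768 -> c_q weight is 768 x 768
kv_dim = n_kv_head * head_dim           # 256 -> c_k and c_v weights are 256 x 768 each
print(q_dim, kv_dim)                    # 768 256; each K/V head is shared by 3 query heads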

Attention Forward Pass

The attention forward handles both training (no KV cache) and cached inference; the excerpt below covers the QKV projections, rotary embeddings, QK normalization, and KV-cache insertion:

Source: nanochat/gpt.py:65-85

python
def forward(self, x, cos_sin, kv_cache):
    B, T, C = x.size()

    # Project the input to get queries, keys, and values
    q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
    k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
    v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

    # Apply Rotary Embeddings to queries and keys
    cos, sin = cos_sin
    q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
    q, k = norm(q), norm(k) # QK norm
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    # Apply KV cache: insert current k,v into cache, get the full view so far
    if kv_cache is not None:
        k, v = kv_cache.insert_kv(self.layer_idx, k, v)
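
The excerpt stops at the KV-cache update. The remainder of the forward computes attention over the (possibly cached) keys and values and projects back to n_embd. The continuation below is a hedged sketch rather than the verbatim nanochat code: it uses PyTorch's built-in scaled_dot_product_attention with enable_gqa (PyTorch 2.5+) and simplifies masking to the two common cases of full-sequence training and single-token decoding.

python
    # Continuation sketch (not verbatim from gpt.py)
    enable_gqa = self.n_kv_head < self.n_head  # let SDPA broadcast K/V across query-head groups
    y = F.scaled_dot_product_attention(
        q, k, v,
        is_causal=kv_cache is None,  # causal during training; attend to the full prefix when decoding one token
        enable_gqa=enable_gqa,
    )
    y = y.transpose(1, 2).contiguous().view(B, T, -1)  # back to (B, T, C)
    y = self.c_proj(y)  # output projection
    return y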

4. Multi-Layer Perceptron (MLP)

The MLP uses ReLU² activation instead of traditional GELU:

Source: nanochat/gpt.py:109-117

python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU² activation
        x = self.c_proj(x)
        return x

The squared-ReLU activation, relu(x)², is cheaper to evaluate than GELU and has been found to perform comparably in Transformer language models.

5. Transformer Block

Each transformer block combines attention and MLP with pre-normalization:

Source: nanochat/gpt.py:120-128

python
class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)  # Pre-norm residual
        x = x + self.mlp(norm(x))                      # Pre-norm residual
        return x

Main GPT Model

Model Structure

Source: nanochat/gpt.py:131-155

python
class GPT(nn.Module):
    def __init__(self, config, pad_vocab_size_to=64):
        super().__init__()
        self.config = config
        # Pad vocab_size for DDP efficiency
        padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
        
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(padded_vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
        
        # Precompute rotary embeddings
        self.rotary_seq_len = config.sequence_len * 10  # 10X over-compute for safety
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

Weight Initialization

The model uses an explicit, per-module initialization scheme for stable training:

Source: nanochat/gpt.py:161-188

python
def init_weights(self):
    """
    Initialize the full model for maximum clarity.

    wte (embedding):     normal, std=1.0
    lm_head:             normal, std=0.001
    for each block:
        attn.c_q:        uniform, std=1/sqrt(n_embd)
        attn.c_k:        uniform, std=1/sqrt(n_embd)
        attn.c_v:        uniform, std=1/sqrt(n_embd)
        attn.c_proj:     zeros
        mlp.c_fc:        uniform, std=1/sqrt(n_embd)
        mlp.c_proj:      zeros
    """

    # Embedding and unembedding
    torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
    torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)

    # Transformer blocks: uniform init with bound = sqrt(3) * std
    n_embd = self.config.n_embd
    s = 3**0.5 * n_embd**-0.5  # sqrt(3) multiplier for uniform distribution
    for block in self.transformer.h:
        torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
        torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
        torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
        torch.nn.init.zeros_(block.attn.c_proj.weight)  # projections are zero
        torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
        torch.nn.init.zeros_(block.mlp.c_proj.weight)
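
The sqrt(3) factor exists because Uniform(-a, a) has standard deviation a/sqrt(3), so the bound s above yields the intended standard deviation of 1/sqrt(n_embd). A quick numerical check:

python
import torch

n_embd = 768
s = 3**0.5 * n_embd**-0.5                            # the uniform bound used above
w = torch.empty(4 * n_embd, n_embd).uniform_(-s, s)  # e.g. an mlp.c_fc-shaped weight
print(w.std().item())   # ~0.0361 (up to sampling noise)
print(n_embd**-0.5)     # 0.0361..., the target std of 1/sqrt(768)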

Forward Pass

Source: nanochat/gpt.py:280-310

python
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
    B, T = idx.size()

    # Grab the rotary embeddings for the current sequence length
    assert T <= self.cos.size(1), f"Sequence length grew beyond rotary cache: {T} > {self.cos.size(1)}"
    T0 = 0 if kv_cache is None else kv_cache.get_pos()
    cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]

    # Forward the trunk of the Transformer
    x = self.transformer.wte(idx)
    x = norm(x)  # norm after token embedding
    for block in self.transformer.h:
        x = block(x, cos_sin, kv_cache)
    x = norm(x)

    # Forward the lm_head (compute logits)
    softcap = 15  # smoothly cap the logits to the range [-softcap, softcap]
    logits = self.lm_head(x)
    logits = logits[..., :self.config.vocab_size]  # slice to remove padding
    logits = logits.float()  # switch to fp32 for logit softcap and loss computation
    logits = softcap * torch.tanh(logits / softcap)  # squash the logits

    if targets is not None:
        # training: compute and return the loss
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), 
                              ignore_index=-1, reduction=loss_reduction)
        return loss
    else:
        # inference: return logits
        return logits
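
A minimal usage sketch of both paths, assuming the model is constructed with the default config as in the excerpts above and runs on the default device; shapes and batch sizes are illustrative:

python
import torch
from nanochat.gpt import GPT, GPTConfig

config = GPTConfig()          # defaults from the dataclass above
model = GPT(config)
model.init_weights()

B, T = 4, 128
idx = torch.randint(0, config.vocab_size, (B, T))
targets = torch.randint(0, config.vocab_size, (B, T))  # positions to ignore would be set to -1

loss = model(idx, targets)    # training path: scalar cross-entropy loss
logits = model(idx)           # inference path: (B, T, vocab_size) logits, softcapped to [-15, 15]
print(loss.item(), logits.shape)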

Performance Analysis

FLOP Estimation

The model includes detailed FLOP counting for performance analysis:

Source: nanochat/gpt.py:218-230

python
def estimate_flops(self):
    """
    Return the estimated FLOPs per token for the model (forward + backward).
    Each matmul weight parameter contributes 6 FLOPs total (2 forward + 4 backward).
    """
    nparams = sum(p.numel() for p in self.parameters())
    nparams_embedding = self.transformer.wte.weight.numel()
    l, h, q, t = (self.config.n_layer, self.config.n_head, 
                 self.config.n_embd // self.config.n_head, self.config.sequence_len)
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
    return num_flops_per_token
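
Plugging the default configuration into the second term gives a feel for the magnitudes: the attention contribution 12·l·h·q·t is roughly 113M FLOPs per token, on top of 6 FLOPs per non-embedding parameter.

python
# Attention term of the estimate for the default GPTConfig
l, h, q, t = 12, 6, 768 // 6, 1024
attn_flops = 12 * l * h * q * t
print(attn_flops)  # 113_246_208 FLOPs/token from attention

# Total estimate, given a non-embedding parameter count N:
#   flops_per_token = 6 * N + attn_flops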

Parameter Counting

Source: nanochat/gpt.py:234-245

python
def num_scaling_params(self):
    """
    Return all parameters (Chinchilla approach for cleaner scaling laws).
    Includes embedding parameters unlike Kaplan et al.
    """
    nparams = sum(p.numel() for p in self.parameters())
    return nparams

Optimizer Setup

The model uses two optimizers: AdamW for the embedding and lm_head parameters, and Muon for the transformer-block matrices:

Source: nanochat/gpt.py:247-280

python
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, 
                    weight_decay=0.0, adam_betas=(0.8, 0.95)):
    # Separate parameters into 3 groups
    matrix_params = list(self.transformer.h.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())
    
    # Scale LR by ∝1/√dmodel for AdamW parameters
    model_dim = self.config.n_embd
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    
    # Create AdamW optimizer for embedding and lm_head
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    adamw_optimizer = AdamWFactory(adam_groups, ...)
    
    # Create Muon optimizer for linear layers
    muon_optimizer = MuonFactory(matrix_params, lr=matrix_lr, momentum=0.95)
    
    return [adamw_optimizer, muon_optimizer]
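
The 1/sqrt(d_model) scaling shrinks the AdamW learning rates as the model widens relative to the 768-dimensional baseline; a quick illustration with the default embedding and unembedding rates:

python
for model_dim in (768, 1280, 2048):
    scale = (model_dim / 768) ** -0.5
    print(model_dim, round(0.2 * scale, 4), round(0.004 * scale, 5))
# 768  -> 0.2      0.004     (baseline)
# 1280 -> 0.1549   0.0031
# 2048 -> 0.1225   0.00245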

Key Design Decisions

1. No Bias Terms: Simplifies the model while maintaining performance

2. Pre-normalization: Applies RMSNorm before attention and MLP layers

3. Untied Weights: Separate weights for input embeddings and output projection

4. QK Normalization: Improves training stability

5. Group-Query Attention (GQA): Reduces KV-cache memory usage during inference

6. ReLU² Activation: Computationally efficient alternative to GELU

7. Rotary Embeddings: Better length generalization than learned positions

This modern Transformer implementation balances simplicity, performance, and training stability while incorporating state-of-the-art techniques for efficient inference.


Sources:

  • nanochat/gpt.py (complete model implementation)
  • Architecture design decisions and optimizations
Last updated: 1/10/2026