Inference Engine

The NanoChat inference engine provides efficient autoregressive text generation built on KV caching, tool use integration, and multi-sample generation. This document details the architecture and implementation of the inference system.

Overview

Source: nanochat/engine.py:1-15

python
"""
Engine for efficient inference of our models.

Everything works around token sequences:
- The user can send token sequences to the engine
- The engine returns the next token

Notes:
- The engine knows nothing about tokenization, it's purely token id sequences.

The whole thing is made as efficient as possible.
"""

The engine architecture separates concerns cleanly:

  • Token-level operations: Core generation logic works with token IDs
  • KV cache management: Efficient memory usage during autoregressive generation
  • Tool use integration: Python calculator and other tool capabilities
  • Multi-sample generation: Parallel generation of multiple completions
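
A minimal usage sketch, assuming a loaded model and tokenizer (real chat prompts are assembled from the special chat tokens; plain text is used here purely for illustration):

python
# Hypothetical usage; model/tokenizer construction is project-specific.
engine = Engine(model, tokenizer)
prompt_tokens = tokenizer.encode("What is 111 * 111?")  # list of token ids
results, masks = engine.generate_batch(
    prompt_tokens, num_samples=2, max_tokens=64, temperature=0.8, top_k=50,
)
print(tokenizer.decode(results[0]))  # prompt plus the first sampled completion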

KV Cache Architecture

The KV cache is central to efficient autoregressive generation, avoiding recomputation of attention keys and values.

KVCache Class

Source: nanochat/engine.py:82-94

python
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.
    Note that the .pos advances automatically after the last layer inserts.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0 # current position in time in the cache

The cache stores keys and values for all transformer layers in a 6D tensor:

  • Dimension 0: Number of layers
  • Dimension 1: 2 (separate K and V)
  • Dimension 2: Batch size
  • Dimension 3: Number of key/value attention heads (n_kv_head)
  • Dimension 4: Sequence length (time dimension)
  • Dimension 5: Head dimension
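
The memory footprint follows directly from this shape. A quick back-of-the-envelope sketch with illustrative dimensions (not nanochat's actual configuration):

python
import math

# Illustrative dimensions only; substitute the real model config.
num_layers, num_kv_heads, head_dim = 12, 4, 64
batch_size, seq_len = 1, 2048

kv_shape = (num_layers, 2, batch_size, num_kv_heads, seq_len, head_dim)
n_elements = math.prod(kv_shape)                      # 12,582,912 elements
print(f"~{n_elements * 2 / 1e6:.0f} MB in bfloat16")  # 2 bytes per element -> ~25 MB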

Cache Operations

Lazy Initialization and Dynamic Growth

Source: nanochat/engine.py:135-148

python
def insert_kv(self, layer_idx, k, v):
    # Lazy initialize the cache here because we need to know the dtype/device
    if self.kv_cache is None:
        self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
    
    # Insert new keys/values to the cache and return the full cache so far
    B, H, T_add, D = k.size()
    t0, t1 = self.pos, self.pos + T_add
    
    # Dynamically grow the cache if needed
    if t1 > self.kv_cache.size(4):
        t_needed = t1 + 1024  # as much as we need plus buffer of 1024
        t_needed = (t_needed + 1023) & ~1023  # round up to nearest multiple of 1024
        additional_shape = list(self.kv_cache.shape)
        additional_shape[4] = t_needed - self.kv_cache.size(4)
        additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
        self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
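
The growth rule adds a 1024-token buffer and rounds up to the next multiple of 1024, so reallocations are rare. A standalone sketch of that arithmetic (the helper name is ours, not part of the source):

python
def grown_seq_len(required_t1: int) -> int:
    # Mirrors the rounding arithmetic above (illustrative helper, not in the source).
    t_needed = required_t1 + 1024       # what we need plus a 1024-token buffer
    return (t_needed + 1023) & ~1023    # round up to the next multiple of 1024

print(grown_seq_len(2050))  # 4096: a 2048-slot cache that overflows grows to 4096 slots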

Cache Prefilling

For multi-sample generation, the engine supports cache prefilling from a single-batch cache:

Source: nanochat/engine.py:100-125

python
def prefill(self, other):
    """
    Prefill given another KV cache. Optionally expand along batch dim.
    This is used when we do batch 1 prefill and then want to generate
    multiple samples in parallel from there.
    """
    # Validate the shapes
    assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
    assert other.kv_cache is not None, "Cannot prefill with a None KV cache"

    # Extract and validate dimensions
    self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
    other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape

    # Batch size can be expanded (other can be 1, self can be larger)
    assert self_batch == other_batch or other_batch == 1

    # Initialize and copy cache data
    dtype, device = other.kv_cache.dtype, other.kv_cache.device
    self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
    self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
    self.pos = other.pos
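
The batch expansion relies on the assignment above broadcasting a batch-1 source across a larger destination batch. A tiny demonstration with toy shapes:

python
import torch

# Toy shapes: (layers, 2, batch, heads, T, head_dim)
src = torch.randn(2, 2, 1, 3, 5, 4)   # batch-1 "prefill" cache, 5 positions filled
dst = torch.empty(2, 2, 4, 3, 8, 4)   # batch-4 "decode" cache with room to grow
dst[:, :, :, :, :5, :] = src          # the batch dimension broadcasts 1 -> 4
assert torch.equal(dst[:, :, 0, :, :5, :], dst[:, :, 3, :, :5, :])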

Sampling Strategy

Token Sampling

Source: nanochat/engine.py:162-178

python
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)

sample_next_token supports three decoding modes:

  • Greedy decoding (temperature=0.0)
  • Temperature sampling with configurable randomness
  • Top-k sampling to limit choices to most probable tokens
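
An illustrative call, assuming logits from a model forward pass (the vocabulary size here is made up):

python
import torch

rng = torch.Generator(device="cpu")
rng.manual_seed(0)
logits = torch.randn(2, 50304)                                        # (B, vocab_size)
greedy = sample_next_token(logits, rng, temperature=0.0)              # argmax, shape (2, 1)
sampled = sample_next_token(logits, rng, temperature=0.8, top_k=50)   # top-k + temperature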

Tool Use Integration

Calculator Tool

The engine includes a restricted Python calculator that handles simple arithmetic expressions and a small set of string operations:

Source: nanochat/engine.py:45-80

python
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)

    # Check if it's a string operation we support
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                         'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                         'getattr', 'setattr', 'delattr', 'hasattr']
    if any(pattern in expr.lower() for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now
    if '.count(' not in expr:
        return None

    return eval_with_timeout(expr)
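
A few illustrative calls and their return values, following the rules in the excerpt above:

python
print(use_calculator("1,234 + 5 * 6"))            # 1264  (commas stripped, pure-math path)
print(use_calculator("2 ** 10"))                  # None  (power operator is disallowed)
print(use_calculator("'strawberry'.count('r')"))  # 3     (.count() is the only allowed method)
print(use_calculator("__import__('os')"))         # None  (dangerous pattern rejected)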

Timeout Protection

Source: nanochat/engine.py:18-33

python
@contextmanager
def timeout(duration, formula):
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    yield
    signal.alarm(0)

def eval_with_timeout(formula, max_time=3):
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SyntaxWarning)
                return eval(formula, {"__builtins__": {}}, {})
    except Exception:
        return None
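
eval_with_timeout swallows every failure, including timeouts, and returns None, so callers only check for that one sentinel; note that the SIGALRM-based alarm only works on Unix and in the main thread. For example:

python
print(eval_with_timeout("2 + 2"))           # 4
print(eval_with_timeout("undefined_name"))  # None: the NameError is caught and swallowed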

Generation Architecture

Row State Management

For multi-sample generation, each "row" (sample) maintains its own state:

Source: nanochat/engine.py:181-190

python
class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or [] # Current token sequence for this row
        self.forced_tokens = deque() # Queue of tokens to force inject
        self.in_python_block = False # Whether we are inside a python block
        self.python_expr_tokens = [] # Tokens of the current python expression
        self.completed = False # Whether this row has completed generation

Main Generation Loop

Source: nanochat/engine.py:225-250

python
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
    batch_size=1,
    seq_len=len(tokens),
    **kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :].expand(num_samples, -1)  # (num_samples, vocab_size)

# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
    batch_size=num_samples,
    seq_len=kv_length_hint,
    **kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around

# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

Tool Use State Machine

Source: nanochat/engine.py:270-290

python
# Handle tool logic
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
    state.python_expr_tokens = []
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)

The engine manages tool use through special tokens:

  • <|python_start|> and <|python_end|> delimit tool invocations
  • <|output_start|> and <|output_end|> delimit tool outputs
  • Expressions are evaluated and results are injected as forced tokens
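
Put together, a tool-using exchange looks like the following decoded stream (illustrative text; the arithmetic result is injected by the engine, not sampled from the model):

python
# Model-generated:   ... Let me compute that. <|python_start|>111 * 111<|python_end|>
# Engine-injected:   <|output_start|>12321<|output_end|>   (forced tokens, bypassing sampling)
# Model continues:   111 * 111 = 12321.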

Engine Class Interface

Main Engine Class

Source: nanochat/engine.py:192-205

python
class Engine:

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer # needed for tool use

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Same as generate, but does single prefill and then clones the KV cache."""
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
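
generate() is a streaming generator: each step yields one token column and one mask column, with an entry per sample, which generate_batch below consumes. A hypothetical streaming consumer (decoding token-by-token is for illustration only):

python
# Hypothetical streaming use; prints sample 0's completion as it is generated.
engine = Engine(model, tokenizer)
prompt_tokens = tokenizer.encode("Tell me a joke.")
for token_column, token_masks in engine.generate(prompt_tokens, num_samples=2,
                                                 max_tokens=32, temperature=0.7):
    print(tokenizer.decode([token_column[0]]), end="", flush=True)
print()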

Batch Generation

Source: nanochat/engine.py:300-320

python
def generate_batch(self, tokens, num_samples=1, **kwargs):
    """
    Non-streaming batch generation that just returns the final token sequences.
    Returns a list of token sequences (list of lists of ints).
    Terminal tokens (assistant_end, bos) are not included in the results.
    """
    assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
    bos = self.tokenizer.get_bos_token_id()
    results = [tokens.copy() for _ in range(num_samples)]
    masks = [[0] * len(tokens) for _ in range(num_samples)]
    completed = [False] * num_samples
    
    for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
        for i, (token, mask) in enumerate(zip(token_column, token_masks)):
            if not completed[i]:
                if token == assistant_end or token == bos:
                    completed[i] = True
                else:
                    results[i].append(token)
                    masks[i].append(mask)
        if all(completed):
            break
    return results, masks

Performance Optimizations

Memory Efficiency

  1. Lazy Cache Initialization: The KV cache tensor is allocated on the first insert, once its dtype and device are known
  2. Dynamic Growth: The cache grows on demand in 1024-aligned chunks with a 1024-token buffer
  3. Cache Prefilling: A batch-1 prompt cache is expanded across the batch dimension for multi-sample generation
  4. Prefill Cache Release: The batch-1 prefill cache is deleted once its contents are copied into the decode cache

Computational Efficiency

  1. Single Prefill: Prompt is processed once, then cache is replicated
  2. Streaming Generation: Tokens are yielded as they're generated
  3. Early Termination: Generation stops when all samples complete
  4. Tool Integration: Tool results are computed outside the model and injected as forced tokens, bypassing sampling for those positions

Integration with GPT Model

The engine hooks directly into the GPT model's attention layers: each layer inserts its new keys and values into the cache and then attends over the full cached sequence:

Source: nanochat/gpt.py:75-85

python
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
    k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)

# Attention with different cases for training vs inference
enable_gqa = self.n_head != self.n_kv_head
if kv_cache is None or Tq == Tk:
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
    y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
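
The decode branch (Tq == 1) is the common case during generation: the single new query attends over every cached position, so no causal mask is needed. A toy illustration of that call (shapes are made up; the enable_gqa flag requires a recent PyTorch):

python
import torch
import torch.nn.functional as F

# Toy decode step: a single query attends over 16 cached positions.
# 8 query heads share 2 KV heads (grouped-query attention).
q = torch.randn(1, 8, 1, 64)    # (B, n_head, Tq=1, head_dim)
k = torch.randn(1, 2, 16, 64)   # (B, n_kv_head, Tk=16, head_dim)
v = torch.randn(1, 2, 16, 64)
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=True)
print(y.shape)  # torch.Size([1, 8, 1, 64])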

The inference engine provides a complete, efficient path for autoregressive text generation with tool use and multi-sample generation, while keeping the token-level generation logic cleanly separated from tokenization and managing KV-cache memory efficiently.


Sources:

  • nanochat/engine.py (complete inference engine implementation)
  • nanochat/gpt.py (attention mechanism integration)