Inference Scripts

Interactive and web-based inference interfaces for running nanochat models in production and development environments.

Overview

The inference system provides two main interfaces:

  • CLI Chat Interface - Interactive terminal-based conversations
  • Web Server - FastAPI-based server with UI and API endpoints

Both stream tokens in real time and are built on the same inference engine; the web server additionally distributes requests across a multi-GPU worker pool.

Key Files:

  • scripts/chat_cli.py - Command-line chat interface
  • scripts/chat_web.py - Web server with UI and API
  • nanochat/ui.html - Chat web interface
  • nanochat/engine.py - Underlying inference engine

CLI Chat Interface

Interactive command-line interface for single-user conversations.

Source: scripts/chat_cli.py:1-15

python
"""
New and upgraded chat mode because a lot of the code has changed since the last one.

Intended to be run single GPU only atm:
python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model

Features

  1. Interactive Conversations - Real-time streaming responses
  2. Conversation State - Maintains context across turns
  3. Special Commands - clear, quit, and exit for session control (see the sketch after this list)
  4. Single-Shot Mode - One-off prompts with --prompt
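
These features come together in a simple read-eval-print loop. Below is a minimal sketch of how the special commands can be handled; the chat_loop function and respond callback are illustrative, not the script's actual structure:

python
# Illustrative REPL skeleton, not a verbatim excerpt from scripts/chat_cli.py
def chat_loop(respond, bos_token_id):
    """respond(tokens, user_input) returns the updated conversation tokens."""
    conversation_tokens = [bos_token_id]
    while True:
        user_input = input("\nUser: ").strip()
        if user_input.lower() in ("quit", "exit"):
            break
        if user_input.lower() == "clear":
            conversation_tokens = [bos_token_id]  # drop all prior turns
            print("Conversation cleared.")
            continue
        if not user_input:
            continue
        conversation_tokens = respond(conversation_tokens, user_input)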

Conversation Flow

Source: scripts/chat_cli.py:40-80

python
# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")

# Create Engine for efficient generation
engine = Engine(model, tokenizer)

print("\nNanoChat Interactive Mode")
print("-" * 50)
print("Type 'quit' or 'exit' to end the conversation")
print("Type 'clear' to start a new conversation")
print("-" * 50)

conversation_tokens = [bos]

The conversation maintains full context by building up token sequences:

Source: scripts/chat_cli.py:90-120

python
# Add User message to the conversation
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(user_input))
conversation_tokens.append(user_end)

# Kick off the assistant
conversation_tokens.append(assistant_start)
generate_kwargs = {
    "num_samples": 1,
    "max_tokens": 256,
    "temperature": args.temperature,
    "top_k": args.top_k,
}
response_tokens = []
print("\nAssistant: ", end="", flush=True)
with autocast_ctx:
    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
        token = token_column[0] # pop the batch dimension (num_samples=1)
        response_tokens.append(token)
        token_text = tokenizer.decode([token])
        print(token_text, end="", flush=True)
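
To keep context across turns, the assistant's reply is folded back into conversation_tokens before the next user turn. A hedged sketch of that step; the script's exact handling of the terminating <|assistant_end|> token may differ:

python
# Sketch: persist the assistant turn back into the running context
# (exact handling of the terminal token in scripts/chat_cli.py may differ)
conversation_tokens.extend(response_tokens)
if conversation_tokens[-1] != assistant_end:
    conversation_tokens.append(assistant_end)  # close the assistant turn explicitly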

Usage Examples

bash
# Start interactive chat with SFT model
python -m scripts.chat_cli -i sft

# Use specific model checkpoint
python -m scripts.chat_cli -i rl --model-tag v1.0 --step 5000

# Single prompt mode
python -m scripts.chat_cli -i mid --prompt "Explain quantum computing"

# Adjust generation parameters
python -m scripts.chat_cli -i sft --temperature 0.8 --top-k 40

Command Line Arguments

  • -i, --source - Model source: sft|mid|rl
  • -g, --model-tag - Specific model tag to load
  • -s, --step - Specific training step to load
  • -p, --prompt - Single prompt mode (non-interactive)
  • -t, --temperature - Generation temperature (default: 0.6)
  • -k, --top-k - Top-k sampling parameter (default: 50)
  • --device-type - Force device: cuda|cpu|mps
  • -d, --dtype - Precision: float32|bfloat16

Web Server Interface

FastAPI server with a multi-GPU worker pool, a bundled web UI, and a streaming chat API.

Source: scripts/chat_web.py:1-30

python
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.

Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.

Launch examples:

- single available GPU (default)
python -m scripts.chat_web

- 4 GPUs
python -m scripts.chat_web --num-gpus 4

To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)

Endpoints:
  GET  /           - Chat UI
  POST /chat/completions - Chat API (streaming only)
  GET  /health     - Health check with worker pool status
  GET  /stats      - Worker pool statistics and GPU utilization
"""

Multi-GPU Architecture

The server uses a worker pool to distribute inference across multiple GPUs:

Source: scripts/chat_web.py:85-110

python
@dataclass
class Worker:
    """A worker with a model loaded on a specific GPU."""
    gpu_id: int
    device: torch.device
    engine: Engine
    tokenizer: object
    autocast_ctx: torch.amp.autocast

class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""

    def __init__(self, num_gpus: Optional[int] = None):
        if num_gpus is None:
            if device_type == "cuda":
                num_gpus = torch.cuda.device_count()
            else:
                num_gpus = 1 # e.g. cpu|mps
        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()
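
Idle workers sit on the asyncio.Queue: a request handler takes one off the queue and puts it back once the response has finished streaming. acquire_worker is used by the chat endpoint shown below; the release counterpart is sketched here under an assumed name:

python
# Sketch of two WorkerPool methods implementing the hand-out/hand-back pattern
# (acquire_worker appears in the chat handler below; the release name is assumed)
async def acquire_worker(self) -> Worker:
    # Blocks until a worker is free, so excess requests wait instead of failing
    return await self.available_workers.get()

def release_worker(self, worker: Worker) -> None:
    # Return the worker so the next waiting request can proceed
    self.available_workers.put_nowait(worker)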

Worker Initialization

Each GPU gets its own model replica for parallel inference:

Source: scripts/chat_web.py:115-140

python
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
    """Load model on each GPU."""
    print(f"Initializing worker pool with {self.num_gpus} GPUs...")
    if self.num_gpus > 1:
        assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."

    for gpu_id in range(self.num_gpus):

        if device_type == "cuda":
            device = torch.device(f"cuda:{gpu_id}")
            print(f"Loading model on GPU {gpu_id}...")
        else:
            device = torch.device(device_type) # e.g. cpu|mps
            print(f"Loading model on {device_type}...")

        model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
        engine = Engine(model, tokenizer)
        autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()

        worker = Worker(
            gpu_id=gpu_id,
            device=device,
            engine=engine,
            tokenizer=tokenizer,
            autocast_ctx=autocast_ctx
        )
        self.workers.append(worker)
        await self.available_workers.put(worker)
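
The pool is constructed once at startup and stored on app.state, where the request handlers read it (app.state.worker_pool below). One way to wire this up is FastAPI's lifespan hook; the actual script may register startup differently, and args here stands for the parsed command-line flags:

python
# Hypothetical startup wiring; not necessarily how scripts/chat_web.py does it
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    pool = WorkerPool(num_gpus=args.num_gpus)  # one model replica per GPU
    await pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    app.state.worker_pool = pool               # read by the request handlers
    yield                                      # serve traffic; nothing to tear down here

app = FastAPI(lifespan=lifespan)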

API Endpoints

Chat Completions API

An OpenAI-style streaming chat API:

Source: scripts/chat_web.py:290-320

python
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""

    # Basic validation to prevent abuse
    validate_chat_request(request)

    # Log incoming conversation to console
    logger.info("="*20)
    for i, message in enumerate(request.messages):
        logger.info(f"[{message.role.upper()}]: {message.content}")
    logger.info("-"*20)

    # Acquire a worker from the pool (will wait if all are busy)
    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()

    try:
        # Build conversation tokens
        bos = worker.tokenizer.get_bos_token_id()
        user_start = worker.tokenizer.encode_special("<|user_start|>")
        user_end = worker.tokenizer.encode_special("<|user_end|>")
        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
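
The handler then renders request.messages into the same token layout the CLI builds by hand. The script's loop is not reproduced here; a hedged sketch of the idea, assuming user and assistant turns alternate:

python
# Sketch: render the OpenAI-style message list into nanochat's chat token format
# (assumes alternating user/assistant roles; the real handler may differ)
conversation_tokens = [bos]
for message in request.messages:
    if message.role == "user":
        conversation_tokens.append(user_start)
        conversation_tokens.extend(worker.tokenizer.encode(message.content))
        conversation_tokens.append(user_end)
    elif message.role == "assistant":
        conversation_tokens.append(assistant_start)
        conversation_tokens.extend(worker.tokenizer.encode(message.content))
        conversation_tokens.append(assistant_end)
# Prime the model to respond as the assistant
conversation_tokens.append(assistant_start)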

Abuse Prevention

The server enforces hard limits on every request to prevent resource abuse:

Source: scripts/chat_web.py:40-55

python
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096

# Abuse Prevention:
#   - Maximum 500 messages per request
#   - Maximum 8000 characters per message
#   - Maximum 32000 characters total conversation length
#   - Temperature clamped to 0.0-2.0
#   - Top-k clamped to 1-200
#   - Max tokens clamped to 1-4096
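
validate_chat_request, called at the top of the handler, checks requests against these constants before any GPU work is scheduled. A minimal sketch of what such a check can look like (field names follow the request model implied above; the script may clamp values rather than reject them):

python
from fastapi import HTTPException

# Sketch of request validation against the limits above (not the script's exact code)
def validate_chat_request(request):
    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(status_code=400, detail="too many messages")
    total_chars = 0
    for message in request.messages:
        if len(message.content) > MAX_MESSAGE_LENGTH:
            raise HTTPException(status_code=400, detail="message too long")
        total_chars += len(message.content)
    if total_chars > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(status_code=400, detail="conversation too long")
    if request.temperature is not None and not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
        raise HTTPException(status_code=400, detail="temperature out of range")
    if request.top_k is not None and not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
        raise HTTPException(status_code=400, detail="top_k out of range")
    if request.max_tokens is not None and not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
        raise HTTPException(status_code=400, detail="max_tokens out of range")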

Streaming Generation

The streaming generator accumulates tokens before decoding so that multi-byte UTF-8 characters (such as emoji) are only emitted once they are complete:

Source: scripts/chat_web.py:245-275

python
async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming."""
    # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
    accumulated_tokens = []
    # Track the last complete UTF-8 string (without replacement characters)
    last_clean_text = ""

    with worker.autocast_ctx:
        for token_column, token_masks in worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=random.randint(0, 2**31 - 1)
        ):
            token = token_column[0]

            # Append the token to sequence
            accumulated_tokens.append(token)
            # Decode all accumulated tokens to get proper UTF-8 handling
            current_text = worker.tokenizer.decode(accumulated_tokens)
            # Only emit text if it doesn't end with a replacement character
            # This ensures we don't emit incomplete UTF-8 sequences
            if not current_text.endswith('�'):
                # Extract only the new text since last clean decode
                new_text = current_text[len(last_clean_text):]
                if new_text:  # Only yield if there's new content
                    yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
                    last_clean_text = current_text

Usage Examples

bash
# Single GPU server
python -m scripts.chat_web

# Multi-GPU server
python -m scripts.chat_web --num-gpus 4

# Custom port and host
python -m scripts.chat_web --port 9000 --host 127.0.0.1

# Specific model and generation settings
python -m scripts.chat_web -i rl --temperature 0.7 --max-tokens 1024

# Production deployment
python -m scripts.chat_web --num-gpus 8 --host 0.0.0.0 --port 8000

API Usage

bash
# Health check
curl http://localhost:8000/health

# Worker stats
curl http://localhost:8000/stats

# Chat API example
curl -X POST http://localhost:8000/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "Hello!"}
    ],
    "temperature": 0.7,
    "max_tokens": 100
  }'
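
Because the endpoint is streaming-only, clients read the response as Server-Sent Events, where each data: line carries a JSON chunk in the format produced by generate_stream above. A small illustrative client using the requests library:

python
import json
import requests

# Illustrative streaming client for POST /chat/completions
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 100,
}
with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank SSE separators
        try:
            chunk = json.loads(line[len("data: "):])
        except json.JSONDecodeError:
            continue  # e.g. a non-JSON terminal sentinel, if the server sends one
        print(chunk.get("token", ""), end="", flush=True)
print()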

Command Line Arguments

  • -n, --num-gpus - Number of GPUs to use (default: auto-detect)
  • -i, --source - Model source: sft|mid|rl
  • -t, --temperature - Default temperature (default: 0.8)
  • -k, --top-k - Default top-k (default: 50)
  • -m, --max-tokens - Default max tokens (default: 512)
  • -g, --model-tag - Model tag to load
  • -s, --step - Step to load
  • -p, --port - Server port (default: 8000)
  • --host - Host to bind (default: 0.0.0.0)
  • -d, --dtype - Precision: float32|bfloat16

Sources:

  • scripts/chat_cli.py:1-15,40-80,90-120
  • scripts/chat_web.py:1-30,85-110,115-140,290-320,40-55,245-275