Inference Scripts
Interactive and web-based inference interfaces for running nanochat models in production and development environments.
Overview
The inference system provides two main interfaces:
- CLI Chat Interface - Interactive terminal-based conversations
- Web Server - FastAPI-based server with UI and API endpoints
Key capabilities:
- Multi-GPU Support - Worker pool for distributed inference (web server only)
- Streaming Generation - Real-time token streaming
Key Files:
- scripts/chat_cli.py - Command-line chat interface
- scripts/chat_web.py - Web server with UI and API
- nanochat/ui.html - Chat web interface
- nanochat/engine.py - Underlying inference engine
CLI Chat Interface
Interactive command-line interface for single-user conversations.
Source: scripts/chat_cli.py:1-15
"""
New and upgraded chat mode because a lot of the code has changed since the last one.
Intended to be run single GPU only atm:
python -m scripts.chat_cli -i mid
"""
import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
Features
- Interactive Conversations - Real-time streaming responses
- Conversation State - Maintains context across turns
- Special Commands - clear, quit, exit for control
- Single-Shot Mode - One-off prompts with --prompt
Conversation Flow
Source: scripts/chat_cli.py:40-80
# Special tokens for the chat state machine
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")
# Create Engine for efficient generation
engine = Engine(model, tokenizer)
print("\nNanoChat Interactive Mode")
print("-" * 50)
print("Type 'quit' or 'exit' to end the conversation")
print("Type 'clear' to start a new conversation")
print("-" * 50)
conversation_tokens = [bos]
The conversation maintains full context by building up token sequences:
Source: scripts/chat_cli.py:90-120
# Add User message to the conversation
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(user_input))
conversation_tokens.append(user_end)
# Kick off the assistant
conversation_tokens.append(assistant_start)
generate_kwargs = {
    "num_samples": 1,
    "max_tokens": 256,
    "temperature": args.temperature,
    "top_k": args.top_k,
}
response_tokens = []
print("\nAssistant: ", end="", flush=True)
with autocast_ctx:
    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
        token = token_column[0]  # pop the batch dimension (num_samples=1)
        response_tokens.append(token)
        token_text = tokenizer.decode([token])
        print(token_text, end="", flush=True)
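These excerpts sit inside a read-eval loop that also implements the special commands from the feature list. A minimal sketch of that loop, reusing the variables defined above (the exact structure in scripts/chat_cli.py may differ):
while True:
    user_input = input("\nUser: ").strip()
    if user_input.lower() in ("quit", "exit"):
        break  # end the session
    if user_input.lower() == "clear":
        conversation_tokens = [bos]  # drop all prior turns, keep only the BOS token
        print("Conversation cleared.")
        continue
    if not user_input:
        continue
    # ... append <|user_start|> / user tokens / <|user_end|>, open the assistant turn,
    # and stream engine.generate() output exactly as shown above ...
    conversation_tokens.extend(response_tokens)  # keep the reply in context for the next turn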
Usage Examples
# Start interactive chat with SFT model
python -m scripts.chat_cli -i sft
# Use specific model checkpoint
python -m scripts.chat_cli -i rl --model-tag v1.0 --step 5000
# Single prompt mode
python -m scripts.chat_cli -i mid --prompt "Explain quantum computing"
# Adjust generation parameters
python -m scripts.chat_cli -i sft --temperature 0.8 --top-k 40
Command Line Arguments
- -i, --source - Model source: sft | mid | rl
- -g, --model-tag - Specific model tag to load
- -s, --step - Specific training step to load
- -p, --prompt - Single prompt mode (non-interactive)
- -t, --temperature - Generation temperature (default: 0.6)
- -k, --top-k - Top-k sampling parameter (default: 50)
- --device-type - Force device: cuda | cpu | mps
- -d, --dtype - Precision: float32 | bfloat16
Web Server Interface
Production-ready FastAPI server with multi-GPU worker pool and web UI.
Source: scripts/chat_web.py:1-30
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:
- single available GPU (default)
python -m scripts.chat_web
- 4 GPUs
python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
"""
Multi-GPU Architecture
The server uses a worker pool to distribute inference across multiple GPUs:
Source: scripts/chat_web.py:85-110
@dataclass
class Worker:
    """A worker with a model loaded on a specific GPU."""
    gpu_id: int
    device: torch.device
    engine: Engine
    tokenizer: object
    autocast_ctx: torch.amp.autocast

class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""
    def __init__(self, num_gpus: Optional[int] = None):
        if num_gpus is None:
            if device_type == "cuda":
                num_gpus = torch.cuda.device_count()
            else:
                num_gpus = 1  # e.g. cpu|mps
        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()
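Requests borrow a worker from this queue and return it when they finish, so at most one request runs per GPU and excess requests simply wait. A sketch of that pattern, continuing the WorkerPool class; acquire_worker is referenced by the endpoint below, while the release method shown here is an assumption:
    async def acquire_worker(self) -> Worker:
        # Blocks until a worker is free; with N GPUs, at most N requests run concurrently.
        return await self.available_workers.get()

    def release_worker(self, worker: Worker) -> None:
        # Assumed counterpart: return the worker so the next queued request can proceed.
        self.available_workers.put_nowait(worker)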
Worker Initialization
Each GPU gets its own model replica for parallel inference:
Source: scripts/chat_web.py:115-140
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
    """Load model on each GPU."""
    print(f"Initializing worker pool with {self.num_gpus} GPUs...")
    if self.num_gpus > 1:
        assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
    for gpu_id in range(self.num_gpus):
        if device_type == "cuda":
            device = torch.device(f"cuda:{gpu_id}")
            print(f"Loading model on GPU {gpu_id}...")
        else:
            device = torch.device(device_type)  # e.g. cpu|mps
            print(f"Loading model on {device_type}...")
        model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
        engine = Engine(model, tokenizer)
        autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
        worker = Worker(
            gpu_id=gpu_id,
            device=device,
            engine=engine,
            tokenizer=tokenizer,
            autocast_ctx=autocast_ctx
        )
        self.workers.append(worker)
        await self.available_workers.put(worker)
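The pool must exist before the first request arrives; the endpoint below reaches it through app.state.worker_pool. A minimal sketch of that wiring using a FastAPI lifespan hook, treating args as the parsed CLI arguments (the actual startup code may differ):
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load one model replica per GPU before the server starts accepting traffic.
    pool = WorkerPool(num_gpus=args.num_gpus)
    await pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    app.state.worker_pool = pool
    yield  # server runs; nothing to tear down in this sketch

app = FastAPI(lifespan=lifespan)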
API Endpoints
Chat Completions API
OpenAI-compatible streaming chat API:
Source: scripts/chat_web.py:290-320
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
    # Basic validation to prevent abuse
    validate_chat_request(request)
    # Log incoming conversation to console
    logger.info("="*20)
    for i, message in enumerate(request.messages):
        logger.info(f"[{message.role.upper()}]: {message.content}")
    logger.info("-"*20)
    # Acquire a worker from the pool (will wait if all are busy)
    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()
    try:
        # Build conversation tokens
        bos = worker.tokenizer.get_bos_token_id()
        user_start = worker.tokenizer.encode_special("<|user_start|>")
        user_end = worker.tokenizer.encode_special("<|user_end|>")
        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
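The excerpt ends where the special tokens are collected. The remainder of the endpoint walks request.messages, encodes each turn between its role markers, opens the assistant turn, and returns the stream as server-sent events via FastAPI's StreamingResponse. A rough sketch of that continuation inside the try block (the exact code, including how the worker is released afterwards, is an assumption here):
        # Sketch: turn the OpenAI-style message list into nanochat chat tokens.
        conversation_tokens = [bos]
        for message in request.messages:
            start, end = (user_start, user_end) if message.role == "user" else (assistant_start, assistant_end)
            conversation_tokens.append(start)
            conversation_tokens.extend(worker.tokenizer.encode(message.content))
            conversation_tokens.append(end)
        conversation_tokens.append(assistant_start)  # the model responds next

        # Stream the reply back as SSE; generate_stream is shown below.
        return StreamingResponse(
            generate_stream(
                worker,
                conversation_tokens,
                temperature=request.temperature,
                max_new_tokens=request.max_tokens,
                top_k=request.top_k,
            ),
            media_type="text/event-stream",
        )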
Abuse Prevention
Comprehensive limits to prevent resource abuse:
Source: scripts/chat_web.py:40-55
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
# Abuse Prevention:
# - Maximum 500 messages per request
# - Maximum 8000 characters per message
# - Maximum 32000 characters total conversation length
# - Temperature clamped to 0.0-2.0
# - Top-k clamped to 1-200
# - Max tokens clamped to 1-4096
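The endpoint above calls validate_chat_request before doing any work. A sketch of such a check against the limits above (the actual validation in scripts/chat_web.py may differ, e.g. it may clamp values instead of rejecting the request):
from fastapi import HTTPException

def validate_chat_request(request: ChatRequest) -> None:
    """Reject requests that exceed the abuse-prevention limits (sketch)."""
    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(status_code=400, detail="Too many messages in request")
    total_chars = 0
    for message in request.messages:
        if len(message.content) > MAX_MESSAGE_LENGTH:
            raise HTTPException(status_code=400, detail="Message too long")
        total_chars += len(message.content)
    if total_chars > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(status_code=400, detail="Conversation too long")
    if request.temperature is not None and not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
        raise HTTPException(status_code=400, detail="temperature out of range")
    if request.top_k is not None and not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
        raise HTTPException(status_code=400, detail="top_k out of range")
    if request.max_tokens is not None and not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
        raise HTTPException(status_code=400, detail="max_tokens out of range")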
Streaming Generation
The streaming generator decodes tokens incrementally and handles multi-byte UTF-8 characters (e.g. emoji) correctly:
Source: scripts/chat_web.py:245-275
async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Generate assistant response with streaming."""
    # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
    accumulated_tokens = []
    # Track the last complete UTF-8 string (without replacement characters)
    last_clean_text = ""
    with worker.autocast_ctx:
        for token_column, token_masks in worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=random.randint(0, 2**31 - 1)
        ):
            token = token_column[0]
            # Append the token to sequence
            accumulated_tokens.append(token)
            # Decode all accumulated tokens to get proper UTF-8 handling
            current_text = worker.tokenizer.decode(accumulated_tokens)
            # Only emit text if it doesn't end with a replacement character
            # This ensures we don't emit incomplete UTF-8 sequences
            if not current_text.endswith('�'):
                # Extract only the new text since last clean decode
                new_text = current_text[len(last_clean_text):]
                if new_text:  # Only yield if there's new content
                    yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
                last_clean_text = current_text
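The replacement-character check matters because a BPE token can end in the middle of a multi-byte UTF-8 character, and decoding such a prefix yields '�'. A tiny standalone illustration of the effect (plain Python, independent of the tokenizer):
# "🙂" is 4 bytes in UTF-8; decoding only part of it produces the replacement character.
emoji_bytes = "🙂".encode("utf-8")  # b'\xf0\x9f\x99\x82'
partial = emoji_bytes[:2].decode("utf-8", errors="replace")
complete = emoji_bytes.decode("utf-8")
print(partial.endswith("�"))   # True  -> hold the text back, wait for more tokens
print(complete.endswith("�"))  # False -> safe to emit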
Usage Examples
# Single GPU server
python -m scripts.chat_web
# Multi-GPU server
python -m scripts.chat_web --num-gpus 4
# Custom port and host
python -m scripts.chat_web --port 9000 --host 127.0.0.1
# Specific model and generation settings
python -m scripts.chat_web -i rl --temperature 0.7 --max-tokens 1024
# Production deployment
python -m scripts.chat_web --num-gpus 8 --host 0.0.0.0 --port 8000
API Usage
# Health check
curl http://localhost:8000/health
# Worker stats
curl http://localhost:8000/stats
# Chat API example
curl -X POST http://localhost:8000/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Hello!"}
],
"temperature": 0.7,
"max_tokens": 100
}'
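Each event is a data: line containing a JSON object with the emitted text fragment (see generate_stream above). A small Python client sketch using the requests library; any end-of-stream marker is not assumed here:
import json
import requests

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 100,
}
with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separators between events
        try:
            event = json.loads(line[len("data: "):])
        except json.JSONDecodeError:
            continue  # e.g. a non-JSON terminator line
        if "token" in event:
            print(event["token"], end="", flush=True)
print()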
Command Line Arguments
- -n, --num-gpus - Number of GPUs to use (default: auto-detect)
- -i, --source - Model source: sft | mid | rl
- -t, --temperature - Default temperature (default: 0.8)
- -k, --top-k - Default top-k (default: 50)
- -m, --max-tokens - Default max tokens (default: 512)
- -g, --model-tag - Model tag to load
- -s, --step - Step to load
- -p, --port - Server port (default: 8000)
- --host - Host to bind (default: 0.0.0.0)
- -d, --dtype - Precision: float32 | bfloat16
Sources:
- scripts/chat_cli.py:1-15,40-80,90-120
- scripts/chat_web.py:1-30,85-110,115-140,290-320,40-55,245-275