Evaluation Framework
The NanoChat evaluation framework provides a comprehensive system for assessing model performance across multiple domains. Built around a flexible Task abstraction, it supports both multiple-choice (categorical) and generative evaluation tasks, along with utilities for mixing and sequencing tasks during training.
Framework Overview
Source: tasks/common.py:1-8
"""
Base class for all Tasks.
A Task is basically a dataset of conversations, together with some
metadata and often also evaluation criteria.
Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
"""
The evaluation framework consists of:
- Base Task abstraction for consistent interfaces
- Task implementations for specific benchmarks
- Task mixing utilities for training data composition
- Evaluation loops for different task types
- Distributed evaluation support for multi-GPU setups
Base Task Architecture
Core Task Class
Source: tasks/common.py:10-35
class Task:
    """
    Base class of a Task. Allows for lightweight slicing of the underlying dataset.
    """

    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
        assert step >= 1, f"Step must be strictly positive, got {step}"
        self.start = start
        self.stop = stop  # could be None here
        self.step = step

    @property
    def eval_type(self):
        # one of 'generative' | 'categorical'
        raise NotImplementedError

    def num_examples(self):
        raise NotImplementedError

    def get_example(self, index):
        raise NotImplementedError

    def evaluate(self, problem, completion):
        raise NotImplementedError
Key Interface Methods
- eval_type: Returns either 'generative' or 'categorical'
- num_examples(): Total number of examples in the task
- get_example(index): Retrieve a specific example as a conversation
- evaluate(problem, completion): Score a model's response
Dataset Slicing
The Task class supports efficient slicing without data duplication:
Source: tasks/common.py:37-49
def __len__(self):
    start = self.start
    stop = self.num_examples() if self.stop is None else self.stop
    step = self.step
    span = stop - start
    num = (span + step - 1) // step  # ceil_div(span, step)
    assert num >= 0, f"Negative number of examples???: {num}"  # prevent footguns
    return num

def __getitem__(self, index: int):
    assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
    physical_index = self.start + index * self.step
    conversation = self.get_example(physical_index)
    return conversation
This enables creating task subsets like Task(start=100, stop=200, step=2) without copying data.
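For illustration, a minimal sketch of how the logical view composes with a concrete subclass. The ListTask wrapper here is hypothetical (not part of the repository) and exists only to show the index arithmetic:

# Hypothetical in-memory task used only to illustrate the slicing view.
class ListTask(Task):
    def __init__(self, conversations, **kwargs):
        super().__init__(**kwargs)
        self.conversations = conversations

    @property
    def eval_type(self):
        return 'categorical'

    def num_examples(self):
        return len(self.conversations)

    def get_example(self, index):
        return self.conversations[index]

full = ListTask([{"messages": [{"role": "user", "content": str(i)}]} for i in range(1000)])
subset = ListTask(full.conversations, start=100, stop=200, step=2)
print(len(full), len(subset))  # 1000 50 -- the subset is a view, no data is copied
print(subset[0])               # maps logical index 0 to physical index 100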
Task Composition
TaskMixture
For training, tasks can be mixed with deterministic shuffling:
Source: tasks/common.py:52-87
class TaskMixture(Task):
    """
    For SFT Training it becomes useful to train on a mixture of datasets.
    Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        # tasks is a list of Task objects
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)
        # Build list of all (task_idx, local_idx) pairs
        self.index_map = []
        for task_idx, task_length in enumerate(self.lengths):
            for local_idx in range(task_length):
                self.index_map.append((task_idx, local_idx))
        # Deterministically shuffle to mix tasks throughout training
        rng = random.Random(42)
        rng.shuffle(self.index_map)

    def get_example(self, index):
        """
        Access conversations according to a deterministic shuffle of all examples.
        This ensures tasks are mixed throughout training, regardless of dataset size.
        """
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx][local_idx]
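A sketch of the oversampling trick mentioned in the docstring: listing the same task twice doubles its share of the mixture. The constructors below match those used elsewhere in this section; the exact counts depend on the downloaded datasets.

# Oversample ARC-Easy 2x relative to GSM8K by passing it in twice.
arc_easy = ARC(subset="ARC-Easy", split="train")
gsm8k = GSM8K(subset="main", split="train")
train_mix = TaskMixture([arc_easy, arc_easy, gsm8k])
print(train_mix.num_conversations)  # 2 * len(arc_easy) + len(gsm8k)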
TaskSequence
For curriculum learning, tasks can be sequenced:
Source: tasks/common.py:90-115
class TaskSequence(Task):
    """
    For SFT Training sometimes we want to sequentially train on a list of tasks.
    This is useful for cases that require a training curriculum.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)

    def get_example(self, index):
        assert 0 <= index < self.num_conversations, f"Index {index} out of range"
        for task_idx, task_length in enumerate(self.lengths):
            if index < task_length:
                return self.tasks[task_idx][index]
            index -= task_length
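A sketch of a simple two-stage curriculum, assuming the same constructor signatures as above. Because indices are not shuffled, all examples of the first task are visited before any of the second:

# Easy-to-hard curriculum: ARC-Easy first, then ARC-Challenge.
curriculum = TaskSequence([
    ARC(subset="ARC-Easy", split="train"),
    ARC(subset="ARC-Challenge", split="train"),
])
first = curriculum.get_example(0)  # drawn from ARC-Easy
# indices at or past len(ARC-Easy) fall through to ARC-Challenge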
Multiple Choice Tasks
Rendering Format
Multiple choice tasks use a standardized format optimized for smaller models:
Source: tasks/common.py:118-135
def render_mc(question, letters, choices):
    """
    The common multiple choice rendering format we will use.
    Note two important design decisions:
    1) Smaller models prefer to have the letter *after* the choice for better binding.
    2) There is no whitespace between the delimiter (=) and the letter.
       This is critical because the tokenizer has different token ids
       for " A" vs. "A". The assistant responses will be just the letter itself.
    """
    query = f"Multiple Choice question: {question}\n"
    query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
    query += "\nRespond only with the letter of the correct answer."
    return query
Example output:
Multiple Choice question: What is the capital of France?
- London=A
- Berlin=B
- Paris=C
- Madrid=D

Respond only with the letter of the correct answer.
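The output above corresponds to a call like the following; note that letters and choices are passed as parallel lists:

query = render_mc(
    question="What is the capital of France?",
    letters=["A", "B", "C", "D"],
    choices=["London", "Berlin", "Paris", "Madrid"],
)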
ARC Implementation
Source: tasks/arc.py:9-40
class ARC(Task):

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
        assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
        self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)

    @property
    def eval_type(self):
        return 'categorical'

    def get_example(self, index):
        row = self.ds[index]
        question = row["question"]
        choices = row["choices"]["text"]
        answer_string = row["answerKey"]  # e.g. "A", "B", "C", "D"
        letters = row["choices"]["label"]
        # create and return the Conversation object
        user_message = render_mc(question, letters, choices)
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": answer_string},
        ]
        conversation = {
            "messages": messages,
            "letters": letters,  # useful during evaluation for answer validation
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        assert assistant_response in conversation['letters']
        assistant_message = conversation['messages'][-1]['content']
        return assistant_response == assistant_message
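A usage sketch (requires the allenai/ai2_arc dataset to be reachable via Hugging Face datasets):

task = ARC(subset="ARC-Easy", split="test")
conv = task[0]                        # rendered multiple-choice conversation
prompt = conv["messages"][0]["content"]   # user message produced by render_mc
gold = conv["messages"][-1]["content"]    # gold answer letter
print(task.evaluate(conv, gold))      # True: the gold letter matches itself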
Generative Tasks
GSM8K Mathematics
GSM8K demonstrates tool use with calculator integration:
Source: tasks/gsm8k.py:35-75
def get_example(self, index):
    row = self.ds[index]
    question = row['question']
    answer = row['answer']  # string with solution and answer after #### marker
    # Parse tool calls from GSM8K format
    assistant_message_parts = []
    parts = re.split(r'(<<[^>]+>>)', answer)
    for part in parts:
        if part.startswith('<<') and part.endswith('>>'):
            # This is a calculator tool call
            inner = part[2:-2]  # Remove << >>
            if '=' in inner:
                expr, result = inner.rsplit('=', 1)
            else:
                expr, result = inner, ""
            # Add the tool call as a part
            assistant_message_parts.append({"type": "python", "text": expr})
            # Add the result as a part
            assistant_message_parts.append({"type": "python_output", "text": result})
        else:
            # Regular text between tool calls
            assistant_message_parts.append({"type": "text", "text": part})
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": assistant_message_parts},
    ]
    return {"messages": messages}
Answer Extraction
Source: tasks/gsm8k.py:15-25
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

def extract_answer(completion):
    """
    Extract the numerical answer after #### marker.
    Follows official code for normalization.
    """
    match = GSM_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    return None
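For example:

print(extract_answer("So she makes $10 per hour.\n#### 1,234"))  # "1234" (comma stripped)
print(extract_answer("No final marker here"))                    # None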
Evaluation Execution
Categorical Evaluation
Source: scripts/chat_eval.py:85-140
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    num_passed, total = 0, 0
    # Process in batches
    for batch_start in range(ddp_rank * batch_size, num_problems, ddp_world_size * batch_size):
        batch_end = min(batch_start + batch_size, num_problems)
        batch_conversations = [task_object[i] for i in range(batch_start, batch_end)]
        # Batch the prompts and get logits
        batch_prompts = [tokenizer.render_for_completion(conv) for conv in batch_conversations]
        max_length = max(len(prompt) for prompt in batch_prompts)
        # Pad all prompts to same length
        inputs = torch.full((len(batch_prompts), max_length), tokenizer.get_bos_token_id(),
                            dtype=torch.long, device=device)
        for i, prompt in enumerate(batch_prompts):
            inputs[i, :len(prompt)] = torch.tensor(prompt, dtype=torch.long)
        # Get logits and extract predictions
        with torch.no_grad():
            logits = model(inputs)  # (B, T, vocab_size)
            last_logits = logits[:, -1, :]  # (B, vocab_size)
        # Convert letter choices to token ids and find most likely
        for i, conversation in enumerate(batch_conversations):
            letters = conversation.get("letters", ["A", "B", "C", "D"])
            letter_logits = []
            for letter in letters:
                letter_token = tokenizer.encode(letter)[0]
                letter_logits.append(last_logits[i, letter_token])
            # Find the letter with highest logit
            best_idx = torch.argmax(torch.stack(letter_logits))
            predicted_letter = letters[best_idx]
            # Evaluate this prediction
            outcome = task_object.evaluate(conversation, predicted_letter)
            num_passed += int(outcome)
            total += 1
    return num_passed / total if total > 0 else 0.0
Generative Evaluation
Source: scripts/chat_eval.py:25-65
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens,
                        temperature, top_k, max_problems=None):
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_problems, ddp_world_size):
        conversation = task_object[i]
        # Tokenize the prompt for completion
        encoded_prompt = tokenizer.render_for_completion(conversation)
        # Generate response(s)
        if num_samples == 1:
            response_tokens_list, _ = engine.generate_batch(
                encoded_prompt, num_samples=1, max_tokens=max_new_tokens,
                temperature=temperature, top_k=top_k
            )
            response_tokens = response_tokens_list[0]
            response_text = tokenizer.decode(response_tokens)
            # Evaluate the response
            outcome = task_object.evaluate(conversation, response_text)
            num_passed += int(outcome)
        else:
            # Multiple samples - take best of N
            best_outcome = 0
            response_tokens_list, _ = engine.generate_batch(
                encoded_prompt, num_samples=num_samples, max_tokens=max_new_tokens,
                temperature=temperature, top_k=top_k
            )
            for response_tokens in response_tokens_list:
                response_text = tokenizer.decode(response_tokens)
                outcome = task_object.evaluate(conversation, response_text)
                best_outcome = max(best_outcome, outcome)
            num_passed += int(best_outcome)
        total += 1
    return num_passed / total if total > 0 else 0.0
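Both loops stride over problems by ddp_rank, so each GPU scores a disjoint subset and returns a per-rank accuracy. One way the per-rank results could be combined (a sketch using torch.distributed, not necessarily how scripts/chat_eval.py aggregates them; reducing raw counts rather than the ratio avoids weighting issues when ranks see different numbers of problems):

import torch
import torch.distributed as dist

# Per-rank counts from the local evaluation loop (illustrative values).
counts = torch.tensor([num_passed, total], dtype=torch.long, device=device)
if dist.is_initialized():
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)  # sum counts across all ranks
global_acc = counts[0].item() / max(counts[1].item(), 1)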
Evaluation Entry Point
Unified Evaluation Interface
Source: scripts/chat_eval.py:160-180
def run_chat_eval(task_name, model, tokenizer, engine,
                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
                  max_problems=None):
    # Create the evaluation object
    task_module = {
        'HumanEval': HumanEval,
        'MMLU': partial(MMLU, subset="all", split="test"),
        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
        'SpellingBee': partial(SpellingBee, size=256, split="test"),
    }[task_name]
    task_object = task_module()
    # Run the appropriate evaluation type
    if task_object.eval_type == 'generative':
        acc = run_generative_eval(task_object, tokenizer, model, engine,
                                  num_samples, max_new_tokens, temperature, top_k,
                                  max_problems=max_problems)
    elif task_object.eval_type == 'categorical':
        acc = run_categorical_eval(task_object, tokenizer, model, batch_size,
                                   max_problems=max_problems)
    else:
        raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
    return acc
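A usage sketch, assuming a loaded model, tokenizer, and generation engine (the variable names are placeholders):

acc = run_chat_eval(
    "ARC-Easy", model, tokenizer, engine,
    batch_size=16,        # used by the categorical path
    max_problems=500,     # cap the run for a quick smoke test
)
print(f"ARC-Easy accuracy: {acc:.3f}")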
CORE Metric
Base Model Evaluation
The CORE metric from the DCLM paper evaluates base model quality:
Source: nanochat/core_eval.py:1-15
"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794
TODOs:
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
"""
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
Available Tasks
Multiple Choice Tasks
- ARC-Easy/Challenge: Science reasoning questions
- MMLU: Broad knowledge across academic subjects
- Auxiliary MMLU: Training subset with diverse topics
Generative Tasks
- GSM8K: Grade school math with tool use
- HumanEval: Python code generation
- SpellingBee: Letter counting and spelling
- SimpleSpelling: Basic word spelling
Conversational Tasks
- SmolTalk: General conversation dataset
- CustomJSON: Synthetic identity conversations
Usage Examples
Training Data Mixture
Source: scripts/chat_sft.py:85-95
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"),       # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"),  # 1.1K rows
    GSM8K(subset="main", split="train"),         # 8K rows
    SmolTalk(split="train", stop=10_000),        # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath),  # 1K synthetic identity conversations
    SimpleSpelling(size=300, split="train"),     # 300 rows of Simple Spelling
    SpellingBee(size=300, split="train"),        # 300 rows of Spelling Bee
])  # 23K rows total
Evaluation Command
# Single task evaluation
python -m scripts.chat_eval -a ARC-Easy
# Distributed evaluation
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a GSM8K
# Multiple tasks
python -m scripts.chat_eval -a ARC-Easy,MMLU,GSM8K
The evaluation framework provides a comprehensive, extensible system for assessing model capabilities across diverse domains while maintaining consistency and supporting distributed execution.
Sources:
- tasks/common.py (base framework and task composition)
- tasks/arc.py (multiple choice task implementation)
- tasks/gsm8k.py (generative task with tool use)
- scripts/chat_eval.py (evaluation execution logic)
- nanochat/core_eval.py (CORE metric implementation)