Evaluation Framework

The NanoChat evaluation framework provides a comprehensive system for assessing model performance across multiple domains. Built around a flexible Task abstraction, it supports both multiple-choice and generative evaluation tasks with sophisticated mixing and sequencing capabilities.

Framework Overview

Source: tasks/common.py:1-8

python
"""
Base class for all Tasks.
A Task is basically a dataset of conversations, together with some
metadata and often also evaluation criteria.
Example tasks: MMLU, ARC-Easy, ARC-Challenge, GSM8K, HumanEval, SmolTalk.
"""

The evaluation framework consists of:

  • Base Task abstraction for consistent interfaces
  • Task implementations for specific benchmarks
  • Task mixing utilities for training data composition
  • Evaluation loops for different task types
  • Distributed evaluation support for multi-GPU setups

Base Task Architecture

Core Task Class

Source: tasks/common.py:10-35

python
class Task:
    """
    Base class of a Task. Allows for lightweight slicing of the underlying dataset.
    """

    def __init__(self, start=0, stop=None, step=1):
        # allows a lightweight logical view over a dataset
        assert start >= 0, f"Start must be non-negative, got {start}"
        assert stop is None or stop >= start, f"Stop should be greater than or equal to start, got {stop} and {start}"
        assert step >= 1, f"Step must be strictly positive, got {step}"
        self.start = start
        self.stop = stop # could be None here
        self.step = step

    @property
    def eval_type(self):
        # one of 'generative' | 'categorical'
        raise NotImplementedError

    def num_examples(self):
        raise NotImplementedError

    def get_example(self, index):
        raise NotImplementedError

    def evaluate(self, problem, completion):
        raise NotImplementedError

Key Interface Methods

  1. eval_type: Returns either 'generative' or 'categorical'
  2. num_examples(): Total number of examples in the task
  3. get_example(index): Retrieve a specific example as a conversation
  4. evaluate(problem, completion): Score a model's response
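
The interface is small enough to show end to end. Below is a toy subclass written for this page (ParityTask and its data are invented here, not part of the repository) that illustrates the minimal methods a generative task needs:

python
class ParityTask(Task):
    """Toy generative task: ask whether a number is even or odd."""

    @property
    def eval_type(self):
        return 'generative'

    def num_examples(self):
        return 100

    def get_example(self, index):
        n = index + 1
        answer = "even" if n % 2 == 0 else "odd"
        messages = [
            {"role": "user", "content": f"Is {n} even or odd?"},
            {"role": "assistant", "content": answer},
        ]
        return {"messages": messages}

    def evaluate(self, conversation, completion):
        # naive substring check: credit the model if the reference word appears in its response
        reference = conversation["messages"][-1]["content"]
        return reference in completion.lower()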

Dataset Slicing

The Task class supports efficient slicing without data duplication:

Source: tasks/common.py:37-49

python
def __len__(self):
    start = self.start
    stop = self.num_examples() if self.stop is None else self.stop
    step = self.step
    span = stop - start
    num = (span + step - 1) // step # ceil_div(span, step)
    assert num >= 0, f"Negative number of examples???: {num}" # prevent footguns
    return num

def __getitem__(self, index: int):
    assert isinstance(index, int), f"Index must be an integer, got {type(index)}"
    physical_index = self.start + index * self.step
    conversation = self.get_example(physical_index)
    return conversation

This enables creating lightweight task subsets, e.g. a task instantiated with start=100, stop=200, step=2 (every second example from index 100 up to, but not including, 200), without copying any data.
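
As a quick illustration (using the toy ParityTask defined earlier on this page, not a repository class), slicing composes with __len__ and __getitem__ as follows:

python
subset = ParityTask(start=10, stop=20, step=2)
print(len(subset))       # 5 examples: physical indices 10, 12, 14, 16, 18
conv = subset[0]         # logical index 0 maps to physical index 10
print(conv["messages"][0]["content"])  # "Is 11 even or odd?"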

Task Composition

TaskMixture

For training, tasks can be mixed with deterministic shuffling:

Source: tasks/common.py:52-87

python
class TaskMixture(Task):
    """
    For SFT Training it becomes useful to train on a mixture of datasets.
    Fun trick: if you wish to oversample any task, just pass it in multiple times in the list.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        # tasks is a list of Task objects
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)
        
        # Build list of all (task_idx, local_idx) pairs
        self.index_map = []
        for task_idx, task_length in enumerate(self.lengths):
            for local_idx in range(task_length):
                self.index_map.append((task_idx, local_idx))
        
        # Deterministically shuffle to mix tasks throughout training
        rng = random.Random(42)
        rng.shuffle(self.index_map)

    def num_examples(self):
        # required by the base class: Task.__len__ calls self.num_examples()
        return self.num_conversations

    def get_example(self, index):
        """
        Access conversations according to a deterministic shuffle of all examples.
        This ensures tasks are mixed throughout training, regardless of dataset size.
        """
        task_idx, local_idx = self.index_map[index]
        return self.tasks[task_idx][local_idx]
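
The "pass it in multiple times" trick from the docstring can be used directly. A small sketch (classes as used in the training-mixture example later on this page):

python
mixture = TaskMixture([
    ARC(subset="ARC-Easy", split="train"),
    GSM8K(subset="main", split="train"),
    GSM8K(subset="main", split="train"),  # listed twice, so GSM8K contributes 2x its examples
])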

TaskSequence

For curriculum learning, tasks can be sequenced:

Source: tasks/common.py:90-115

python
class TaskSequence(Task):
    """
    For SFT Training sometimes we want to sequentially train on a list of tasks.
    This is useful for cases that require a training curriculum.
    """

    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.tasks = tasks
        self.lengths = [len(task) for task in self.tasks]
        self.num_conversations = sum(self.lengths)

    def num_examples(self):
        # required by the base class: Task.__len__ calls self.num_examples()
        return self.num_conversations

    def get_example(self, index):
        assert 0 <= index < self.num_conversations, f"Index {index} out of range"
        for task_idx, task_length in enumerate(self.lengths):
            if index < task_length:
                return self.tasks[task_idx][index]
            index -= task_length
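
Unlike TaskMixture, indices walk the tasks in order, so earlier tasks are seen first during training. A minimal curriculum sketch (the task choices here are illustrative, mirroring the classes used elsewhere on this page):

python
# simple-to-hard curriculum: spelling drills first, then grade-school math
curriculum = TaskSequence([
    SimpleSpelling(size=300, split="train"),   # 300 examples
    GSM8K(subset="main", split="train"),
])
curriculum[0]     # comes from SimpleSpelling
curriculum[299]   # last SimpleSpelling example
curriculum[300]   # first GSM8K example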

Multiple Choice Tasks

Rendering Format

Multiple choice tasks use a standardized format optimized for smaller models:

Source: tasks/common.py:118-135

python
def render_mc(question, letters, choices):
    """
    The common multiple choice rendering format we will use.

    Note two important design decisions:
    1) Smaller models prefer to have the letter *after* the choice for better binding.
    2) There is no whitespace between the delimiter (=) and the letter.
       This is critical because the tokenizer has different token ids
       for " A" vs. "A". The assistant responses will be just the letter itself.
    """
    query = f"Multiple Choice question: {question}\n"
    query += "".join([f"- {choice}={letter}\n" for letter, choice in zip(letters, choices)])
    query += "\nRespond only with the letter of the correct answer."
    return query

Example output:

text
Multiple Choice question: What is the capital of France?
- London=A
- Berlin=B
- Paris=C
- Madrid=D

Respond only with the letter of the correct answer.

ARC Implementation

Source: tasks/arc.py:9-40

python
class ARC(Task):

    def __init__(self, subset, split, **kwargs):
        super().__init__(**kwargs)
        assert subset in ["ARC-Easy", "ARC-Challenge"], "ARC subset must be ARC-Easy or ARC-Challenge"
        assert split in ["train", "validation", "test"], "ARC split must be train|validation|test"
        self.ds = load_dataset("allenai/ai2_arc", subset, split=split).shuffle(seed=42)

    @property
    def eval_type(self):
        return 'categorical'

    def get_example(self, index):
        row = self.ds[index]
        question = row["question"]
        choices = row["choices"]["text"]
        answer_string = row["answerKey"] # e.g. "A", "B", "C", "D"
        letters = row["choices"]["label"]
        
        # create and return the Conversation object
        user_message = render_mc(question, letters, choices)
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": answer_string}
        ]
        conversation = {
            "messages": messages,
            "letters": letters, # useful during evaluation for answer validation
        }
        return conversation

    def evaluate(self, conversation, assistant_response):
        assert assistant_response in conversation['letters']
        assistant_message = conversation['messages'][-1]['content']
        return assistant_response == assistant_message
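
Putting it together, a hedged usage sketch (the dataset is fetched via load_dataset and cached on first use):

python
task = ARC(subset="ARC-Easy", split="test")
conv = task[0]
print(conv["messages"][0]["content"])   # rendered multiple-choice prompt
print(conv["letters"])                  # e.g. ["A", "B", "C", "D"]
print(task.evaluate(conv, conv["messages"][-1]["content"]))  # True for the gold letter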

Generative Tasks

GSM8K Mathematics

GSM8K demonstrates tool use with calculator integration:

Source: tasks/gsm8k.py:35-75

python
def get_example(self, index):
    row = self.ds[index]
    question = row['question']
    answer = row['answer'] # string with solution and answer after #### marker
    
    # Parse tool calls from GSM8K format
    assistant_message_parts = []
    parts = re.split(r'(<<[^>]+>>)', answer)
    for part in parts:
        if part.startswith('<<') and part.endswith('>>'):
            # This is a calculator tool call
            inner = part[2:-2]  # Remove << >>
            if '=' in inner:
                expr, result = inner.rsplit('=', 1)
            else:
                expr, result = inner, ""
            # Add the tool call as a part
            assistant_message_parts.append({"type": "python", "text": expr})
            # Add the result as a part
            assistant_message_parts.append({"type": "python_output", "text": result})
        else:
            # Regular text between tool calls
            assistant_message_parts.append({"type": "text", "text": part})
    
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": assistant_message_parts},
    ]
    return {"messages": messages}

Answer Extraction

Source: tasks/gsm8k.py:15-25

python
GSM_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

def extract_answer(completion):
    """
    Extract the numerical answer after #### marker.
    Follows official code for normalization.
    """
    match = GSM_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    return None
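
A quick check of the normalization behaviour:

python
extract_answer("The total is 1,234.\n#### 1,234")  # -> "1234" (comma stripped)
extract_answer("I am not sure.")                   # -> None (no #### marker)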

Evaluation Execution

Categorical Evaluation

Source: scripts/chat_eval.py:85-140

python
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
    
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    num_passed, total = 0, 0
    
    # Process in batches
    for batch_start in range(ddp_rank * batch_size, num_problems, ddp_world_size * batch_size):
        batch_end = min(batch_start + batch_size, num_problems)
        batch_conversations = [task_object[i] for i in range(batch_start, batch_end)]
        
        # Batch the prompts and get logits
        batch_prompts = [tokenizer.render_for_completion(conv) for conv in batch_conversations]
        max_length = max(len(prompt) for prompt in batch_prompts)
        
        # Pad all prompts to same length
        inputs = torch.full((len(batch_prompts), max_length), tokenizer.get_bos_token_id(), 
                           dtype=torch.long, device=device)
        for i, prompt in enumerate(batch_prompts):
            inputs[i, :len(prompt)] = torch.tensor(prompt, dtype=torch.long)
        
        # Get logits and read the prediction at each prompt's final token
        # (prompts are right-padded, so the last column is not the correct position for shorter prompts)
        last_pos = torch.tensor([len(p) - 1 for p in batch_prompts], device=device)
        with torch.no_grad():
            logits = model(inputs)  # (B, T, vocab_size)
            last_logits = logits[torch.arange(len(batch_prompts), device=device), last_pos, :]  # (B, vocab_size)
        
        # Convert letter choices to token ids and find most likely
        for i, conversation in enumerate(batch_conversations):
            letters = conversation.get("letters", ["A", "B", "C", "D"])
            letter_logits = []
            for letter in letters:
                letter_token = tokenizer.encode(letter)[0]
                letter_logits.append(last_logits[i, letter_token])
            
            # Find the letter with highest logit
            best_idx = torch.argmax(torch.stack(letter_logits))
            predicted_letter = letters[best_idx]
            
            # Evaluate this prediction
            outcome = task_object.evaluate(conversation, predicted_letter)
            num_passed += int(outcome)
            total += 1

    return num_passed / total if total > 0 else 0.0
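
The outer loop stripes batches across ranks, so each GPU processes a disjoint slice of the problems. A standalone sketch of the index assignment (values chosen for illustration):

python
num_problems, batch_size, world_size = 20, 4, 2
for rank in range(world_size):
    starts = list(range(rank * batch_size, num_problems, world_size * batch_size))
    print(rank, starts)
# 0 [0, 8, 16]   -> batches [0:4], [8:12], [16:20]
# 1 [4, 12]      -> batches [4:8], [12:16]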

Generative Evaluation

Source: scripts/chat_eval.py:25-65

python
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, 
                       temperature, top_k, max_problems=None):

    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    num_passed, total = 0, 0
    
    for i in range(ddp_rank, num_problems, ddp_world_size):
        conversation = task_object[i]
        
        # Tokenize the prompt for completion
        encoded_prompt = tokenizer.render_for_completion(conversation)
        
        # Generate response(s)
        if num_samples == 1:
            response_tokens_list, _ = engine.generate_batch(
                encoded_prompt, num_samples=1, max_tokens=max_new_tokens,
                temperature=temperature, top_k=top_k
            )
            response_tokens = response_tokens_list[0]
            response_text = tokenizer.decode(response_tokens)
            
            # Evaluate the response
            outcome = task_object.evaluate(conversation, response_text)
            num_passed += int(outcome)
        else:
            # Multiple samples - take best of N
            best_outcome = 0
            response_tokens_list, _ = engine.generate_batch(
                encoded_prompt, num_samples=num_samples, max_tokens=max_new_tokens,
                temperature=temperature, top_k=top_k
            )
            for response_tokens in response_tokens_list:
                response_text = tokenizer.decode(response_tokens)
                outcome = task_object.evaluate(conversation, response_text)
                best_outcome = max(best_outcome, outcome)
            num_passed += int(best_outcome)
        
        total += 1

    return num_passed / total if total > 0 else 0.0

Evaluation Entry Point

Unified Evaluation Interface

Source: scripts/chat_eval.py:160-180

python
def run_chat_eval(task_name, model, tokenizer, engine,
                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
                  max_problems=None):
    # Create the evaluation object
    task_module = {
        'HumanEval': HumanEval,
        'MMLU': partial(MMLU, subset="all", split="test"),
        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
        'SpellingBee': partial(SpellingBee, size=256, split="test"),
    }[task_name]
    task_object = task_module()
    
    # Run the appropriate evaluation type
    if task_object.eval_type == 'generative':
        acc = run_generative_eval(task_object, tokenizer, model, engine, 
                                num_samples, max_new_tokens, temperature, top_k, 
                                max_problems=max_problems)
    elif task_object.eval_type == 'categorical':
        acc = run_categorical_eval(task_object, tokenizer, model, batch_size, 
                                 max_problems=max_problems)
    else:
        raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
    return acc
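
A hedged call sketch: the model, tokenizer, and engine objects are assumed to come from the repository's own loading utilities, which are outside the scope of this page:

python
# model, tokenizer, engine loaded elsewhere (e.g. from a chat-SFT checkpoint)
acc = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=8)
print(f"ARC-Easy accuracy: {acc:.3f}")

gsm = run_chat_eval("GSM8K", model, tokenizer, engine, max_new_tokens=512, temperature=0.0)
print(f"GSM8K accuracy: {gsm:.3f}")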

CORE Metric

Base Model Evaluation

The CORE metric from the DCLM paper evaluates base model quality:

Source: nanochat/core_eval.py:1-15

python
"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794

TODOs:
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
"""

def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
    """Render complete prompts for a multiple choice question"""
    template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}

{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
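
The template above is a Jinja2 string; it is presumably rendered once per answer choice so the model's likelihood of each continuation can be compared. A hedged sketch of how such a template could be rendered, assuming item is a dict with "query" and "choices" fields as the template suggests (this is not the repository's exact code path):

python
from jinja2 import Template

def render_choice_prompts(item, continuation_delimiter=" ", fewshot_examples=None):
    template = Template(template_str)  # template_str as defined in render_prompts_mc above
    return [
        template.render(item=item, choice=choice,
                        continuation_delimiter=continuation_delimiter,
                        fewshot_examples=fewshot_examples or [])
        for choice in item["choices"]
    ]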

Available Tasks

Multiple Choice Tasks

  • ARC-Easy/Challenge: Science reasoning questions
  • MMLU: Broad knowledge across academic subjects
  • Auxiliary MMLU: Training subset with diverse topics

Generative Tasks

  • GSM8K: Grade school math with tool use
  • HumanEval: Python code generation
  • SpellingBee: Letter counting and spelling
  • SimpleSpelling: Basic word spelling

Conversational Tasks

  • SmolTalk: General conversation dataset
  • CustomJSON: Synthetic identity conversations

Usage Examples

Training Data Mixture

Source: scripts/chat_sft.py:85-95

python
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath), # 1K synthetic identity
    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling
    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee
]) # 23K rows total

Evaluation Command

bash
# Single task evaluation
python -m scripts.chat_eval -a ARC-Easy

# Distributed evaluation
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a GSM8K

# Multiple tasks
python -m scripts.chat_eval -a ARC-Easy,MMLU,GSM8K

The evaluation framework provides a comprehensive, extensible system for assessing model capabilities across diverse domains while maintaining consistency and supporting distributed execution.


Sources:

  • tasks/common.py (base framework and task composition)
  • tasks/arc.py (multiple choice task implementation)
  • tasks/gsm8k.py (generative task with tool use)
  • scripts/chat_eval.py (evaluation execution logic)
  • nanochat/core_eval.py (CORE metric implementation)