Chapter 10

实战:多跳问答 RAG + 蒸馏到小模型

把前 9 章串起来:从 HotpotQA 数据出发,写 MultiHop Module,用 MIPROv2 把 GPT-4o-mini 调到 F1 0.51,再 BootstrapFinetune 到 Llama-3 8B,推理成本降到 1/20。

任务定义

HotpotQA:给定问题 + 维基百科检索接口,要回答"哪部小说的作者曾在哈佛任教?"这种需要多次检索 + 推理的开放问答。

目标:

  1. baseline:单跳 RAG + GPT-4o-mini,F1 约 0.32
  2. MultiHop + MIPROv2:同模型 F1 提到 0.51
  3. 蒸馏到 Llama-3 8B:成本降 20 倍,F1 保持 0.46+

Step 1:准备数据

from datasets import load_dataset
from dspy import Example
import random

raw = load_dataset("hotpot_qa", "distractor", split="validation")

def to_example(r):
    return Example(
        question=r["question"],
        answer=r["answer"],
        gold_titles=set(r["supporting_facts"]["title"]),
    ).with_inputs("question")

data = [to_example(r) for r in raw]
random.seed(42); random.shuffle(data)

trainset = data[:200]
valset   = data[200:400]
testset  = data[400:600]

200 train + 200 val 跑 MIPROv2 差不多 30 分钟,成本 $8 左右(gpt-4o-mini + gpt-4o 提 prompt)。

Step 2:写 MultiHop Module

import dspy

class GenerateQuery(dspy.Signature):
    """根据已有 context 和 question 生成下一次的检索关键词。
    关键词要能补充缺失的事实,而不是重复已知的。"""
    context: list[str] = dspy.InputField()
    question: str = dspy.InputField()
    search_query: str = dspy.OutputField()

class GenerateAnswer(dspy.Signature):
    """从 context 中提取简短答案(通常 1-5 个词)。"""
    context: list[str] = dspy.InputField()
    question: str = dspy.InputField()
    answer: str = dspy.OutputField()

class MultiHopRAG(dspy.Module):
    def __init__(self, max_hops=2, k=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=k)
        self.gen_query = dspy.ChainOfThought(GenerateQuery)
        self.gen_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops

    def forward(self, question):
        context = []
        for _ in range(self.max_hops):
            q = self.gen_query(context=context, question=question).search_query
            context = dspy.deduplicate(context + self.retrieve(q).passages)
        return self.gen_answer(context=context, question=question)

Step 3:metric

import re

def normalize(s):
    return re.sub(r"[^\w\s]", "", s.lower()).strip()

def hotpot_f1(ex, pred, trace=None):
    g = set(normalize(ex.answer).split())
    p = set(normalize(pred.answer).split())
    if not g or not p: return 0.0
    prec = len(g & p) / len(p)
    rec  = len(g & p) / len(g)
    return 2*prec*rec/(prec+rec) if (prec+rec) else 0.0

Step 4:baseline 跑一次

dspy.configure(
    lm=dspy.LM("openai/gpt-4o-mini", max_tokens=300),
    rm=dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts"),
)

from dspy.evaluate import Evaluate
ev = Evaluate(devset=testset, metric=hotpot_f1, num_threads=8)

baseline = MultiHopRAG()
print("baseline F1:", ev(baseline))   # 约 0.32

Step 5:MIPROv2 编译

from dspy.teleprompt import MIPROv2

mipro = MIPROv2(
    metric=hotpot_f1,
    prompt_model=dspy.LM("openai/gpt-4o"),
    task_model=dspy.LM("openai/gpt-4o-mini"),
    auto="medium",
    num_threads=8,
)

teacher = mipro.compile(
    student=MultiHopRAG(),
    trainset=trainset,
    valset=valset,
    requires_permission_to_run=False,
)

teacher.save("artifacts/hotpot_teacher.json")
print("teacher F1:", ev(teacher))      # 约 0.51

看看 MIPRO 选出了什么指令:

for name, pred in teacher.named_predictors():
    print("==", name)
    print(pred.extended_signature.instructions[:300])

你会看到类似这样的生成指令——不是你写的,是 GPT-4o 自动搜出来的:

== gen_query.predict
You are a research assistant that formulates precise Wikipedia search
queries. Analyze which entity or fact is missing from the current context
and produce a keyword query that targets exactly that gap. Prefer proper
nouns and distinctive terms over common words...

Step 6:BootstrapFinetune 蒸馏

from dspy.teleprompt import BootstrapFinetune

finetune = BootstrapFinetune(
    metric=hotpot_f1,
    num_threads=8,
    multitask=True,          # gen_query 和 gen_answer 共用一个 LoRA
)

student = finetune.compile(
    teacher,
    trainset=trainset + valset,    # 蒸馏时可以合并
    target="meta-llama/Llama-3-8B-Instruct",
    epochs=3,
    lr=1e-5,
    batch_size=8,
)

student.save("artifacts/hotpot_student.json")
蒸馏的本质
BootstrapFinetune 让 teacher 把 trainset 每道题都跑一遍,记录完整 trace(每次 LLM 的输入 prompt + 输出)。过 metric 的 trace 变成 Llama-3 的 supervised 数据,用 LoRA 小步长 finetune 几 epoch。

Step 7:换 LM 评估 student

with dspy.context(lm=dspy.LM("hosted_vllm/Llama-3-8B-Instruct+ft", base_url="http://vllm:8000/v1")):
    print("student F1:", ev(student))   # 约 0.46

F1 只掉了 5 个点,但是:

GPT-4o-mini teacherLlama-3-8B student
F10.510.46
每千次推理成本$1.20$0.06
p95 latency2.4s0.9s(本地 GPU)
数据可控❌(OpenAI)✅(自托管)

Step 8:上线

from fastapi import FastAPI

app = FastAPI()

@app.on_event("startup")
def setup():
    dspy.configure(
        lm=dspy.LM("hosted_vllm/Llama-3-8B-Instruct+ft", base_url="http://vllm:8000/v1"),
        rm=dspy.ColBERTv2(url="http://colbert:8893"),
    )
    app.state.rag = MultiHopRAG()
    app.state.rag.load("artifacts/hotpot_student.json")

@app.post("/ask")
async def ask(q: dict):
    pred = await app.state.rag.acall(question=q["question"])
    return {"answer": pred.answer}

全流程时间 / 成本账

阶段时间成本
数据准备10 分钟$0
baseline 评估5 分钟$0.5
MIPROv2 编译(auto=medium)30 分钟$8
teacher 评估5 分钟$0.8
BootstrapFinetune(3 epoch,1 张 A100)约 40 分钟$2 GPU 租金
student 评估3 分钟$0(本地)
总计约 1.5 小时约 $11.3

从 $11 投入到每千次推理省下 $1.14,大约 1 万次调用就回本,之后都是净赚。

踩过的坑

  1. retriever 返回的段落过长挤爆 context:把 k=5 改成 k=3,或在 retriever 端开 passage truncation
  2. MIPRO 有一轮突然劣化:task_model 临时 rate-limit 导致部分样本失败被当低分,加 num_retries=3
  3. Llama-3 蒸馏后输出冗长:Signature desc 里硬加 "answer in 1-5 words",student 才学会简短
  4. vLLM 加载 LoRA 要 merge:训练完把 LoRA merge 回 base 权重,避免推理时多一层 adapter 延迟

扩展方向

本章小结

通关感言
你已经学完 10 章了。现在回到第 1 章那段"以后 prompt 要像代码一样调"——希望读完这本教程,你能把它当成一条生产原则,而不仅是口号。祝编译愉快。