最简 RAG Module
class RAG(dspy.Module): def __init__(self, num_passages=3): super().__init__() self.retrieve = dspy.Retrieve(k=num_passages) self.generate = dspy.ChainOfThought("context, question -> answer") def forward(self, question): passages = self.retrieve(question).passages return self.generate(context=passages, question=question) # 配好 RM 就能直接跑 dspy.configure(lm=lm, rm=dspy.ColBERTv2(url="http://...")) rag = RAG() rag("谁发明了 Transformer?").answer
dspy.Module 的约定
__init__
把子 Module(Predict/CoT/Retrieve/自定义)存成实例属性。DSPy 会自动发现并纳入优化范围。
forward(self, ...inputs)
定义实际逻辑。参数名要和 Signature 的 InputField 对应。返回 dspy.Prediction 或一个子 Module 的调用结果。
__call__
不要自己覆盖!基类会自动走 forward,并把所有中间调用记录进 trace,Optimizer 要用。
条件分支 Module
class SmartQA(dspy.Module): def __init__(self): super().__init__() self.classify = dspy.Predict("question -> kind: Literal['factual','opinion','calc']") self.retrieve = dspy.Retrieve(k=5) self.factual = dspy.ChainOfThought("context, question -> answer") self.opinion = dspy.ChainOfThought("question -> answer") self.calc = dspy.ProgramOfThought("question -> answer") def forward(self, question): kind = self.classify(question=question).kind if kind == "factual": ctx = self.retrieve(question).passages return self.factual(context=ctx, question=question) if kind == "calc": return self.calc(question=question) return self.opinion(question=question)
把分类器也当成 Module——优化器会同时优化分类器 + 每个分支。
多跳检索:MultiHop
一个问题可能需要多次检索(比如"比较 A 和 B",得分别查 A 和 B):
class MultiHopRAG(dspy.Module): def __init__(self, max_hops=3, num_passages=3): super().__init__() self.retrieve = dspy.Retrieve(k=num_passages) self.generate_query = dspy.ChainOfThought( "context, question -> search_query" ) self.generate_answer = dspy.ChainOfThought( "context, question -> answer" ) self.max_hops = max_hops def forward(self, question): context = [] for _ in range(self.max_hops): query = self.generate_query( context=context, question=question ).search_query passages = self.retrieve(query).passages context = dspy.deduplicate(context + passages) return self.generate_answer(context=context, question=question)
关键技巧:同名 Module 可以多次调用
上面
上面
generate_query 在循环里被调用多次,但只是一个 Module。Optimizer 优化一次,循环里每次调用都用同一份优化后的 prompt/demos。
参数化 Module
通过 __init__ 把超参暴露出来,便于不同场景复用:
class Summarizer(dspy.Module): def __init__(self, max_words=80, style="neutral"): super().__init__() self.max_words = max_words self.style = style sig = type( "SumSig", (dspy.Signature,), { "__doc__": f"产生 {style} 风格的摘要,不超过 {max_words} 字。", "text": dspy.InputField(), "summary": dspy.OutputField(), "__annotations__": {"text": str, "summary": str}, }, ) self.predict = dspy.Predict(sig) def forward(self, text): return self.predict(text=text)
嵌套 Module
Module 里用别的 Module,DSPy 的 trace 会正确 unroll:
class ArticleWriter(dspy.Module): def __init__(self): super().__init__() self.outline = dspy.ChainOfThought("topic -> outline: list[str]") self.section_writer = SectionWriter() # 子 Module self.editor = dspy.ChainOfThought("draft -> polished") def forward(self, topic): outline = self.outline(topic=topic).outline sections = [self.section_writer(topic=topic, point=p).text for p in outline] draft = "\n\n".join(sections) return self.editor(draft=draft)
查看 Module 内部
rag = RAG() # 列出所有 Predict 子模块 for name, pred in rag.named_predictors(): print(name, pred.signature) # 查看最近一次调用 trace out = rag("...") dspy.inspect_history(n=3) # 打印最近 3 次 LLM 调用的完整 prompt/response
保存与加载
# 保存 Module 状态(含优化后的 demos 和指令) rag.save("rag_compiled.json") # 加载 new = RAG() new.load("rag_compiled.json")
调试 Module
三板斧
①
② 临时把某个子 Module 换成
③ 开 LM 的
①
dspy.inspect_history():看 prompt 真的长什么样② 临时把某个子 Module 换成
dspy.Predict("... -> ..."),二分法定位问题③ 开 LM 的
cache=False 排除缓存干扰
Module 设计准则
- 职责单一:一个 Module 最好只负责 1-3 步,太深了难以优化
- 接口明确:
forward参数名和 Signature 对齐,不要用 **kwargs 魔法 - 子 Module 命名语义化:
self.query_gen比self.cot1好 - 避免副作用:Module 不写文件、不改全局,Optimizer 要反复调用
- 可测:forward 逻辑能 mock LM 跑单测,至少覆盖主分支
常见坑
| 坑 | 症状 | 解法 |
|---|---|---|
忘记 super().__init__() | 保存/优化失败 | 永远调父类构造 |
| 在 forward 里 new Module | Optimizer 抓不到 | 子 Module 必须在 __init__ 里定义 |
| forward 返回 dict | 属性访问报错 | 返回 dspy.Prediction(...) 或子 Module 调用 |
| 循环里不 dedupe | context 爆炸 | 用 dspy.deduplicate 或自行维护 set |
本章小结
- 继承
dspy.Module,__init__建子模块,forward写逻辑 - 分支、循环、嵌套都写普通 Python 即可——DSPy 会追踪全程
- 同一 Module 多次调用用同一份 prompt,Optimizer 优化一次就够
save/load把编译后的程序持久化,inspect_history查调用现场