实战教程:使用FastAPI构建RAG服务

2025-06发布5次浏览

构建基于RAG(Retrieval-Augmented Generation)的服务是当前自然语言处理领域中的热门话题。FastAPI 是一个现代、快速的 Python 框架,非常适合用于构建高性能的 API 服务。本文将详细介绍如何使用 FastAPI 构建一个 RAG 服务,并涵盖从数据准备到模型部署的完整流程。


1. RAG 简介

RAG 是一种结合检索和生成的混合模型架构,它通过以下步骤工作:

  1. 检索阶段:从大规模文档集合中检索与用户查询相关的上下文。
  2. 生成阶段:将检索到的上下文与用户查询一起输入到生成模型中,生成最终的回答。

这种架构的优势在于可以动态地利用外部知识库,而不需要重新训练模型。


2. 环境准备

在开始之前,请确保安装了以下依赖项:

pip install fastapi uvicorn transformers faiss-cpu sentence-transformers torch
  • FastAPI:用于构建 API。
  • Transformers:Hugging Face 提供的库,用于加载预训练模型。
  • FAISS:Facebook 开发的高效向量检索库。
  • Sentence-Transformers:用于生成文本嵌入。

3. 数据准备

假设我们有一个包含文档的 JSON 文件 data.json,格式如下:

[
    {"id": "doc1", "title": "Python 编程入门", "content": "Python 是一种易于学习的编程语言..."},
    {"id": "doc2", "title": "机器学习基础", "content": "机器学习是一种人工智能技术..."}
]

步骤:

  1. 加载文档数据。
  2. 使用 sentence-transformers 将文档内容转换为向量表示。
  3. 将向量存储到 FAISS 索引中以支持快速检索。

代码示例:

from sentence_transformers import SentenceTransformer
import faiss
import json

# 加载文档数据
with open("data.json", "r") as f:
    documents = json.load(f)

# 初始化句子嵌入模型
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# 提取文档内容并生成嵌入
texts = [doc["content"] for doc in documents]
embeddings = model.encode(texts)

# 创建 FAISS 索引
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

4. 构建 FastAPI 服务

接下来,我们将创建一个 FastAPI 应用程序,提供两个主要功能:

  1. 检索相关文档
  2. 生成回答

完整代码示例:

from fastapi import FastAPI, Query
from transformers import pipeline
from typing import List

app = FastAPI()

# 加载生成模型
generator = pipeline("text2text-generation", model="facebook/bart-base")

@app.post("/retrieve/")
def retrieve(query: str = Query(..., description="用户查询")):
    """
    根据用户查询检索相关文档。
    """
    query_embedding = model.encode([query])
    _, indices = index.search(query_embedding, k=3)  # 检索前 3 个最相似的文档
    related_docs = [documents[i] for i in indices[0]]
    return {"related_docs": related_docs}

@app.post("/generate/")
def generate(query: str = Query(..., description="用户查询")):
    """
    根据用户查询生成回答。
    """
    # 检索相关文档
    query_embedding = model.encode([query])
    _, indices = index.search(query_embedding, k=3)
    contexts = [documents[i]["content"] for i in indices[0]]

    # 将查询和上下文拼接
    input_text = f"Query: {query}\nContext: {' '.join(contexts)}"
    
    # 生成回答
    result = generator(input_text, max_length=100, num_return_sequences=1)
    return {"answer": result[0]["generated_text"]}

5. 测试服务

启动服务:

uvicorn main:app --reload

访问以下端点进行测试:

  • 检索相关文档

    curl -X POST "http://127.0.0.1:8000/retrieve/" -d '{"query":"Python"}'
    
  • 生成回答

    curl -X POST "http://127.0.0.1:8000/generate/" -d '{"query":"什么是机器学习?"}'
    

6. 性能优化

为了提高性能,可以考虑以下优化措施:

  1. 异步处理:使用 FastAPI 的异步功能处理耗时任务。
  2. 批量检索:对多个查询进行批量处理以减少延迟。
  3. GPU 支持:将模型加载到 GPU 上以加速推理。

7. 讨论与扩展

RAG 服务的应用场景非常广泛,例如客服系统、搜索引擎增强等。未来可以进一步探索:

  • 引入更复杂的检索模型(如 BM25 和语义检索结合)。
  • 使用更大的生成模型(如 T5 或 GPT 系列)。
  • 实现多语言支持。
flowchart TD
    A[用户查询] --> B{检索阶段}
    B --> C[FAISS 检索相关文档]
    C --> D{生成阶段}
    D --> E[生成回答]