构建基于RAG(Retrieval-Augmented Generation)的服务是当前自然语言处理领域中的热门话题。FastAPI 是一个现代、快速的 Python 框架,非常适合用于构建高性能的 API 服务。本文将详细介绍如何使用 FastAPI 构建一个 RAG 服务,并涵盖从数据准备到模型部署的完整流程。
RAG 是一种结合检索和生成的混合模型架构,它通过以下步骤工作:
这种架构的优势在于可以动态地利用外部知识库,而不需要重新训练模型。
在开始之前,请确保安装了以下依赖项:
pip install fastapi uvicorn transformers faiss-cpu sentence-transformers torch
假设我们有一个包含文档的 JSON 文件 data.json
,格式如下:
[
{"id": "doc1", "title": "Python 编程入门", "content": "Python 是一种易于学习的编程语言..."},
{"id": "doc2", "title": "机器学习基础", "content": "机器学习是一种人工智能技术..."}
]
sentence-transformers
将文档内容转换为向量表示。代码示例:
from sentence_transformers import SentenceTransformer
import faiss
import json
# 加载文档数据
with open("data.json", "r") as f:
documents = json.load(f)
# 初始化句子嵌入模型
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# 提取文档内容并生成嵌入
texts = [doc["content"] for doc in documents]
embeddings = model.encode(texts)
# 创建 FAISS 索引
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
接下来,我们将创建一个 FastAPI 应用程序,提供两个主要功能:
from fastapi import FastAPI, Query
from transformers import pipeline
from typing import List
app = FastAPI()
# 加载生成模型
generator = pipeline("text2text-generation", model="facebook/bart-base")
@app.post("/retrieve/")
def retrieve(query: str = Query(..., description="用户查询")):
"""
根据用户查询检索相关文档。
"""
query_embedding = model.encode([query])
_, indices = index.search(query_embedding, k=3) # 检索前 3 个最相似的文档
related_docs = [documents[i] for i in indices[0]]
return {"related_docs": related_docs}
@app.post("/generate/")
def generate(query: str = Query(..., description="用户查询")):
"""
根据用户查询生成回答。
"""
# 检索相关文档
query_embedding = model.encode([query])
_, indices = index.search(query_embedding, k=3)
contexts = [documents[i]["content"] for i in indices[0]]
# 将查询和上下文拼接
input_text = f"Query: {query}\nContext: {' '.join(contexts)}"
# 生成回答
result = generator(input_text, max_length=100, num_return_sequences=1)
return {"answer": result[0]["generated_text"]}
启动服务:
uvicorn main:app --reload
访问以下端点进行测试:
检索相关文档:
curl -X POST "http://127.0.0.1:8000/retrieve/" -d '{"query":"Python"}'
生成回答:
curl -X POST "http://127.0.0.1:8000/generate/" -d '{"query":"什么是机器学习?"}'
为了提高性能,可以考虑以下优化措施:
RAG 服务的应用场景非常广泛,例如客服系统、搜索引擎增强等。未来可以进一步探索:
flowchart TD A[用户查询] --> B{检索阶段} B --> C[FAISS 检索相关文档] C --> D{生成阶段} D --> E[生成回答]