基于AWS SageMaker和Hugging Face的检索增强生成(RAG)问答系统实现
2025-07-10 05:13:44作者:宣利权Counsellor
概述
本文将详细介绍如何在AWS SageMaker平台上构建一个检索增强生成(Retrieval-Augmented Generation, RAG)问答系统。该系统结合了Hugging Face的大型语言模型(LLM)和句子嵌入技术,能够基于自定义知识库提供准确的问答服务。
技术架构
整个系统由以下几个核心组件构成:
- 大型语言模型(LLM):使用Flan T5 XL作为生成模型
- 嵌入模型:使用MiniLM句子嵌入模型
- 检索机制:基于向量相似度的文档检索
- 提示工程:精心设计的提示模板
环境准备
首先需要安装必要的Python库:
!pip install -qU \
sagemaker==2.173.0 \
pinecone-client==3.1.0 \
ipywidgets==7.0.0
部署大型语言模型
我们使用SageMaker JumpStart来部署Flan T5 XL模型:
import sagemaker
from sagemaker.huggingface import (
HuggingFaceModel,
get_huggingface_llm_image_uri
)
role = sagemaker.get_execution_role()
hub_config = {
'HF_MODEL_ID':'google/flan-t5-xl',
'HF_TASK':'text-generation'
}
llm_image = get_huggingface_llm_image_uri(
"huggingface",
version="0.8.2"
)
huggingface_model = HuggingFaceModel(
env=hub_config,
role=role,
image_uri=llm_image
)
llm = huggingface_model.deploy(
initial_instance_count=1,
instance_type="ml.g5.4xlarge",
endpoint_name="flan-t5-demo"
)
直接提问的局限性
在没有提供上下文的情况下,直接向模型提问会得到不准确的结果:
question = "Which instances can I use with Managed Spot Training in SageMaker?"
out = llm.predict({"inputs": question})
# 输出可能不准确
检索增强生成(RAG)方法
1. 部署嵌入模型
我们使用MiniLM模型来生成文档和查询的嵌入向量:
hub_config = {
'HF_MODEL_ID': 'sentence-transformers/all-MiniLM-L6-v2',
'HF_TASK': 'feature-extraction'
}
huggingface_model = HuggingFaceModel(
env=hub_config,
role=role,
transformers_version="4.6",
pytorch_version="1.7",
py_version="py36",
)
encoder = huggingface_model.deploy(
initial_instance_count=1,
instance_type="ml.t2.large",
endpoint_name="minilm-demo"
)
2. 嵌入向量生成
定义一个函数来处理文本并生成嵌入向量:
from typing import List
import numpy as np
def embed_docs(docs: List[str]) -> List[List[float]]:
out = encoder.predict({'inputs': docs})
embeddings = np.mean(np.array(out), axis=1)
return embeddings.tolist()
3. 知识库准备
使用Amazon SageMaker FAQs作为知识库:
s3_path = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv"
!aws s3 cp $s3_path Amazon_SageMaker_FAQs.csv
完整RAG流程实现
- 文档处理:读取知识库文档并生成嵌入向量
- 查询处理:对用户查询生成嵌入向量
- 相似度计算:找到与查询最相关的文档
- 答案生成:将相关文档作为上下文与问题一起输入LLM
prompt_template = """Answer the following QUESTION based on the CONTEXT
given. If you do not know the answer and the CONTEXT doesn't
contain the answer truthfully say "I don't know".
CONTEXT:
{context}
QUESTION:
{question}
ANSWER:
"""
def generate_answer(question, context):
text_input = prompt_template.replace("{context}", context).replace("{question}", question)
out = llm.predict({"inputs": text_input})
return out[0]["generated_text"]
系统优势
- 准确性提升:通过检索相关文档提供上下文,显著提高回答准确性
- 灵活性:可以轻松替换知识库以适应不同领域
- 可控性:通过提示工程控制模型行为
- 可扩展性:支持多种LLM和嵌入模型
实际应用建议
- 知识库优化:确保知识库文档质量高且覆盖面广
- 分块策略:合理设置文档块大小以平衡信息量和计算效率
- 模型选择:根据需求选择合适的LLM和嵌入模型
- 性能监控:定期评估系统性能并进行优化
通过本文介绍的方法,开发者可以在AWS SageMaker平台上快速构建高效、准确的问答系统,适用于各种业务场景。