首页
/ 基于AWS SageMaker和Hugging Face的检索增强生成(RAG)问答系统实现

基于AWS SageMaker和Hugging Face的检索增强生成(RAG)问答系统实现

2025-07-10 05:13:44作者:宣利权Counsellor

概述

本文将详细介绍如何在AWS SageMaker平台上构建一个检索增强生成(Retrieval-Augmented Generation, RAG)问答系统。该系统结合了Hugging Face的大型语言模型(LLM)和句子嵌入技术,能够基于自定义知识库提供准确的问答服务。

技术架构

整个系统由以下几个核心组件构成:

  1. 大型语言模型(LLM):使用Flan T5 XL作为生成模型
  2. 嵌入模型:使用MiniLM句子嵌入模型
  3. 检索机制:基于向量相似度的文档检索
  4. 提示工程:精心设计的提示模板

环境准备

首先需要安装必要的Python库:

!pip install -qU \
    sagemaker==2.173.0 \
    pinecone-client==3.1.0 \
    ipywidgets==7.0.0

部署大型语言模型

我们使用SageMaker JumpStart来部署Flan T5 XL模型:

import sagemaker
from sagemaker.huggingface import (
    HuggingFaceModel,
    get_huggingface_llm_image_uri
)

role = sagemaker.get_execution_role()

hub_config = {
    'HF_MODEL_ID':'google/flan-t5-xl',
    'HF_TASK':'text-generation'
}

llm_image = get_huggingface_llm_image_uri(
  "huggingface",
  version="0.8.2"
)

huggingface_model = HuggingFaceModel(
    env=hub_config,
    role=role,
    image_uri=llm_image
)

llm = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.4xlarge",
    endpoint_name="flan-t5-demo"
)

直接提问的局限性

在没有提供上下文的情况下,直接向模型提问会得到不准确的结果:

question = "Which instances can I use with Managed Spot Training in SageMaker?"

out = llm.predict({"inputs": question})
# 输出可能不准确

检索增强生成(RAG)方法

1. 部署嵌入模型

我们使用MiniLM模型来生成文档和查询的嵌入向量:

hub_config = {
    'HF_MODEL_ID': 'sentence-transformers/all-MiniLM-L6-v2',
    'HF_TASK': 'feature-extraction'
}

huggingface_model = HuggingFaceModel(
    env=hub_config,
    role=role,
    transformers_version="4.6",
    pytorch_version="1.7",
    py_version="py36",
)

encoder = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.t2.large",
    endpoint_name="minilm-demo"
)

2. 嵌入向量生成

定义一个函数来处理文本并生成嵌入向量:

from typing import List
import numpy as np

def embed_docs(docs: List[str]) -> List[List[float]]:
    out = encoder.predict({'inputs': docs})
    embeddings = np.mean(np.array(out), axis=1)
    return embeddings.tolist()

3. 知识库准备

使用Amazon SageMaker FAQs作为知识库:

s3_path = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv"
!aws s3 cp $s3_path Amazon_SageMaker_FAQs.csv

完整RAG流程实现

  1. 文档处理:读取知识库文档并生成嵌入向量
  2. 查询处理:对用户查询生成嵌入向量
  3. 相似度计算:找到与查询最相关的文档
  4. 答案生成:将相关文档作为上下文与问题一起输入LLM
prompt_template = """Answer the following QUESTION based on the CONTEXT
given. If you do not know the answer and the CONTEXT doesn't
contain the answer truthfully say "I don't know".

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

def generate_answer(question, context):
    text_input = prompt_template.replace("{context}", context).replace("{question}", question)
    out = llm.predict({"inputs": text_input})
    return out[0]["generated_text"]

系统优势

  1. 准确性提升:通过检索相关文档提供上下文,显著提高回答准确性
  2. 灵活性:可以轻松替换知识库以适应不同领域
  3. 可控性:通过提示工程控制模型行为
  4. 可扩展性:支持多种LLM和嵌入模型

实际应用建议

  1. 知识库优化:确保知识库文档质量高且覆盖面广
  2. 分块策略:合理设置文档块大小以平衡信息量和计算效率
  3. 模型选择:根据需求选择合适的LLM和嵌入模型
  4. 性能监控:定期评估系统性能并进行优化

通过本文介绍的方法,开发者可以在AWS SageMaker平台上快速构建高效、准确的问答系统,适用于各种业务场景。