首页
/ 基于Unsloth的AI量化交易股票预测模型训练指南

基于Unsloth的AI量化交易股票预测模型训练指南

2025-07-09 05:30:23作者:蔡怀权

项目概述

本项目构建了一个基于金融新闻分析的股票价格预测系统,采用Unsloth框架对Qwen3-0.6B模型进行4-bit LoRA微调,实现了高效的金融文本推理能力。系统能够分析上市公司相关新闻,预测下一个交易日的股价涨跌趋势。

技术架构

核心组件

  1. 模型架构:基于Qwen3-0.6B模型,采用4-bit量化技术
  2. 微调方法:使用Unsloth框架进行LoRA微调
  3. 推理引擎:支持Ollama本地部署

性能表现

在500个测试样本上的准确率对比:

  • LoRA微调模型:49%
  • 原始Qwen3-0.6B模型:47.4%
  • DeepSeek-R1基准模型:45.6%

数据准备

数据来源

使用IDEA-FinAI/Finance-R1-Reasoning数据集,包含:

  • 1000条训练样本
  • 1000条测试样本

数据格式

每条数据包含以下字段:

  • instruction:任务指令
  • input:输入新闻内容
  • thinks:模型推理过程
  • output:预测结果(上涨/下跌)

数据处理脚本

def convert_format(in_file_path, out_file_path):
    data_lst = []
    df = pd.read_csv(in_file_path)
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        instruction = row['instruction']
        inputs = row['input']
        thinks = row['thinks']
        output = row['output']
        dialog = [
                {"role": "user", "content": '###新闻###\n' + inputs +
                                            '\n\n###任务###\n' + instruction},
                {"role": "assistant", "content": '<think>\n' + thinks + '\n</think>\n\n' + output},
        ]
        data_lst.append(dialog)
    logger.info(f"Total {len(data_lst)} records converted.")

    conv = {"conversations": data_lst}
    with open(out_file_path, 'w', encoding='utf-8') as f:
        json.dump(conv, f, ensure_ascii=False, indent=4)
    logger.info(f"Converted data saved to {out_file_path}")

环境配置

硬件要求

  • GPU:支持CUDA 12.1及以上
  • 显存:至少8GB(推荐16GB以上)

软件依赖

# 基础依赖
pip install -r requirements.txt

# 更新Unsloth
pip install --upgrade --force-reinstall --no-cache-dir unsloth unsloth_zoo

# Flash Attention 2
pip3 install flash-attn --no-build-isolation

环境验证

# 检查xformers安装
python -m xformers.info

# 检查bitsandbytes安装
python -m bitsandbytes

模型训练

训练配置

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = mdl_path,
    max_seq_length = 32768,
    load_in_4bit = True,
    load_in_8bit = False,
    full_finetuning = False,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 32,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = True,
    loftq_config = None,
)

训练参数

参数 说明
batch_size 2 单设备批次大小
gradient_accumulation 4 梯度累积步数
learning_rate 2e-5 学习率
epochs 1 训练轮数
warmup_steps 5 预热步数

显存占用参考

模型参数 4-bit QLoRA 16-bit LoRA
3B 3.5 GB 8 GB
7B 5 GB 19 GB
8B 6 GB 22 GB

模型部署

Ollama安装

apt-get update
apt-get install pciutils -y
curl -fsSL https://ollama.com/install.sh | sh

模型运行

ollama run hf.co/unsloth/Qwen3-8B-GGUF:Q4_K_XL

优化方向

数据优化

  1. 清理无效样本(新闻为空的数据)
  2. 补充股价数据作为辅助特征
  3. 扩展时间窗口(周/月级别新闻分析)

模型优化

  1. 引入强化学习机制
  2. 尝试更大规模的模型
  3. 优化推理过程提示词

测试框架

  1. 构建回测系统验证预测效果
  2. 增加多维度评估指标
  3. 优化结果提取逻辑

注意事项

  1. 本模型预测结果仅供参考,不构成投资建议
  2. 金融市场存在风险,投资需谨慎
  3. 建议在实际应用前进行充分的回测验证

通过本指南,开发者可以快速搭建一个基于金融新闻分析的股票预测系统,并可根据实际需求进行定制化调整。