LaTeX-OCR 模型训练指南:从数据准备到模型微调
前言
LaTeX-OCR 是一个将数学公式图像转换为 LaTeX 代码的开源项目,基于深度学习技术实现。本文将详细介绍如何训练和微调 LaTeX-OCR 模型,包括环境配置、数据准备、模型训练等完整流程。
环境准备
基础环境配置
首先需要安装必要的 Python 包:
pip install pix2tex[train] -qq
pip install gpustat -q
pip install opencv-python-headless==4.1.2.30 -U -q
pip install --upgrade --no-cache-dir gdown -q
这些包包括:
pix2tex[train]
: LaTeX-OCR 的核心训练模块gpustat
: GPU 状态监控工具opencv-python-headless
: 图像处理库gdown
: Google Drive 文件下载工具
检查 GPU 资源
训练深度学习模型需要 GPU 支持,可以使用以下命令检查 GPU 状态:
gpustat
确保你的 GPU 有足够的内存(至少 8GB)来训练模型。
数据准备
数据集下载
LaTeX-OCR 支持混合使用不同类型的数学公式数据:
mkdir -p dataset/data
mkdir images
gdown -O dataset/data/crohme.zip --id 13vjxGYrFCuYnwgDIUqkxsNGKk__D_sOM
gdown -O dataset/data/pdf.zip --id 176PKaCUDWmTJdQwc-OfkO0y8t4gLsIvQ
gdown -O dataset/data/pdfmath.txt --id 1QUjX6PFWPa-HBWdcY-7bA5TRVUnbyS1D
数据集包括:
- 手写数学公式图像 (crohme.zip)
- PDF 提取的数学公式图像 (pdf.zip)
- PDF 数学公式的 LaTeX 标注 (pdfmath.txt)
数据解压与分割
cd dataset/data
unzip -q crohme.zip
unzip -q pdf.zip
# 将手写数据分割为训练集和验证集
cd images
mkdir ../valimages
ls | shuf -n 1000 | xargs -i mv {} ../valimages
cd ../../..
这里我们将手写数据随机选取 1000 张作为验证集,其余作为训练集。
数据集预处理
生成训练集和验证集的 pickle 文件:
python -m pix2tex.dataset.dataset -i dataset/data/images dataset/data/train -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt -o dataset/data/train.pkl
python -m pix2tex.dataset.dataset -i dataset/data/valimages dataset/data/val -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt -o dataset/data/val.pkl
这些 pickle 文件包含图像路径、尺寸和对应的 LaTeX 代码,便于后续批量加载。
模型训练
预训练模型下载
curl -L -o weights.pth https://github.com/lukas-blecher/LaTeX-OCR/releases/download/v0.0.1/weights.pth
我们使用官方提供的预训练模型作为起点进行微调。
训练配置
创建训练配置文件 colab.yaml
:
backbone_layers: [2, 3, 7]
betas: [0.9, 0.999]
batchsize: 10
bos_token: 1
channels: 1
data: dataset/data/train.pkl
debug: true
decoder_args:
attn_on_attn: true
cross_attend: true
ff_glu: true
rel_pos_bias: false
use_scalenorm: false
dim: 256
encoder_depth: 4
eos_token: 2
epochs: 50
gamma: 0.9995
heads: 8
id: null
load_chkpt: weights.pth
lr: 0.001
lr_step: 30
max_height: 192
max_seq_len: 512
max_width: 672
min_height: 32
min_width: 32
model_path: checkpoints
name: mixed
num_layers: 4
num_tokens: 8000
optimizer: Adam
output_path: outputs
pad: false
pad_token: 0
patch_size: 16
sample_freq: 2000
save_freq: 1
scheduler: StepLR
seed: 42
temperature: 0.2
test_samples: 5
testbatchsize: 20
tokenizer: dataset/tokenizer.json
valbatches: 100
valdata: dataset/data/val.pkl
关键配置说明:
batchsize
: 训练批次大小lr
: 初始学习率epochs
: 训练轮数max_height/max_width
: 输入图像最大尺寸num_tokens
: tokenizer 词汇表大小save_freq
: 模型保存频率
启动训练
python -m pix2tex.train --config colab.yaml
训练过程中会输出损失值和准确率等信息。如果配置了 W&B (Weights & Biases),还可以实时监控训练过程。
训练技巧与建议
-
学习率调整:初始学习率 0.001 适合大多数情况,如果训练不稳定可以适当降低。
-
批次大小:根据 GPU 内存调整 batchsize,较大的 batchsize 通常能带来更稳定的训练。
-
数据增强:可以添加随机裁剪、旋转等增强手段提高模型鲁棒性。
-
混合精度训练:如果支持,可以启用 AMP (Automatic Mixed Precision) 加速训练。
-
早停机制:监控验证集损失,当连续多轮不下降时停止训练。
常见问题解决
-
CUDA 内存不足:减小 batchsize 或输入图像尺寸。
-
训练不收敛:检查学习率是否合适,数据标注是否正确。
-
过拟合:增加数据量或添加正则化手段如 Dropout。
-
推理效果差:确保训练数据和实际应用场景相似。
总结
本文详细介绍了 LaTeX-OCR 模型的训练流程,从环境配置、数据准备到模型训练和微调。通过混合不同类型的数据(PDF 提取和手写公式),可以训练出更强大的 OCR 模型。实际应用中,可以根据具体需求调整模型结构和训练策略,以获得最佳性能。