PyTorch Image Models (timm) 快速入门指南
2025-07-05 03:38:22作者:龚格成
前言
PyTorch Image Models (timm) 是一个强大的计算机视觉模型库,提供了大量预训练模型和训练工具。本文将带你快速了解如何使用 timm 库进行模型加载、微调、特征提取和推理等常见任务。
安装 timm
使用 pip 可以轻松安装 timm 库:
pip install timm
建议在 Python 3.7+ 和 PyTorch 1.8+ 环境中使用 timm 以获得最佳体验。
加载预训练模型
timm 提供了 create_model
函数来加载预训练模型。以下示例展示了如何加载 MobileNetV3 大型模型:
import timm
# 加载预训练模型
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
# 将模型设置为评估模式
model.eval()
重要提示:新加载的模型默认处于训练模式,如果用于推理,必须调用 .eval()
方法。
查看可用模型
timm 提供了丰富的预训练模型,可以通过以下方式查看:
import timm
from pprint import pprint
# 查看所有有预训练权重的模型
pretrained_models = timm.list_models(pretrained=True)
pprint(pretrained_models)
# 使用通配符筛选特定模型
resnet_models = timm.list_models('*resne*t*')
pprint(resnet_models)
模型微调
微调预训练模型通常只需要修改最后的分类层:
# 假设我们的新任务有10个类别
NUM_CLASSES = 10
# 加载模型并修改分类层
model = timm.create_model(
'mobilenetv3_large_100',
pretrained=True,
num_classes=NUM_CLASSES
)
微调时,你需要准备自己的数据集并编写训练循环,或者修改 timm 提供的训练脚本。
特征提取
timm 支持不修改网络结构直接提取特征:
import torch
# 创建随机输入张量
x = torch.randn(1, 3, 224, 224)
# 加载模型
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
# 提取特征(跳过分类头和全局池化)
features = model.forward_features(x)
print(features.shape) # 输出特征图维度
图像预处理
正确的图像预处理对模型性能至关重要。timm 提供了便捷的预处理方法:
# 创建通用预处理变换
transform = timm.data.create_transform((3, 224, 224))
# 获取模型特定的预处理配置
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
data_config = timm.data.resolve_data_config(model.pretrained_cfg)
# 创建模型特定的预处理变换
specific_transform = timm.data.create_transform(**data_config)
完整推理示例
下面是一个完整的图像分类推理流程:
import requests
from PIL import Image
import torch
import timm
# 1. 加载图像
url = 'https://example.com/cat.jpg' # 替换为实际图片URL
image = Image.open(requests.get(url, stream=True).raw)
# 2. 加载模型和预处理
model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)
# 3. 预处理图像
image_tensor = transform(image).unsqueeze(0) # 添加批次维度
# 4. 模型推理
with torch.no_grad():
output = model(image_tensor)
# 5. 处理输出
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_values, top5_indices = torch.topk(probabilities, 5)
# 6. 显示结果
labels = [...] # 这里应该是实际的标签列表
results = [{'label': labels[idx], 'probability': val.item()}
for val, idx in zip(top5_values, top5_indices)]
print(results)
总结
timm 库提供了丰富的功能和便捷的接口,使得计算机视觉任务变得更加简单。通过本文介绍的方法,你可以快速开始使用 timm 进行模型加载、微调、特征提取和推理等任务。对于更高级的用法,建议查阅 timm 的详细文档和源代码。