首页
/ PyTorch Image Models (timm) 快速入门指南

PyTorch Image Models (timm) 快速入门指南

2025-07-05 03:38:22作者:龚格成

前言

PyTorch Image Models (timm) 是一个强大的计算机视觉模型库,提供了大量预训练模型和训练工具。本文将带你快速了解如何使用 timm 库进行模型加载、微调、特征提取和推理等常见任务。

安装 timm

使用 pip 可以轻松安装 timm 库:

pip install timm

建议在 Python 3.7+ 和 PyTorch 1.8+ 环境中使用 timm 以获得最佳体验。

加载预训练模型

timm 提供了 create_model 函数来加载预训练模型。以下示例展示了如何加载 MobileNetV3 大型模型:

import timm

# 加载预训练模型
model = timm.create_model('mobilenetv3_large_100', pretrained=True)

# 将模型设置为评估模式
model.eval()

重要提示:新加载的模型默认处于训练模式,如果用于推理,必须调用 .eval() 方法。

查看可用模型

timm 提供了丰富的预训练模型,可以通过以下方式查看:

import timm
from pprint import pprint

# 查看所有有预训练权重的模型
pretrained_models = timm.list_models(pretrained=True)
pprint(pretrained_models)

# 使用通配符筛选特定模型
resnet_models = timm.list_models('*resne*t*')
pprint(resnet_models)

模型微调

微调预训练模型通常只需要修改最后的分类层:

# 假设我们的新任务有10个类别
NUM_CLASSES = 10

# 加载模型并修改分类层
model = timm.create_model(
    'mobilenetv3_large_100', 
    pretrained=True, 
    num_classes=NUM_CLASSES
)

微调时,你需要准备自己的数据集并编写训练循环,或者修改 timm 提供的训练脚本。

特征提取

timm 支持不修改网络结构直接提取特征:

import torch

# 创建随机输入张量
x = torch.randn(1, 3, 224, 224)

# 加载模型
model = timm.create_model('mobilenetv3_large_100', pretrained=True)

# 提取特征(跳过分类头和全局池化)
features = model.forward_features(x)
print(features.shape)  # 输出特征图维度

图像预处理

正确的图像预处理对模型性能至关重要。timm 提供了便捷的预处理方法:

# 创建通用预处理变换
transform = timm.data.create_transform((3, 224, 224))

# 获取模型特定的预处理配置
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
data_config = timm.data.resolve_data_config(model.pretrained_cfg)

# 创建模型特定的预处理变换
specific_transform = timm.data.create_transform(**data_config)

完整推理示例

下面是一个完整的图像分类推理流程:

import requests
from PIL import Image
import torch
import timm

# 1. 加载图像
url = 'https://example.com/cat.jpg'  # 替换为实际图片URL
image = Image.open(requests.get(url, stream=True).raw)

# 2. 加载模型和预处理
model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
transform = timm.data.create_transform(
    **timm.data.resolve_data_config(model.pretrained_cfg)
)

# 3. 预处理图像
image_tensor = transform(image).unsqueeze(0)  # 添加批次维度

# 4. 模型推理
with torch.no_grad():
    output = model(image_tensor)

# 5. 处理输出
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_values, top5_indices = torch.topk(probabilities, 5)

# 6. 显示结果
labels = [...]  # 这里应该是实际的标签列表
results = [{'label': labels[idx], 'probability': val.item()} 
           for val, idx in zip(top5_values, top5_indices)]
print(results)

总结

timm 库提供了丰富的功能和便捷的接口,使得计算机视觉任务变得更加简单。通过本文介绍的方法,你可以快速开始使用 timm 进行模型加载、微调、特征提取和推理等任务。对于更高级的用法,建议查阅 timm 的详细文档和源代码。