ArtLine项目训练教程:从零开始构建线条艺术生成模型
2025-07-09 05:37:41作者:幸俭卉
项目概述
ArtLine是一个基于深度学习的图像生成项目,能够将普通的人物肖像照片转化为精美的线条艺术画作。该项目利用了生成对抗网络(GAN)和卷积神经网络(CNN)技术,通过多阶段的训练过程,逐步提高生成图像的质量和细节表现力。
环境准备
在开始训练前,我们需要准备以下Python库:
import torch
import torch.nn as nn
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from torchvision.models import vgg16_bn
from fastai.utils.mem import *
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torchvision.transforms as transforms
这些库提供了深度学习模型构建、训练和评估所需的核心功能。
边缘检测实现
边缘检测是生成线条艺术的关键步骤,项目中实现了一个自定义的梯度计算函数:
def _gradient_img(img):
img = img.squeeze(0)
ten=torch.unbind(img)
x=ten[0].unsqueeze(0).unsqueeze(0)
# Sobel算子X方向卷积核
a=np.array([[1, 0, -1],[2,0,-2],[1,0,-1]])
conv1=nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
conv1.weight=nn.Parameter(torch.from_numpy(a).float().unsqueeze(0).unsqueeze(0))
G_x=conv1(Variable(x)).data.view(1,x.shape[2],x.shape[3])
# Sobel算子Y方向卷积核
b=np.array([[1, 2, 1],[0,0,0],[-1,-2,-1]])
conv2=nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
conv2.weight=nn.Parameter(torch.from_numpy(b).float().unsqueeze(0).unsqueeze(0))
G_y=conv2(Variable(x)).data.view(1,x.shape[2],x.shape[3])
# 计算梯度幅值
G=torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2))
return G
gradient = TfmPixel(_gradient_img)
这个实现使用了Sobel算子来检测图像边缘,通过计算X和Y方向的梯度,然后合成最终的边缘图像。
数据准备
项目使用了两种类型的数据集:
- 混合面部特征数据集
- 完整肖像数据集
数据加载函数如下:
def get_data(bs,size):
data = (src.label_from_func(lambda x: path_hr/x.name)
.transform(get_transforms(xtra_tfms=[gradient()]), size=size, tfm_y=True)
.databunch(bs=bs,num_workers = 0).normalize(imagenet_stats, do_y=True))
data.c = 3
return data
模型架构
项目采用了ResNet34作为基础架构:
arch = models.resnet34
特征损失函数
为了生成高质量的线条艺术,项目实现了一个自定义的特征损失函数,结合了VGG16的特征提取能力:
class FeatureLoss(nn.Module):
def __init__(self, m_feat, layer_ids, layer_wgts):
super().__init__()
self.m_feat = m_feat
self.loss_features = [self.m_feat[i] for i in layer_ids]
self.hooks = hook_outputs(self.loss_features, detach=False)
self.wgts = layer_wgts
self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
] + [f'gram_{i}' for i in range(len(layer_ids))]
def make_features(self, x, clone=False):
self.m_feat(x)
return [(o.clone() if clone else o) for o in self.hooks.stored]
def forward(self, input, target):
out_feat = self.make_features(target, clone=True)
in_feat = self.make_features(input)
self.feat_losses = [base_loss(input,target)]
self.feat_losses += [base_loss(f_in, f_out)*w
for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
self.metrics = dict(zip(self.metric_names, self.feat_losses))
return sum(self.feat_losses)
def __del__(self): self.hooks.remove()
训练流程
项目采用了渐进式训练策略,从低分辨率开始,逐步提高图像分辨率:
1. 64px分辨率训练
bs,size=20,64
data = get_data(bs,size)
learn_gen = create_gen_learner()
lr = 1e-01
epoch = 5
do_fit('da', slice(lr))
2. 128px分辨率训练
data = get_data(8,128)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db')
epoch =5
lr = 1E-03
do_fit('db2',slice(lr))
3. 192px分辨率训练
data = get_data(5,192)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db3')
epoch =5
lr = 1E-06
do_fit('db4')
完整肖像训练
在完成面部特征训练后,项目转向完整肖像的训练:
1. 128px分辨率
data = get_data(8,128)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db5')
epoch = 5
lr = 1e-03
do_fit('db6')
2. 192px分辨率
data = get_data(4,192)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db7')
epoch = 5
lr = 4.37E-05
do_fit('db8')
训练技巧
- 渐进式训练:从低分辨率开始,逐步提高分辨率,有助于模型学习不同尺度的特征。
- 学习率调整:使用
lr_find()
方法寻找最佳学习率,并在不同阶段调整学习率。 - 模型冻结与解冻:在切换分辨率时冻结部分层,稳定训练过程。
- 特征损失:结合VGG网络的多层次特征,提高生成图像的质量。
总结
ArtLine项目展示了一个完整的线条艺术生成模型的训练流程。通过精心设计的网络架构、特征损失函数和渐进式训练策略,模型能够将普通肖像照片转化为精美的线条艺术画作。这种技术可以应用于艺术创作、设计辅助等多个领域。
训练过程中需要注意内存管理,特别是在高分辨率阶段,可能需要调整批量大小以适应GPU内存限制。此外,训练数据的质量和多样性对最终生成效果有重要影响,建议使用高质量、多样化的肖像数据集进行训练。