首页
/ ArtLine项目训练教程:从零开始构建线条艺术生成模型

ArtLine项目训练教程:从零开始构建线条艺术生成模型

2025-07-09 05:37:41作者:幸俭卉

项目概述

ArtLine是一个基于深度学习的图像生成项目,能够将普通的人物肖像照片转化为精美的线条艺术画作。该项目利用了生成对抗网络(GAN)和卷积神经网络(CNN)技术,通过多阶段的训练过程,逐步提高生成图像的质量和细节表现力。

环境准备

在开始训练前,我们需要准备以下Python库:

import torch
import torch.nn as nn
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from torchvision.models import vgg16_bn
from fastai.utils.mem import *
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torchvision.transforms as transforms

这些库提供了深度学习模型构建、训练和评估所需的核心功能。

边缘检测实现

边缘检测是生成线条艺术的关键步骤,项目中实现了一个自定义的梯度计算函数:

def _gradient_img(img):
    img = img.squeeze(0)
    ten=torch.unbind(img)
    x=ten[0].unsqueeze(0).unsqueeze(0)
    
    # Sobel算子X方向卷积核
    a=np.array([[1, 0, -1],[2,0,-2],[1,0,-1]])
    conv1=nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
    conv1.weight=nn.Parameter(torch.from_numpy(a).float().unsqueeze(0).unsqueeze(0))
    G_x=conv1(Variable(x)).data.view(1,x.shape[2],x.shape[3])

    # Sobel算子Y方向卷积核
    b=np.array([[1, 2, 1],[0,0,0],[-1,-2,-1]])
    conv2=nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
    conv2.weight=nn.Parameter(torch.from_numpy(b).float().unsqueeze(0).unsqueeze(0))
    G_y=conv2(Variable(x)).data.view(1,x.shape[2],x.shape[3])

    # 计算梯度幅值
    G=torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2))
    return G

gradient = TfmPixel(_gradient_img)

这个实现使用了Sobel算子来检测图像边缘,通过计算X和Y方向的梯度,然后合成最终的边缘图像。

数据准备

项目使用了两种类型的数据集:

  1. 混合面部特征数据集
  2. 完整肖像数据集

数据加载函数如下:

def get_data(bs,size):
    data = (src.label_from_func(lambda x: path_hr/x.name)
           .transform(get_transforms(xtra_tfms=[gradient()]), size=size, tfm_y=True)
           .databunch(bs=bs,num_workers = 0).normalize(imagenet_stats, do_y=True))
    data.c = 3
    return data

模型架构

项目采用了ResNet34作为基础架构:

arch = models.resnet34

特征损失函数

为了生成高质量的线条艺术,项目实现了一个自定义的特征损失函数,结合了VGG16的特征提取能力:

class FeatureLoss(nn.Module):
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self): self.hooks.remove()

训练流程

项目采用了渐进式训练策略,从低分辨率开始,逐步提高图像分辨率:

1. 64px分辨率训练

bs,size=20,64
data = get_data(bs,size)
learn_gen = create_gen_learner()
lr = 1e-01
epoch = 5
do_fit('da', slice(lr))

2. 128px分辨率训练

data = get_data(8,128)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db')
epoch =5
lr = 1E-03
do_fit('db2',slice(lr))

3. 192px分辨率训练

data = get_data(5,192)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db3')
epoch =5
lr = 1E-06
do_fit('db4')

完整肖像训练

在完成面部特征训练后,项目转向完整肖像的训练:

1. 128px分辨率

data = get_data(8,128)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db5')
epoch = 5
lr = 1e-03
do_fit('db6')

2. 192px分辨率

data = get_data(4,192)
learn_gen.data = data
learn_gen.freeze()
learn_gen.load('db7')
epoch = 5
lr = 4.37E-05
do_fit('db8')

训练技巧

  1. 渐进式训练:从低分辨率开始,逐步提高分辨率,有助于模型学习不同尺度的特征。
  2. 学习率调整:使用lr_find()方法寻找最佳学习率,并在不同阶段调整学习率。
  3. 模型冻结与解冻:在切换分辨率时冻结部分层,稳定训练过程。
  4. 特征损失:结合VGG网络的多层次特征,提高生成图像的质量。

总结

ArtLine项目展示了一个完整的线条艺术生成模型的训练流程。通过精心设计的网络架构、特征损失函数和渐进式训练策略,模型能够将普通肖像照片转化为精美的线条艺术画作。这种技术可以应用于艺术创作、设计辅助等多个领域。

训练过程中需要注意内存管理,特别是在高分辨率阶段,可能需要调整批量大小以适应GPU内存限制。此外,训练数据的质量和多样性对最终生成效果有重要影响,建议使用高质量、多样化的肖像数据集进行训练。