首页
/ SwinIR图像恢复模型的核心网络架构解析

SwinIR图像恢复模型的核心网络架构解析

2025-07-08 04:16:53作者:胡易黎Nicole

概述

SwinIR是基于Swin Transformer架构的图像恢复模型,在图像超分辨率、去噪等任务中表现出色。本文将深入解析SwinIR模型的核心网络架构实现,帮助读者理解其关键组件和工作原理。

核心组件

1. 多层感知机(MLP)

MLP模块是Transformer架构中的标准组件,由两个全连接层和激活函数组成:

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

MLP模块的特点:

  • 使用GELU激活函数
  • 每个全连接层后都包含Dropout层防止过拟合
  • 隐藏层维度默认为输入维度的4倍(通过mlp_ratio参数控制)

2. 窗口划分与还原

SwinIR采用窗口化的注意力机制,需要先将图像划分为窗口:

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
    windows = x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)
    return windows

窗口还原则是逆过程:

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    x = x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)
    return x

3. 窗口注意力机制(WindowAttention)

这是SwinIR的核心创新之一,在局部窗口内计算自注意力:

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, 
                 attn_drop=0., proj_drop=0.):
        # 初始化相对位置偏置表
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1), num_heads))
        
        # 计算相对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        # ... (详细位置编码计算)

关键特点:

  • 引入相对位置偏置,使模型能够感知空间位置信息
  • 支持多头注意力机制
  • 计算效率高,仅在局部窗口内计算注意力

4. Swin Transformer块(SwinTransformerBlock)

这是构成SwinIR的基本构建块:

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):

每个块包含:

  1. 层归一化(LayerNorm)
  2. 窗口注意力(WindowAttention)
  3. 残差连接
  4. MLP前馈网络
  5. 可选的shift操作实现窗口间通信

5. 残差Swin Transformer块(RSTB)

这是SwinIR的高层组件,由多个SwinTransformerBlock组成:

class RSTB(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

特点:

  • 包含多个SwinTransformerBlock(深度由depth参数控制)
  • 可选的下采样层
  • 支持梯度检查点以节省内存

关键创新点

  1. 局部窗口注意力:将全局注意力计算限制在局部窗口内,大幅降低计算复杂度
  2. 移位窗口机制:通过周期性移位窗口实现不同窗口间的信息交互
  3. 层次化设计:通过Patch Merging构建层次化特征表示
  4. 相对位置编码:在注意力计算中引入相对位置偏置,保留空间信息

性能考虑

SwinIR在设计中考虑了计算效率:

  • 使用窗口化注意力降低计算复杂度
  • 实现FLOPs计算方法便于评估计算量
  • 支持梯度检查点节省内存

总结

SwinIR的网络架构通过巧妙结合CNN的局部性和Transformer的全局建模能力,在图像恢复任务中取得了优异的性能。其核心创新在于窗口化的注意力机制和层次化设计,既保持了Transformer的强大表示能力,又控制了计算复杂度。理解这些核心组件对于进一步研究和应用SwinIR模型至关重要。