SwinIR图像恢复模型的核心网络架构解析
2025-07-08 04:16:53作者:胡易黎Nicole
概述
SwinIR是基于Swin Transformer架构的图像恢复模型,在图像超分辨率、去噪等任务中表现出色。本文将深入解析SwinIR模型的核心网络架构实现,帮助读者理解其关键组件和工作原理。
核心组件
1. 多层感知机(MLP)
MLP模块是Transformer架构中的标准组件,由两个全连接层和激活函数组成:
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
MLP模块的特点:
- 使用GELU激活函数
- 每个全连接层后都包含Dropout层防止过拟合
- 隐藏层维度默认为输入维度的4倍(通过mlp_ratio参数控制)
2. 窗口划分与还原
SwinIR采用窗口化的注意力机制,需要先将图像划分为窗口:
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
windows = x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)
return windows
窗口还原则是逆过程:
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
x = x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)
return x
3. 窗口注意力机制(WindowAttention)
这是SwinIR的核心创新之一,在局部窗口内计算自注意力:
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
attn_drop=0., proj_drop=0.):
# 初始化相对位置偏置表
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1), num_heads))
# 计算相对位置索引
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
# ... (详细位置编码计算)
关键特点:
- 引入相对位置偏置,使模型能够感知空间位置信息
- 支持多头注意力机制
- 计算效率高,仅在局部窗口内计算注意力
4. Swin Transformer块(SwinTransformerBlock)
这是构成SwinIR的基本构建块:
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
每个块包含:
- 层归一化(LayerNorm)
- 窗口注意力(WindowAttention)
- 残差连接
- MLP前馈网络
- 可选的shift操作实现窗口间通信
5. 残差Swin Transformer块(RSTB)
这是SwinIR的高层组件,由多个SwinTransformerBlock组成:
class RSTB(nn.Module):
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
特点:
- 包含多个SwinTransformerBlock(深度由depth参数控制)
- 可选的下采样层
- 支持梯度检查点以节省内存
关键创新点
- 局部窗口注意力:将全局注意力计算限制在局部窗口内,大幅降低计算复杂度
- 移位窗口机制:通过周期性移位窗口实现不同窗口间的信息交互
- 层次化设计:通过Patch Merging构建层次化特征表示
- 相对位置编码:在注意力计算中引入相对位置偏置,保留空间信息
性能考虑
SwinIR在设计中考虑了计算效率:
- 使用窗口化注意力降低计算复杂度
- 实现FLOPs计算方法便于评估计算量
- 支持梯度检查点节省内存
总结
SwinIR的网络架构通过巧妙结合CNN的局部性和Transformer的全局建模能力,在图像恢复任务中取得了优异的性能。其核心创新在于窗口化的注意力机制和层次化设计,既保持了Transformer的强大表示能力,又控制了计算复杂度。理解这些核心组件对于进一步研究和应用SwinIR模型至关重要。