EAST文本检测模型架构解析与实现细节
2025-07-10 03:50:49作者:昌雅子Ethen
模型概述
EAST(Efficient and Accurate Scene Text detector)是一种高效准确的场景文本检测器,其核心思想是通过全卷积网络直接预测文本区域的位置和方向。本文将从技术实现角度深入解析EAST模型的架构设计、特征融合机制以及损失函数计算等关键部分。
模型架构详解
1. 输入预处理
模型首先对输入图像进行归一化处理,使用mean_image_subtraction
函数减去ImageNet数据集上的均值:
def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]):
num_channels = images.get_shape().as_list()[-1]
channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)
for i in range(num_channels):
channels[i] -= means[i]
return tf.concat(axis=3, values=channels)
这种预处理方式有助于模型更快收敛,提高训练稳定性。
2. 主干网络
EAST采用ResNet-50作为特征提取主干网络:
with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')
ResNet的残差连接结构能有效缓解深层网络梯度消失问题,适合提取多尺度文本特征。
3. 特征融合机制
EAST的核心创新在于其特征融合策略,它从ResNet不同阶段提取特征并进行逐步融合:
f = [end_points['pool5'], end_points['pool4'],
end_points['pool3'], end_points['pool2']]
特征融合过程采用U-Net式的上采样结构:
- 从最深层的特征开始(pool5)
- 每层与上一层的上采样结果进行拼接
- 通过1x1和3x3卷积进行特征变换
- 使用双线性插值进行上采样
for i in range(4):
if i == 0:
h[i] = f[i]
else:
c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
if i <= 2:
g[i] = unpool(h[i])
else:
g[i] = slim.conv2d(h[i], num_outputs[i], 3)
这种设计能够同时利用深层特征的语义信息和浅层特征的细节信息,对于检测不同尺度的文本区域非常有效。
4. 输出预测
模型最终输出两个部分:
- 文本得分图(F_score):使用sigmoid激活,表示每个位置是文本的概率
- 几何特征图(F_geometry):包含4个通道的轴对齐边界框和1个通道的旋转角度
F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2
F_geometry = tf.concat([geo_map, angle_map], axis=-1)
其中角度预测被限制在[-π/4, π/4]范围内,对应[-45°, 45°]的旋转角度。
损失函数设计
EAST的损失函数由两部分组成:
1. 分类损失(Dice Loss)
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
eps = 1e-5
intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
loss = 1. - (2 * intersection / union)
return loss
Dice Loss相比传统的交叉熵损失,对类别不平衡问题(文本区域通常远小于非文本区域)更加鲁棒。
2. 几何损失
几何损失包括两部分:
- AABB损失:衡量预测边界框与真实框的重叠程度
- 角度损失:衡量预测角度的准确性
area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
area_intersect = w_union * h_union
area_union = area_gt + area_pred - area_intersect
L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0))
L_theta = 1 - tf.cos(theta_pred - theta_gt)
最终损失是两者的加权和:
L_g = L_AABB + 20 * L_theta
return tf.reduce_mean(L_g * y_true_cls * training_mask) + classification_loss
角度损失被赋予更高的权重(20倍),因为角度预测对最终检测结果影响较大。
模型特点总结
- 全卷积结构:可以处理任意尺寸的输入图像
- 多尺度特征融合:有效检测不同大小的文本区域
- 端到端训练:直接输出文本位置和方向,无需复杂后处理
- 高效推理:单次前向传播即可得到检测结果
EAST模型通过精心设计的网络架构和损失函数,在场景文本检测任务上实现了高精度和高效率的平衡,适合实际应用部署。