【目标检测】RetinaNet 论文详解




背景与简介

目前最高的检测器是以 R-CNN 开启的 two-stage 检测方法,这种方法把分类器应用在离散的目标候选区。而 one-stage 检测器则是通过在可能出现目标的地方做规则、密集的采样,然后将检测器用于这些地方

one-stage 检测器更快和更简单的,但是在精度是却落后two-stage,导致 one-stage detector 精度较低的主要原因是训练时样本存在极端的正负样本类(foreground & background class)不平衡现象,即 class imbalance。大致可以理解为由于训练时出现大量的容易 negative 样本,使得损失函数绝大部分由不含信息的负样本构成,进而得到的 loss 无法为模型训练提供好的指导。

eg: 在二分类中正负样本比例存在较大差距,导致模型的预测偏向某一类别。如果正样本占据 1%,而负样本占据 99%,那么模型只需要对所有样本输出预测为负样本,那么模型轻松可以达到 99% 的正确率,所以这样的样本训练出来的模型是不准确的。


解决类别不平衡(class imbalance)

  • R-CNN 这样的 two stage detecors,
    • 在其第一个 stage —— proposal stage 的时候就缩小了候选区域的数量,过滤了大多数的 background 样本(负样本);
    • 第二 stage 的时候,通过一些启发式的采样,如保持正负样本 1:3 比例,以及 onem 在线难样本挖掘等操作,很好的保持了 class balance。
  • one-stage detectors 则需要处理大的多的候选区,密集的涵盖了每个 spatial location,scale,和 ratio,虽然也能使用启发式的采样,但是效果很差,因为训练被大量易区分的 background 样本所支配着,这是检测中的经典问题,通常以难样本挖掘等技术来解决。


作者提出以一个新损失函数 Focus loss 来解决 class imbalance 现象,该 loss 函数重新改造了标准的交叉熵损失(standard cross entropy loss),是一个动态尺度的交叉熵损失(dynamically scaled cross entropy loss),随着正确样本的置信度上升 scale 会降低至 0。这个损失函数能减小分配给良好分类的样本的损失权重,并聚焦于那些难样本。这些损失在后面会细讲~

所谓难样本,可以简单理解为不容易区分他是否含有目标物,比如只包含一条腿的一个 anchor,难以判断是否为人这个目标。

为了评估这个损失函数 Focus loss 的有效性,作者设计并训练了一个 one-stage 的样本密度检测器——RetinaNet,结果表明 Focus loss 来解决 class imbalance 的方案远远好于启发式采样或者难样本挖掘等以往应用在 one-stage 中的方案,使用 focus loss 时 RetinaNet 既有one-stage 的速度,还超过了当时所有最好的 two-stage detector 的精度。

Note: focus loss 的具体形式并不是很关键。


Focal Loss

focal loss 是一个能够动态缩放的cross entropy loss,当正确类别的置信度提高时缩放因子衰减为0,缩放因子可以自动降低easy例子在训练期间贡献loss的权重.


为了介绍 Focal loss,首先引入用于二分类的 cross entropy(CE) 交叉熵的概念

cross entropy(CE):

上式中 y 取值 +1 或 -1,+1 代表为 ground truth class,-1 则不是。p 则代表 y 为 groud truth 的情况下,估计其为该 class 的概率。按照惯例,写作:

令 CE(p, y) = CE(pt) = − log(pt) 。pt 可以看作样本被正确分类的一个概率值,包括被正确分类为 background 的错误样本和被正确分类为 gt class 的正确样本。


CE(pt) 损失的曲线可以被看作上图中的蓝色曲线,该曲线具有的一个显著特征是,即使样本很容易分辨(即 pt>0.5,属于 easy sample),仍然会造成较大的损失。当大量的 easy examples loss 相加起来就会得到一个很大的 loss 值,其他样本的 loss 就会被 overwhelm 掉。


Balanced Cross Entropy

解决 class imbalance 的一个方法是为类 1 添加一个权重因子 α ∈ [0, 1],为类 -1 添加一个权重因子 1-α。在实践中,α 可以通过逆类频率来设置,也可以作为一个超参数通过交叉验证来设置。为了方便,以类似定义 pt 的方式定义 αt。然后得到 α-balanced CE loss:

这是交叉熵的一个简单扩展,通过该 α 权值,我们可以解决正负样本的不平衡。


Focal Loss Definition

α 解决了正负样本的不平衡,但是却不能区分 easy/hard examples。因此,focal loss 改造 cross loss,使之降低对 easy examples 的权重,聚焦于 hard examples 的训练。

具体来讲,给交叉熵增加了一个调节因子(modulating factor)——(1 − pt)γ,其中可调节的参数 γ>0。

于是我们的 focal loss 为:

不同 γ 取值对应的曲线上图中有显示,其中 γ 为 0 时函数退化为交叉熵损失。该损失函数有两个特性:

  • 样本被误分类时 pt 很小,整个调节因子接近于 1,对 loss 值几乎无影响。而当样本被很好分类时,pt 值趋近于 1,调节因子接近于 0,使得良好分类的样本 loss 降低,从而相对的变为关注难样本。
  • γ 增加,调节因子的影响也增加,实验中 γ 取2最好。

在计算 [公式] 时用 sigmoid 方法比 softmax 准确度更高;

Focal Loss 的公式并不是固定的,也可以有其它形式,性能差异不大,所以说 Focal Loss 的表达式并不关键。


直观地看,调节因子减小了简单样本的损失贡献,扩展了简单样本的低损失范围。

实践中,使用带 α-balance 变量的 focal loss 精度会有小的提升,α 和 γ 是有关系的,γ 增加时,α 应该稍微减小,最佳的取值为 γ=2,α=0.25:



Total loss

训练的总损失由分类的 Focal loss 和边框回归的 smooth L1 loss 组成,其中 Focal loss 要计算所有 anchor 的 loss,这一点与其他方法不同,其他方法至少选择一部分的 anchor。并且做 normalization 时除以的是被分配到 ground truth 的那些 anchor 的数量。


Model Initialization

在有显著的 class imbalance 时,即 background examples 远大于 foreground examples,但是网络初始化时,对于任何输入,其判断为 positive 和 negative 的概率是一样的,而我们 focal loss 对判断为正确的样本损失会降权重,所以谁被误判的多,谁造成的 loss 更大。

可认为一半负样本被误分为正,一半正样本被误分为负,但由于负样本基数远大于正样本,于是负样本几乎贡献了 loss 绝大部分,于是在训练初期,函数会不稳定,模型会偏向于去把样本分为负样本。

为了缓解这个偏差,作者对最后一级用于分类的卷积的 bias (具体位置见图2) 作了下小修改,把它初始化成一个特殊的值 b = − log((1 − π)/π) 。π 在论文中取 0.01,这样做能在训练初始阶段提高 positive 的分类概率。


RetinaNet Detector

RetinaNet Detector 由 backbone 和两个用于特定任务的 subnetworks 组成,backbone 提取整张图片的特征。然后两个 subnetworks 在 backbone 输出的特征图上做卷积计算,第一个 subnetworks 卷积计算 object classification 结果,第二个 subnetworks 做 bounding box regression.

结构图:

backbone 使用的是 RenNet+FPN,每个使用的三个特征层都为 256channel。每个特征层用三个不同比例,三个尺寸的 anchor,总共 9 个 anchor。每个 anchor 得到一个 k 维的分类向量和一个 4 维的 box 回归向量。IoU>0.5 的 anchor 为 positive,被分配到对应的 Ground Truth,IoU 在 [0, 0.4] 的为 negative,即 background,IoU 为 [0.4,0.5] 的忽略。


分类分支和边框回归分支都是 FCN 子网络,网络的参数对于 FPN 的每层特征都共享。


分类分支最后输出的 channel 是 A×k,即 anchor 数 × num_class。

边框回归分支输出的 channel 是 4A,每个 anchor 对应 4 个参数预测了 anchor 和 ground truth 的相对位移。使用 RCNN 中的参数化方法来回归。


Inference

inference 只需简单前向传播就行,为了加速,每个 FPN 层只选择检测阈值 0.05 筛选后得分前 1000 的 anchor 做 decode(还原得到 anchor 在原图上的位置),最后所有 RPN 的预测做 merge,然后非极大值抑制 nms 得到最终结果。





reference

https://zhuanlan.zhihu.com/p/310508374