目标检测4--Adaptive Training Sample Selection(ATSS)算法
论文Bridging the Gap Between Anchor-based and Anchor-free Detection via代码https://github.com/sfzhang15/ATSSATSS是中科院自动化研究所的等最早于2019年12月份提交的论文中提出的方法,发表在CVPR2020会议上。文中分析了和的检测方法,性能差异的主要原因在于正负训练样本的定义方式不同,而和回归
文章目录
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
1.简介
论文Bridging the Gap Between Anchor-based and Anchor-free Detection via
Adaptive Training Sample Selection
代码https://github.com/sfzhang15/ATSS
ATSS
是中科院自动化研究所的Shifeng Zhang
等最早于2019年12月份提交的论文中提出的方法,发表在CVPR2020会议上。
文中分析了Anchor Based
和Anchor Free
的检测方法,性能差异的主要原因在于正负训练样本的定义方式不同,而和回归目标是基于**点式(point)还是盒式(box)**关系不大。Anchor Free
检测常用的有两种方法,一种是keypoint_based
,另一种是center_based
。keypoint_based
的Anchor Free
目标检测算法同标准的keypoint estimation pipeline
,和anchor based
的目标检测算法差异较大。但center_based
的Anchor Free
目标检测算法与Anchor Based
的方法比较相近,center_based
方法将point
作为预设样本(如FCOS),Anchor Based
方法是将anchor
作为预设样本(如RetinaNet)。Anchor Based
的RetinatNet
与Center Based
的FCOS
的主要区别是:
- 1)
feature map
中每个位置的anchor
数量不同,RetinaNet
每个点生成多个anchor boxes
,FCOS
每个点生成一个anchor point
- 2)正负样本的定义方式不同,
RetinaNet
使用IoU
来判定正负样本,FCOS
使用patial and scale constraints
来判断。 - 3)回归起始状态不同。
RetinaNet
是基于Anchor Box
的 ( t x , t y . t ω , t h ) (t_x,t_y.t_\omega,t_h) (tx,ty.tω,th),FCOS
是基于Anchor Point
的 ( l l , l t , l r , l o ) (l_l,l_t,l_r,l_o) (ll,lt,lr,lo)。
ATSS
分析了Anchor Based
和Anchor Free
检测算法实现上的差异,得出的结论是正负样本定义方式的不同影响了两种方法检测效果的差异。基于此论文提出了Adaptive Training Sample Selection(ATSS)
算法以基于目标特征自动的计算正负样本。本文还基于实验得出了在同个位置没必要使用多个anchor box
做检测的结论。
2.目标检测相关
3.Anchor Based
与Anchor Free
目标检测算法的差异分析
Anchor Based
选择RetinaNet
作为代表,Anchor Free
选择FCOS
作为代表,从以下三方面进行分析:
- 1)正负样本定义
- 2)初始回归状态,是回归 t x , t y , t w , t h t_x,t_y,t_w,t_h tx,ty,tw,th还是 l l , l t , l r , l o l_l,l_t,l_r,l_o ll,lt,lr,lo
- 3)每个位置的
anchor
数量
3.1 RetinaNet
与FCOS
的对比
设置RetinaNet
的Anchor box
数量为1
。对FCOS
的改进:
- 1)将
centernerss
移到regression
分支 - 2)使用
GIoU Loss
- 3)将回归目标使用对应level的stride来归一化
这些提升了FCOS
的检测效果,coco minival
上的map
从37.1
提升到了37.8
,进一步拉开了Anchor=1
的RetinaNet
与FCOS
的差距。
FCOS
中使用的一些trick
在Anchor=1
的RetinaNet
中也能使用,如检测头中使用的Group Normlization
, GIoU
,限制ground truth box
中的正样本,对特征金字塔的每层加上一个中心度分支和可训练参数。将这些trick
逐一加到RetinaNet
上的对比结果为:
从上图可以看出,将所有的通用trick
都应用到RetinaNet
上后,MAP
依然有0.8的差距。除了以上指出的通用性差异后,还有两点不同,一个是正负样本的定义方式,另一个是回归任务本身,RetinaNet
是基于Anchor Box
回归,FCOS
是基于Anchor Point
回归。
3.1 正负样本定义的区别
如上图,RetinaNet
根据ground truth box
与anchor box
之间的IoU
的值来判断是正样本还是负样本,通常设置两个超参数
(
I
o
U
n
e
g
,
I
o
U
p
o
s
)
(IoU_{neg}, IoU_{pos})
(IoUneg,IoUpos),小于
I
o
U
n
e
g
IoU_{neg}
IoUneg的是负样本,大于
I
o
U
p
o
s
IoU_{pos}
IoUpos的是正样本,在两者之间的Anchor Box
被忽略,不参与训练,RPN
生产的Proposal Box
基于FPN
论文中提出的方程式2赋值给某个feature
层。FCOS
则先根据Anchor Point
的空间位置是否落在ground truth box
中找出可能为正的Anchor Point
,再根据Anchor Point
对应feature map
上的回归范围regression scale
来近一步确认是否为正样本,参考见博客FCOSNet。基于Spatial and Scale
的正样本判定方式决定了检测器的优秀性能,如下表,使用Spatial and Scale
后,Anchor=1的RetinaNet
的MAP
也提升到了37.8
,换用IoU
的FCOS
的MAP
降到了36.9
:
3.2 回归起始位置的差异
如下图,Anchor=1的RetinaNet
回归的是AnchorBox
相对于ground truth box
的平移缩放
(
t
x
,
t
y
,
t
w
,
t
h
)
(t_x,t_y,t_w,t_h)
(tx,ty,tw,th)即基于box
的回归,而FCOS
回归的是中心点距离ground truth box
四边的距离
l
l
,
l
t
,
l
r
,
l
b
l_l,l_t,l_r,l_b
ll,lt,lr,lb,即基于点的回归。从上图中按行方向比较可以发现,使用box
或point
的回归方式对最终的结果影响不大,37->36.9
,‵37.8->37.8`。
综合3.1和3.2的分析,可以得出结论:是正负样本的定义方式不同影响了Anchor Based
和Anchor Free
算法的性能。
4.自适应训练样本选择
从前面作者得出的结论,How to define positive and negative samples极大影响了检测器的性能,基于此作者提出了新的samples
分类算法,自适应训练样本选择(Adaptive Training Sample Selection, ATSS)。
Anchor Based
基于IoU
和Anchor Free
基于Scale Range
的正样本定义方法都依赖预先定义好的超参数,ATSS
提出了一种自适应取阈值的方法,减少了sample definition
所需的超参数。
以一张输入图像为例说明上图ATSS
算法的工作流程:
- 1)对于1个
ground truth box
,分别在每个金字塔特征层上取中心 L 2 L_2 L2距离最近的 k k k个anchor boxes
作为候选positive sample
,对于有 L \mathcal{L} L个金字塔特征层的网络,共得到 k L k\mathcal{L} kL个candidate positive anchor boxes
- 2)计算
candidates
与ground truth boxes
g ∈ D g g\in \mathcal{D}_{g} g∈Dg之间的IoU
- 3)计算2)中
IoU
的均值 m g m_{\mathcal{g}} mg和标准差 v g \mathcal{v}_{\mathcal{g}} vg - 4)取
t
g
=
m
g
+
v
g
t_g=m_{\mathcal{g}}+\mathcal{v}_{\mathcal{g}}
tg=mg+vg作为阈值,大于
t
g
t_g
tg的是
positive
,其余的Anchor Boxes
都是negative
作者指出,当一个anchor box
同时落入两个ground truth box
中时,会将其分配给IoU
比较大的ground truth box
。
从上图可以看出ATSS
的作用,对于某个ground truth box
,图a中标准差较大,意味着有某个金字塔特征层比较适合预测该box
,因此阈值
t
g
t_g
tg也比较大。图b中标准差不大,意味者可能有多个特征层适合预测当前box
,因此选取的阈值
t
g
t_g
tg也较小。
作者还指出使用ATSS
,可以使得对于不同大小的目标对象得到相同比例的正负训练样本。对于标准正态分布有16%
的样本落在
[
v
+
σ
,
1
]
[v+\sigma,1]
[v+σ,1]之间,虽然IoU of candidates
不是正态分布,正样本的比例依然保持在了20% of
k
L
k\mathcal{L}
kL 左右,和目标
s
c
a
l
e
/
a
s
p
e
c
t
r
a
t
i
o
/
l
o
c
a
t
i
o
n
scale/aspect ratio/location
scale/aspectratio/location无关。而RetinaNet
和FCOS
都会倾向于对大目标生成更多的正样本。
ATSS
使用的超参数很少,只有k
一个,且算法效果对k
不敏感。实验证明k
取[3, 5, 7, 9, 11, 13, 15, 17, 19]
时map
变化不大:
5.代码实现
mmdetection
中ATSS
算法的实现在ATSSAssigner
类中,assign
的部分代码如下:
# Selecting candidates based on the center distance
candidate_idxs = []
start_idx = 0
for level, bboxes_per_level in enumerate(num_level_bboxes):
# on each pyramid level, for each gt,
# select k bbox whose center are closest to the gt center
end_idx = start_idx + bboxes_per_level
distances_per_level = distances[start_idx:end_idx, :]
selectable_k = min(self.topk, bboxes_per_level)
_, topk_idxs_per_level = distances_per_level.topk(
selectable_k, dim=0, largest=False)
candidate_idxs.append(topk_idxs_per_level + start_idx)
start_idx = end_idx
candidate_idxs = torch.cat(candidate_idxs, dim=0)
# get corresponding iou for the these candidates, and compute the
# mean and std, set mean + std as the iou threshold
candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
overlaps_mean_per_gt = candidate_overlaps.mean(0)
overlaps_std_per_gt = candidate_overlaps.std(0)
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
# limit the positive sample's center in gt
for gt_idx in range(num_gt):
candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
num_gt, num_bboxes).contiguous().view(-1)
ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
num_gt, num_bboxes).contiguous().view(-1)
candidate_idxs = candidate_idxs.view(-1)
# calculate the left, top, right, bottom distance between positive
# bbox center and gt side
l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
is_pos = is_pos & is_in_gts
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
参考资料
更多推荐
所有评论(0)