文章目录
- 1. 为什么要使用NMS
- 2. NMS算法原理
-
- 2.1 IoU与置信度
- 2.2 算法流程
- 3. Python代码实现
1. 为什么要使用NMS
大多数目标检测算法(稠密预测)在得到最终的预测结果时,特征图的每个位置都会输出多个检测结果,整个特征图上会出很多个重叠的框。例如要检测一辆车,可能会有多个bbox都把这辆车给框了出来,因此需要从这些bbox中选出框得最好的,删除掉其它的。要定义框得好与不好,就得看bbox的预测置信度;为了删掉重叠的多余的框,就得利用IoU来检查重叠程度。
2. NMS算法原理
2.1 IoU与置信度
在NMS中,需要将与当前bbox的IoU超过一定阈值的框给删除。而当前bbox的选择则是根据置信度的排序,置信度最高的说明框得最准,将它作为基准,删除掉和它重叠度(也就是IoU)过高的bbox。而置信度的计算在不同算法中也是不同的,可以参考YOLOv1算法中置信度的定义作为一种参考。为了定义重叠度是否过高,需要引入一个阈值超参数,这个阈值和计算mAP定义正负样本时的阈值是完全不一样的。此外,在执行NMS之前一般会先把置信度过低的bbox给初筛一遍,这些bbox框得不准,放到NMS中会增加计算负担,这里也会给置信度设置一个阈值,所以区分这几个阈值对于理解目标检测算法流程非常重要。并且,NMS一般是对每个类别的bbox分别使用的,这样的话就不会把重叠度高但属于不同类别的bbox给误删了。
2.2 算法流程
Step1:将原始列表中的所有bbox按照置信度从高到低进行排序;
Step2:选取当前置信度最高的bbox,记为
b
b
b,并将其放到最终的结果列表里;
Step3:计算剩余所有bbox与
b
b
b的IoU,将IoU大于阈值的bbox全部删除;
Step4:从原始列表中删除
b
b
b;
Step5:重复Step2-4,直到原始列表中不再有bbox;
Step6:返回结果列表,即为NMS筛选后的结果。
3. Python代码实现
import numpy as np
# dets: 检测出的图中某一类别的bbox及对应的置信度,列表中的每个元素为[x1, y1, x2, y2, confidence];
# thresh: 设定的IoU阈值
def nms(dets, thresh):
# 预处理
# 提取各个bbox的位置,即左上角和右下角坐标,用于后续计算IoU里的各种面积
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4] # 提取各个bbox的置信度
areas = (x2 - x1 + 1) * (y2 - y1 + 1) # 计算各个bbox的面积
# Step1:将原始列表中的所有bbox按照置信度从高到低进行排序
order = scores.argsort()[::-1] # 在从大到小排序后返回索引值,即order[0]表示scores列表里最大值的索引
keep = [] # 保存筛选后重叠度低于阈值的bbox,注意,返回的是原始列表中要保留的bbox的索引
# Step5:重复Step2-4,直到原始列表中不再有bbox
while order.size > 0:
# Step2:选取当前置信度最高的bbox,记为$b$,并将其放到最终的结果列表里
i = order[0] # scores列表里置信度最高的bbox对应的的索引
keep.append(i) # 将当前这个框得最准的bbox保存到输出结果里
# Step3:计算剩余所有bbox与$b$的IoU,将IoU大于阈值的bbox全部删除
# 首先要计算出bbox重叠部分的左上角和右下角坐标
# 即取两个bbox中左上角值较大者和右下角值较小者
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
# 计算重叠部分的宽和高,如果算出来是负值则说明两个bbox不重叠,因此要把相应的宽/高置0
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
# 计算重叠部分的面积
inter = w * h
# 计算$b$与剩余所有bbox的IoU
ovr = inter / (areas[i] + areas[order[1:]] - inter)
# 将IoU大于阈值的bbox全部删除,也就是把重叠度较小的bbox给保留下来
inds = np.where(ovr thresh)[0]
# Step4:从原始列表中删除$b$
# 由于之前的操作都是算的剩余的bbox与$b$的关系,也就是排除了原始列表的首个元素(从算xx1开始)
# 所以上面得到的inds要+1才是真正对应到原始列表中的索引,这个过程也就自动地把$b$拿掉了
order = order[inds + 1]
# Step6:返回结果列表,即为NMS筛选后的结果
return keep
if __name__ == '__main__':
bounding_boxes = np.array([[187, 82, 337, 317, 0.9], [150, 67, 305, 282, 0.75], [246, 121, 368, 304, 0.8]])
threshold = 0.6
keep = nms(bounding_boxes, threshold)
print(keep)
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
相关推荐: 文盘Rust —— rust连接oss | 京东云技术团队
作者:京东科技 贾世闻 对象存储是云的基础组件之一,各大云厂商都有相关产品。这里跟大家介绍一下rust与对象存储交到的基本套路和其中的一些技巧。 基本连接 我们以 [S3 sdk]( https://github.com/awslabs/aws-sdk-rus…