YOLOv3学习:(二)网络结构推导与实现

jupiter
2021-02-06 / 0 评论 / 874 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2021年11月25日,已超过964天没有更新,若内容或图片失效,请留言反馈。

YOLOv3学习:(二)网络结构推导与实现

网络结构图简版:

网络结构图简版+特征图的大小变换:

网络结构-详细版

网络结构模块化

网络结构图展开(超详细版)

网络结构+示例-3D版(利用多尺度特征进行对象检测)

9种尺度的先验框

随着输出的特征图的数量和尺度的变化,先验框的尺寸也需要相应的调整。YOLO2已经开始采用K-means聚类得到先验框的尺寸,YOLO3延续了这种方法,为每种下采样尺度设定3种先验框,总共聚类出9种尺寸的先验框。在COCO数据集这9个先验框是:(10x13),(16x30),(33x23),(30x61),(62x45),(59x119),(116x90),(156x198),(373x326)。

分配上,在最小的1313特征图上(有最大的感受野)应用较大的先验框(116x90),(156x198),(373x326),适合检测较大的对象。中等的2626特征图上(中等感受野)应用中等的先验框(30x61),(62x45),(59x119),适合检测中等大小的对象。较大的52*52特征图上(较小的感受野)应用较小的先验框(10x13),(16x30),(33x23),适合检测较小的对象。

感受一下9种先验框的尺寸,下图中蓝色框为聚类得到的先验框。黄色框式ground truth,红框是对象中心点所在的网格。

输入到输出的映射(包含输出参数的解释)

不考虑神经网络结构细节的话,总的来说,对于一个输入图像,YOLO3将其映射到3个尺度的输出张量,代表图像各个位置存在各种对象的概率。

我们看一下YOLO3共进行了多少个预测。对于一个416416的输入图像,在每个尺度的特征图的每个网格设置3个先验框,总共有 13133 + 26263 + 5252*3 = 10647 个预测。每一个预测是一个(4+1+80)=85维向量,这个85维向量包含边框坐标(4个数值),边框置信度(1个数值),对象类别的概率(对于COCO数据集,有80种对象)。

对比一下,YOLO2采用13135 = 845个预测,YOLO3的尝试预测边框数量增加了10多倍,而且是在不同分辨率上进行,所以mAP以及对小物体的检测效果有一定的提升。

代码实现

代码

import torch
import torch.nn as nn

# Darknet53 中的基本块--卷积块,由Conv+BN+LeakyReLU共同组成
class ConvBNReLU(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size,stride,padding):
        super(ConvBNReLU,self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)
        self.BN = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.ReLU6(inplace=True)
    def forward(self,x):
        x = self.conv(x)
        x = self.BN(x)
        x = self.leaky_relu(x)
        return x

# Darknet53 中的基本块--下采样块,用卷积(stride=2)实现
class DownSample(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(DownSample,self).__init__()
        self.down_samp = nn.Conv2d(in_channels,out_channels,3,2,1)
    def forward(self,x):
        x = self.down_samp(x)
        return x

# Darknet53 中的基本块--ResBlock
class ResBlock(nn.Module):
    def __init__(self, nchannels):
        super(ResBlock, self).__init__()
        mid_channels = nchannels // 2
        self.conv1x1 = ConvBNReLU(nchannels, mid_channels,1,1,0)
        self.conv3x3 = ConvBNReLU(mid_channels, nchannels,3,1,1)

    def forward(self, x):
        out = self.conv3x3(self.conv1x1(x))
        return out + x
    
# YOLOv3 骨干网络 -DarkNet53
class DarkNet53_YOLOv3(nn.Module):
    def __init__(self):
        super(DarkNet53_YOLOv3, self).__init__()
        self.conv_bn_relu = ConvBNReLU(3,32,3,1,1)
        self.down_samp_0 = DownSample(32,64)
        self.res_block_1 = ResBlock(64)
        self.down_samp_1 = DownSample(64,128)
        self.res_block_2 = ResBlock(128)
        self.down_samp_2 = DownSample(128,256)
        self.res_block_3 = ResBlock(256)
        self.down_samp_3 = DownSample(256,512)
        self.res_block_4 = ResBlock(512)
        self.down_samp_4 = DownSample(512,1024)
        self.res_block_5 = ResBlock(1024)

    def forward(self, x):
        out1 = self.conv_bn_relu(x)
        
        out1 = self.down_samp_0(out1)
        
        out1 = self.res_block_1(out1)
        
        out1 = self.down_samp_1(out1)
        
        out1 = self.res_block_2(out1)
        out1 = self.res_block_2(out1)
        
        out1 = self.down_samp_2(out1)
        
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        out1 = self.res_block_3(out1)
        
        out2 = self.down_samp_3(out1)
        
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        out2 = self.res_block_4(out2)
        
        out3 = self.down_samp_4(out2)
        
        out3 = self.res_block_5(out3)
        out3 = self.res_block_5(out3)
        out3 = self.res_block_5(out3)
        out3 = self.res_block_5(out3)
        out3 = self.res_block_5(out3)
        
        return out1,out2,out3

# YOLOv3 13*13 输出分支的darknet53后的几层
class Out1LastLayers(nn.Module): #input_shape = (1024, 13, 13) out_shape = (255,13,13) out_branck_shape = (512,13,13)
    def __init__(self):
        super(Out1LastLayers, self).__init__()
        self.conv1x1 = ConvBNReLU(1024,512,1,1,0)
        self.conv3x3 = ConvBNReLU(512, 1024,3,1,1)
        self.conv1x1_last = ConvBNReLU(1024,255,1,1,0)
    
    def forward(self,x):
        out = self.conv1x1(x)
        out = self.conv3x3(out)
        
        out = self.conv1x1(out)
        out = self.conv3x3(out)
        
        out = self.conv1x1(out)
        out_branch = out
        out = self.conv3x3(out)
        
        out = self.conv1x1_last(out)
        
        return out,out_branch

# YOLOv3 26*26 输出分支的darknet53后的几层
class Out2LastLayers(nn.Module): #input_shape = (512, 26, 26) out_shape = (255,26,26) out_branck_shape = (256,26,26)
    def __init__(self):
        super(Out2LastLayers, self).__init__()
        self.conv1x1 = ConvBNReLU(512,256,1,1,0)
        self.conv3x3 = ConvBNReLU(256,512,3,1,1)
        self.up_sample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1x1_after_concat = ConvBNReLU(768,256,1,1,0)
        self.conv1x1_last = ConvBNReLU(512,255,1,1,0)
    
    def forward(self,x,x_branch):
        out = self.conv1x1(x_branch)
        out = self.up_sample(out)
        out = torch.cat([x,out],1)
        
        out = self.conv1x1_after_concat(out)
        out = self.conv3x3(out)
        
        out = self.conv1x1(out)
        out = self.conv3x3(out)
        
        out = self.conv1x1(out)
        out_branch = out
        out = self.conv3x3(out)
        
        out = self.conv1x1_last(out)
        
        return out,out_branch

# YOLOv3 52*52 输出分支的darknet53后的几层
class Out3LastLayers(nn.Module): #input_shape = (256, 52, 52) out_shape = (255,52,52) 
    def __init__(self):
        super(Out3LastLayers, self).__init__()
        self.conv1x1 = ConvBNReLU(256,128,1,1,0)
        self.conv3x3 = ConvBNReLU(128,256,3,1,1)
        self.up_sample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1x1_after_concat = ConvBNReLU(384,128,1,1,0)
        self.conv1x1_last = ConvBNReLU(256,255,1,1,0)
    
    def forward(self,x,x_branch):
        out = self.conv1x1(x_branch)
        out = self.up_sample(out)
        out = torch.cat([x,out],1)
        
        out = self.conv1x1_after_concat(out)
        out = self.conv3x3(out)
        
        out = self.conv1x1(out)
        out = self.conv3x3(out)
        
        out = self.conv1x1(out)
        out = self.conv3x3(out)
        
        out = self.conv1x1_last(out)
        
        return out

# YOLOv3模型
class YOLOv3(nn.Module):
    def __init__(self):
        super(YOLOv3, self).__init__()
        self.darknet53 = DarkNet53_YOLOv3()
        self.out1_last_layers =  Out1LastLayers()
        self.out2_last_layers =  Out2LastLayers()
        self.out3_last_layers =  Out3LastLayers()
        
        
    def forward(self, x):
        out3,out2,out1 = self.darknet53(x) # out1.shape,out2.shape,out3.shape =  (256, 52, 52),(512, 26, 26),(1024, 13, 13)
        out1,out1_branch = self.out1_last_layers(out1)
        out2,out2_branch = self.out2_last_layers(out2,out1_branch)
        out3 = self.out3_last_layers(out3,out2_branch)
        
        return out1,out2,out3

输入输出测试

fake_input = torch.zeros((1,3,416,416))
print(fake_input.shape)
model = YOLOv3()
out1,out2,out3= model(fake_input)
print(out1.shape,out2.shape,out3.shape)
torch.Size([1, 3, 416, 416])
torch.Size([1, 255, 13, 13]) torch.Size([1, 255, 26, 26]) torch.Size([1, 255, 52, 52])

参考资料

  1. YOLOv3网络结构和解析:https://blog.csdn.net/dz4543/article/details/90049377
  2. Darknet53网络各层参数详解:https://blog.csdn.net/qq_40210586/article/details/106144197
  3. 目标检测0-02:YOLO V3-网络结构输入输出解析:https://blog.csdn.net/weixin_43013761/article/details/98349080
  4. YOLOv3 深入理解:https://www.jianshu.com/p/d13ae1055302
0

评论 (0)

打卡
取消