YOLOv1学习:(一)网络结构推导与实现

jupiter
2021-01-12 / 0 评论 / 632 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2021年11月25日,已超过882天没有更新,若内容或图片失效,请留言反馈。

YOLOv1学习:(一)网络结构推导与实现

原论文网络结构

img

知乎看到的网络结构分析(见参考资料1)

image-20210112143928520

二次网络结构分析

image-20210112155724399

7*7*30输出解释

img

如图所示,YOLOv1 将输入图像划分为 7×7 的网格,每个网格单元预测两个 bounding box。

如果一个目标的中心落入某个网格单元中,该网格单元就负责检测该目标。

每个网格单元中的每个 bounding box 都要预测置信度和边界框的位置。每个 bounding box 需要 4 个数值来表示其位置:(Center_x, Center_y, width, height),即 bounding box 中心点的 x 坐标、y 坐标,以及 bounding box 的宽度、高度。其中中心坐标是相对于所在网格单元的偏移量,宽高是相对于整幅图像宽高的比例,因此这 4 个值都落在 [0, 1] 之间。

置信度定义为 Pr(Object) × IOU(pred, truth),即该边框内包含物体的概率与预测框和真实框交并比(IOU)的乘积。打标签的时候,正样本(与真实物体有最大 IOU 的边框设为正样本)的置信度真值为 1(原论文中取预测框与真实框的 IOU,很多实现直接取 1),负样本为 0。

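此外,每个网格单元还要预测 20 个类别的条件概率 Pr(Class_i | Object)。注意这组类别概率是按网格单元预测的:无论每个网格预测几个 bounding box,都只输出一组类别概率,由同一网格的两个 bounding box 共享。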

因此 7*7*30 中的 30 = 20(类别概率)+ 2×5(每个 bounding box 的 4 个位置参数和 1 个置信度)。

img

Pytorch实现网络结构

基本骨架

import torch
import torch.nn as nn

# YOLOv1 backbone skeleton: convolutions and max-pools only (activations are
# added in the next version of the network further down).
# Each tuple is (out_channels, kernel_size, stride) for one Conv2d; "M" is a
# 2x2 / stride-2 max-pool.  Input channels are chained automatically.
_SKELETON_SPEC = [
    (64, 7, 2), "M",
    (192, 3, 1), "M",
    (128, 1, 1), (256, 3, 1), (256, 1, 1), (512, 3, 1), "M",
    (256, 1, 1), (512, 3, 1),
    (256, 1, 1), (512, 3, 1),
    (256, 1, 1), (512, 3, 1),
    (256, 1, 1), (512, 3, 1),
    (512, 1, 1), (1024, 3, 1), "M",
    (512, 1, 1), (1024, 3, 1),
    (512, 1, 1), (1024, 3, 1),
    (1024, 3, 1), (1024, 3, 2),
    (1024, 3, 1), (1024, 3, 1),
]


def _build_skeleton_features(spec, in_channels=3):
    """Build an ``nn.Sequential`` of Conv2d / MaxPool2d layers from *spec*.

    Each convolution uses ``padding = kernel_size // 2`` (7 -> 3, 3 -> 1,
    1 -> 0), so stride-1 convolutions keep the spatial size unchanged.
    """
    layers = []
    channels = in_channels
    for entry in spec:
        if entry == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels, kernel_size, stride = entry
        layers.append(
            nn.Conv2d(
                in_channels=channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            )
        )
        channels = out_channels
    return nn.Sequential(*layers)


feature = _build_skeleton_features(_SKELETON_SPEC)

# Detection head: the backbone downsamples by 2**6 = 64 (two stride-2 convs and
# four max-pools), so a 448x448 input yields a 1024x7x7 map that is flattened
# and regressed to 7*7*30 = 1470 values.
classify = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=1024 * 7 * 7, out_features=4096),
    nn.Linear(in_features=4096, out_features=7 * 7 * 30),
)

yolov1 = nn.Sequential(feature, classify)

基本骨架-结构打印(print(yolov1) 的输出)

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (7): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (12): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (14): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (16): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    (18): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (20): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
    (21): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
    (23): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (26): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1): Sequential(
    (0): Flatten()
    (1): Linear(in_features=50176, out_features=4096, bias=True)
    (2): Linear(in_features=4096, out_features=1470, bias=True)
  )
)

加入激活函数(LeakyReLU)和Dropout

import torch
import torch.nn as nn

# The YOLOv1 paper (Redmon et al., 2016, Sec. 2.2) uses a leaky ReLU with slope
# 0.1 for negative inputs on every layer except the final one.  PyTorch's
# default ``nn.LeakyReLU()`` slope is 0.01, so the slope is set explicitly.
LEAKY_SLOPE = 0.1


def _conv_leaky(in_channels, out_channels, kernel_size, stride=1):
    """Return ``[Conv2d, LeakyReLU]`` for one YOLOv1 convolution layer.

    Padding is ``kernel_size // 2`` (7 -> 3, 3 -> 1, 1 -> 0), so stride-1
    convolutions keep the spatial size unchanged.  Creating the activation
    together with its convolution guarantees no convolution is left without one.
    """
    return [
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
        ),
        nn.LeakyReLU(negative_slope=LEAKY_SLOPE),
    ]


# Backbone: 24 convolutions, each followed by a LeakyReLU, plus 4 max-pools.
# The shape comments assume the paper's 448x448x3 input.
feature = nn.Sequential(
    # 448x448x3 -> 112x112x64
    *_conv_leaky(3, 64, 7, stride=2),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # -> 56x56x192
    *_conv_leaky(64, 192, 3),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # -> 28x28x512
    *_conv_leaky(192, 128, 1),
    *_conv_leaky(128, 256, 3),
    *_conv_leaky(256, 256, 1),
    *_conv_leaky(256, 512, 3),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 4 x (1x1 reduce to 256, 3x3 expand to 512), then 1x1/3x3 to 1024
    # -> 14x14x1024
    *_conv_leaky(512, 256, 1),
    *_conv_leaky(256, 512, 3),
    *_conv_leaky(512, 256, 1),
    *_conv_leaky(256, 512, 3),
    *_conv_leaky(512, 256, 1),
    *_conv_leaky(256, 512, 3),
    *_conv_leaky(512, 256, 1),
    *_conv_leaky(256, 512, 3),
    *_conv_leaky(512, 512, 1),
    *_conv_leaky(512, 1024, 3),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 2 x (1x1 reduce, 3x3 expand), then a stride-2 3x3 conv: -> 7x7x1024
    *_conv_leaky(1024, 512, 1),
    *_conv_leaky(512, 1024, 3),
    *_conv_leaky(1024, 512, 1),
    *_conv_leaky(512, 1024, 3),
    *_conv_leaky(1024, 1024, 3),
    *_conv_leaky(1024, 1024, 3, stride=2),

    # 7x7x1024 -> 7x7x1024
    *_conv_leaky(1024, 1024, 3),
    *_conv_leaky(1024, 1024, 3),
)

# Detection head: FC-4096 with LeakyReLU and Dropout(0.5) after the first FC
# layer (as in the paper), then a linear FC layer with no activation producing
# 1470 = 7*7*30 values.  NOTE(review): the output stays flat (N, 1470); callers
# presumably reshape it to a 7x7 grid of 30-vectors when decoding.
classify = nn.Sequential(
    nn.Flatten(),
    nn.Linear(1024 * 7 * 7, 4096),
    nn.LeakyReLU(negative_slope=LEAKY_SLOPE),
    nn.Dropout(0.5),
    nn.Linear(4096, 1470),  # 1470 = 7*7*30
)

yolov1 = nn.Sequential(
    feature,
    classify,
)

print(yolov1)

参考资料

  1. YOLO V1 网络结构分析:https://zhuanlan.zhihu.com/p/220062200?utm_source=wechat_session
  2. YOLOv1算法理解:https://www.cnblogs.com/ywheunji/p/10808989.html
0

评论

博主关闭了当前页面的评论