PyTorch 实战:使用 ResNet18 实现是否佩戴口罩的图片分类

jupiter
2021-11-25 / 0 评论 / 730 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2021年12月07日,已超过842天没有更新,若内容或图片失效,请留言反馈。

1.实验环境

  • torch = 1.6.0
  • torchvision = 0.7.0
  • matplotlib = 3.3.3 # 绘图用
  • progressbar = 2.5 # 绘制进度条用
  • easydict # 超参数字典功能增强

2.数据集

数据按 torchvision `ImageFolder` 的目录结构组织:`./data/train/<类别名>/<图片文件>`,每个子目录对应一个类别(下文模型默认输出 3 类)。

3.导入相关的包

# Imports
import os
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets,transforms

import matplotlib.pyplot as plt
from progressbar import *

4.设置超参数

# Hyperparameters and run configuration.
from easydict import EasyDict

super_param={
    'train_data_root' : './data/train',
    # NOTE(review): validation reads the *training* directory, so the reported
    # validation accuracy is measured on training images. Point this at a
    # held-out split (e.g. './data/val') if one exists.
    'val_data_root' : './data/train',
    # Fall back to the CPU when CUDA is unavailable. The fallback must be the
    # string 'cpu'; a bare `cpu` is an undefined name (NameError on CPU-only hosts).
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': 0.001,              # optimizer learning rate
    'epochs': 3,              # number of epochs run by the training loop
    'batch_size': 1,
    'begain_epoch':0,         # first epoch index of the loop (key spelling kept; other code reads it)
    'model_load_flag':False,  # whether to load a previously saved checkpoint
    'model_load_path':'./model/resnet18/epoch_1_0.8861347792408986.pkl',
    'model_save_dir':'./model/resnet18',
}
super_param = EasyDict(super_param)  # enables attribute access, e.g. super_param.lr

# makedirs also creates missing parent directories (./model) and tolerates an
# already existing target, unlike a single os.mkdir call.
os.makedirs(super_param.model_save_dir, exist_ok=True)

5.模型搭建

# Model: torchvision ResNet18 with its final classification layer replaced.
class Modified_Resnet18(nn.Module):
    """ResNet18 whose last fully-connected layer outputs ``num_classs`` scores.

    Args:
        num_classs: number of output classes (parameter name kept unchanged
            for compatibility with existing callers).
        pretrained: passed to ``torchvision.models.resnet18``; ``True`` (the
            default, matching the previous behaviour) loads torchvision's
            pretrained weights, ``False`` builds a randomly initialised
            backbone and skips the weight download.
    """

    def __init__(self, num_classs=3, pretrained=True):
        super(Modified_Resnet18, self).__init__()
        model = torchvision.models.resnet18(pretrained=pretrained)
        # Replace the backbone's classifier with a num_classs-way linear layer,
        # keeping the same input feature width.
        model.fc = nn.Linear(model.fc.in_features, num_classs)
        self.model = model

    def forward(self, x):
        """Return class scores (logits; no softmax is applied) for batch ``x``."""
        return self.model(x)


model = Modified_Resnet18()

print(model)
Modified_Resnet18(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer2): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer3): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=512, out_features=3, bias=True)
  )
)

6.加载数据集

# Wrap the training and validation images in DataLoaders.

# Preprocessing applied to every image.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),  # PIL image (0-255) -> float tensor in [0, 1]
    # NOTE(review): pretrained backbones are usually fed normalized inputs;
    # normalization is presumably disabled so samples can be shown directly
    # with imshow. Consider enabling it for training:
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = torchvision.datasets.ImageFolder(root=super_param.train_data_root, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=super_param.batch_size, shuffle=True)

val_dataset = torchvision.datasets.ImageFolder(root=super_param.val_data_root, transform=transform)
# Validation metrics do not depend on sample order, so there is no need to shuffle.
val_loader = DataLoader(val_dataset, batch_size=super_param.batch_size, shuffle=False)

# Keep the class-name <-> class-index mapping (ImageFolder builds it from sub-directory names).
class_to_idx = train_dataset.class_to_idx
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print(len(train_dataset), len(val_dataset))
  • 查看一个数据样例
# Show one random validation sample together with its class label.
import matplotlib.pyplot as plt

# randrange excludes its upper bound; randint(0, len(val_dataset)) could pick
# len(val_dataset) itself and raise IndexError.
index = random.randrange(len(val_dataset))
img, idx = val_dataset[index]

# ToTensor produces channel-first tensors; imshow expects channels last.
img = img.permute(1, 2, 0)

label = idx_to_class[idx]

print("label=", label)
plt.figure(dpi=100)
plt.xticks([])
plt.yticks([])
plt.imshow(img)
plt.show()

7.定义损失函数和优化器

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
# Move the model to the target device *before* building the optimizer, as the
# torch.optim documentation recommends. Module.to() returns the same module, and
# calling it again with the same device is harmless.
model = model.to(super_param.device)
optimizer = optim.Adam(model.parameters(), lr=super_param.lr)

8.定义单个epoch的训练函数

# Training routine for a single epoch.
def train_epoch(model,train_loader,super_param,criterion,optimizer,epoch):
    """Train ``model`` for one full pass over ``train_loader``.

    Every batch is moved to ``super_param.device``, scored with ``criterion``
    and followed by a single ``optimizer`` step. On every 10th batch a progress
    line (rewritten in place using a carriage return) reports the batch loss.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: sized iterable yielding ``(inputs, targets)`` batches.
        super_param: config object read for ``device`` and ``batch_size``.
        criterion: loss callable invoked as ``criterion(output, targets)``.
        optimizer: optimizer over the model's parameters.
        epoch: epoch number, used only in the progress message.
    """
    device = super_param.device
    model.train()  # training-time behaviour for layers such as BatchNorm

    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()                     # drop gradients from the previous step
        loss = criterion(model(inputs), targets)  # forward pass + loss
        loss.backward()                           # back-propagate
        optimizer.step()                          # update the parameters

        if step % 10 != 0:
            continue
        message = "\rEpoch:{} Batch Index(batch_size={}):{}/{}   Loss:{}".format(
            epoch, super_param.batch_size, step, len(train_loader), loss.item())
        print(message, end="")

9.定义验证函数

# Validation routine.
def val(model,val_loader,super_param,criterion):
    """Evaluate ``model`` on ``val_loader``; print and return loss and accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        val_loader: DataLoader whose ``.dataset`` length is the number of
            evaluated samples.
        super_param: config object read for ``device``.
        criterion: loss module. Its ``reduction`` attribute ('mean' if absent)
            decides how per-batch losses are turned into per-sample sums.

    Returns:
        (val_loss, val_accuracy): mean per-sample loss and the fraction of
        correctly classified samples.

    Raises:
        ValueError: if the validation dataset is empty.
    """
    model.eval()  # evaluation mode (e.g. BatchNorm uses running statistics)

    num_samples = len(val_loader.dataset)
    if num_samples == 0:
        raise ValueError("validation dataset is empty")
    num_batches = len(val_loader)

    # With reduction='mean' the criterion returns a per-batch average, so it
    # must be re-weighted by the batch size before dividing by the sample
    # count; otherwise the reported loss is wrong whenever batch_size > 1.
    reduction = getattr(criterion, 'reduction', 'mean')

    correct = 0.0   # number of correctly classified samples
    val_loss = 0.0  # summed per-sample loss

    # Progress bar
    widgets = ['Val: ',Percentage(), ' ', Bar('#'),' ', Timer(),' ', ETA()]
    pbar = ProgressBar(widgets=widgets, maxval=100).start()

    with torch.no_grad():  # inference only: no autograd graph is recorded
        for batch_index, (imgs, labels) in enumerate(val_loader):
            imgs, labels = imgs.to(super_param.device), labels.to(super_param.device)

            output = model(imgs)

            batch_loss = criterion(output, labels).item()
            if reduction == 'mean':
                batch_loss *= labels.size(0)
            val_loss += batch_loss

            # argmax over dim=1 (the class scores) gives the predicted index.
            pred = output.argmax(dim=1)
            correct += pred.eq(labels.view_as(pred)).sum().item()

            # (batch_index + 1) so the bar reaches 100% on the last batch.
            pbar.update((batch_index + 1) / num_batches * 100)

    pbar.finish()

    val_loss /= num_samples
    val_accuracy = correct / num_samples
    time.sleep(0.01)  # presumably lets the progress bar finish drawing before the summary line
    print("Val --- Avg Loss:{},Accuracy:{}".format(val_loss,val_accuracy))

    return val_loss, val_accuracy

10.模型训练

# Move the model to the configured device.
model = model.to(super_param.device)

if super_param.model_load_flag:
    # Resume from a saved state_dict. map_location lets a checkpoint saved on a
    # GPU be loaded on a CPU-only machine (and places tensors on the chosen device).
    model.load_state_dict(torch.load(super_param.model_load_path, map_location=super_param.device))

# Per-epoch statistics, used for plotting and for checkpoint selection.
epoch_list = []
loss_list = []
accuracy_list = []

best_accuracy = 0.0

for epoch in range(super_param.begain_epoch, super_param.begain_epoch + super_param.epochs):
    train_epoch(model, train_loader, super_param, criterion, optimizer, epoch)
    val_loss, val_accuracy = val(model, val_loader, super_param, criterion)

    epoch_list.append(epoch)
    loss_list.append(val_loss)
    # Record this epoch's accuracy value (not the list object itself).
    accuracy_list.append(val_accuracy)

    # Save a checkpoint whenever validation accuracy improves.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        ckpt_name = 'epoch_' + str(epoch) + '_' + str(best_accuracy) + '.pkl'
        torch.save(model.state_dict(), os.path.join(super_param.model_save_dir, ckpt_name))
        print(ckpt_name + "保存成功")

11.查看数据统计结果

# Plot the per-epoch validation loss (top) and accuracy (bottom).

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12), dpi=70)

# Subplot 1: validation loss, titled with the key hyperparameters.
title = "batch_size={},lr={}".format(super_param.batch_size, super_param.lr)
ax1.set_title(title, fontsize=15)
ax1.set_xlabel('Epochs', fontsize=15)
ax1.set_ylabel('Loss', fontsize=15)
ax1.tick_params(labelsize=13)
ax1.plot(epoch_list, loss_list)

# Subplot 2: validation accuracy.
ax2.set_xlabel('Epochs', fontsize=15)
ax2.set_ylabel('Accuracy', fontsize=15)
ax2.tick_params(labelsize=13)
ax2.plot(epoch_list, accuracy_list, 'r')

plt.show()
0

评论 (0)

打卡
取消