1.实验环境
- torch = 1.6.0
- torchvision = 0.7.0
- matplotlib = 3.3.3 # 绘图用
- progressbar = 2.5 # 绘制进度条用
- easydict # 超参数字典功能增强
2.数据集
数据集介绍
- 包含2582张图片,3个类别(yes/unknow/no)
- 下载地址:口罩检测数据集
3.导入相关的包
# Imports
import os
import random
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from progressbar import *  # ProgressBar and the Percentage/Bar/Timer/ETA widgets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
4.设置超参数
# Hyperparameters and run configuration.
from easydict import EasyDict

super_param = {
    # ImageFolder roots: one sub-directory per class.
    'train_data_root': './data/train',
    # NOTE(review): validation reads the same directory as training, so the
    # reported validation accuracy is measured on training images. Point this
    # at a held-out split (e.g. './data/val') if the dataset provides one.
    'val_data_root': './data/train',
    # Use the first GPU when available, otherwise fall back to the CPU.
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'lr': 0.001,
    'epochs': 3,
    'batch_size': 1,
    # Starting epoch number; training runs `epochs` epochs from here.
    'begain_epoch': 0,
    # Whether to load previously saved weights from model_load_path.
    'model_load_flag': False,
    'model_load_path': './model/resnet18/epoch_1_0.8861347792408986.pkl',
    # Directory where improved checkpoints are written.
    'model_save_dir': './model/resnet18',
}
# EasyDict allows attribute access, e.g. super_param.lr.
super_param = EasyDict(super_param)
# makedirs also creates the missing parent directory ('./model');
# exist_ok makes this a no-op when the directory already exists.
os.makedirs(super_param.model_save_dir, exist_ok=True)
5.模型搭建
# Model: torchvision's pretrained ResNet-18 with its classifier head replaced.
class Modified_Resnet18(nn.Module):
    """ResNet-18 backbone whose final layer outputs ``num_classs`` logits.

    The torchvision ``resnet18`` is created with ``pretrained=True``. Its
    final fully-connected layer ``fc`` is swapped for a new ``nn.Linear``
    with the same input width and ``num_classs`` outputs (3 by default, one
    per dataset class). The backbone is stored as ``self.model``, so
    state_dict keys are prefixed with ``model.``.
    """

    def __init__(self, num_classs=3):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        head_in = backbone.fc.in_features
        backbone.fc = nn.Linear(head_in, num_classs)
        self.model = backbone

    def forward(self, x):
        """Return the raw class logits produced by the backbone for ``x``."""
        return self.model(x)


# Build the 3-class network and print its module tree.
model = Modified_Resnet18(num_classs=3)
print(model)
Modified_Resnet18(
(model): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=3, bias=True)
)
)
6.加载数据集
# Wrap the training and validation image folders in DataLoaders.
# Preprocessing: resize every image to 224x224 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # pixel values 0-255 -> 0-1
    # NOTE(review): ImageNet normalization is disabled. The pretrained backbone
    # presumably expects it; enabling it would also require de-normalizing
    # before displaying samples with imshow.
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_dataset = torchvision.datasets.ImageFolder(root=super_param.train_data_root, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=super_param.batch_size, shuffle=True)
val_dataset = torchvision.datasets.ImageFolder(root=super_param.val_data_root, transform=transform)
# Validation only accumulates totals, so shuffling adds nothing except
# nondeterminism; iterate in a fixed order.
val_loader = DataLoader(val_dataset, batch_size=super_param.batch_size, shuffle=False)

# Mapping between class-folder names and the integer labels ImageFolder assigns.
class_to_idx = train_dataset.class_to_idx
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
print(len(train_dataset), len(val_dataset))
- 查看一个数据样例
# Show one random validation sample together with its class name.
import matplotlib.pyplot as plt

# randrange excludes the upper bound. randint(0, len(...)) could return
# len(val_dataset) and raise IndexError.
index = random.randrange(len(val_dataset))
img, idx = val_dataset[index]
# The tensor is channels-first (C, H, W); imshow expects (H, W, C).
img = img.permute(1, 2, 0)
label = idx_to_class[idx]
print("label=", label)
plt.figure(dpi=100)
plt.xticks([])
plt.yticks([])
plt.imshow(img)
plt.show()
7.定义损失函数和优化器
# Loss: cross-entropy on the model's class logits.
# Optimizer: Adam over all model parameters, learning rate from super_param.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=super_param.lr)
8.定义单个epoch的训练函数
# Single-epoch training routine.
def train_epoch(model, train_loader, super_param, criterion, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: iterable of ``(imgs, labels)`` batches supporting ``len()``.
        super_param: config object; ``device`` and ``batch_size`` are read.
        criterion: loss called as ``criterion(output, labels)``.
        optimizer: optimizer over the model's parameters.
        epoch: epoch number, used only in the progress message.

    Every 10th batch (indices 0, 10, 20, ...) the current loss is printed.
    The message starts with a carriage return and has no trailing newline,
    so each print overwrites the previous one.
    """
    model.train()
    device = super_param.device
    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        if step % 10 != 0:
            continue
        message = "\rEpoch:{} Batch Index(batch_size={}):{}/{} Loss:{}".format(
            epoch, super_param.batch_size, step, len(train_loader), loss.item())
        print(message, end="")
9.定义验证函数
# Validation routine.
def val(model, val_loader, super_param, criterion):
    """Evaluate ``model`` on ``val_loader`` and report loss and accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        val_loader: DataLoader yielding ``(imgs, labels)`` batches; the length
            of its ``dataset`` is used as the sample count.
        super_param: config object; only ``device`` is read.
        criterion: loss returning the *mean* loss of a batch (the default
            ``reduction='mean'`` of ``nn.CrossEntropyLoss``).

    Returns:
        ``(val_loss, val_accuracy)``: the per-sample average loss and the
        fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    correct = 0.0   # number of correctly classified samples
    val_loss = 0.0  # summed loss over all samples
    # Text progress bar, scaled 0-100.
    widgets = ['Val: ', Percentage(), ' ', Bar('#'), ' ', Timer(), ' ', ETA()]
    pbar = ProgressBar(widgets=widgets, maxval=100).start()
    num_batches = len(val_loader)
    with torch.no_grad():  # inference only: no gradients are computed
        for batch_index, (imgs, labels) in enumerate(val_loader):
            imgs, labels = imgs.to(super_param.device), labels.to(super_param.device)
            output = model(imgs)
            # criterion returns the batch mean. Weighting by batch size makes the
            # division by dataset size below a true per-sample mean, even when
            # batch_size > 1 or the last batch is smaller.
            val_loss += criterion(output, labels).item() * labels.size(0)
            pred = output.argmax(dim=1)  # index of the highest logit per sample
            correct += pred.eq(labels.view_as(pred)).sum().item()
            # Count completed batches so the bar reaches 100% on the last one.
            pbar.update((batch_index + 1) / num_batches * 100)
    pbar.finish()
    num_samples = len(val_loader.dataset)
    val_loss /= num_samples
    val_accuracy = correct / num_samples
    time.sleep(0.01)  # presumably lets the progress bar output settle before printing
    print("Val --- Avg Loss:{},Accuracy:{}".format(val_loss, val_accuracy))
    return val_loss, val_accuracy
10.模型训练
# Move the model to the target device and optionally resume from a checkpoint.
model = model.to(super_param.device)
if super_param.model_load_flag:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine, and vice versa.
    model.load_state_dict(torch.load(super_param.model_load_path,
                                     map_location=super_param.device))

# Per-epoch statistics, used for plotting and for deciding when to save.
epoch_list = []
loss_list = []
accuracy_list = []
best_accuracy = 0.0
for epoch in range(super_param.begain_epoch, super_param.begain_epoch + super_param.epochs):
    train_epoch(model, train_loader, super_param, criterion, optimizer, epoch)
    val_loss, val_accuracy = val(model, val_loader, super_param, criterion)
    epoch_list.append(epoch)
    loss_list.append(val_loss)
    # Store the scalar accuracy so the accuracy curve can be plotted.
    accuracy_list.append(val_accuracy)
    # Save a checkpoint only when validation accuracy improves.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        ckpt_name = 'epoch_' + str(epoch) + '_' + str(best_accuracy) + '.pkl'
        torch.save(model.state_dict(), os.path.join(super_param.model_save_dir, ckpt_name))
        print(ckpt_name + "保存成功")
11.查看数据统计结果
# Plot validation loss and accuracy against the epoch number.
fig = plt.figure(figsize=(12, 12), dpi=70)

# Subplot 1: validation loss, titled with the key hyperparameters.
ax1 = plt.subplot(2, 1, 1)
title = "batch_size={},lr={}".format(super_param.batch_size, super_param.lr)
plt.title(title, fontsize=15)
plt.xlabel('Epochs', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.plot(epoch_list, loss_list)

# Subplot 2: validation accuracy, drawn in red.
ax2 = plt.subplot(2, 1, 2)
plt.xlabel('Epochs', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.plot(epoch_list, accuracy_list, 'r')
plt.show()
评论 (0)