pytorch 逐层加载模型参数

1.现象描述

当对模型结构的某些层做了微小调整之后,导致该层参数的shape发生了微小变化导致无法对整个模型进行加载,这时可以考虑通过逐层加载的方式来跳过某些层完成对保存好的模型的加载。

2.代码实现

# 模型参数逐层加载
ckpt_state_dict = torch.load("ckpt/best_att_wer_0.015267175572519083.pth")["model"]
model_state_dict = model.state_dict()
for key in model_state_dict.keys():
    if model_state_dict[key].shape == ckpt_state_dict[key].shape:
        model_state_dict[key] = ckpt_state_dict[key]
model.load_state_dict(model_state_dict)
Last modification:January 8th, 2022 at 02:04 pm
如果觉得我的文章对你有用,请随意赞赏