pytorch 逐层加载模型参数

jupiter
2022-01-08 / 0 评论 / 670 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2022年01月08日,已超过839天没有更新,若内容或图片失效,请留言反馈。

pytorch 逐层加载模型参数

1.现象描述

当对模型结构的某些层做了微小调整之后,导致该层参数的shape发生了微小变化导致无法对整个模型进行加载,这时可以考虑通过逐层加载的方式来跳过某些层完成对保存好的模型的加载。

2.代码实现

# 模型参数逐层加载
ckpt_state_dict = torch.load("ckpt/best_att_wer_0.015267175572519083.pth")["model"]
model_state_dict = model.state_dict()
for key in model_state_dict.keys():
    if model_state_dict[key].shape == ckpt_state_dict[key].shape:
        model_state_dict[key] = ckpt_state_dict[key]
model.load_state_dict(model_state_dict)
0

评论 (0)

打卡
取消