转载需注明出处:https://www.codelast.com/

根据PyTorch文档,在把PyTorch模型保存成文件的时候有两种方法,第一种是推荐的:

torch.save(the_model.state_dict(), PATH)

对应地,加载模型这样做:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))


另一种方法是不推荐的:

torch.save(the_model, PATH)

对应地,加载模型这样做:

the_model = torch.load(PATH)

文章来源:https://www.codelast.com/
这两者的区别:第1种方法只保存了模型的参数,而第2种方法保存了整个模型(结构+参数),所以第2种方法保存出来的文件体积会比第1种方法大
使用第2种方法的话,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用或经过一些大的重构后,它可能会失效。

[原创] PyTorch模型的两种保存方法
Tagged on:         

发表评论

电子邮件地址不会被公开。 必填项已用*标注