基于pytorch的保存和加载模型参数的方法

yipeiwu_com6年前Python基础

当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了。

保存和加载模型参数有两种方式:

方式一:

torch.save(net.state_dict(),path):

功能:保存训练完的网络的各层参数(即weights和bias)

其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)

net2.load_state_dict(torch.load(path)):

功能:加载保存到path中的各层参数到神经网络

注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

方式二:

torch.save(net,path):

功能:保存训练完的整个网络模型(不止weights和bias)

net2=torch.load(path):

功能:加载保存到path中的整个神经网络

说明:官方推荐方式一,原因自然是保存的内容少,速度会更快。

以上这篇基于pytorch的保存和加载模型参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

将python代码和注释分离的方法

python的注释方式和C语言、C++、java有所不同 python语言中,使用‘#' 来进行注释,其次还有使用 三个引号来进行注释 本文的程序将把 python 中 使用‘#' 号...

python列表list保留顺序去重的实例

常规通过迭代或set方法,都无法保证去重后的顺序问题 如下,我们可以通过列表的索引功能,对set结果进行序列化 old_list=["a",1,"b","a","b",2,5,1]...

django 信号调度机制详解

django 信号调度机制详解

前言 Django中提供了“信号调度”,用于在框架执行操作时解耦。通俗来讲,就是一些动作发生的时候,信号允许特定的发送者去提醒一些接受者。 1、Django内置信号 Model si...

python利用sklearn包编写决策树源代码

python利用sklearn包编写决策树源代码

本文实例为大家分享了python编写决策树源代码,供大家参考,具体内容如下 因为最近实习的需要,所以用python里的sklearn包重新写了一次决策树。 工具:sklearn,将dot...

解决python2 绘图title,xlabel,ylabel出现中文乱码的问题

解决python2 绘图title,xlabel,ylabel出现中文乱码的问题

绘制图形时使用了中文标题,会出现乱码 原因是matplotlib.pyplot在显示时无法找到合适的字体。 先把需要的字体(在系统盘C盘的windows下的fonts目录内)添加到Fo...