基于pytorch的保存和加载模型参数的方法

yipeiwu_com5年前Python基础

当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了。

保存和加载模型参数有两种方式:

方式一:

torch.save(net.state_dict(),path):

功能:保存训练完的网络的各层参数(即weights和bias)

其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)

net2.load_state_dict(torch.load(path)):

功能:加载保存到path中的各层参数到神经网络

注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

方式二:

torch.save(net,path):

功能:保存训练完的整个网络模型(不止weights和bias)

net2=torch.load(path):

功能:加载保存到path中的整个神经网络

说明:官方推荐方式一,原因自然是保存的内容少,速度会更快。

以上这篇基于pytorch的保存和加载模型参数的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Django如何开发简单的查询接口详解

前言 Django处理json也是一把好手,有时候在工作中各个部门都会提供自己的相关接口,但是信息也只是单方的信息,这时候需要运维将各个部门的信息进行集成,统一出一个查询接口或页面,方便...

Django RBAC权限管理设计过程详解

Django RBAC权限管理设计过程详解

一.权限简介 1. 问:为什么程序需要权限控制? 答:生活中的权限限制,① 看灾难片电影《2012》中富人和权贵有权登上诺亚方舟,穷苦老百姓只有等着灾难的来临;② 屌丝们,有没有想过为...

Python中的模块导入和读取键盘输入的方法

导入模块 import 语句 想使用Python源文件,只需在另一个源文件里执行import语句,语法如下: import module1[, module2[,... modul...

浅析pandas 数据结构中的DataFrame

浅析pandas 数据结构中的DataFrame

DataFrame 类型类似于数据库表结构的数据结构,其含有行索引和列索引,可以将DataFrame 想成是由相同索引的Series组成的Dict类型。在其底层是通过二维以及一维的数据块...

Python使用psutil获取进程信息的例子

psutil是什么 psutil是一个能够获取系统信息(包括进程、CPU、内存、磁盘、网络等)的Python模块。主要用来做系统监控,性能分析,进程管理,像glances也是基于psut...