pytorch 实现打印模型的参数值

yipeiwu_com6年前Python基础

对于简单的网络

例如全连接层Linear

可以使用以下方法打印linear层:

fc = nn.Linear(3, 5)
params = list(fc.named_parameters())
print(params.__len__())
print(params[0])
print(params[1])

输出如下:

由于Linear默认是偏置bias的,所有参数列表的长度是2。第一个存的是全连接矩阵,第二个存的是偏置。

对于稍微复杂的网络

例如MLP

mlp = nn.Sequential(
      nn.Dropout(p=0.3),
      nn.Linear(1024, 256),
      nn.Linear(256, 64),
      nn.Linear(64, 16),
      nn.Linear(16, 1)
    )
params = list(mlp.named_parameters())
print(params.__len__())

print(params[0])
print(params[1])

print(params[2])
print(params[3])

输出:

可以发现,堆叠起来的网络,参数是依次放置的。先是全连接的权重,然后偏置。然后是下一层网络的权重+偏置。依次进行下去。

这里有4层fc,4*2=8.所以一共有8个参数矩阵。

以上这篇pytorch 实现打印模型的参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python自动结束mysql慢查询会话的实例代码

生产环境的有些sql查询写得太复杂,或是表很大,对应索引未建立或建立不合理,或是查询未充分使用索引等,就有可能出现慢查询,一些慢查询需要修改程序,可能没那么快能解决,这时如果有个脚本能自...

python读写json文件的简单实现

python读写json文件的简单实现

JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式。它基于ECMAScript的一个子集。 JSON采用完全独立于语言的文本格式,但是也使用了类...

python操作MySQL 模拟简单银行转账操作

python操作MySQL 模拟简单银行转账操作

一、基础知识 1、MySQL-python的安装 下载,然后 pip install 安装包 2、python编写通用数据库程序的API规范 (1)、数据库连接对象 connection...

python树的同构学习笔记

python树的同构学习笔记

一、题意理解 给定两棵树T1和T2。如果T1可以通过若干次左右孩子互换就变成T2,则我们称两棵树是“同构的”。现给定两棵树,请你判断它们是否是同构的。 输入格式:输入给出2棵二叉树的信...

python如何制作英文字典

本文实例为大家分享了python制作英文字典的具体代码,供大家参考,具体内容如下 功能有添加单词,多次添加单词的意思,查询,退出,建立单词文件。 keys=[] dic={} def...