pytorch 固定部分参数训练的方法

yipeiwu_com6年前Python基础

需要自己过滤

optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

另外,如果是Variable,则可以初始化时指定

j = Variable(torch.randn(5,5), requires_grad=True)

但是如果是

m = nn.Linear(10,10)

是没有requires_grad传入的

m.requires_grad也没有

需要

for i in m.parameters():
  i.requires_grad=False

另外一个小技巧就是在nn.Module里,可以在中间插入这个

for p in self.parameters():
  p.requires_grad=False

这样前面的参数就是False,而后面的不变

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 6, 5)
    self.conv2 = nn.Conv2d(6, 16, 5)

    for p in self.parameters():
      p.requires_grad=False

    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

以上这篇pytorch 固定部分参数训练的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

通过cmd进入python的实例操作

通过cmd进入python的实例操作

通过cmd启动Python需要先设置系统环境,设置步骤如下: 1、首先,在桌面找到 “计算机” 右键 找到 “属性”或者按下 win 键 再右键“计算机” 找到 “属性”也可以。如下图所...

python Web开发你要理解的WSGI & uwsgi详解

python Web开发你要理解的WSGI & uwsgi详解

WSGI协议 首先弄清下面几个概念: WSGI:全称是Web Server Gateway Interface,WSGI不是服务器,python模块,框架,API或者任何软件,只是一种...

python实现图片处理和特征提取详解

python实现图片处理和特征提取详解

这是一张灵异事件图。。。开个玩笑,这就是一张普通的图片。 毫无疑问,上面的那副图画看起来像一幅电脑背景图片。这些都归功于我的妹妹,她能够将一些看上去奇怪的东西变得十分吸引眼球。然而,我...

python实现批量视频分帧、保存视频帧

本篇博客介绍利用python脚本实现视频分帧,并将每一帧保存到本地。主要基于opencv包来实现,在运行代码前确保opencv包已正确安装。下面是主要代码: import os i...

Python判断文件和文件夹是否存在的方法

一、python判断文件和文件夹是否存在、创建文件夹 复制代码 代码如下: >>> import os >>> os.path.exists('d:...