pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com6年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

将Pytorch模型从CPU转换成GPU的实现方法

最近将Pytorch程序迁移到GPU上去的一些工作和思考 环境:Ubuntu 16.04.3 Python版本:3.5.2 Pytorch版本:0.4.0 0. 序言 大家知道,在深度学...

Python简单实现两个任意字符串乘积的方法示例

本文实例讲述了Python简单实现两个任意字符串乘积的方法。分享给大家供大家参考,具体如下: 题目: 给定两个任意数字组成的字符串,求乘积,字符可能很大,但是python具有无限精度的整...

Python探索之SocketServer详解

SocketServer,网络通信服务器,是Python标准库中的一个模块,其作用是创建网络服务器。SocketServer模块定义了一些类来处理诸如TCP、UDP、UNIX流和UNIX...

用TensorFlow实现戴明回归算法的示例

用TensorFlow实现戴明回归算法的示例

如果最小二乘线性回归算法最小化到回归直线的竖直距离(即,平行于y轴方向),则戴明回归最小化到回归直线的总距离(即,垂直于回归直线)。其最小化x值和y值两个方向的误差,具体的对比图如下图。...

Python2与Python3的区别实例分析

本文实例讲述了Python2与Python3的区别。分享给大家供大家参考,具体如下: python2与python3的区别 1、性能 2、编码格式utf-8 3、打印语句变成了打印函数...