pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com6年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

微信跳一跳python辅助软件思路及图像识别源码解析

微信跳一跳python辅助软件思路及图像识别源码解析

本文将梳理github上最火的wechat_jump_game的实现思路,并解析其图像处理部分源码 首先废话少说先看效果 核心思想 获取棋子到下一个方块的中心点的距离 计算触摸屏...

pygame学习笔记(5):游戏精灵

pygame学习笔记(5):游戏精灵

据说在任天堂FC时代,精灵的作用相当巨大,可是那时候只知道怎么玩超级玛丽、魂斗罗,却对精灵一点也不知。pygame.sprite.Sprite就是Pygame里面用来实现精灵的一个类,使...

python计算日期之间的放假日期

本文实例为大家分享了python计算日期之间的放假日期,供大家参考,具体内容如下 代码如下: #encoding=utf-8 print '中国' #自动查询节日 给定...

浅析Python数字类型和字符串类型的内置方法

一、数字类型内置方法 1.1 整型的内置方法 作用 描述年龄、号码、id号 定义方式 x = 10 x = int('10') x = int(10.1) x = int('10...

Python学习笔记基本数据结构之序列类型list tuple range用法分析

本文实例讲述了Python学习笔记基本数据结构之序列类型list tuple range用法。分享给大家供大家参考,具体如下: list 和 tuple list:列表,由 []...