pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com5年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python增加矩阵维度的实例讲解

numpy.expand_dims(a, axis) Examples >>> x = np.array([1,2]) >>> x.shape...

Python数据结构之栈、队列的实现代码分享

Python数据结构之栈、队列的实现代码分享

1. 栈 栈(stack)又名堆栈,它是一种运算受限的线性表。其限制是仅允许在表的一端进行插入和删除运算。这一端被称为栈顶,相对地,把另一端称为栈底。向一个栈插入新元素又称作进栈、入栈或...

python 常用的基础函数

Python: 1. print()函数:打印字符串 2. raw_input()函数:从用户键盘捕获字符 3. len()函数:计算字符长度 4. format(12.3654,'6....

如何将你的应用迁移到Python3的三个步骤

Python 2.x 很快就要 失去官方支持 了,尽管如此,从 Python 2 迁移到 Python 3 却并没有想象中那么难。我在上周用了一个晚上的时间将一个 3D 渲染器的前端代码...

python类的继承实例详解

python 类的继承 对于许多文章讲解python类的继承,大多数都是说一些什么oop,多态等概念,我认为这样可能对有一定基础的开发者帮助不是那么大,不如直接用在各种情况下所写的代码,...