pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com6年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python和c语言的主要区别总结

python和c语言的主要区别总结

Python可以说是目前最火的语言之一了,人工智能的兴起让Python一夜之间变得家喻户晓,Python号称目前最最简单易学的语言,现在有不少高校开始将Python作为大一新生的入门语言...

致Python初学者 Anaconda入门使用指南完整版

打算学习 Python 来做数据分析的你,是不是在开始时就遇到各种麻烦呢? 到底该装 Python2 呢还是 Python3 ? 为什么安装 Python 时总是出错? 怎么安装工具包呢...

python多线程http压力测试脚本

本文实例为大家分享了python多线程http压力测试的具体代码,供大家参考,具体内容如下 #coding=utf-8 import sys import time import...

Python的函数嵌套的使用方法

例子:复制代码 代码如下:def re_escape(fn):    def arg_escaped(this, *args):  &n...

python调用百度语音识别api

python调用百度语音识别api

最近在处理语音检索相关的事。 其中用到语音识别,调用的是讯飞与百度的api,前者使用js是实现,后者用python3实现(因为自己使用python) 环境: python3.5 ce...