pytorch 实现在预训练模型的 input上增减通道

yipeiwu_com5年前Python基础

如何把imagenet预训练的模型,输入层的通道数随心所欲的修改,从而来适应自己的任务

#增加一个通道
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))
 
#方式2
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, torch.zeros(64, 1, 7, 7)), dim=1))
 
 
#单通道输入
layers[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])

以上这篇pytorch 实现在预训练模型的 input上增减通道就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python3.6.2调用ffmpeg的方法

本文是为了学习python调用C语言的库写的例子。 去ffmpeg官网下载编译好的avcodec-57.dll、avutil-55.dll、swresample-2.dll,准备好了C语...

python编程线性回归代码示例

python编程线性回归代码示例

 用python进行线性回归分析非常方便,有现成的库可以使用比如:numpy.linalog.lstsq例子、scipy.stats.linregress例子、pandas.o...

Python读csv文件去掉一列后再写入新的文件实例

用了两种方式解决该问题,都是网上现有的解决方案。 场景说明: 有一个数据文件,以文本方式保存,现在有三列user_id,plan_id,mobile_id。目标是得到新文件只有mobil...

python使用tornado实现登录和登出

本文实例为大家分享了tornado实现登录和登出的具体代码,供大家参考,具体内容如下 main.py如下: import tornado.httpserver import torn...

Python中dictionary items()系列函数的用法实例

本文实例讲述了Python中dictionary items()系列函数的用法,对Python程序设计有很好的参考借鉴价值。具体分析如下: 先来看一个示例: import html...