Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com5年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

python存储16bit和32bit图像的实例

笔记:python中存储16bit和32bit图像的方法。 说明:主要是利用scipy库和pillow库,比较其中的不同。 ''' 测试16bit和32bit图像的python存储方...

解决在pycharm中显示额外的 figure 窗口问题

解决在pycharm中显示额外的 figure 窗口问题

问题描述 在电脑中重新安装Anaconda3&PyCharm后,运行原来的程序画图时出现了下图界面。 不能弹出如下图所示的“figure”窗口。 解决方法: 这是因为PyCharm在...

利用Python自动监控网站并发送邮件告警的方法

前言 因为有一些网站需要每日检查是否有问题,所以需要一个报警监控的机制,这个需要你指定你发送的邮箱和你接收的邮箱,就可以做到对网站自动监控了。 这里用的是python3.5 需要安装的插...

python实现三次样条插值

python实现三次样条插值

本文实例为大家分享了python实现三次样条插值的具体代码,供大家参考,具体内容如下 函数: 算法分析 三次样条插值。就是在分段插值的一种情况。 要求: 在每个分段区间上是三次多...

Python操作rabbitMQ的示例代码

Python操作rabbitMQ的示例代码

引入 RabbitMQ 是一个由 Erlang 语言开发的 AMQP 的开源实现。 rabbitMQ是一款基于AMQP协议的消息中间件,它能够在应用之间提供可靠的消息传输。在易用性,扩展...