PyTorch加载预训练模型实例(pretrained)

yipeiwu_com6年前Python基础

使用预训练模型的代码如下:

# 加载预训练模型
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

 # 读取参数
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 # 将pretained_dict里不属于model_dict的键剔除掉
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 # 更新现有的model_dict
 model_dict.update(pretrained_dict)

 # 加载真正需要的state_dict
 ResNet50.load_state_dict(model_dict)

以上这篇PyTorch加载预训练模型实例(pretrained)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

基于Python数据可视化利器Matplotlib,绘图入门篇,Pyplot详解

基于Python数据可视化利器Matplotlib,绘图入门篇,Pyplot详解

Pyplot matplotlib.pyplot是一个命令型函数集合,它可以让我们像使用MATLAB一样使用matplotlib。pyplot中的每一个函数都会对画布图像作出相应的改变,...

详细分析python3的reduce函数

详细分析python3的reduce函数

reduce() 函数在 python 2 是内置函数, 从python 3 开始移到了 functools 模块。 官方文档是这样介绍的 reduce(...) reduce(fu...

Python multiprocessing模块中的Pipe管道使用实例

multiprocessing.Pipe([duplex]) 返回2个连接对象(conn1, conn2),代表管道的两端,默认是双向通信.如果duplex=False,conn1只能...

使用TensorFlow实现二分类的方法示例

使用TensorFlow实现二分类的方法示例

使用TensorFlow构建一个神经网络来实现二分类,主要包括输入数据格式、隐藏层数的定义、损失函数的选择、优化函数的选择、输出层。下面通过numpy来随机生成一组数据,通过定义一种正负...

django中使用事务及接入支付宝支付功能

django中使用事务及接入支付宝支付功能

之前一直想记录一下在项目中使用到的事务以及支付宝支付功能,自己一直犯懒没有完,趁今天有点兴致,在这记录一下。 商城项目必备的就是支付订单的功能,所以就会涉及到订单的保存以及支付接口的引入...