详解tensorflow实现迁移学习实例

yipeiwu_com6年前Python基础

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python提取字典key列表的方法

本文实例讲述了python提取字典key列表的方法。分享给大家供大家参考。具体如下: 这段代码可以把字典的所有key输出为一个数组 d2 = {'spam': 2, 'ham': 1...

Python中使用hashlib模块处理算法的教程

Python的hashlib提供了常见的摘要算法,如MD5,SHA1等等。 什么是摘要算法呢?摘要算法又称哈希算法、散列算法。它通过一个函数,把任意长度的数据转换为一个长度固定的数据串(...

记录Python脚本的运行日志的方法

一、logging模块 Python中有一个模块logging,可以直接记录日志 # 日志级别 # CRITICAL 50 # ERROR 40 # WARNING 30 #...

Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例

Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例

本文实例讲述了Python2.7基于笛卡尔积算法实现N个数组的排列组合运算。分享给大家供大家参考,具体如下: 说明:本人前段时间遇到的求n个数组的所有排列组合的问题,发现笛卡尔积算法可以...

使用k8s部署Django项目的方法步骤

接触了一下docker和k8s,感觉是非常不错的东西。能够方便的部署线上环境,而且还能够更好的利用机器的资源,感觉是以后的大趋势。最近刚好有一个基于django的项目,所以就把这个项目打...