TensorFlow dataset.shuffle、batch、repeat的使用详解

yipeiwu_com6年前Python基础

直接看代码例子,有详细注释!!

import tensorflow as tf
import numpy as np


d = np.arange(0,60).reshape([6, 10])

# 将array转化为tensor
data = tf.data.Dataset.from_tensor_slices(d)

# 从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本
# buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,
# 此时会再次打乱
data = data.shuffle(buffer_size=3)

# 每次从buffer中抽取4个样本
data = data.batch(4)

# 将data数据集重复,其实就是2个epoch数据集
data = data.repeat(2)

# 构造获取数据的迭代器
iters = data.make_one_shot_iterator()

# 每次从迭代器中获取一批数据
batch = iters.get_next()

sess = tf.Session()

sess.run(batch)
# 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
In [21]: d
Out[21]: 
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
  [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
  [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
  [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
  [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])
In [22]: sess.run(batch)
Out[22]: 
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
  [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
  [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])

In [23]: sess.run(batch)
Out[23]: 
array([[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
  [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

从输出结果可以看出:

shuffle是按顺序将数据放入buffer里面的;

当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。

那么,当repeat函数在shuffle之前会怎么样呢?如下:

data = data.repeat(2)

data = data.shuffle(buffer_size=3)

data = data.batch(4)
In [25]: sess.run(batch)
Out[25]: 
array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
  [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
  [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

In [26]: sess.run(batch)
Out[26]: 
array([[50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
  [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
  [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
  [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])

In [27]: sess.run(batch)
Out[27]: 
array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
  [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
  [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
  [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。

以上这篇TensorFlow dataset.shuffle、batch、repeat的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python写的一个squid访问日志分析的小程序

python写的一个squid访问日志分析的小程序

这两周组里面几位想学习python,于是我们就创建了一个这样的环境和氛围来给大家学习。 昨天在群里,贴了一个需求,就是统计squid访问日志中ip 访问数和url的访问数并排序,不少同学...

python itchat实现微信自动回复的示例代码

今天在实验楼发现一个特别好玩的,Python 微信库itchat,可以实现自动回复等多种功能,好玩到根本停不下来啊,尤其是调戏调戏不懂计算机的,特别有成就感,哈哈!! 代码如下: #...

Python图像处理PIL各模块详细介绍(推荐)

Python图像处理PIL各模块详细介绍(推荐)

 Image模块 Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。如open、save、conver、show…等功能...

Python文件操作之合并文本文件内容示例代码

Python文件操作之合并文本文件内容示例代码

前言 相信大家初入某个项目,一般都要看代码。有时候,想把代码文件打印下来看,不过一般代码文件数量都在两位数或更多,逐一打开、打印,确实太耗费精力了,此外,也会出现某个代码文件打印到纸上只...

np.newaxis 实现为 numpy.ndarray(多维数组)增加一个轴

如下所示: >> type(np.newaxis) NoneType >> np.newaxis == None True np.newaxis 在使用和功...