tensorflow中 batch, shuffle, repeat

该博客展示了如何利用TensorFlow 1.14.0版本创建数据集,通过设置缓冲区对数据进行随机洗牌,并进行批量处理。示例中,将一个numpy数组转换为TensorFlow数据集,使用shuffle和batch方法进行预处理,再通过repeat创建多个epoch。最后,使用迭代器获取并打印数据批次。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

notice:tensorflow 1.14.0版本

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

d = np.arange(0, 60).reshape([6, 10])

data = tf.data.Dataset.from_tensor_slices(d) // numpy 转为 tensor

#从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本
# buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,
# 此时会再次打乱
data = data.shuffle(buffer_size=3)

## 每次从buffer中抽取4个样本
data = data.batch(4)

# 将data数据集重复,其实就是2个epoch数据集
data = data.repeat(2)

# 构造获取数据的迭代器
iters = data.make_one_shot_iterator()

# 每次从迭代器中获取一批数据
batch = iters.get_next()

sess = tf.Session()
while True:
    try:
        print(sess.run(batch))
    except tf.errors.OutOfRangeError:
        break

res:

[[20 21 22 23 24 25 26 27 28 29]
 [ 0  1  2  3  4  5  6  7  8  9]
 [10 11 12 13 14 15 16 17 18 19]
 [50 51 52 53 54 55 56 57 58 59]]
[[40 41 42 43 44 45 46 47 48 49]
 [30 31 32 33 34 35 36 37 38 39]]

[[ 0  1  2  3  4  5  6  7  8  9]
 [20 21 22 23 24 25 26 27 28 29]
 [30 31 32 33 34 35 36 37 38 39]
 [10 11 12 13 14 15 16 17 18 19]]
[[50 51 52 53 54 55 56 57 58 59]
 [40 41 42 43 44 45 46 47 48 49]]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值