dataset
之前的文章,稍微讲了一下Estimator的用法,也提到Estimator的数据处理使用的是tf.data这两个模块是Tensorflow初学者必须掌握的内容。现在,就让我们从大的概念入手,来慢慢理解tf.data的用法
转载请注明出处
推荐官方文档:https://tensorflow.Google.cn/programmers_guide/datasets
tf.data的作用
在机器学习过程中,对数据的获取、过滤、使用、存储是很重要的一个内容,因为数据可能是不完整的、有杂质的、来源不同的。面对海量数据,我们当然不可能每次都手动整合。Tensorflow框架下,对数据的处理使用的是tf.data,它可以帮助我们以多种方式获取数据、灵活的处理数据和保存数据,使我们能够把更多的精力专注在算法的逻辑上。下面就让我们一起来学习。
tf.data获取数据的方式
这里着着重理解Dataset的概念
Dataset是存储Tensor结构的类,它可以保存一批Tensor结构,以供模型来训练或者测试。这里,Tensor结构是自己定义的,可以有多种格式。
Dataset获取数据的方式有多种,可以从Tensor获取,也可以从另一个Dataset转换而来,我们暂时只讲从Tensor获取。
用到的接口为:
tf.data.Dataset.from_tensor_slices()
这个接口允许我们传递一个或多个Tensor结构给Dataset,因为默认把Tensor的第一个维度作为数据数目的标识,所以要保持数据结构中第一维的一致性,用代码说明一下:
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
这里包含如下信息:
1、该接口可以接受一个字典变量。实际上,该接口接受任何Iterator
2、第一个维度被认为是数据的数量,可以看到,观察数据的shapes的时候,只显示第一维以后的,为什么呢,因为第一维被认为是数据的数量,所以不参与构成shapes
Dataset输出数据的方式
make_one_shot_iterator迭代器
有进就有出,那么数据怎么从Dataset出来呢,代码如下:
dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(10,3))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.session() as sess:
for i in range(10):
value = sess.run(next_element)
print(i, value)
output:
0 [ 0.78891609 0.31016679 -2.22261044]
1 [ 3.06918115 0.14014906 0.86654045]
2 [ 2.08348332 0.57866576 -0.66946627]
3 [-1.28344434 1.96287407 0.70896466]
4 [-1.28056116 -0.65352575 0.39975416]
5 [-0.70007014 -0.94034185 1.02308444]
6 [ 0.70819506 -0.56918389 0.75509558]
7 [ 0.26925763 -0.18980865 -0.90350774]
8 [ 1.45644465 -1.13308881 -0.37863782]
9 [ 0.4485246 -0.48737583 -0.40142893]
这里,我们先用numpy生成随机数据并储存在Dataset,之后,是用了dataset.make_one_shot_iterator()
迭代器来读取数据。one_shot迭代器人如其名,意思就是数据输出一次后就丢弃了。
这就构成了数据进出的一种方式,下面,我们多了解几种数据输出的迭代器
make_initializable_iterator 迭代器
可初始化迭代器允许Dataset中存在占位符,这样可以在数据需要输出的时候,再进行feed操作。实验代码如下:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})#需要取数据的时候才将需要的参数feed进去
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})#feed了不同的参数
for i in range(100):
value = sess.run(next_element)
assert i == value
reinitializable 迭代器
这个迭代器构造方式是根据数据的shapes和type,所以只要shapes和type相同,就可以接受不同的数据源来进行初始化,且可以反复初始化,见以下代码:
# define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
with tf.Session() as sess:
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):#可以重复初始化
# Initialize an iterator over the training dataset.
sess.run(training_init_op)#每次初始化的数据可以不同
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)#初始化了另一组数据
for _ in range(50):
sess.run(next_element)
Iterator.from_string_handle 迭代器
可以看到,reinitializable 已经具有较强的灵活性了,但是它还是每次加载数据都需要重新初始化,有没有可能省掉这一步呢,是可以的,Iterator.from_string_handle通过feed初始化句柄的方式,取得了更高的灵活性,代码如下
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
# Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
从Dataset初始化Dataset
讲完了如何读取数据,我们再回过头来讲获取数据的另一种方法:从Dataset获取。
之所以这么安排,是因为要结合输出才能理解这种获取方法的意义。
为了从dataset中初始化,这里有三个接口:
Dataset.map
Dataset.flat_map
Dataset.filter
这三个接口从字面上就很好理解,map就是对于给定Dataset中的每一个元素,都执行一次map操作,而flat_map就是既执行了map,还对数据进行了一次扁平化,也就是降维,而filter就是进行了一次过滤, 我们直接从代码的角度看以看这三个接口怎么用。
代码虽然很大一段,需要理解的东西却很少, 首先,我们定义了一个3*2*3的随机多维数组,可以看到
正常的输出就是输出3个2*3的数组
而map的作用我这里写的是各数加一,所以输出的是3个2*3的各数加一的数组
flat_map的作用是降维,所以输出的是6个1*3的数组
filter的作用是过滤,所以输出的是我的过滤内容:[0][0]元素大于0.8的,由于最有一个数组其元素为0.76103773,被过滤掉了
with tf.Session() as sess:
np.random.seed(0)#持有种子,使得每次随机出来的数组是一样的
normal_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3))
np.random.seed(0)
map_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).map(map_func=lambda x:x+1)#各数加一
np.random.seed(0)
flat_map_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).flat_map(map_func=lambda x:tf.data.Dataset.from_tensor_slices(x))#输出的还是原来的x,但是降维了
np.random.seed(0)
filter_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).filter(lambda x:x[0][0] > 0.8)#进行了一次过滤
iterator1 = tf.data.Iterator.from_structure(normal_dataset.output_types,
normal_dataset.output_shapes)
iterator2 = tf.data.Iterator.from_structure(map_dataset.output_types,
map_dataset.output_shapes)
iterator3 = tf.data.Iterator.from_structure(flat_map_dataset.output_types,
flat_map_dataset.output_shapes)
iterator4 = tf.data.Iterator.from_structure(filter_dataset.output_types,
filter_dataset.output_shapes)
next_element1 = iterator1.get_next()
next_element2 = iterator2.get_next()
next_element3 = iterator3.get_next()
next_element4 = iterator4.get_next()
training_init_op1 = iterator1.make_initializer(normal_dataset)
training_init_op2 = iterator2.make_initializer(map_dataset)
training_init_op3 = iterator3.make_initializer(flat_map_dataset)
training_init_op4 = iterator4.make_initializer(filter_dataset)
print("normal:")
sess.run(training_init_op1)
for _ in range(3):
print(sess.run(next_element1))
print("map:")
sess.run(training_init_op2)
for _ in range(3):
print(sess.run(next_element2))
print("falt_map:")
sess.run(training_init_op3)
for _ in range(6):
print(sess.run(next_element3))
print("filter:")
sess.run(training_init_op4)
for _ in range(2):
print(sess.run(next_element4))
output:
normal:
[[ 1.76405235 0.40015721 0.97873798]
[ 2.2408932 1.86755799 -0.97727788]]
[[ 0.95008842 -0.15135721 -0.10321885]
[ 0.4105985 0.14404357 1.45427351]]
[[ 0.76103773 0.12167502 0.44386323]
[ 0.33367433 1.49407907 -0.20515826]]
map:
[[ 2.76405235 1.40015721 1.97873798]
[ 3.2408932 2.86755799 0.02272212]]
[[ 1.95008842 0.84864279 0.89678115]
[ 1.4105985 1.14404357 2.45427351]]
[[ 1.76103773 1.12167502 1.44386323]
[ 1.33367433 2.49407907 0.79484174]]
#各数加一了
falt_map:
[ 1.76405235 0.40015721 0.97873798]
[ 2.2408932 1.86755799 -0.97727788]
[ 0.95008842 -0.15135721 -0.10321885]
[ 0.4105985 0.14404357 1.45427351]
[ 0.76103773 0.12167502 0.44386323]
[ 0.33367433 1.49407907 -0.20515826]
#这里降维成一维数组了
filter:
[[ 1.76405235 0.40015721 0.97873798]
[ 2.2408932 1.86755799 -0.97727788]]
[[ 0.95008842 -0.15135721 -0.10321885]
[ 0.4105985 0.14404357 1.45427351]]
#最后一个被过滤掉了
到这里,就基本讲述了一下Dataset的输入输出方法,篇幅有限,这篇博文就到这里,之后会另开一篇,写一写数据的消费等更高级的操作!
相关阅读
一、为什么要用require.js? 最早的时候,所有Javascript代码都写在一个文件里面,只要加载这一个文件就够了。后来,代码越来越多,一个文
MONTHS_BETWEEN (date1, date2) 用于计算date1和date2之间有几个月。 如果date1在日历中比date2晚,那么MONTHS_BET
1.判断值为null (null表示空值并不是空字符串,有区别的) <c:if test="${ empty var.ASSESSNOTE_2ND}"></c:if> 2.判断值不等
函数: stringObject.substring(start,stop) 参数: start 必需。一个非负的整数,规定要提取的子串的第一个字符在 stringObject 中
1)如何查看本机所开端口: 用netstat -an命令查看!再stat下面有一些英文,我来简单说一下这些英文具体都代表什么~ LISTEN:侦听来自远