读取数据共有3种方法:

  • 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据
  • 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)
  • 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据

参考资料: I/O Op:https://www.tensorflow.org/api_guides/python/io_ops

Feeding机制

TF的feeding机制允许将数据写入到计算图中的任意Tensor中,甚至包括Variable和Constants,一般使用tf.placeholder节点作为feed的目标。计算时,tf.placeholder如果没有被赋值,则会报错。

Feeding机制通过在run()eval()调用中使用feed_dict={}参数修改计算图中的Tensor。

with tf.Session():
  input = tf.placeholder(tf.float32)
  classifier = ...
  print(classifier.eval(feed_dict={input: my_python_preprocessing_fn()}))

加载所有数据

使用tf.constant加载所有数据时,数据会拷贝到计算图中,浪费内存。使用tf.Variable,在计算时feed数据,可以减少这部分内存浪费。

然后,使用tf.train.slice_input_producertf.train.batch来读取一个batch数据,这两个Op都使用Queue实现,并添加各自的QueueRunner到Graph's QUEUE_RUNNER集合中。

  • tf.train.slice_input_producer:将Tensor切分成不含第一维的单个Tensor (即[n,a,b,c]变成n个[a,b,c]),可以shuffle
  • tf.train.batch:合并batch个Tensor,生成一个大Tensor

参考:

从文件读取

results matching ""

    No results matching ""