tensorflow之tf.record如何实现存浮点数数组 tensorflow之tf.record实现存浮点数数组代码

作者:袖梨 2020-02-17

本篇文章小编给大家分享一下tensorflow之tf.record实现存浮点数数组代码,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。

原因:数据存入tf.record,转为二进制也就是使用来tobytes()函数,再将数据存入tf.record,浮点数以二进制存入会有精度丢失问题。

如何将浮点数组存进tf.record:简单记录

import tensorflow as tf
import numpy as np
 
def _floats_feature(value):
  #这里的value=后面没有括号
  #千万不要写成return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
# data you would like to save, dtype=float32
#这里我生成了一个浮点数数组,来假定作为我的数据
data = np.random.randn(shape=(5, 5))
#这里一定要铺平,不然存不进去
data = data.flatten()
 
 
# open tfrecord file
writer = tf.python_io.TFRecordWriter(train_data_path)
 
# make train example
example = tf.train.Example(features=tf.train.Features(
  feature={'data': _floats_feature(data)}))
 
# write on the file
writer.write(example.SerializeToString())

这就是存数据了,下一步读取数据,一定要注意将原来铺平的数据reshape为原来的形状。

# open tfrecorder reader
reader = tf.TFRecordReader()
 
# read file
_, serialized_example = reader.read(filename_queue)
 
# read data
features = tf.parse_single_example(serialized_example,
  features={'data': tf.VarLenFeature(tf.float32)})
 
# make it dense tensor
data = tf.sparse_tensor_to_dense(features['data'], default_value=0)
 
# reshape
data = tf.reshape(data, [5,5])
 
return tf.train.batch(data, batch_size, num_threads, capacity)

相关文章

精彩推荐