TensorFlow版本:1.14.0
Python版本:3.6.8
TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据。
TFRecord 内部有一系列的 Example ,Example 是 protocolbuf 协议下的消息体。
把一个一维的numpy数组保存为TFRecord文件很容易,但如果numpy数组是二维的可能就比较容易写错。下面是一个例子。
- """
- 本程序演示了如何保存numpy array为TFRecords文件,并将其读取出来。
- """
- import random
- import numpy as np
- import tensorflow as tf
- __author__ = 'Darran Zhang @ codelast.com'
- def save_tfrecords(state_data, action_data, reward_data, dest_file):
- """
- 保存numpy array到TFRecord文件中。
- 这里输入了三个不同的numpy array来做演示,它们含有不同类型的元素。
- Args:
- state_data: 要保存到TFRecord文件的第1个numpy array,每一个 state_data[i] 是一个 numpy.ndarray(数组里的每个元素又是一个浮点
- 数),因此不能用 Int64List 或 FloatList 来存储,只能用 BytesList。
- action_data: 要保存到TFRecord文件的第2个numpy array,每一个 action_data[i] 是一个整数,使用 Int64List 来存储。
- reward_data: 要保存到TFRecord文件的第3个numpy array,每一个 reward_data[i] 是一个整数,使用 Int64List 来存储。
- dest_file: 输出文件的路径。
- Returns:
- 不返回任何值
- """
- with tf.io.TFRecordWriter(dest_file) as writer:
- for i in range(len(state_data)):
- features = tf.train.Features(
- feature={
- "state": tf.train.Feature(
- bytes_list=tf.train.BytesList(value=[state_data[i].astype(np.float32).tostring()])),
- "action": tf.train.Feature(
- int64_list=tf.train.Int64List(value=[action_data[i]])),
- "reward": tf.train.Feature(
- int64_list=tf.train.Int64List(value=[reward_data[i]]))
- }
- )
- tf_example = tf.train.Example(features=features)
- serialized = tf_example.SerializeToString()
- writer.write(serialized)
- def parse_fn(example_proto):
- features = {"state": tf.FixedLenFeature((), tf.string),
- "action": tf.FixedLenFeature((), tf.int64),
- "reward": tf.FixedLenFeature((), tf.int64)}
- parsed_features = tf.parse_single_example(example_proto, features)
- return tf.decode_raw(parsed_features['state'], tf.float32), parsed_features['action'], parsed_features['reward']
- if __name__ == '__main__':
- buffer_s, buffer_a, buffer_r = [], [], []
- # 随机生成一些数据
- for i in range(3):
- state = [round(random.random() * 100, 2) for _ in range(0, 10)] # 一个数组,里面有10个数,每个都是一个浮点数
- action = random.randrange(0, 2) # 一个数,值为 0 或 1
- reward = random.randrange(0, 100) # 一个数,值域 [0, 100)
- # 把生成的数分别添加到3个list中
- buffer_s.append(state)
- buffer_a.append(action)
- buffer_r.append(reward)
- # 查看生成的数据
- print(buffer_s)
- print(buffer_a)
- print(buffer_r)
- # 在水平方向把各个list堆叠起来,堆叠的结果:得到3个矩阵
- s_stacked = np.vstack(buffer_s)
- a_stacked = np.vstack(buffer_a)
- r_stacked = np.vstack(buffer_r)
- print(s_stacked.shape) # (3, 10)
- print(a_stacked.shape) # (3, 1)
- print(r_stacked.shape) # (3, 1)
- # 写入TFRecord文件
- output_file = './data.tfrecord' # 输出文件的路径
- save_tfrecords(s_stacked, a_stacked, r_stacked, output_file)
- # 读取TFRecord文件并打印出其内容
- for example in tf.io.tf_record_iterator(output_file):
- print(tf.train.Example.FromString(example))
- # 或者用下面的方法
- # from google.protobuf.json_format import MessageToJson
- # jsonMessage = MessageToJson(tf.train.Example.FromString(example))
- # print(jsonMessage)
- # 读取TFRecord文件并还原成numpy array,再打印出来
- with tf.Session() as sess:
- dataset = tf.data.TFRecordDataset(output_file) # 加载TFRecord文件
- dataset = dataset.map(parse_fn) # 解析data到Tensor
- dataset = dataset.repeat(1) # 重复N epochs
- dataset = dataset.batch(3) # batch size
- iterator = dataset.make_one_shot_iterator()
- next_data = iterator.get_next()
- while True:
- try:
- state, action, reward = sess.run(next_data)
- print(state)
- print(action)
- print(reward)
- except tf.errors.OutOfRangeError:
- break
注意:对二维数组,需要用 tf.train.BytesList 来保存,还原成numpy array的时候,要用 tf.decode_raw() 来解析。
由于生成的数据是随机数,因此你看到的输出会和我不一样。
文章来源:https://www.codelast.com/
生成的数据:
[[56.31, 8.72, 78.21, 44.52, 98.18, 95.23, 85.89, 95.76, 63.96, 41.56], [21.78, 66.52, 17.58, 35.36, 29.25, 63.54, 49.12, 82.71, 77.38, 20.04], [65.86, 78.81, 17.64, 3.21, 60.88, 92.98, 80.63, 92.86, 80.7, 4.12]][1, 0, 1][55, 97, 89]
numpy数组写成TFRecord后再重新读取出来,并重新转成numpy数组后,数据是:
[[56.31 8.72 78.21 44.52 98.18 95.23 85.89 95.76 63.96 41.56][21.78 66.52 17.58 35.36 29.25 63.54 49.12 82.71 77.38 20.04][65.86 78.81 17.64 3.21 60.88 92.98 80.63 92.86 80.7 4.12]][1 0 1][55 97 89]
可见数据和生成的一样,这说明上面的程序互相转没有问题。
文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤
转载需注明出处:codelast.com
感谢关注我的微信公众号(微信扫一扫):