[原创] 二维numpy数组保存到TFRecord并读取还原回来

TensorFlow版本:1.14.0
Python版本:3.6.8

TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据。
TFRecord 内部有一系列的 Example ,Example 是 protocolbuf 协议下的消息体。

把一个一维的numpy数组保存为TFRecord文件很容易,但如果numpy数组是二维的可能就比较容易写错。下面是一个例子。

  1. """ 
  2. 本程序演示了如何保存numpy array为TFRecords文件,并将其读取出来。 
  3. """  
  4. import random  
  5.   
  6. import numpy as np  
  7. import tensorflow as tf  
  8.   
  9. __author__ = 'Darran Zhang @ codelast.com'  
  10.   
  11.   
  12. def save_tfrecords(state_data, action_data, reward_data, dest_file):  
  13.     """ 
  14.     保存numpy array到TFRecord文件中。 
  15.     这里输入了三个不同的numpy array来做演示,它们含有不同类型的元素。 
  16.     Args: 
  17.         state_data: 要保存到TFRecord文件的第1个numpy array,每一个 state_data[i] 是一个 numpy.ndarray(数组里的每个元素又是一个浮点 
  18.                     数),因此不能用 Int64List 或 FloatList 来存储,只能用 BytesList。 
  19.         action_data: 要保存到TFRecord文件的第2个numpy array,每一个 action_data[i] 是一个整数,使用 Int64List 来存储。 
  20.         reward_data: 要保存到TFRecord文件的第3个numpy array,每一个 reward_data[i] 是一个整数,使用 Int64List 来存储。 
  21.         dest_file: 输出文件的路径。 
  22.     Returns: 
  23.         不返回任何值 
  24.     """  
  25.     with tf.io.TFRecordWriter(dest_file) as writer:  
  26.         for i in range(len(state_data)):  
  27.             features = tf.train.Features(  
  28.                 feature={  
  29.                     "state": tf.train.Feature(  
  30.                         bytes_list=tf.train.BytesList(value=[state_data[i].astype(np.float32).tostring()])),  
  31.                     "action": tf.train.Feature(  
  32.                         int64_list=tf.train.Int64List(value=[action_data[i]])),  
  33.                     "reward": tf.train.Feature(  
  34.                         int64_list=tf.train.Int64List(value=[reward_data[i]]))  
  35.                 }  
  36.             )  
  37.             tf_example = tf.train.Example(features=features)  
  38.             serialized = tf_example.SerializeToString()  
  39.             writer.write(serialized)  
  40.   
  41.   
  42. def parse_fn(example_proto):  
  43.     features = {"state": tf.FixedLenFeature((), tf.string),  
  44.                 "action": tf.FixedLenFeature((), tf.int64),  
  45.                 "reward": tf.FixedLenFeature((), tf.int64)}  
  46.     parsed_features = tf.parse_single_example(example_proto, features)  
  47.     return tf.decode_raw(parsed_features['state'], tf.float32), parsed_features['action'], parsed_features['reward']  
  48.   
  49.   
  50. if __name__ == '__main__':  
  51.     buffer_s, buffer_a, buffer_r = [], [], []  
  52.   
  53.     # 随机生成一些数据  
  54.     for i in range(3):  
  55.         state = [round(random.random() * 100, 2) for _ in range(0, 10)]  # 一个数组,里面有10个数,每个都是一个浮点数  
  56.         action = random.randrange(0, 2)  # 一个数,值为 0 或 1  
  57.         reward = random.randrange(0, 100)  # 一个数,值域 [0, 100)  
  58.         # 把生成的数分别添加到3个list中  
  59.         buffer_s.append(state)  
  60.         buffer_a.append(action)  
  61.         buffer_r.append(reward)  
  62.   
  63.     # 查看生成的数据  
  64.     print(buffer_s)  
  65.     print(buffer_a)  
  66.     print(buffer_r)  
  67.   
  68.     # 在水平方向把各个list堆叠起来,堆叠的结果:得到3个矩阵  
  69.     s_stacked = np.vstack(buffer_s)  
  70.     a_stacked = np.vstack(buffer_a)  
  71.     r_stacked = np.vstack(buffer_r)  
  72.   
  73.     print(s_stacked.shape)  # (3, 10)  
  74.     print(a_stacked.shape)  # (3, 1)  
  75.     print(r_stacked.shape)  # (3, 1)  
  76.   
  77.     # 写入TFRecord文件  
  78.     output_file = './data.tfrecord'  # 输出文件的路径  
  79.     save_tfrecords(s_stacked, a_stacked, r_stacked, output_file)  
  80.   
  81.     # 读取TFRecord文件并打印出其内容  
  82.     for example in tf.io.tf_record_iterator(output_file):  
  83.         print(tf.train.Example.FromString(example))  
  84.         # 或者用下面的方法  
  85.         # from google.protobuf.json_format import MessageToJson  
  86.         # jsonMessage = MessageToJson(tf.train.Example.FromString(example))  
  87.         # print(jsonMessage)  
  88.   
  89.     # 读取TFRecord文件并还原成numpy array,再打印出来  
  90.     with tf.Session() as sess:  
  91.         dataset = tf.data.TFRecordDataset(output_file)  # 加载TFRecord文件  
  92.         dataset = dataset.map(parse_fn)  # 解析data到Tensor  
  93.         dataset = dataset.repeat(1)  # 重复N epochs  
  94.         dataset = dataset.batch(3)  # batch size  
  95.   
  96.         iterator = dataset.make_one_shot_iterator()  
  97.         next_data = iterator.get_next()  
  98.   
  99.         while True:  
  100.             try:  
  101.                 state, action, reward = sess.run(next_data)  
  102.                 print(state)  
  103.                 print(action)  
  104.                 print(reward)  
  105.             except tf.errors.OutOfRangeError:  
  106.                 break  

注意:对二维数组,需要用 tf.train.BytesList 来保存,还原成numpy array的时候,要用 tf.decode_raw() 来解析。
由于生成的数据是随机数,因此你看到的输出会和我不一样。
文章来源:https://www.codelast.com/
生成的数据:

[[56.31, 8.72, 78.21, 44.52, 98.18, 95.23, 85.89, 95.76, 63.96, 41.56], [21.78, 66.52, 17.58, 35.36, 29.25, 63.54, 49.12, 82.71, 77.38, 20.04], [65.86, 78.81, 17.64, 3.21, 60.88, 92.98, 80.63, 92.86, 80.7, 4.12]]
[1, 0, 1]
[55, 97, 89]

numpy数组写成TFRecord后再重新读取出来,并重新转成numpy数组后,数据是:

[[56.31  8.72 78.21 44.52 98.18 95.23 85.89 95.76 63.96 41.56]
 [21.78 66.52 17.58 35.36 29.25 63.54 49.12 82.71 77.38 20.04]
 [65.86 78.81 17.64  3.21 60.88 92.98 80.63 92.86 80.7   4.12]]
[1 0 1]
[55 97 89]

可见数据和生成的一样,这说明上面的程序互相转没有问题。

文章来源:https://www.codelast.com/
➤➤ 版权声明 ➤➤ 
转载需注明出处:codelast.com 
感谢关注我的微信公众号(微信扫一扫):

wechat qrcode of codelast

发表评论