TensorFlow2实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传
1、TFRecords
在训练过程中,基本都是使用GPU来计算,但是取一个一个batch取数据还是必须要用cpu,这个过程耗费时间也会影响训练时间,制作TFRecords可以有效解决这个问题,此外制作TFRecords数据可以更好的管理存储数据
为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord是一种比较常用的存储二进制序列数据的方法,tf.Example类是一种将数据表示为{“string”: value}形式的meassage类型,Tensorflow经常使用tf.Example来写入、读取TFRecord数据
通常情况下,tf.Example中可以使用以下几种格式:
- tf.train.BytesList: 可以使用的类型包括 string和byte
- tf.train.FloatList: 可以使用的类型包括 float和double
- tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
TFRecords是TensorFlow官方推荐的
2、转化示例
def _bytes_feature(value):
"""Returns a bytes_list from a string/byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Return a float_list form a float/double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Return a int64_list from a bool/enum/int/uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
定义3个函数分别对3种类型的数据进行转换成对应的TensorFlow的数据格式
# tf.train.BytesList
print(_bytes_feature(b'test_string'))
print(_bytes_feature('test_string'.encode('utf8')))
# tf.train.FloatList
print(_float_feature(np.exp(1)))
# tf.train.Int64List
print(_int64_feature(True))
print(_int64_feature(1))
传进几个numpy格式的数据,再调用上面的函数进行转换,再打印:
bytes_list { value: “test_string” }
bytes_list { value: “test_string” }
float_list { value: 2.7182817459106445 }
int64_list { value: 1 }
int64_list { value: 1 }
3、TFRecords制作方法
def serialize_example(feature0, feature1, feature2, feature3):
"""
创建tf.Example
"""
# 转换成相应类型
feature = {
'feature0': _int64_feature(feature0),
'feature1': _int64_feature(feature1),
'feature2': _bytes_feature(feature2),
'feature3': _float_feature(feature3),
}
#使用tf.train.Example来创建
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
#SerializeToString方法转换为二进制字符串
return example_proto.SerializeToString()
- 定义一个函数,传入4个参数
- 使用前面定义的函数对4个参数分别转换成相应的格式
- 构建Example将转换完的数据创建一条数据
- 序列化 tf.Example:返回一个二进制的字符串
n_observations = int(1e4)
feature0 = np.random.choice([False, True], n_observations)
feature1 = np.random.randint(0, 5, n_observations)
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
feature3 = np.random.randn(n_observations)
- 定义一个一万备用
- 随机选择一万个布尔数据
- 随机选择一万个0、1、2、3、4这5个整数
- 随机构造字符串
- 随机构造浮点数
filename = 'tfrecord-1'
with tf.io.TFRecordWriter(filename) as writer:
for i in range(n_observations):
example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
writer.write(example)
- 定义文件名
- 定义一个写的模块,传进文件名,写入数据
- 迭代一万次
- 按照零到一万的索引,分别传入上面构造的4个特征
- 写入数据
这段代码执行后,会得到一个名为tfrecord-1的文件:
4、加载tfrecord文件
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
打印结果: