TensorFlow 2.0高效开发指南
Effective TensorFlow 2.0
为使 TensorFLow 用户更高效,TensorFlow 2.0 中进行了多出更改。TensorFlow 2.0 删除了篇冗余 API,使 API 更加一致(统一 RNNs, 统一优化器),并通过Eager execution更好地与 Python 集成。
许多 RFCs 已经解释了 TensorFlow 2.0 带来的变化。本指南介绍了 TensorFlow 2.0 应该怎么进行开发。这假设您已对 TensorFlow 1.x 有一定了解。
A brief summary of major changes
API Cleanup
许多 API 在 TF 2.0 中进行了移动或删除。一些主要的变化包括删除tf.app
,tf.flags
,使tf.logging
支持现在开源的 absl-py,重新生成项目的tf.contribe
,通过清理tf.*
中那些较少使用的命名空间,例如tf.math
。一些 API 已替换为自己的 2.0 版本-tf.summary
,tf.keras.metrics
, 和tf.keras.optimizers
。最快升级应用这些重命名带来的变化可使用v2 升级脚本。
Eager execution
TensorFlow 1.x 要求用户通过tf.*
API 手动的将抽象语法树(图)拼接在一起。然后它要求用户通过一组输入、输出张量传递给session.run()
从而手动编译调用这个图。TensorFlow 2.0 Eager execution 可以像 Python 那样执行,在 2.0 中,graph 和 session 会像实现细节一样。
值得注意的是tf.control_dependencies()
不再需要了,因为所有代码都是行顺序执行的(用tf.function
声明)。
No more globals
TensorFlow 1.x 严重依赖隐式全局命名空间。当你调用tf.Variable()
,它会被放入默认图中,即使你忘了指向它的 Python 变量,它也会被保留在那里。然后你可以恢复它,但前提是你得知道它创建时的名称。如果你无法控制变量的创建,这很难做到。其结果是,各种各样的机制,试图帮助用户再次找到他们的变量,以及为框架找到用户创建的变量:Variable scopes, global collections。例如tf.get_global_step()
,tf.global_variables_initializer()
,还有优化器隐式计算所有可训练变量的梯度等等。
TensorFlow 2.0 消除了这些机制(Variable 2.0 RFC)默认支持的机制:跟踪你的变量!如果你忘记了一个tf.Variable
,它就会当作垃圾被回收。
Functions, not sessions
session.run()
几乎可以像函数一样调用:指定输入和被调用的函数,你可以得到一组输出。在 TensorFlow 2.0 中,您可以使用 Python 函数tf.function()
来标记它以进行 JIT 编译,以便 TensorFlow 将其作为单个图运行(Function 2.0 RFC)。这种机制允许 TensorFlow 2.0 获得图模型所有的好处:
- 性能:函数可以被优化(node pruning, kernel fusion, etc.)
- 可移植性:该功能可以被导出/重新导入(SavedModel 2.0 RFC),允许用户重用和共享模块化 TensorFlow 功能。
1 | # TensorFlow 1.X |
凭借穿插 Python 和 TensorFlow 代码的能力,我们希望用户能够充分利用 Python 的表现力。除了在没有 Python 解释器的情况下执行 TensorFlow,如 mobile, C++, 和 JS。为了帮助用户避免在添加时重写代码@tf.function
, AutoGraph会将 Python 构造的一个子集转换为他们的 TensorFlow 等价物:
for
/while
->tf.while_loop
(支持 break 和 continue)if
->tf.cond
for _ in dataset
->dataset.reduce
AutoGraph 支持控制流的任意嵌套,这使得可以有较好性能并且简洁地实现许多复杂的 ML 程序,如序列模型,强化学习,自定义训练循环等。
Recommendations for idiomatic TensorFlow 2.0
Refactor your code into smaller functions
TensorFlow 1.x 中常见使用模式是“kitchen sink”策略,其中所有可能的计算的联合被预先布置,然后选择被评估的张量,通过session.run()
运行。在 TensorFlow 2.0 中,用户应该将代码重构为较小的函数,这些函数根据需要被调用。通常,没有必要用tf.function
去装饰那些比较小的函数;仅用tf.function
去装饰高等级的计算,例如,训练的一个步骤,或模型的前向传递。
Use Keras layers and models to manage variables
Keras 模型和图层提供了方便 variables 和 trainable_variables 属性,它以递归方式收集所有因变量。这使得在本地管理变量非常容易。
对比:
1 | def dense(x, W, b): |
Keras 版本:
1 | # 可以调用每个图层,其签名等效于 linear(x) |
Keras layers/models 继承自tf.train.Checkpointable
并集成了@tf.function
,这使得直接从 Keras 对象导出 SavedModels 或 checkpoint 成为可能。您不一定要使用 Keras 的.fit
API 来利用这些集成。
这是一个迁移学习的例子,演示了 Keras 如何轻松收集相关变量的子集。假设你正在训练一个带有共享主干的多头模型:
1 | trunk = tf.keras.Sequential([...]) |
Combine tf.data.Datasets and @tf.function
在内存中迭代拟合训练数据时,可以随意使用常规的 Python 迭代。或者,tf.data.Dataset
是从硬盘读取训练数据流的最好方法。Datasets 是可迭代的(不是迭代器),它可以像在 Eager 模式下的其他 Python 迭代一样工作。您可以通过用tf.function()
包装代码来充分利用数据集异步预取/流功能,这将使用 AutoGraph 等效的图操作替换 Python 的迭代。
1 |
|
如果您使用 Keras.fit()
API,则无需担心数据集迭代。
1 | model.compile(optimizer=optimizer, loss=loss_fn) |
Take advantage of AutoGraph with Python control flow
AutoGraph 提供了一种将依赖于数据的控制流转换为等效图形模式的方法,如tf.cond
和tf.while_loop
。
数据相关控制流出现的一个常见位置是序列模型。tf.keras.layers.RNN
包装了一个 RNN cell,允许您既可以静态也可以动态的循环展开。为了演示,您可以重新实现动态展开,如下所示:
1 | class DynamicRNN(tf.keras.Model): |
有关 AutoGraph 功能的更详细概述,请参阅指南
Use tf.metrics to aggregate data and tf.summary to log it
要记录摘要,请使用tf.summary.(scalar|histogram|...)
上下文管理器将其重定向到编写器。(如果省略上下文管理器,则不会发生任何事情。)与 TF 1.x 不同,摘要直接发送给编写器; 没有单独的“合并”操作,也没有单独的 add_summary()调用,这意味着 step 必须在调用点提供该值。
1 | summary_writer = tf.summary.create_file_writer('/tmp/summaries') |
要在将数据记录为摘要之前聚合数据,请使用tf.metrics
。Metrics 是有状态的;它们积累值并在您调用.result()
时返回结果。清除积累值,请使用.reset_states()
。
1 | def train(model, optimizer, dataset, log_freq=10): |
通过将 TensorBoard 指向摘要日志目录来可视化生成的摘要:tensorboard --logdir /tmp/summaries
。
- 欢迎关注我的公众号,一起学习!
本文作者 : HeoLis
原文链接 : https://ishero.net/TensorFlow%202.0%E9%AB%98%E6%95%88%E5%BC%80%E5%8F%91%E6%8C%87%E5%8D%97.html
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!
学习、记录、分享、获得