TensorFlow 模型中的 XLA 集成¶
加速线性代数(Accelerated Linear Algebra,简称 XLA)是用于加速 TensorFlow 模型运行时的编译器。根据官方文档:
XLA 是一种特定领域的线性代数编译器,可以在不改变源代码的情况下加速 TensorFlow 模型。
在 TensorFlow 中使用 XLA 非常简单——它已经包含在 tensorflow 库中,只需通过 jit_compile 参数触发即可。例如,在使用像 tf.function 这样的图创建函数时,或在使用 Keras 的 fit() 和 predict() 方法时,可以通过传递 jit_compile 参数给 model.compile() 来启用 XLA。XLA 并不限于这些方法,还可以加速任意的 tf.function。
一些 🤗 Transformers 库中的 TensorFlow 方法已经被重写为 XLA 兼容,包括用于 GPT2、T5 和 OPT 模型的文本生成,以及用于 Whisper 模型的语音处理。
虽然具体的加速效果取决于模型本身,但对于 🤗 Transformers 库中的 TensorFlow 文本生成模型,我们观察到的速度提升大约为 100 倍。本文将介绍如何使用 XLA 提升这些模型的性能,并提供一些额外资源,帮助你了解更多关于基准测试和 XLA 集成的设计理念。
使用 XLA 运行 TensorFlow 函数¶
让我们以一个简单的 TensorFlow 模型为例:
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"),
tf.keras.layers.Dense(5, activation="softmax")
])
上面的模型接收维度为 (10, ) 的输入。我们可以像这样运行前向传播:
# 生成随机输入
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))
# 运行前向传播
_ = model(random_inputs)
要使用 XLA 编译的函数运行前向传播,可以这样做:
xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs) # 使用 XLA 编译的函数运行前向传播
默认情况下,model 的 call() 函数用于编译 XLA 图。如果你要编译其他模型函数,也可以这样做:
my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)
使用 🤗 Transformers 库中的 XLA 运行文本生成¶
要启用 🤗 Transformers 库中的 XLA 加速生成,你需要安装最新版本的 transformers。可以通过以下命令安装:
pip install transformers --upgrade
然后可以运行以下代码:
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
# 确保安装了最新的 Transformers 版本
from transformers.utils import check_min_version
check_min_version("4.21.0")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="<</s>>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]
# 创建一个 XLA 生成函数
xla_generate = tf.function(model.generate, jit_compile=True)
tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
你会发现,启用 generate() 的 XLA 支持只需一行代码。不过,有一些需要注意的地方才能真正实现 XLA 带来的加速效果,我们将在下一节中讨论。
注意事项¶
当你第一次执行一个 XLA 启用的函数(如上面的 xla_generate()),它会尝试推断计算图,这个过程需要时间,被称为“追踪”。
你可能会注意到首次生成时速度并不快。然而,后续调用 xla_generate()(或任何其他 XLA 启用的函数)时,如果输入形状与初始构建计算图时一致,则无需重新推断计算图,从而提高生成速度。
对于形状固定的模态(如图像),这不是问题。但对于输入形状变化的模态(如文本),你需要注意这一点。为了确保 xla_generate() 始终使用相同的输入形状,你可以在调用分词器时指定 padding 参数。
例如:
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="<</s>>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]
xla_generate = tf.function(model.generate, jit_compile=True)
# 调用分词器时指定填充选项
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
这样,你可以确保 xla_generate() 始终使用相同的输入形状,从而提高生成速度。你可以通过以下代码验证这一点:
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="<</s>>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
xla_generate = tf.function(model.generate, jit_compile=True)
for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
start = time.time_ns()
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
end = time.time_ns()
print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
在 Tesla T4 GPU 上,你可能会看到类似以下的输出:
Execution time -- 30819.6 ms
Execution time -- 79.0 ms
Execution time -- 78.9 ms
首次调用 xla_generate() 时由于需要追踪计算图,所以耗时较长,但后续调用则快得多。请注意,任何生成选项的变化都会触发重新追踪,从而导致生成时间变慢。
我们没有涵盖 🤗 Transformers 提供的所有文本生成选项。如果你想了解更多高级用法,请参考官方文档。
进一步资源¶
如果你想深入了解 🤗 Transformers 和 TensorFlow 中的 XLA,以下是一些额外资源:
- 这个 Colab Notebook 提供了一个交互式演示,展示了 XLA 兼容的编码器-解码器(如 T5)和只解码器(如 GPT2)文本生成模型。
- 这篇博客文章 提供了 XLA 兼容模型的基准测试比较,以及关于 TensorFlow 中 XLA 的友好介绍。
- 这篇博客文章 讨论了我们在 🤗 Transformers 中添加 XLA 支持的设计理念。
- 推荐的学习资源: