新智元报道
编辑:编辑部
JAX 在最近的基准测试中的性能已经不声不响地超过了 Pytorch 和 TensorFlow,也许未来会有更多的大模型诞生在这个平台上。谷歌在背后的默默付出终于得到了回报。
谷歌力推的 JAX 在最近的基准测试中性能已经超过 Pytorch 和 TensorFlow,7 项指标排名第一。
而且测试并不是在 JAX 性能表现最好的 TPU 上完成的。
虽然现在在开发者中,Pytorch 依然比 Tensorflow 更受欢迎。
但未来,也许有更多的大模型会基于 JAX 平台进行训练和运行。
模型
最近,Keras 团队为三个后端(TensorFlow、JAX、PyTorch)与原生 PyTorch 实现以及搭配 TensorFlow 的 Keras 2 进行了基准测试。
首先,他们为生成式和非生成式人工智能任务选择了一组主流的计算机视觉和自然语言处理模型:
对于模型的 Keras 版本,其采用了 KerasCV 和 KerasNLP 中已有的实现进行构建。而对于原生的 PyTorch 版本,则选择了网络上最流行的几个选项:
- 来自 HuggingFace Transformers 的 BERT、Gemma、Mistral
- 来自 HuggingFace Diffusers 的 StableDiffusion
- 来自 Meta 的 SegmentAnything
他们将这组模型称作「Native PyTorch」,以便与使用 PyTorch 后端的 Keras 3 版本进行区分。
他们对所有基准测试都使用了合成数据,并在所有 LLM 训练和推理中使用了 bfloat16 精度,同时在所有 LLM 训练中使用了 LoRA(微调)。
根据 PyTorch 团队的建议,他们在原生 PyTorch 实现中使用了 torch.compile (model, mode="reduce-overhead")(由于不兼容,Gemma 和 Mistral 训练除外)。
为了衡量开箱即用的性能,他们使用高级 API(例如 HuggingFace 的 Trainer ()、标准 PyTorch 训练循环和 Keras model.fit ()),并尽可能减少配置。
硬件配置
所有基准测试均使用 Google Cloud Compute Engine 进行,配置为:一块拥有 40GB 显存的 NVIDIA A100 GPU、12 个虚拟 CPU 和 85GB 的主机内存。
基准测试结果
表 2 显示了基准测试结果(以步/毫秒为单位)。每步都涉及对单个数据批次进行训练或预测。
结果是 100 步的平均值,但排除了第一个步,因为第一步包括了模型创建和编译,这会额外花费时间。
为了确保比较的公平性,对于相同的模型和任务(不论是训练还是推理)都使用相同的批大小。
然而,对于不同的模型和任务,由于它们的规模和架构有所不同,可根据需要调整数据批大小,从而避免因过大而导致内存溢出,或是批过小而导致 GPU 使用不足。
过小的批大小也会使 PyTorch 看起来较慢,因为会增加 Python 的开销。
对于大型语言模型(Gemma 和 Mistral),测试时也使用了相同的批处理大小,因为它们是相同类型的模型,具有类似数量的参数(7B)。
考虑到用户对单批文本生成的需求,也对批大小为 1 的文本生成情况进行了基准测试。
关键发现
发现1
不存在「最优」后端。
Keras 的三种后端各展所长,重要的是,就性能而言,并没有哪一个后端能够始终胜出。
选择哪个后端最快,往往取决于模型的架构。
这一点突出了选择不同框架以追求最佳性能的重要性。Keras 3 可以帮助轻松切换后端,以便为模型找到最合适的选择。
发现2
Keras 3 的性能普遍超过 PyTorch 的标准实现。
相对于原生 PyTorch,Keras 3 在吞吐量(步/毫秒)上有明显的提升。
特别是,在 10 个测试任务中,有 5 个的速度提升超过了 50%。其中,最高更是达到了 290%。
如果是 100%,意味着 Keras 3 的速度是 PyTorch 的 2 倍;如果是0%,则表示两者性能相当
发现3
Keras 3 提供一流的「开箱即用」性能。
也就是,所有参与测试的 Keras 模型都未进行过任何优化。相比之下,使用原生 PyTorch 实现时,通常需要用户自行进行更多性能优化。
除了上面分享的数据,测试中还注意到在 HuggingFace Diffusers 的 StableDiffusion 推理功能上,从版本 0.25.0 升级到 0.3.0 时,性能提升超过了 100%。
同样,在 HuggingFace Transformers 中,Gemma 从 4.38.1 版本升级至 4.38.2 版本也显著提高了性能。
这些性能的提升凸显了 HuggingFace 在性能优化方面的专注和努力。
对于一些手动优化较少的模型,如 SegmentAnything,则使用了研究作者提供的实现。在这种情况下,与 Keras 相比,性能差距比大多数其他模型更大。
这表明,Keras 能够提供卓越的开箱即用性能,用户无需深入了解所有优化技巧即可享受到快速的模型运行速度。
发现4
Keras 3 的表现始终优于 Keras 2。
例如,SegmentAnything 的推理速度提升了惊人的 380%,StableDiffusion 的训练处理速度提升了 150% 以上,BERT 的训练处理速度也提升了 100% 以上。
这主要是因为 Keras 2 在某些情况下直接使用了更多的 TensorFlow 融合操作,而这可能对于 XLA 的编译并不是最佳选择。
值得注意的是,即使仅升级到 Keras 3 并继续使用 TensorFlow 后端,也能显著提升性能。
结论
框架的性能在很大程度上取决于具体使用的模型。
Keras 3 能够帮助为任务选择最快的框架,这种选择几乎总能超越 Keras 2 和 PyTorch 实现。
更为重要的是,Keras 3 模型无需进行复杂的底层优化,即可提供卓越的开箱即用性能。
参考资料: