必备Keras提示:
1. 在急切执行模式下进行调试/原型制作。您可以使用PyTorch后端进行原型设计,或者如果您仅考虑一个层或前向传递,甚至可以使用NumPy后端(NumPy不支持梯度/训练)。使用JAX或TF时,请确保使用run_eagerly=True,以便于调试的急切执行。
2. 当真正开始训练时,通常切换到JAX后端是个好主意,它提供了最先进的性能(至少比PT快10%,尤其是对于基于Transformer的模型,通常快达3倍)。请注意,XLA编译的TF性能也非常有竞争力。TF通常也比JAX或PT更节省内存。
3. 如果您正在寻找模型的最佳可实现性能(无论是训练还是推断),请尝试所有可能的后端,并坚持使用最快的一个。最快的后端可能并不总是JAX!
4. 对于TPU或GPU集群上的大规模数据/模型并行性,请始终使用JAX后端,配合keras.distribution API(例如,使用keras.distribution.ModelParallel来指定分片配置)。与其他所有选项(尤其是非JAX选项)相比,这将为您节省很多麻烦。