在 TPU Pod 切片上运行 JAX 代码
在单个 TPU 板上运行 JAX 代码后,您可以通过在 TPU Pod 切片上运行代码来扩容代码。 TPU Pod 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU Pod 切片上运行 JAX 代码;如需了解更深入的信息,请参阅在多主机和多进程环境中使用 JAX。
如果您想使用装载的 NFS 进行数据存储,则必须为所有服务设置 OS Login Pod 切片中的 TPU 虚拟机。如需了解详情,请参阅 使用 NFS 进行数据存储。创建 TPU Pod 切片
在运行本文档中的命令之前,请确保您已按照 设置账号和 Cloud TPU 项目中的说明。 在本地机器上运行以下命令。
使用 gcloud
命令可以创建 TPU Pod 切片。例如,要创建
v4-32 Pod 切片使用如下命令:
$ gcloud compute tpus tpu-vm create tpu-name \
--zone=us-central2-b \
--accelerator-type=v4-32 \
--version=tpu-ubuntu2204-base
在 Pod 切片上安装 JAX
创建 TPU Pod 切片之后,您必须在 TPU Pod 切片中的所有主机上安装 JAX。您可以使用 --worker=all
选项通过一个命令在所有主机上安装 JAX:
gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b --worker=all --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
在 Pod 切片上运行 JAX 代码
要在 TPU Pod 切片上运行 JAX 代码,您必须在 TPU Pod 切片中的每个主机上运行代码。jax.device_count()
调用停止响应,直到
Pod 切片中每个主机上调用的方法。以下示例说明了���何
在 TPU Pod 切片上运行简单的 JAX 计算。
准备代码
您的 gcloud
版本不低于 344.0.0(对于
scp
命令)。
使用 gcloud --version
检查您的 gcloud
版本,
运行 gcloud components upgrade
(如果需要)。
使用以下代码创建一个名为 example.py
的文件:
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
将 example.py
复制到 Pod 切片中的所有 TPU 工作器虚拟机
$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
--worker=all \
--zone=us-central2-b
如果您以前未使用过 scp
命令,则可能会看到
错误,类似于以下内容:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try again.
如需解决此错误,请运行 ssh-add
命令,如
错误消息并重新运行该命令。
在 Pod 切片上运行代码
在每个虚拟机上启动 example.py
程序:
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="python3 example.py"
输出(使用 v4-32 Pod 切片生成):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
清理
完成后,您可以使用 gcloud
命令释放 TPU 虚拟机资源:
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone=us-central2-b