Escalonar cargas de trabalho de ML usando o Ray

Introdução

A ferramenta Ray do Cloud TPU combina a API Cloud TPU e o Ray Jobs com o objetivo de melhorar a experiência de desenvolvimento dos usuários no Cloud TPU. Neste guia do usuário, fornecemos um exemplo mínimo de como usar o Ray com Cloud TPUs. Esses exemplos não devem ser usados em serviços de produção e servem apenas para fins ilustrativos.

O que está incluído nessa ferramenta?

Para sua conveniência, a ferramenta oferece:

  • Abstrações genéricas que ocultam o código boilerplate para ações comuns da TPU
  • Exemplos de brinquedos que você pode bifurcar para fluxos de trabalho básicos

Especificamente:

  • tpu_api.py: wrapper do Python para operações básicas de TPU usando a API Cloud TPU.
  • tpu_controller.py: representação de classe de uma TPU. Esse é essencialmente um wrapper para tpu_api.py.
  • ray_tpu_controller.py: controlador de TPU com funcionalidade Ray. Isso remove o código boilerplate do cluster do Ray e dos jobs do Ray.
  • run_basic_jax.py: exemplo básico que mostra como usar RayTpuController para print(jax.device_count()).
  • run_hp_search.py: exemplo básico que mostra como o Ray Tune pode ser usado com o JAX/Flax no MNIST.
  • run_pax_autoresume.py: exemplo que mostra como é possível usar RayTpuController para treinamento tolerante a falhas com o PAX como uma carga de trabalho de exemplo.

Como configurar o nó principal do cluster do Ray

Uma das maneiras básicas de usar o Ray com um pod de TPU é configurá-lo como um cluster do Ray. Criar uma VM de CPU separada como VM de coordenador é a maneira natural de fazer isso. O gráfico a seguir mostra um exemplo de configuração de cluster Ray:

Um exemplo de configuração de cluster Ray

Os comandos a seguir mostram como configurar um cluster Ray usando a Google Cloud CLI:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Para sua conveniência, também fornecemos scripts básicos para criar uma VM de coordenador e implantar o conteúdo dessa pasta na sua VM de coordenador. Para ver o código-fonte, consulte create_cpu.sh e deploy.sh.

Esses scripts definem alguns valores padrão:

  • create_cpu.sh criará uma VM chamada $USER-admin e usará o projeto e a zona definidos para seus padrões de gcloud config. Execute gcloud config list para conferir esses padrões.
  • Por padrão, create_cpu.sh aloca um disco de inicialização de 200 GB.
  • deploy.sh pressupõe que o nome da VM é $USER-admin. Se você mudar esse valor em create_cpu.sh, mude-o em deploy.sh.

Para usar os scripts de conveniência:

  1. Clone o repositório do GitHub na sua máquina local e insira a pasta ray_tpu:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. Se você não tiver uma conta de serviço dedicada para administração da TPU (altamente recomendado), configure uma:

    $ ./create_tpu_service_account.sh
    
  3. Crie uma VM do coordenador:

    $ ./create_cpu.sh
    

    Este script instala dependências na VM usando um script de inicialização e bloqueia automaticamente até que o script de inicialização seja concluído.

  4. Implante o código local na VM do coordenador:

    $ ./deploy.sh
    
  5. SSH para a VM:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    O encaminhamento de portas é ativado aqui, já que o Ray iniciará automaticamente um painel na porta 8265. Na máquina que você usa SSH para acessar a VM do coordenador, acesse o painel em http://127.0.0.1:8265/.

  6. Se você pulou a etapa 0, configure suas credenciais gcloud na VM da CPU:

    $ (vm) gcloud auth login --update-adc
    

    Essa etapa define as informações de ID do projeto e permite que a API Cloud TPU seja executada na VM do coordenador.

  7. Requisitos de instalação:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. Inicie o Ray na VM do coordenador, e a VM dele se torna o nó principal do cluster do Ray:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

Exemplos de uso

Exemplo básico de JAX

run_basic_jax.py é um exemplo mínimo que demonstra como é possível usar o ambiente de execução de jobs de Ray e Ray em um cluster do Ray com VMs de TPU para executar uma carga de trabalho do JAX.

Para frameworks de ML compat��veis com Cloud TPUs que usam um modelo de programação com vários controladores, como JAX e PyTorch/XLA PJRT, é necess��rio executar pelo menos um processo por host. Para mais informações, consulte Modelo de programação de vários processos. Na prática, isso seria assim:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Se você tiver mais de 16 hosts, como uma v4-128, você terá problemas de escalonabilidade de SSH e seu comando talvez precise mudar para:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Isso poderá se tornar um obstáculo na velocidade do desenvolvedor se my_bug_free_python_code tiver bugs. Uma das maneiras de resolver esse problema é usando um orquestrador como o Kubernetes ou o Ray. O Ray inclui o conceito de um ambiente de execução que, quando aplicado, implanta o código e as dependências quando o aplicativo Ray é executado.

Combinar o ambiente de execução do Ray com o cluster e os jobs do Ray permite ignorar o ciclo de SCP/SSH. Considerando que você seguiu os exemplos acima, isso pode ser executado com:

$ python3 legacy/run_basic_jax.py

O resultado será assim:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Treinamento tolerante a falhas

Este exemplo mostra como usar RayTpuController para implementar um treinamento tolerante a falhas. Para este exemplo, pré-treinamos um LLM simples no PAX em uma v4-16, mas é possível substituir essa carga de trabalho do PAX por qualquer outra carga de trabalho de longa duração. Para ver o código-fonte, consulte run_pax_autoresume.py.

Para executar esse exemplo:

  1. Clone paxml na VM do coordenador:

    $ git clone https://github.com/google/paxml.git
    

    Para demonstrar a facilidade de uso que o Ray Runtime Environment oferece para fazer e implantar alterações no JAX, este exemplo requer a modificação do PAX.

  2. Adicione uma nova configuração de experimento:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Execute run_pax_autoresume.py:

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. Durante a execução da carga de trabalho, teste o que acontece quando você exclui sua TPU, por padrão, com o nome $USER-tpu-ray:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    O Ray detectará que a TPU está inativa e mostrará a seguinte mensagem:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    O job recria automaticamente a VM da TPU e reinicia o job de treinamento para que possa retomar o treinamento do checkpoint mais recente (200 etapas neste exemplo):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

Este exemplo mostra o uso do Ray Tune do Ray AIR para ajustar hiperparâmetro a partir do JAX/FLAX. Para ver o código-fonte, consulte run_hp_search.py.

Para executar esse exemplo:

  1. Instale os requisitos:

    $ pip3 install -r src/tune/requirements.txt
    
  2. Execute run_hp_search.py:

    $ python3 src/tune/run_hp_search.py
    

    O resultado será assim:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

Solução de problemas

O nó principal do Ray não pode se conectar

Se você executar uma carga de trabalho que crie/exclua o ciclo de vida da TPU, às vezes isso não desconectará os hosts da TPU do cluster do Ray. Isso pode aparecer como erros gRPC que sinalizam que o nó principal do Ray não consegue se conectar a um conjunto de endereços IP.

Como resultado, talvez seja necessário encerrar a sessão de raio (ray stop) e reiniciá-la (ray start --head --port=6379 --num-cpus=0).

O Ray Job falha diretamente sem nenhuma saída de registro

O PAX é experimental, e este exemplo pode falhar devido a dependências de pip. Se isso acontecer, você verá algo assim:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Para ver a causa raiz do erro, acesse http://127.0.0.1:8265/ e veja o painel dos jobs em execução/com falha, que fornecerão mais informações. runtime_env_agent.log mostra todas as informações de erro relacionadas à configuração deruntime_env, por exemplo:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts