Scaling des charges de travail de ML à l'aide de Ray

Introduction

L'outil Cloud TPU Ray combine l'API Cloud TPU et les tâches Ray dans le but d'améliorer l'expérience de développement des utilisateurs sur Cloud TPU. Ce guide de l'utilisateur fournit un exemple minimal d'utilisation de Ray avec des Cloud TPU. Ces exemples ne sont pas destinés à être utilisés dans les services de production et sont fournis à titre d'illustration uniquement.

Que comprend cet outil ?

Pour plus de commodité, l'outil inclut les éléments suivants:

  • Abstractions génériques qui masquent le code récurrent pour les actions TPU courantes
  • Exemples de jouets que vous pouvez dupliquer pour vos propres workflows de base

à savoir :

  • tpu_api.py : wrapper Python pour les opérations TPU de base utilisant l'API Cloud TPU.
  • tpu_controller.py : représentation de classe d'un TPU. Il s'agit essentiellement d'un wrapper pour tpu_api.py.
  • ray_tpu_controller.py : contrôleur TPU avec la fonctionnalité Ray. Cela élimine le code récurrent pour les tâches Ray Cluster et Ray.
  • run_basic_jax.py : exemple de base montrant comment utiliser RayTpuController pour print(jax.device_count()).
  • run_hp_search.py : exemple de base illustrant l'utilisation de Ray Tun avec JAX/Flax sur MNIST.
  • run_pax_autoresume.py : exemple qui montre comment utiliser RayTpuController pour un entraînement tolérant aux pannes en utilisant PAX comme exemple de charge de travail.

Configurer le nœud principal du cluster Ray

L'un des moyens les plus simples d'utiliser Ray avec un pod TPU consiste à configurer le pod TPU en tant que cluster Ray. La création d'une VM processeur distincte en tant que VM coordinateur est la méthode la plus naturelle. Le graphique suivant présente un exemple de configuration de cluster Ray:

Exemple de configuration de cluster Ray

Les commandes suivantes montrent comment configurer un cluster Ray à l'aide de la Google Cloud CLI:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Pour plus de commodité, nous fournissons également des scripts de base pour créer une VM de coordinateur et déployer le contenu de ce dossier sur la VM de votre coordinateur. Pour le code source, consultez create_cpu.sh et deploy.sh.

Ces scripts définissent des valeurs par défaut:

  • create_cpu.sh créera une VM nommée $USER-admin et utilisera le projet et la zone définis sur vos valeurs par défaut pour gcloud config. Exécutez gcloud config list pour afficher ces valeurs par défaut.
  • Par défaut, create_cpu.sh alloue une taille de disque de démarrage de 200 Go.
  • deploy.sh suppose que le nom de votre VM est $USER-admin. Si vous modifiez cette valeur dans create_cpu.sh, veillez à le faire dans deploy.sh.

Pour utiliser les scripts de commodité:

  1. Clonez le dépôt GitHub sur votre ordinateur local et saisissez le chemin d'accès ray_tpu:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. Si vous ne disposez pas de compte de service dédié à l'administration du TPU (fortement recommandé), configurez-en un:

    $ ./create_tpu_service_account.sh
    
  3. Créez une VM de coordinateur:

    $ ./create_cpu.sh
    

    Ce script installe les dépendances sur la VM à l'aide d'un script de démarrage et se bloque automatiquement jusqu'à la fin de ce script.

  4. Déployez du code local sur la VM du coordinateur:

    $ ./deploy.sh
    
  5. Connectez-vous en SSH à la VM:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    Le transfert de port est activé ici, car Ray lance automatiquement un tableau de bord sur le port 8265. Depuis la machine avec laquelle vous vous connectez en SSH à la VM de votre coordinateur, vous pourrez accéder à ce tableau de bord à l'adresse http://127.0.0.1:8265/.

  6. Si vous avez ignoré l'étape 0, configurez vos identifiants gcloud dans la VM processeur:

    $ (vm) gcloud auth login --update-adc
    

    Cette étape permet de définir les informations d'ID du projet et de permettre à l'API Cloud TPU de s'exécuter sur la VM du coordinateur.

  7. Conditions d'installation:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. Démarrez Ray sur la VM du coordinateur. Celle-ci devient le nœud principal du cluster Ray:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

Exemples d'utilisation

Exemple de code JAX de base

run_basic_jax.py est un exemple minimal qui montre comment utiliser l'environnement d'exécution Ray Jobs et Ray sur un cluster Ray avec des VM TPU pour exécuter une charge de travail JAX.

Pour les frameworks de ML compatibles avec les Cloud TPU qui utilisent un modèle de programmation multicontrôleur, comme JAX et PyTorch/XLA PJRT, vous devez exécuter au moins un processus par hôte. Pour en savoir plus, consultez la section Modèle de programmation multiprocessus. En pratique, cela peut se présenter comme suit:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Si vous avez plus de 16 hôtes, par exemple la version v4-128, vous rencontrerez des problèmes d'évolutivité SSH et votre commande devra peut-être être remplacée par:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Cela peut nuire à la rapidité des développeurs si my_bug_free_python_code contient des bugs. L'une des méthodes permettant de résoudre ce problème consiste à faire appel à un orchestration tel que Kubernetes ou Ray. Ray inclut le concept d'environnement d'exécution qui, lorsqu'il est appliqué, déploie du code et des dépendances lors de l'exécution de l'application Ray.

La combinaison de l'environnement d'exécution Ray avec un cluster Ray et des tâches Ray vous permet de contourner le cycle SCP/SSH. En supposant que vous ayez suivi les exemples ci-dessus, vous pouvez exécuter cette commande avec la commande suivante:

$ python3 legacy/run_basic_jax.py

Le résultat ressemble à ce qui suit :

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Entraînement à tolérance aux pannes

Cet exemple montre comment utiliser RayTpuController pour implémenter un entraînement tolérant aux pannes. Dans cet exemple, nous pré-entraînons un LLM simple sur PAX sur une version 4-16, mais notez que vous pouvez remplacer cette charge de travail PAX par toute autre charge de travail de longue durée. Pour en savoir plus sur le code source, consultez run_pax_autoresume.py.

Pour exécuter cet exemple:

  1. Clonez paxml sur la VM de votre coordinateur:

    $ git clone https://github.com/google/paxml.git
    

    Pour démontrer la facilité d'utilisation de l'environnement d'exécution Ray pour effectuer et déployer des modifications JAX, vous devez modifier PAX dans cet exemple.

  2. Ajoutez une configuration de test:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Exécutez run_pax_autoresume.py :

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. Pendant l'exécution de la charge de travail, testez ce qui se passe lorsque vous supprimez votre TPU (nommé $USER-tpu-ray par défaut) :

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray détecte que le TPU est arrêté et le message suivant s'affiche:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    La tâche recrée automatiquement la VM TPU et redémarre la tâche d'entraînement pour pouvoir reprendre l'entraînement à partir du dernier point de contrôle (200 étapes dans cet exemple):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

Cet exemple présente l'utilisation de Ray Tun du Ray AIR pour le réglage des hyperparamètres MNIST depuis JAX/FLAX. Pour en savoir plus sur le code source, consultez run_hp_search.py.

Pour exécuter cet exemple:

  1. Installez les éléments requis:

    $ pip3 install -r src/tune/requirements.txt
    
  2. Exécutez run_hp_search.py :

    $ python3 src/tune/run_hp_search.py
    

    Le résultat ressemble à ce qui suit :

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

Dépannage

Le nœud principal Ray ne parvient pas à se connecter

Si vous exécutez une charge de travail qui crée/supprime le cycle de vie du TPU, cela ne déconnecte parfois pas les hôtes TPU du cluster Ray. Cela peut se produire en tant qu'erreurs gRPC qui signalent que le nœud principal Ray ne peut pas se connecter à un ensemble d'adresses IP.

Par conséquent, vous devrez peut-être arrêter votre session Ray (ray stop) et la redémarrer (ray start --head --port=6379 --num-cpus=0).

La tâche Ray échoue directement sans sortie de journal

PAX est expérimental et cet exemple peut ne pas fonctionner en raison de dépendances de pip. Dans ce cas, un écran semblable à celui-ci peut s'afficher:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Pour connaître la cause de l'erreur, vous pouvez accéder à http://127.0.0.1:8265/ et consulter le tableau de bord des tâches en cours d'exécution ou ayant échoué, qui fournit plus d'informations. runtime_env_agent.log affiche toutes les informations d'erreur liées à la configuration de "runtime_env", par exemple:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts