ML-Arbeitslasten mit Ray skalieren

Einführung

Das Ray-Tool von Cloud TPU kombiniert die Cloud TPU API und Ray-Jobs mit dem Ziel, die Entwicklungsumgebung auf Cloud TPU für Nutzer zu verbessern. Dieses Nutzerhandbuch enthält ein minimales Beispiel dafür, wie Sie Ray mit Cloud TPUs verwenden können. Diese Beispiele sind nicht zur Verwendung in Produktionsdiensten gedacht und dienen nur zur Veranschaulichung.

Inhalt dieses Tools

Das Tool bietet folgende Funktionen:

  • Generische Abstraktionen, die Boilerplate für gängige TPU-Aktionen verbergen
  • Spielzeugbeispiele, die Sie für Ihre eigenen grundlegenden Workflows verzweigen können

Zum Beispiel:

  • tpu_api.py: Python-Wrapper für einfache TPU-Vorgänge mit der Cloud TPU API.
  • tpu_controller.py: Klassendarstellung einer TPU. Dies ist im Wesentlichen ein Wrapper für tpu_api.py.
  • ray_tpu_controller.py: TPU-Controller mit Ray-Funktionalität. Dadurch wird der Boilerplate-Code für Ray-Cluster- und Ray-Jobs entfernt.
  • run_basic_jax.py: Einfaches Beispiel, das zeigt, wie RayTpuController für print(jax.device_count()) verwendet wird.
  • run_hp_search.py: Einfaches Beispiel, das zeigt, wie Ray Tune mit JAX/Flax auf MNIST verwendet werden kann.
  • run_pax_autoresume.py: Beispiel, das zeigt, wie Sie RayTpuController für fehlertolerantes Training mit PAX als Beispielarbeitslast verwenden können.

Hauptknoten des Ray-Clusters einrichten

Eine der grundlegenden Möglichkeiten, Ray mit einem TPU-Pod zu verwenden, besteht darin, den TPU-Pod als Ray-Cluster einzurichten. Dazu erstellen Sie normalerweise eine separate CPU-VM als Koordinator-VM. Die folgende Grafik zeigt ein Beispiel für eine Ray-Clusterkonfiguration:

Beispiel für eine Ray-Clusterkonfiguration

Die folgenden Befehle zeigen, wie Sie einen Ray-Cluster mit der Google Cloud CLI einrichten:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Um Ihnen die Arbeit zu erleichtern, stellen wir auch grundlegende Skripts zum Erstellen einer Koordinator-VM und zum Bereitstellen des Inhalts dieses Ordners auf Ihrer Koordinator-VM bereit. Den Quellcode finden Sie unter create_cpu.sh und deploy.sh.

Diese Skripts legen einige Standardwerte fest:

  • create_cpu.sh erstellt eine VM mit dem Namen $USER-admin und verwendet das Projekt und die Zone, die auf die Standardeinstellungen von gcloud config festgelegt sind. Führen Sie gcloud config list aus, um diese Standardeinstellungen zu sehen.
  • create_cpu.sh weist standardmäßig ein Bootlaufwerk von 200 GB zu.
  • deploy.sh geht davon aus, dass Ihr VM-Name $USER-admin ist. Wenn Sie diesen Wert in create_cpu.sh ändern, müssen Sie ihn auch in deploy.sh ändern.

So verwenden Sie die praktischen Skripts:

  1. Klonen Sie das GitHub-Repository auf Ihren lokalen Computer und geben Sie den Ordner ray_tpu ein:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. Wenn Sie kein dediziertes Dienstkonto für die TPU-Verwaltung haben (dringend empfohlen), richten Sie eines ein:

    $ ./create_tpu_service_account.sh
    
  3. Erstellen Sie eine Koordinator-VM:

    $ ./create_cpu.sh
    

    Dieses Skript installiert Abhängigkeiten auf der VM mit einem Startup-Skript. Es wird automatisch blockiert, bis das Startup-Skript abgeschlossen ist.

  4. Stellen Sie lokalen Code auf der Koordinator-VM bereit:

    $ ./deploy.sh
    
  5. Stellen Sie eine SSH-Verbindung zur VM her:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    Die Portweiterleitung ist hier aktiviert, da Ray automatisch ein Dashboard an Port 8265 startet. Sie können von der Maschine, die Sie über die SSH-Verbindung zur Koordinator-VM verbinden, auf dieses Dashboard unter http://127.0.0.1:8265/ zugreifen.

  6. Wenn Sie Schritt 0 übersprungen haben, richten Sie Ihre gcloud-Anmeldedaten auf der CPU-VM ein:

    $ (vm) gcloud auth login --update-adc
    

    Mit diesem Schritt werden Projekt-ID-Informationen festgelegt und die Ausführung der Cloud TPU API auf der Koordinator-VM ermöglicht.

  7. Installationsvoraussetzungen:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. Starten Sie Ray auf der Koordinator-VM. Die Koordinator-VM wird dann zum Hauptknoten des Ray-Clusters:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

Beispiele für die Verwendung

Einfaches JAX-Beispiel

run_basic_jax.py ist ein minimales Beispiel, das zeigt, wie Sie die Ray-Jobs und Ray-Laufzeitumgebung in einem Ray-Cluster mit TPU-VMs verwenden können, um eine JAX-Arbeitslast auszuführen.

Für ML-Frameworks, die mit Cloud TPUs kompatibel sind, die ein Multi-Controller-Programmiermodell wie JAX und PyTorch/XLA PJRT verwenden, müssen Sie mindestens einen Prozess pro Host ausführen. Weitere Informationen finden Sie unter Multi-Prozess-Programmiermodell. In der Praxis könnte das so aussehen:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Wenn Sie mehr als ~16 Hosts haben, z. B. v4-128, werden Probleme mit der SSH-Skalierbarkeit auftreten und der Befehl muss möglicherweise so geändert werden:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Dies kann die Geschwindigkeit der Entwickler beeinträchtigen, wenn my_bug_free_python_code Programmfehler enthält. Dieses Problem lässt sich beispielsweise mithilfe eines Orchestrators wie Kubernetes oder Ray lösen. Ray beinhaltet das Konzept einer Laufzeitumgebung, die bei Anwendung Code und Abhängigkeiten bereitstellt, wenn die Ray-Anwendung ausgeführt wird.

Durch die Kombination der Ray-Laufzeitumgebung mit dem Ray-Cluster und Ray-Jobs können Sie den SCP/SSH-Zyklus umgehen. Wenn Sie die obigen Beispiele befolgt haben, können Sie dies mit folgendem Befehl ausführen:

$ python3 legacy/run_basic_jax.py

Die Ausgabe sieht in etwa so aus:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Fehlertolerantes Training

In diesem Beispiel wird gezeigt, wie Sie mit RayTpuController fehlertolerantes Training implementieren können. In diesem Beispiel wird ein einfaches LLM auf PAX auf einer v4-16 vorab trainiert. Beachten Sie jedoch, dass Sie diese PAX-Arbeitslast durch jede andere lang andauernde Arbeitslast ersetzen können. Den Quellcode finden Sie unter run_pax_autoresume.py.

So führen Sie das Beispiel aus:

  1. Klonen Sie paxml in Ihre Koordinator-VM:

    $ git clone https://github.com/google/paxml.git
    

    Um die Nutzerfreundlichkeit der Ray-Laufzeitumgebung für das Vornehmen und Bereitstellen von JAX-Änderungen zu demonstrieren, müssen Sie in diesem Beispiel PAX ändern.

  2. Fügen Sie eine neue Testkonfiguration hinzu:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Führen Sie run_pax_autoresume.py aus.

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. Experimentieren Sie während der Ausführung der Arbeitslast, was passiert, wenn Sie Ihre standardmäßige TPU mit dem Namen $USER-tpu-ray löschen:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray erkennt, dass die TPU ausgefallen ist, mit der folgenden Meldung:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    Der Job erstellt automatisch die TPU-VM neu und startet den Trainingsjob neu, damit er das Training ab dem letzten Prüfpunkt fortsetzen kann (in diesem Beispiel 200 Schritte):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

In diesem Beispiel wird die Verwendung von Ray-Tune von Ray AIR zur Hyperparameter-Abstimmung MNIST von JAX/FLAX veranschaulicht. Den Quellcode finden Sie unter run_hp_search.py.

So führen Sie das Beispiel aus:

  1. Installieren Sie die Anforderungen:

    $ pip3 install -r src/tune/requirements.txt
    
  2. Führen Sie run_hp_search.py aus.

    $ python3 src/tune/run_hp_search.py
    

    Die Ausgabe sieht in etwa so aus:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

Fehlerbehebung

Ray-Hauptknoten kann nicht verbunden werden

Wenn Sie eine Arbeitslast ausführen, die den TPU-Lebenszyklus erstellt/löscht, werden die TPU-Hosts dadurch manchmal nicht vom Ray-Cluster getrennt. Dies kann als gRPC-Fehler angezeigt werden, die darauf hinweisen, dass der Ray-Hauptknoten keine Verbindung zu einer Gruppe von IP-Adressen herstellen kann.

In diesem Fall müssen Sie Ihre Ray-Sitzung möglicherweise beenden (ray stop) und neu starten (ray start --head --port=6379 --num-cpus=0).

Ray-Job schlägt direkt ohne Logausgabe fehl

PAX ist experimentell und dieses Beispiel kann aufgrund von pip-Abhängigkeiten fehlerhaft sein. In diesem Fall könnte eine Meldung wie diese angezeigt werden:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Die Ursache des Fehlers finden Sie im Dashboard für laufende/fehlgeschlagene Jobs unter http://127.0.0.1:8265/. Dort finden Sie weitere Informationen. runtime_env_agent.log zeigt alle Fehlerinformationen im Zusammenhang mit der Einrichtung von „runtime_env“ an. Beispiel:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts