Esegui un calcolo su una VM Cloud TPU utilizzando JAX
Questo documento fornisce una breve introduzione al lavoro con JAX e Cloud TPU.
Prima di seguire questa guida rapida, devi creare un account per Google Cloud Platform, installare Google Cloud CLI e configurare il comando gcloud
.
Per saperne di più, vedi Configurare un account e un progetto Cloud TPU.
Installa Google Cloud CLI
Google Cloud CLI contiene strumenti e librerie per interagire con i prodotti e i servizi Google Cloud. Per maggiori informazioni, consulta Installare Google Cloud CLI.
Configura il comando gcloud
Esegui i comandi seguenti per configurare gcloud
in modo da utilizzare il progetto Google Cloud e installare i componenti necessari per l'anteprima della VM TPU.
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
Abilita l'API Cloud TPU
Abilita l'API Cloud TPU utilizzando il seguente comando
gcloud
in Cloud Shell. (puoi anche abilitarlo dalla console Google Cloud).$ gcloud services enable tpu.googleapis.com
Esegui questo comando per creare un'identità di servizio.
$ gcloud beta services identity create --service tpu.googleapis.com
Crea una VM Cloud TPU con gcloud
Con le VM Cloud TPU, il modello e il codice vengono eseguiti direttamente sulla macchina host della TPU. Accedi direttamente all'host TPU tramite SSH. Puoi eseguire codice arbitrario, installare pacchetti, visualizzare i log e il codice di debug direttamente sull'host TPU.
Crea la VM TPU eseguendo il comando seguente da Cloud Shell o dal terminale del computer in cui è installato Google Cloud CLI.
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-8 \ --version=tpu-ubuntu2204-base
Campi obbligatori
zone
- La zona in cui prevedi di creare la Cloud TPU.
accelerator-type
- Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per maggiori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU.
version
- La versione software di Cloud TPU. Per tutti i tipi di TPU, utilizza
tpu-ubuntu2204-base
.
Connettiti alla VM Cloud TPU
Accedi alla VM TPU utilizzando il seguente comando:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b
Campi obbligatori
tpu_name
- Il nome della VM TPU a cui ti stai connettendo.
zone
- La zona in cui hai creato la Cloud TPU.
Installa JAX sulla VM Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Controllo del sistema
Verifica che JAX possa accedere alla TPU e possa eseguire operazioni di base:
Avvia l'interprete Python 3:
(vm)$ python3
>>> import jax
Visualizza il numero di core TPU disponibili:
>>> jax.device_count()
Il numero di core TPU viene visualizzato. Se utilizzi una TPU v4, dovrebbe essere
4
. Se utilizzi una TPU v2 o v3, dovrebbe essere 8
.
Esegui un semplice calcolo:
>>> jax.numpy.add(1, 1)
Viene visualizzato il risultato dell'aggiunta di numpy:
Output dal comando:
Array(2, dtype=int32, weak_type=true)
Esci dall'interprete Python:
>>> exit()
Esecuzione di codice JAX su una VM TPU
Ora puoi eseguire qualsiasi codice JAX che vuoi. Gli esempi di lino sono un ottimo punto di partenza per eseguire modelli ML standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:
Installa le dipendenze degli esempi Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Installa FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Esegui lo script di addestramento FLAX MNIST
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Lo script scarica il set di dati e avvia l'addestramento. L'output dello script dovrebbe essere simile a questo:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Esegui la pulizia
Al termine delle operazioni della VM TPU, segui questi passaggi per la pulizia delle risorse.
Disconnettiti dall'istanza di Compute Engine, se non lo hai già fatto:
(vm)$ exit
Elimina la tua Cloud TPU.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central2-b
Verifica che le risorse siano state eliminate eseguendo questo comando. Assicurati che la tua TPU non sia più elencata. L'eliminazione può richiedere qualche minuto.
$ gcloud compute tpus tpu-vm list \ --zone=us-central2-b
Note sulle prestazioni
Di seguito sono riportati alcuni dettagli importanti particolarmente importanti per l'utilizzo delle TPU in JAX.
Spaziatura interna
Una delle cause più comuni di lentezza delle prestazioni sulle TPU è l'introduzione di spaziatura interna involontaria:
- Gli array in Cloud TPU sono affiancati. Ciò comporta il riempimento di una delle dimensioni fino a un multiplo di 8 e una dimensione diversa a un multiplo di 128.
- L'unità di moltiplicazione matriciale funziona meglio con coppie di matrici di grandi dimensioni che riducono al minimo la necessità di spaziatura.
bfloat16 dtype
Per impostazione predefinita, la moltiplicazione della matrice in JAX sulle TPU utilizza bfloat16 con l'accumulo con float32. Può essere controllato con l'argomento precisione sulle chiamate di funzione jax.numpy pertinenti (matmul, punto, einsum e così via). In particolare:
precision=jax.lax.Precision.DEFAULT
: utilizza la precisione bfloat16 mista (più veloce)precision=jax.lax.Precision.HIGH
: utilizza più pass MXU per ottenere una maggiore precisioneprecision=jax.lax.Precision.HIGHEST
: utilizza ancora più passaggi MXU per raggiungere la precisione float32
JAX aggiunge anche il comando dtype bfloat16, che puoi usare per trasmettere esplicitamente array a bfloat16
, ad esempio jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Esecuzione di JAX in un Colab
Quando esegui codice JAX in un blocco note di Colab, Colab crea automaticamente un nodo TPU legacy. I nodi TPU hanno un'architettura diversa. Per ulteriori informazioni, consulta Architettura di sistema.
Passaggi successivi
Per ulteriori informazioni su Cloud TPU, vedi: