Imposta un checkpoint gestito su più livelli - Amazon SageMaker AI

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

Imposta un checkpoint gestito su più livelli

Questa sezione contiene il processo di configurazione per il checkpoint gestito a più livelli per Amazon. SageMaker HyperPod Scoprirai come abilitare la funzionalità nel cluster e implementare i checkpoint nel codice di addestramento.

Prerequisiti

Prima di configurare il checkpoint gestito a più livelli, assicurati di avere:

  • Un HyperPod cluster Amazon EKS con sufficiente memoria CPU disponibile per l'allocazione dei checkpoint

  • PyTorch carichi di lavoro di formazione e lavori DCP (entrambi supportati)

  • Autorizzazioni IAM appropriate per la gestione dei cluster, tra cui:

    • Autorizzazioni di scrittura di Amazon CloudWatch e Amazon S3 per il pod di formazione per leggere/scrivere checkpoint e inviare metriche

    • Queste autorizzazioni possono essere configurate dalla configurazione EKS OIDC

Passaggio 1: abilita il checkpoint gestito su più livelli per il tuo cluster

Importante

È necessario attivare l'utilizzo del checkpoint gestito su più livelli.

Abilita il checkpoint gestito su più livelli tramite HyperPod APIs durante la creazione o l'aggiornamento del cluster. Il servizio installa automaticamente il sistema di gestione della memoria quando specifichi il parametro TieredStorageConfig.

Per i nuovi cluster, puoi usare. create-cluster AWS CLI

aws sagemaker create-cluster \ --cluster-name cluster-name \ --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \ --instance-groups '{ "InstanceGroupName": "instance-group-name", "InstanceType": "instance-type", "InstanceCount": instance-count, "LifeCycleConfig": { "SourceS3Uri": "s3-path-to-lifecycle-scripts", "OnCreate": "lifecycle-script-name" }, "ExecutionRole": "instance-group-iam-role", "ThreadsPerCore": threads-per-core, "InstanceStorageConfigs": [ { "EbsVolumeConfig": {"VolumeSizeInGB": volume-size} } ] }' \ --vpc-config '{ "SecurityGroupIds": ["security-group-ids"], "Subnets": ["subnets"] }' \ --tiered-storage-config '{ "Mode": "Enable" }'

Il parametro InstanceMemoryAllocationPercentage specifica il valore percentage (int) della memoria del cluster da allocare per i checkpoint. L’intervallo è compreso tra 20 e 100.

Fase 2. Installa la libreria Python nell’immagine di addestramento

Installa la libreria Amazon SageMaker checkpointing e le sue dipendenze nella tua immagine di formazione aggiungendola al tuo Dockerfile:

# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3 s3torchconnector

Fase 3: salva i checkpoint nel ciclo di allenamento

Nel ciclo di allenamento, puoi salvare i checkpoint in modo asincrono utilizzando DCP. PyTorch Di seguito è riportato un esempio su come eseguire questa operazione.

import torch import torch.distributed as dist from torch.distributed.checkpoint import async_save, load from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import ( SageMakerTieredStorageWriter, SageMakerTieredStorageReader ) # Initialize distributed training dist.init_process_group(backend="nccl") # Configure checkpointing checkpoint_config = SageMakerCheckpointConfig( # Unique ID for your training job # Allowed characters in ID include: alphanumeric, hyphens, and underscores namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'), # Number of distributed processes/available GPUs world_size=dist.get_world_size(), # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks # Required for SageMakerTieredStorageWriter when save_to_s3 is True s3_tier_base_path="s3://my-bucket/checkpoints" ) # Your model and optimizer model = MyModel() optimizer = torch.optim.AdamW(model.parameters()) # Training loop future = None in_memory_ckpt_freq = 10 s3_ckpt_freq = 50 for training_step in range(1000): # ... training code ... # Save checkpoint if (training_step % in_memory_ckpt_freq == 0 or training_step % s3_ckpt_freq == 0): # Create state dictionary state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": training_step, "epoch": epoch } # Create storage writer for current step checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0 storage_writer = SageMakerTieredStorageWriter( checkpoint_config=checkpoint_config, step=training_step ) # wait for previous checkpoint to get completed if future is not None: exc = future.exception() if exc: print(f"Failure in saving previous checkpoint:{str(exc)}") # Handle failures as required else: result = future.result() # Process results from save, if required # Async save checkpoint using PyTorch DCP future = async_save(state_dict=state_dict, storage_writer=storage_writer) # Continue training while checkpoint saves in background

Fase 4: Caricare i checkpoint per il ripristino

Di seguito è riportato un esempio sul caricamento di un checkpoint.

# Create state dictionary template state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0, "epoch": 0 } # Load latest checkpoint storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config) load(state_dict, storage_reader=storage_reader) # Load specific checkpoint step storage_reader = SageMakerTieredStorageReader( checkpoint_config=checkpoint_config, step=500 # Or don't pass step if you have to load the latest available step. ) try: load(state_dict, storage_reader=storage_reader) except BaseException as e: print(f"Checkpoint load failed: {str(e)}") # Add additional exception handling

Convalida le tue operazioni di checkpoint gestite su più livelli

È possibile convalidare le operazioni di checkpoint gestite su più livelli con i log.

Registrazione di log personalizzata (facoltativo)

Puoi integrare i log dei checkpoint con altri log passando un logger personalizzato alla libreria. Ad esempio, puoi aggiungere un logger personalizzato al codice di addestramento in modo che tutti i log della libreria vengano raccolti anche nel logger di addestramento.

Registrazione avanzata di log dei servizi (facoltativo)

Per migliorare il debug e la visibilità del servizio, puoi montare il percorso del log dei checkpoint /var/log/sagemaker_checkpointing dall’interno del pod a un percorso /var/logs/sagemaker_checkpointing sull’host. Questa operazione garantisce che solo i log specifici della libreria vengano raccolti separatamente. Il team di assistenza può quindi avere una maggiore visibilità per il debug e il supporto.