Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.
Richten Sie verwaltetes mehrstufiges Checkpointing ein
Dieser Abschnitt enthält den Einrichtungsprozess für verwaltetes mehrstufiges Checkpointing für Amazon. SageMaker HyperPod Sie erfahren, wie Sie die Funktion in Ihrem Cluster aktivieren und Checkpointing in Ihrem Trainingscode implementieren.
Themen
Schritt 1: Aktivieren Sie verwaltetes mehrstufiges Checkpointing für Ihren Cluster
Schritt 2: Python-Bibliothek in Ihrem Trainings-Image installieren
Schritt 3: Speichern Sie Checkpoints in Ihrer Trainingsschleife
Schritt 4: Laden Sie die Checkpoints für die Wiederherstellung
Überprüfen Sie Ihre verwalteten mehrstufigen Checkpoint-Operationen
Voraussetzungen
Bevor Sie verwaltetes mehrstufiges Checkpointing einrichten, stellen Sie sicher, dass Sie über Folgendes verfügen:
-
Ein Amazon HyperPod EKS-Cluster mit ausreichend verfügbarem CPU-Speicher für die Checkpoint-Zuweisung
-
PyTorch Trainingsworkloads und DCP-Jobs (beide werden unterstützt)
-
Geeignete IAM-Berechtigungen für die Clusterverwaltung, einschließlich:
-
Amazon CloudWatch - und Amazon S3 S3-Schreibberechtigungen für den Trainings-Pod zum Lesen/Schreiben von Checkpoints und Push-Metriken
-
Diese Berechtigungen können über die EKS-OIDC-Einrichtung konfiguriert werden.
-
Schritt 1: Aktivieren Sie verwaltetes mehrstufiges Checkpointing für Ihren Cluster
Wichtig
Sie müssen sich für die Verwendung von verwaltetem mehrstufigem Checkpointing anmelden.
Aktivieren Sie verwaltetes mehrstufiges Checkpointing über, HyperPod APIs wenn Sie Ihren Cluster erstellen oder aktualisieren. Der Service installiert das Speicherverwaltungssystem automatisch, wenn Sie den TieredStorageConfig-Parameter angeben.
Für neue Cluster können Sie verwenden. create-clusterAWS CLI
aws sagemaker create-cluster \ --cluster-namecluster-name\ --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \ --instance-groups '{ "InstanceGroupName": "instance-group-name", "InstanceType": "instance-type", "InstanceCount":instance-count, "LifeCycleConfig": { "SourceS3Uri": "s3-path-to-lifecycle-scripts", "OnCreate": "lifecycle-script-name" }, "ExecutionRole": "instance-group-iam-role", "ThreadsPerCore":threads-per-core, "InstanceStorageConfigs": [ { "EbsVolumeConfig": {"VolumeSizeInGB":volume-size} } ] }' \ --vpc-config '{ "SecurityGroupIds": ["security-group-ids"], "Subnets": ["subnets"] }' \ --tiered-storage-config '{ "Mode": "Enable" }'
Der InstanceMemoryAllocationPercentage-Parameter gibt die (int) des Cluster-Speichers an, der für Checkpointing zugewiesen werden soll. Der Bereich liegt zwischen 20 und 100.percentage
Schritt 2: Python-Bibliothek in Ihrem Trainings-Image installieren
Installieren Sie die Amazon SageMaker Checkpointing-Bibliothek
# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3 s3torchconnector
Schritt 3: Speichern Sie Checkpoints in Ihrer Trainingsschleife
In deiner Trainingsschleife kannst du Checkpoints mithilfe von DCP asynchron speichern. PyTorch Im Folgenden finden Sie ein Beispiel dafür.
import torch import torch.distributed as dist from torch.distributed.checkpoint import async_save, load from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import ( SageMakerTieredStorageWriter, SageMakerTieredStorageReader ) # Initialize distributed training dist.init_process_group(backend="nccl") # Configure checkpointing checkpoint_config = SageMakerCheckpointConfig( # Unique ID for your training job # Allowed characters in ID include: alphanumeric, hyphens, and underscores namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'), # Number of distributed processes/available GPUs world_size=dist.get_world_size(), # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks # Required for SageMakerTieredStorageWriter when save_to_s3 is True s3_tier_base_path="s3://my-bucket/checkpoints" ) # Your model and optimizer model = MyModel() optimizer = torch.optim.AdamW(model.parameters()) # Training loop future = None in_memory_ckpt_freq = 10 s3_ckpt_freq = 50 for training_step in range(1000): # ... training code ... # Save checkpoint if (training_step % in_memory_ckpt_freq == 0 or training_step % s3_ckpt_freq == 0): # Create state dictionary state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": training_step, "epoch": epoch } # Create storage writer for current step checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0 storage_writer = SageMakerTieredStorageWriter( checkpoint_config=checkpoint_config, step=training_step ) # wait for previous checkpoint to get completed if future is not None: exc = future.exception() if exc: print(f"Failure in saving previous checkpoint:{str(exc)}") # Handle failures as required else: result = future.result() # Process results from save, if required # Async save checkpoint using PyTorch DCP future = async_save(state_dict=state_dict, storage_writer=storage_writer) # Continue training while checkpoint saves in background
Schritt 4: Laden Sie die Checkpoints für die Wiederherstellung
Im Folgenden finden Sie ein Beispiel für das Laden eines Checkpoints.
# Create state dictionary template state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0, "epoch": 0 } # Load latest checkpoint storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config) load(state_dict, storage_reader=storage_reader) # Load specific checkpoint step storage_reader = SageMakerTieredStorageReader( checkpoint_config=checkpoint_config, step=500 # Or don't pass step if you have to load the latest available step. ) try: load(state_dict, storage_reader=storage_reader) except BaseException as e: print(f"Checkpoint load failed: {str(e)}") # Add additional exception handling
Überprüfen Sie Ihre verwalteten mehrstufigen Checkpoint-Operationen
Sie können Ihre verwalteten mehrstufigen Checkpoint-Operationen anhand von Protokollen validieren.
Benutzerdefinierte Protokollierung (optional)
Sie können Checkpointing-Protokolle in andere Protokolle integrieren, indem Sie einen benutzerdefinierten Logger an die Bibliothek übergeben. Beispielsweise können Sie Ihrem Trainingscode einen benutzerdefinierten Logger hinzufügen, sodass alle Protokolle aus der Bibliothek auch im Trainings-Logger gesammelt werden.
Verbesserte Serviceprotokollierung (optional)
Um das Debugging und die Transparenz der Services zu verbessern, können Sie den Checkpointing-Protokollpfad /var/log/sagemaker_checkpointing von Ihrem Pod aus in einen /var/logs/sagemaker_checkpointing-Pfad auf Ihrem Host mounten. Dadurch wird sichergestellt, dass nur bibliotheksspezifische Protokolle separat gesammelt werden. So erhält das Serviceteam verbesserte Transparenz hinsichtlich Debugging und Support.