Checkpointing mit SMP - Amazon SageMaker KI

Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.

Checkpointing mit SMP

Die SageMaker Modellparallelismus-Bibliothek (SMP) unterstützt PyTorch APIs für Checkpoints und stellt APIs diese Hilfesuchpoints bei der Verwendung der SMP-Bibliothek korrekt bereit.

PyTorch FSDP (Fully Sharded Data Parallelism) unterstützt drei Arten von Checkpoints: vollständige, geteilte und lokale Checkpoints, die jeweils unterschiedlichen Zwecken dienen. Vollständige Prüfpunkte werden verwendet, wenn das Modell nach Abschluss des Trainings exportiert wird, da die Generierung eines vollständigen Prüfpunkts ein rechenintensiver Prozess ist. Mit Hilfe von fragmentierten Prüfpunkten kann der Zustand eines Modells, das für jeden einzelnen Rang fragmentiert wurde, gespeichert und geladen werden. Mit Sharded Checkpoints können Sie das Training mit unterschiedlichen Hardwarekonfigurationen fortsetzen, z. B. mit einer anderen Anzahl von. GPUs Das Laden von fragmentierten Prüfpunkten kann jedoch aufgrund der Kommunikation zwischen mehreren Geräten langsam sein. Die SMP-Bibliothek bietet lokale Prüfpunktfunktionen, die ein schnelleres Abrufen des Modellzustands ohne zusätzlichen Kommunikationsaufwand ermöglichen. Beachten Sie, dass von FSDP erstellte Checkpoints in ein gemeinsam genutztes Netzwerkdateisystem wie Amazon geschrieben werden müssen. FSx

Asynchrone lokale Prüfpunkte

Beim Training von Machine-Learning-Modellen sind keine nachfolgenden Iterationen erforderlich, um darauf zu warten, dass die Prüfpunktdateien auf der Festplatte gespeichert werden. Mit der Veröffentlichung von SMP v2.5 unterstützt die Bibliothek das asynchrone Speichern von Prüfpunktdateien. Das bedeutet, dass die nachfolgende Trainingsiteration gleichzeitig mit den Eingabe- und Ausgabeoperationen (Operationen) ausgeführt werden kann. I/O) operations for creating checkpoints, without being slowed down or held back by those I/O Außerdem PyTorch kann das Abrufen von Shard-Modell- und Optimizer-Parametern zeitaufwändig sein, da zusätzliche kollektive Kommunikation erforderlich ist, um verteilte Tensor-Metadaten zwischen Rängen auszutauschen. Selbst wenn es verwendet wird, StateDictType.LOCAL_STATE_DICT um lokale Checkpoints für jeden Rang zu speichern, ruft es PyTorch immer noch Hooks auf, die kollektive Kommunikation durchführen. Um dieses Problem zu beheben und den Zeitaufwand für das Abrufen von Prüfpunkten zu reduzieren, führt SMP SMStateDictType.SM_LOCAL_STATE_DICT für einen schnelleren Abruf von Modell- und Optimierer-Prüfpunkten ein, indem der kollektive Kommunikationsaufwand umgangen wird.

Anmerkung

Die Wahrung der Konsistenz im FSDP SHARD_DEGREE ist eine Voraussetzung für die Nutzung von SMStateDictType.SM_LOCAL_STATE_DICT. Stellen Sie sicher, dass der SHARD_DEGREE unverändert bleibt. Die Anzahl der Modellreplikationen kann zwar variieren, der Grad der Modellfragmentierung muss jedoch mit dem vorherigen Trainingsaufbau identisch sein, wenn der Vorgang von einem Prüfpunkt aus fortgesetzt wird.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

Der folgende Codeausschnitt zeigt, wie Sie mit SMStateDictType.SM_LOCAL_STATE_DICT einen Prüfpunkt laden können.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

Das Speichern von Checkpoints für umfangreiche Sprachmodelle (LLMs) kann teuer sein, da dafür oft ein großes Dateisystemvolumen erstellt werden muss. Um die Kosten zu senken, haben Sie die Möglichkeit, Checkpoints direkt in Amazon S3 zu speichern, ohne dass zusätzliche Dateisystemdienste wie Amazon erforderlich sind. FSx Sie können das vorherige Beispiel mit dem folgenden Codeausschnitt anwenden, um Checkpoints in S3 zu speichern, indem Sie eine S3-URL als Ziel angeben.

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Asynchrone fragmentierten Prüfpunkte

Es kann Situationen geben, in denen Sie das Training mit unterschiedlichen Hardwarekonfigurationen fortsetzen müssen, z. B. wenn Sie die Anzahl der Geräte ändern. GPUs In diesen Fällen müssen Ihre Trainingsprozesse Prüfpunkte während des Reshardings laden, was bedeutet, dass Sie das nachfolgende Training mit einer anderen Anzahl von SHARD_DEGREE wieder aufnehmen. In diesem Szenario, in dem Sie das Training mit einer anderen Anzahl von SHARD_DEGREE fortsetzen müssen, müssen Sie Ihre Modellprüfpunkte mithilfe des Fragmentierungs-Zustandswörterbuchs speichern, das durch StateDictType.SHARDED_STATE_DICT dargestellt wird. Wenn Sie Prüfpunkte in diesem Format speichern, können Sie den Resharding-Prozess ordnungsgemäß durchführen, wenn Sie das Training mit einer geänderten Hardwarekonfiguration fortsetzen. Der bereitgestellte Codeausschnitt veranschaulicht, wie die tsm-API verwendet werden kann, um fragmentierte Prüfpunkte asynchron zu speichern, was einen effizienteren und optimierten Trainingsprozess ermöglicht.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Der Vorgang zum Laden gemeinsamer Prüfpunkte ähnelt dem vorherigen Abschnitt, beinhaltet jedoch die Verwendung von torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader und der load-Methode. Diese load-Methode ermöglicht es Ihnen, die gemeinsamen Prüfpunktdaten nach einem Verfahren zu laden, das dem zuvor beschriebenen Vorgang entspricht.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Vollständige Modellprüfpunkte

Am Ende des Trainings können Sie einen vollständigen Prüfpunkt speichern, der alle Shards eines Modells in einer einzigen Modellprüfpunktdatei zusammenfasst. Die SMP-Bibliothek unterstützt die PyTorch vollständige Modell-Checkpoint-API vollständig, sodass Sie keine Änderungen vornehmen müssen.

Beachten Sie, dass die SMP-Bibliothek das Modell transformiert, wenn Sie SMP Tensor-Parallelität verwenden. In diesem Fall übersetzt die SMP-Bibliothek beim Checkpointing des vollständigen Modells das Modell standardmäßig zurück in das Prüfpunktformat von Hugging Face Transformers.

In Fällen, in denen Sie mit der SMP-Tensorparallelität trainieren und den SMP-Übersetzungsprozess ausschalten, können Sie das translate_on_save Argument der PyTorch FullStateDictConfig API verwenden, um die automatische SMP-Übersetzung nach Bedarf ein- oder auszuschalten. Wenn Sie sich beispielsweise darauf konzentrieren, ein Modell zu trainieren, müssen Sie den Übersetzungsprozess nicht hinzufügen, was zusätzlichen Aufwand hinzufügt. In diesem Fall empfehlen wir die Einstellung von translate_on_save=False. Wenn Sie planen, die SMP-Übersetzung des Modells auch in Zukunft für das Training zu verwenden, können Sie diese Funktion auch ausschalten, um die SMP-Übersetzung des Modells für eine spätere Verwendung zu speichern. Die Rückübersetzung des Modells in das Modellprüfpunktformat von Hugging Face Transformers ist erforderlich, wenn Sie das Training Ihres Modells abschließen und es für Inferenzen verwenden.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Beachten Sie, dass die Option FullStateDictConfig(rank0_only=True, offload_to_cpu=True) dazu dient, das Modell auf der CPU des Geräts des 0ten Ranges zu sammeln, um beim Training großer Modelle Speicherplatz zu sparen.

Um das Modell für die Inferenz wieder zu laden, gehen Sie wie im folgenden Codebeispiel gezeigt vor. Beachten Sie, dass die Klasse AutoModelForCausalLM in Hugging Face Transformers je nach Modell möglicherweise zu anderen Faktor-Builder-Klassen wechselt, wie z. B. AutoModelForSeq2SeqLM. Weitere Informationen finden Sie in der Dokumentation zu Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)