Points de contrôle à l’aide de la SMP - Amazon SageMaker AI

Points de contrôle à l’aide de la SMP

La bibliothèque SageMaker de parallélisme des modèles (SMP) prend en charge les API PyTorch pour les points de contrôle et fournit des API qui aident à contrôler correctement les points lors de l’utilisation de la bibliothèque SMP.

PyTorch FSDP (Fully Sharded Data Parallelism) prend en charge trois types de points de contrôle : complets, partitionnés et locaux, chacun pour des objectifs différents. Des points de contrôle complets sont utilisés lors de l’exportation du modèle une fois l’entraînement terminé, car la génération d’un point de contrôle complet est un processus coûteux en termes de calcul. Les points de contrôle partitionnés permettent d’enregistrer et de charger l’état d’un modèle partitionné pour chaque rang individuel. Grâce aux points de contrôle partitionnés, vous pouvez reprendre l’entraînement avec différentes configurations matérielles, par exemple dans le cas d’un nombre différent de GPU. Cependant, le chargement des points de contrôle partitionnés peut être lent en raison de la communication requise entre plusieurs dispositifs. La bibliothèque SMP fournit des fonctionnalités de points de contrôle locaux, qui permettent d’extraire plus rapidement l’état du modèle sans surcharger davantage les communications. Notez que les points de contrôle créés par FSDP nécessitent d’écrire dans un système de fichiers réseau partagé tel qu’Amazon FSx.

Points de contrôle locaux asynchrones

Lors de l’entraînement des modèles de machine learning, les itérations suivantes n’ont pas à attendre que les fichiers de points de contrôle soient enregistrés sur disque. Avec SMP v2.5, la bibliothèque prend en charge l’enregistrement des fichiers de point de contrôle de manière asynchrone. Cela signifie que l’itération d’entraînement suivante peut être exécutée simultanément avec les opérations d’entrée et de sortie (E/S) pour créer des points de contrôle, sans être ralentie ou freinée par ces opérations d’E/S. De plus, le processus d’extraction des paramètres du modèle partitionné et de l’optimiseur dans PyTorch peut prendre du temps en raison de la communication collective supplémentaire requise pour échanger les métadonnées de tenseur distribuées entre les rangs. Même quand StateDictType.LOCAL_STATE_DICT est utilisé pour enregistrer des points de contrôle locaux pour chaque rang, PyTorch invoque toujours des hooks qui effectuent une communication collective. Pour atténuer ce problème et réduire le temps nécessaire à l’extraction des points de contrôle, SMP introduit SMStateDictType.SM_LOCAL_STATE_DICT, qui permet d’extraire plus rapidement les points de contrôle du modèle et de l’optimiseur en évitant la surcharge de communication collective.

Note

Le maintien de la cohérence du SHARD_DEGREE FSDP est une condition préalable à l’utilisation de SMStateDictType.SM_LOCAL_STATE_DICT. Assurez-vous que le SHARD_DEGREE reste inchangé. Bien que le nombre de réplications du modèle puisse varier, le degré de partitionnement du modèle doit être identique à celui de la configuration d’entraînement précédente lorsque vous reprenez à partir d’un point de contrôle.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

L’extrait de code suivant illustre comment charger un point de contrôle à l’aide de SMStateDictType.SM_LOCAL_STATE_DICT.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

Le stockage de points de contrôle pour les grands modèles de langage (LLM) peut s’avérer coûteux, car il nécessite souvent la création d’un grand volume de système de fichiers. Pour réduire les coûts, vous avez la possibilité d’enregistrer les points de contrôle directement dans Amazon S3 sans avoir besoin de services de système de fichiers supplémentaires tels qu’Amazon FSx. Vous pouvez utiliser l’exemple précédent avec l’extrait de code suivant pour enregistrer des points de contrôle dans S3 en spécifiant une URL S3 comme destination.

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Points de contrôle partitionnés asynchrones

Il peut arriver que vous deviez continuer d’entraîner avec différentes configurations matérielles, par exemple en modifiant le nombre de GPU. Dans ces cas, vos processus d’entraînement doivent charger des points de contrôle lors du repartitionnement, ce qui implique de reprendre l’entraînement suivant avec un nombre différent de SHARD_DEGREE. Afin de répondre au scénario dans lequel vous devez reprendre l’entraînement avec un nombre différent de SHARD_DEGREE, vous devez enregistrer les points de contrôle de votre modèle à l’aide du type de dictionnaire d’état partitionné, représenté par StateDictType.SHARDED_STATE_DICT. L’enregistrement des points de contrôle dans ce format vous permet de gérer correctement le processus de repartitionnement lorsque vous poursuivez l’entraînement avec une configuration matérielle modifiée. L’extrait de code fourni illustre comment utiliser l’API tsm pour enregistrer des points de contrôle partitionnés de manière asynchrone, permettant ainsi un processus d’entraînement plus efficace et rationalisé.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Le processus de chargement des points de contrôle partagés est similaire à celui décrit dans la section précédente, mais il implique d’utiliser torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader et sa méthode load. La méthode load de cette classe permet de charger les données des points de contrôle partagés, en suivant un processus analogue à celui décrit précédemment.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Points de contrôle de modèle complets

À la fin de l’entraînement, vous pouvez enregistrer un point de contrôle complet qui combine toutes les partitions d’un modèle dans un seul fichier de points de contrôle du modèle. La bibliothèque SMP prend entièrement en charge l’API de points de contrôle complets PyTorch, vous n’avez donc pas besoin d’effectuer de modifications.

Notez que si vous utilisez le Parallélisme de tenseur SMP, la bibliothèque SMP transforme le modèle. Dans ce cas, lorsque vous contrôlez les points du modèle complet, la bibliothèque SMP retraduit le modèle au format de point de contrôle des transformeurs Hugging Face par défaut.

Dans les cas où vous entraînez avec le parallélisme de tenseur SMP et que vous désactivez le processus de traduction SMP, vous pouvez utiliser l’argument translate_on_save de l’API FullStateDictConfig PyTorch pour activer ou désactiver la traduction automatique SMP selon vos besoins. Par exemple, si vous vous concentrez sur l’entraînement d’un modèle, vous n’avez pas besoin d’ajouter le processus de traduction qui entraîne une surcharge. Dans ce cas, nous vous recommandons de définir le paramètres sur la valeur translate_on_save=False. De plus, si vous prévoyez de continuer à utiliser la traduction SMP du modèle pour un futur entraînement supplémentaire, vous pouvez la désactiver afin de l’enregistrer pour utilisation ultérieure. Il est nécessaire de retraduire le modèle au format de point de contrôle de modèle de transformeur Hugging Face lorsque vous terminez l’entraînement de votre modèle et que vous l’utilisez à des fins d’inférence.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Notez que l’option FullStateDictConfig(rank0_only=True, offload_to_cpu=True) consiste à rassembler le modèle sur le CPU du dispositif de rang 0 pour économiser de la mémoire lors de l’entraînement de grands modèles.

Pour recharger le modèle à des fins d’inférence, vous devez procéder comme illustré dans l’exemple de code suivant. Notez que la classe AutoModelForCausalLM peut être remplacée par d’autres classes de création de facteurs dans les transformeurs Hugging Face, par exemple AutoModelForSeq2SeqLM, en fonction de votre modèle. Pour plus d’informations, consultez la documentation des transformeurs Hugging Face.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)