Peaufinage - Amazon SageMaker AI

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Peaufinage

Le peaufinage est un processus d’entraînement continu de modèles pré-entraînés afin d’améliorer les performances pour des cas d’utilisation spécifiques.

CPUs Il est très simple de peaufiner les petits modèles qui s'adaptent entièrement à un seul GPU ou ceux qui s'adaptent entièrement à 8 copies du modèle. Elle ne nécessite aucune modification particulière de l’entraînement FSDP régulier. Dans le domaine des modèles plus grands, vous devez envisager d’utiliser la fonctionnalité d’initialisation différée des paramètres, qui peut s’avérer délicate.

Pour résoudre ce problème, la bibliothèque SMP charge le modèle complet sur l’un des rangs tandis que les autres rangs créent des modèles avec des poids vides sur un méta-dispositif. Ensuite, PyTorch FSDP initialise les poids sur les rangs non nuls à l'aide de la init_weights fonction, et synchronise les poids sur tous les rangs avec les poids sur le 0e rang avec défini sur. sync_module_states True L’extrait de code suivant illustre comment le configurer dans votre script d’entraînement.

import torch.distributed as dist from transformers import AutoModelForCasalLM from accelerate import init_empty_weights from torch.sagemaker.delayed_param import DelayedParamIniter if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(..., low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) delayed_initer = DelayedParamIniter(model) model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if dist.get_rank() > 0 else None )

Peaufinage d’un modèle de transformeur Hugging Face pré-entraîné avec le parallélisme de tenseur SMP

Cette section décrit le chargement des modèles de transformeurs pour deux cas d’utilisation : le peaufinage des petits modèles de transformeurs et le peaufinage de grands modèles de transformeurs. Pour les modèles plus petits sans initialisation différée des paramètres, encapsulez le modèle avec l'torch.sagemaker.transformAPI avant de l'encapsuler avec PyTorch FSDP.

import functools from transformers import AutoModelForCausalLM from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.sagemaker import transform model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", low_cpu_mem_usage=True) # Transform model while loading state dictionary from rank 0. tp_model = transform(model, load_state_dict_from_rank0=True) # Wrap with FSDP. model = FSDP( tp_model, ... sync_module_states=True, )

Pour les modèles plus grands, l’approche précédente entraîne un épuisement de la mémoire du CPU. Nous vous recommandons d’utiliser l’initialisation différée des paramètres pour éviter de tels problèmes de mémoire du CPU. Dans ce cas, vous pouvez appliquer l’API torch.sagemaker.transform et l’API torch.sagemaker.delayed_param.DelayedParamIniter comme illustré dans l’exemple de code suivant.

from transformers import AutoModelForCausalLM from torch.sagemaker import transform from torch.sagemaker.delayed_param import DelayedParamIniter # Create one instance of model without delayed param # on CPU, on one rank. if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(...,low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) # Transform model while loading state dictionary from rank 0 model = transform(model, load_state_dict_from_rank0=True) if dist.get_rank() != 0: # For fine-tuning, delayed parameter on non-zero ranks delayed_initer = DelayedParamIniter(model) else: delayed_initer = None with ( delayed_initer.validate_params_and_buffers_inited() if delayed_initer else nullcontext() ): # Wrap the model with FSDP model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if delayed_initer else None )