Points de contrôle d'activation - Amazon SageMaker AI

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Points de contrôle d'activation

Les points de contrôle d’activation sont une technique permettant de réduire l’utilisation de la mémoire en effaçant les activations de certaines couches et en les recalculant lors d’une transmission vers l’arrière. Concrètement, cela échange un temps de calcul supplémentaire contre une réduction de l’utilisation de la mémoire. Si un module est contrôlé, à la fin d'une passe directe, seules les entrées initiales du module et les sorties finales du module restent en mémoire. PyTorch libère tous les tenseurs intermédiaires qui font partie du calcul à l'intérieur de ce module lors de la passe directe. Lors du passage en arrière des modules pointés de contrôle, PyTorch recalcule ces tenseurs. À ce stade, les couches situées au-delà de ce module avec points de contrôle ont terminé leur transmission vers l’arrière. Avec les points de contrôle, le pic d’utilisation de la mémoire est réduit.

SMP v2 prend en charge le module de point de contrôle PyTorch d'activation,. apply_activation_checkpointing Voici des exemples de points de contrôle d’activation du modèle GPT-NeoX Hugging Face.

Points de contrôles de couches de transformeurs du modèle GPT-NeoX Hugging Face

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Points de contrôle de toute autre couche de transformeur du modèle GPT-NeoX Hugging Face

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

Il possède PyTorch également le torch.utils.checkpoint module de point de contrôle, qui est utilisé par un sous-ensemble de modèles Hugging Face Transformers. Ce module fonctionne également avec SMP v2. Toutefois, vous devez avoir accès à la définition du modèle pour ajouter le wrapper de points de contrôle. Par conséquent, nous vous recommandons d’utiliser la méthode apply_activation_checkpointing.