Puntos de control de activación
Los puntos de comprobación de activación son una técnica para reducir el uso de memoria al borrar las activaciones de determinadas capas y volver a calcularlas durante una pasada hacia atrás. De hecho, esto cambia el tiempo de cálculo adicional por un menor uso de memoria. Si se comprueba un módulo, al final de una pasada hacia adelante, las entradas iniciales al módulo y salidas finales del módulo permanecen en la memoria. PyTorch libera todos los tensores intermedios que formen parte de la computación dentro de ese módulo durante la pasada hacia adelante. Durante la pasada hacia atrás de los módulos con puntos de comprobación, PyTorch recalcula estos tensores. En este punto, las capas situadas más allá de este módulo de puntos de comprobación han terminado su pasada hacia atrás, por lo que el uso máximo de memoria con los puntos de comprobación es menor.
SMP v2 es compatible con el módulo de puntos de comprobación de activación de PyTorch, apply_activation_checkpointing
Capas del transformador de puntos de comprobación del modelo de Hugging Face GPT-NeoX
from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)
Aplicación de puntos de comprobación a una capa del transformador de cada dos del modelo de Hugging Face GPT-NeoX
# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)
Alternativamente, PyTorch también tiene el módulo torch.utils.checkpoint para puntos de comprobación, que lo utiliza un subconjunto de modelos de Hugging Face Transformer. Este módulo también funciona con SMP v2. Sin embargo, requiere que tenga acceso a la definición del modelo para añadir el encapsulador de puntos de comprobación. Por tanto, le recomendamos que utilice el método apply_activation_checkpointing.