Prise en charge de FlashAttention - Amazon SageMaker AI

Prise en charge de FlashAttention

La prise en charge de FlashAttention est une fonctionnalité de la bibliothèque applicable uniquement au modèle de transformateur distribué, qui est un modèle Transformer encapsulé par smp.DistributedModel() pour l'entraînement parallèle entre modèles. Cette fonctionnalité est également compatible avec Parallélisme de tenseur.

La bibliothèque FlashAttention ne prend en charge les modèles que lorsque la attention_head_size est définie sur une valeur multiple de 8 et inférieure à 128. Par conséquent, lorsque vous entraînez un transformateur distribué et que vous veillez à ce que FlashAttention fonctionne correctement, vous devez ajuster les paramètres pour que la taille de la tête d'attention soit conforme aux exigences. Pour plus d'informations, consultez également Installation et fonctionnalités dans le référentiel GitHub de FlashAttention (langue française non garantie).

Supposons, par exemple, que vous configurez un modèle Transformer avec hidden_width=864 et num_heads=48. La taille de la tête de FlashAttention est calculée comme attention_head_size = hidden_width / num_heads = 864 / 48 = 18. Pour activer FlashAttention, vous devez régler le paramètre num_heads sur 54, de sorte que attention_head_size = hidden_width / num_heads = 864 / 54 = 16, soit un multiple de 8.