FlashAttention - Amazon SageMaker AI

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

FlashAttention

SMP v2 supporta i FlashAttentionkernel e ne semplifica l'applicazione a vari scenari per i modelli Hugging Face Transformer. Nota che se utilizzi il FlashAttention pacchetto v2.0 o successivo, SMP utilizza la FlashAttention v2; tuttavia, Triton flash attention utilizza per impostazione predefinita il kernel flash attention nella v1.x, rendendolo supportato esclusivamente nella v1. FlashAttention FlashAttention

Il modulo (nn.Module) è un’API di basso livello che definisce i livelli di attenzione di un modello. Dovrebbe essere applicato subito dopo la creazione del modello, ad esempio dall’API AutoModelForCausalLM.from_config(), e prima che si esegua la trasformazione o il wrapping del modello con FSDP.

Usa i kernel per l'attenzione personale FlashAttention

Il frammento di codice riportato di seguito mostra come utilizzare l’API torch.sagemaker.nn.attn.FlashSelfAttention fornita da SMP v2.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Usa i FlashAttention kernel per attirare l'attenzione sulle query raggruppate

SMP v2 supporta anche i FlashAttentionkernel for grouped-query attention (GQA) e ne semplifica l'applicazione a vari scenari per i modelli Hugging Face Transformer. A differenza dell’architettura di attenzione originale, GQA suddivide equamente le teste di query in gruppi e le teste di query dello stesso gruppo condividono le stesse teste di chiavi e valori. Pertanto, le teste q e kv vengono passate alla chiamata in avanti separatamente. Nota: il numero di teste q deve essere divisibile per il numero di teste kv.

Esempio di utilizzo FlashGroupedQueryAttention

Il frammento di codice riportato di seguito mostra come utilizzare l’API torch.sagemaker.nn.attn.FlashGroupedQueryAttention fornita da SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

La libreria SMP fornisce anche torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, che utilizza l’API torch.sagemaker.nn.attn.FlashGroupedQueryAttention a basso livello. A partire dalla v4.36.0., Hugging Face Transformers dispone di un’implementazione simile chiamata LlamaFlashAttention2. Il frammento di codice riportato di seguito mostra come utilizzare l’API LlamaFlashAttention SMP v2 o l’API LlamaFlashAttention2 Transformers per sostituire i livelli di attenzione di un modello Llama esistente.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))