FlashAttention - Amazon SageMaker KI

Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.

FlashAttention

SMP v2 unterstützt FlashAttentionKernel und macht es einfach, sie auf verschiedene Szenarien für Hugging Face Transformer-Modelle anzuwenden. Beachten Sie, dass SMP FlashAttention v2 verwendet, wenn Sie FlashAttention Paket v2.0 oder höher verwenden. Triton Flash Attention verwendet jedoch standardmäßig den Flash Attention-Kernel in FlashAttention v1.x, sodass er ausschließlich in Version 1 unterstützt wird. FlashAttention

Das Modul (nn.Module) ist eine Low-Level-API, die die Aufmerksamkeitsebenen eines Modells definiert. Es sollte direkt nach der Modellerstellung angewendet werden, beispielsweise über die AutoModelForCausalLM.from_config()-API, bevor das Modell transformiert oder mit FSDP umschlossen wird.

Verwenden Sie Kernel zur Selbstwahrnehmung FlashAttention

Der folgende Codeausschnitt veranschaulicht, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashSelfAttention-API verwendet wird.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Verwenden Sie FlashAttention Kernel für die Aufmerksamkeit bei gruppierten Abfragen

SMP v2 unterstützt auch FlashAttentionKernel für Grouped-Query Attention (GQA) und macht es einfach, sie auf verschiedene Szenarien für Hugging Face Transformer-Modelle anzuwenden. Im Unterschied zur ursprünglichen Aufmerksamkeitsarchitektur unterteilt GQA Abfrageköpfe gleichmäßig in Gruppen und die Abfrageköpfe in derselben Gruppe verwenden dieselben Schlüssel- und Wertköpfe. Daher werden q- und kv-Köpfe getrennt an den Vorwärtsaufruf übergeben. Hinweis: Die Anzahl der q-Köpfe muss durch die Anzahl der kv-Köpfe teilbar sein.

Beispiel für die Verwendung FlashGroupedQueryAttention

Der folgende Codeausschnitt veranschaulicht, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashGroupedQueryAttention-API verwendet wird.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

Die SMP-Bibliothek bietet auch die Funktion torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, die die torch.sagemaker.nn.attn.FlashGroupedQueryAttention-API auf niedriger Ebene verwendet. Hugging Face Transformers hat eine ähnliche Implementierung, die ab Version 4.36.0 LlamaFlashAttention2 genannt wird. Der folgende Codeausschnitt zeigt, wie die APIs SMP LlamaFlashAttention v2 oder Transformers LlamaFlashAttention2 verwendet werden, um die Aufmerksamkeitsebenen eines vorhandenen Llama-Modells zu ersetzen.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))