Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.
FlashAttention
SMP v2 unterstützt FlashAttention
Das Modul (nn.Module) ist eine Low-Level-API, die die Aufmerksamkeitsebenen eines Modells definiert. Es sollte direkt nach der Modellerstellung angewendet werden, beispielsweise über die AutoModelForCausalLM.from_config()-API, bevor das Modell transformiert oder mit FSDP umschlossen wird.
Verwenden Sie Kernel zur Selbstwahrnehmung FlashAttention
Der folgende Codeausschnitt veranschaulicht, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashSelfAttention-API verwendet wird.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
Verwenden Sie FlashAttention Kernel für die Aufmerksamkeit bei gruppierten Abfragen
SMP v2 unterstützt auch FlashAttention
Beispiel für die Verwendung FlashGroupedQueryAttention
Der folgende Codeausschnitt veranschaulicht, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashGroupedQueryAttention-API verwendet wird.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
Die SMP-Bibliothek bietet auch die Funktion torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, die die torch.sagemaker.nn.attn.FlashGroupedQueryAttention-API auf niedriger Ebene verwendet. Hugging Face Transformers hat eine ähnliche Implementierung, die ab Version 4.36.0 LlamaFlashAttention2LlamaFlashAttention v2 oder Transformers LlamaFlashAttention2 verwendet werden, um die Aufmerksamkeitsebenen eines vorhandenen Llama-Modells zu ersetzen.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))