Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.
Fine-tuning
Il fine-tuning è un processo di addestramento continuo di modelli preaddestrati per migliorare le prestazioni in casi d’uso specifici.
La messa a punto di piccoli modelli che si adattano completamente a una singola GPU o quelli che contengono 8 copie del modello è semplice. CPUs Non richiede infatti particolari modifiche al normale addestramento FSDP. Per quanto riguarda i modelli più grandi di questo, è necessario prendere in considerazione l’utilizzo della funzionalità di inizializzazione ritardata dei parametri, che può risultare complicata.
Per risolvere questo problema, la libreria SMP carica il modello completo su una delle classificazioni, mentre le altre creano modelli con pesi vuoti su un metadispositivo. Quindi, PyTorch FSDP inizializza i pesi sui ranghi diversi da zero utilizzando la init_weights funzione e sincronizza i pesi su tutti i ranghi con i pesi di grado 0 con impostato su. sync_module_states True Il frammento di codice riportato di seguito mostra come eseguire la configurazione nello script di addestramento.
import torch.distributed as dist from transformers import AutoModelForCasalLM from accelerate import init_empty_weights from torch.sagemaker.delayed_param import DelayedParamIniter if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(..., low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) delayed_initer = DelayedParamIniter(model) model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if dist.get_rank() > 0 else None )
Fine-tuning di un modello di Hugging Face Transformer preaddestrato con parallelizzazione tensoriale SMP
Questa sezione illustra il caricamento dei modelli di trasformatore per due casi d’uso, ovvero il fine-tuning di modelli Transformer di piccole e di grandi dimensioni. Per i modelli più piccoli senza ritardi nell'inizializzazione dei parametri, avvolgete il modello con l'API prima di avvolgerlo con FSDP. torch.sagemaker.transform PyTorch
import functools from transformers import AutoModelForCausalLM from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.sagemaker import transform model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", low_cpu_mem_usage=True) # Transform model while loading state dictionary from rank 0. tp_model = transform(model, load_state_dict_from_rank0=True) # Wrap with FSDP. model = FSDP( tp_model, ... sync_module_states=True, )
Per i modelli più grandi, l’approccio precedente provoca l’esaurimento della memoria della CPU. Per evitare problemi con la memoria della CPU, è consigliabile utilizzare l’inizializzazione ritardata dei parametri. In questo caso, puoi applicare l’API torch.sagemaker.transform e l’API torch.sagemaker.delayed_param.DelayedParamIniter come mostrato nel seguente esempio di codice.
from transformers import AutoModelForCausalLM from torch.sagemaker import transform from torch.sagemaker.delayed_param import DelayedParamIniter # Create one instance of model without delayed param # on CPU, on one rank. if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(...,low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) # Transform model while loading state dictionary from rank 0 model = transform(model, load_state_dict_from_rank0=True) if dist.get_rank() != 0: # For fine-tuning, delayed parameter on non-zero ranks delayed_initer = DelayedParamIniter(model) else: delayed_initer = None with ( delayed_initer.validate_params_and_buffers_inited() if delayed_initer else nullcontext() ): # Wrap the model with FSDP model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if delayed_initer else None )