Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.
So verwenden Sie den Algorithmus Textklassifizierung – TensorFlow von SageMaker AI
Sie können Textklassifizierung – TensorFlow als integrierten Algorithmus von Amazon SageMaker AI verwenden. Im folgenden Abschnitt wird beschrieben, wie Sie Textklassifizierung – TensorFlow mit dem SageMaker AI Python SDK verwenden. Informationen zur Verwendung von Textklassifizierung – TensorFlow über die Benutzeroberfläche von Amazon SageMaker Studio Classic finden Sie unter SageMaker JumpStart vortrainierte Modelle.
Der Textklassifizierung – TensorFlow-Algorithmus unterstützt Transfer Learning unter Verwendung eines der kompatiblen vortrainierten TensorFlow-Modelle. Eine Liste aller verfügbaren vortrainierten Modelle finden Sie unter TensorFlow-Hub-Modelle. Jedes vortrainierte Modell hat ein Unikat model_id. Im folgenden Beispiel wird BERT Base Uncased (model_id:tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2) zur Feinabstimmung eines benutzerdefinierten Datensatzes verwendet. Die vortrainierten Modelle werden alle vorab vom TensorFlow Hub heruntergeladen und in Amazon S3-Buckets gespeichert, sodass Trainingsauftrages netzwerkisoliert ausgeführt werden können. Verwenden Sie diese vorgenerierten Modelltrainingsartefakte, um einen SageMaker AI Estimator zu erstellen.
Rufen Sie zunächst den Docker-Image-URI, den Trainingsskript-URI und den vortrainierten Modell-URI ab. Ändern Sie dann die Hyperparameter nach Bedarf. Sie können ein Python-Wörterbuch mit allen verfügbaren Hyperparametern und ihren Standardwerten mit hyperparameters.retrieve_default sehen. Weitere Informationen finden Sie unter Textklassifizierungs- TensorFlow Hyperparameter. Verwenden Sie diese Werte, um einen SageMaker AI Estimator zu erstellen.
Anmerkung
Die Standard-Hyperparameterwerte sind für verschiedene Modelle unterschiedlich. Bei größeren Modellen ist die Standardstapelgröße beispielsweise kleiner.
In diesem Beispiel wird der SST2.fit indem Sie den Amazon S3-Speicherort Ihres Trainingsdatensatzes verwenden. Jeder S3-Bucket, der in einem Notebook verwendet wird, muss sich in derselben AWS-Region befinden wie die Notebook-Instance, die darauf zugreift.
from sagemaker import image_uris, model_uris, script_uris, hyperparameters from sagemaker.estimator import Estimator model_id, model_version = "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2", "*" training_instance_type = "ml.p3.2xlarge" # Retrieve the Docker image train_image_uri = image_uris.retrieve(model_id=model_id,model_version=model_version,image_scope="training",instance_type=training_instance_type,region=None,framework=None) # Retrieve the training script train_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="training") # Retrieve the pretrained model tarball for transfer learning train_model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="training") # Retrieve the default hyperparameters for fine-tuning the model hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) # [Optional] Override default hyperparameters with custom values hyperparameters["epochs"] = "5" # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/SST2/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tc-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" # Create an Estimator instance tf_tc_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location, ) # Launch a training job tf_tc_estimator.fit({"training": training_dataset_s3_path}, logs=True)
Weitere Informationen zur Verwendung des SageMaker Textklassifizierung – TensorFlow-Algorithmus für Transfer-Leraning in einem benutzerdefinierten Datensatz finden Sie im Notebook Introduction to JumpStart – Textklassifizierung