Fine-tune Qwen 3 on AWS Trainium
This tutorial teaches you how to fine-tune open LLMs such as Qwen 3 on AWS Trainium. In our example, we use Hugging Face Optimum Neuron, Transformers, and Datasets.
Quick intro: AWS Trainium
AWS Trainium (Trn1) is a purpose-built EC2 instance for deep learning (DL) training workloads. Trainium is the successor of AWS Inferentia, focused on high-performance training. It has been optimized for training natural language processing, computer vision, and recommender models.
The biggest Trainium instance, the trn1.32xlarge, comes with over 500 GB of accelerator memory, making it easy to fine-tune ~10B-parameter models on a single instance. Below is an overview of the available instance types:
instance size | accelerators | accelerator memory (GB) | vCPUs | CPU memory (GiB) | price per hour (USD) |
---|---|---|---|---|---|
trn1.2xlarge | 1 | 32 | 8 | 32 | $1.34 |
trn1.32xlarge | 16 | 512 | 128 | 512 | $21.50 |
trn1n.32xlarge (2x network bandwidth) | 16 | 512 | 128 | 512 | $24.78 |
Note: This tutorial was created on a trn1.32xlarge AWS EC2 instance.
1. Set up the AWS environment
In this example, we use a trn1.32xlarge instance, which has 16 Trainium accelerators (32 NeuronCores), together with the Hugging Face Neuron Deep Learning AMI. The Hugging Face AMI comes with all the important libraries (Transformers, Datasets, Optimum, and the Neuron packages) pre-installed, so there is no environment to manage and getting started is straightforward.
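To confirm that the AMI actually provides these libraries, a quick check like the one below can help. It is a minimal sketch that only reads installed package metadata through Python's standard `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the list of distribution names (including `peft`, which this tutorial uses later) is an assumption about what you want to verify.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return a mapping from distribution name to installed version (None if missing)."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Assumed PyPI distribution names for the libraries used in this tutorial
    for name, ver in installed_versions(["transformers", "datasets", "optimum-neuron", "peft"]).items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` would need to be installed before running the rest of the tutorial.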
If you want to know more about distributed training, take a look at the documentation.
2. Load and prepare the dataset
We will use the simple_recipes dataset, a collection of simple recipes, to make the LLM better at suggesting delicious ideas. Each example looks like this:
```python
{
    'recipes': "- Preheat oven to 350 degrees\n- Butter two 9x5' loaf pans\n- Cream the sugar and the butter until light and whipped\n- Add the bananas, eggs, lemon juice, orange rind\n- Beat until blended uniformly\n- Be patient, and beat until the banana lumps are gone\n- Sift the dry ingredients together\n- Fold lightly and thoroughly into the banana mixture\n- Pour the batter into prepared loaf pans\n- Bake for 45 to 55 minutes, until the loaves are firm in the middle and the edges begin to pull away from the pans\n- Cool the loaves on racks for 30 minutes before removing from the pans\n- Freezes well",
    'names': 'Beat this banana bread'
}
```
To load the `simple_recipes` dataset, we use the `load_dataset()` function from the 🤗 Datasets library.
```python
from random import randrange

from datasets import load_dataset

# Load dataset from the hub
dataset_id = "tengomucho/simple_recipes"
recipes = load_dataset(dataset_id, split="train")

dataset_size = len(recipes)
print(f"dataset size: {dataset_size}")
print(recipes[randrange(dataset_size)])
# dataset size: 20000
```
To tune our model, we need to convert our structured examples into conversations: the recipe name becomes a user question and the recipe itself becomes the assistant's answer. We define a preprocessing function that we can map over the dataset.
The dataset should be structured as input-output pairs, where each input is a prompt and each output is the expected response from the model. We will use the model's tokenizer chat template to preprocess the dataset before feeding it to the trainer.
```python
from transformers import AutoTokenizer

model_id = "Qwen/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)


def preprocess_function(examples):
    recipes = examples["recipes"]
    names = examples["names"]

    chats = []
    for recipe, name in zip(recipes, names):
        # Append the EOS token to the response
        recipe += tokenizer.eos_token
        # Each example becomes a user question followed by the assistant's answer
        chat = [
            {"role": "user", "content": f"How can I make {name}?"},
            {"role": "assistant", "content": recipe},
        ]
        chats.append(chat)
    # The chat template is applied later, by the trainer's formatting function
    return {"messages": chats}


dataset = recipes.map(preprocess_function, batched=True, remove_columns=recipes.column_names)
```
Let's test our preprocessing function on a random example:
```python
print(dataset[randrange(dataset_size)])
# {
#   'messages': [
#     {'content': 'How can I make Aunt liz s almond broccoli casserole?', 'role': 'user'},
#     {
#       'content': '- Pre-stream broccoli for about 5 minutes\n- Saute onions and garlic in butter\n- Add soup, cheese whiz and mushrooms to sauteed onion mixture\n- Put broccoli into a greased casserole dish and pour sauce over it\n- Sprinkle the almonds over this and then sprinkle the croutons on top\n- Bake at 350 degf for 30 minutes<|im_end|>',
#       'role': 'assistant'
#     }
#   ]
# }
```
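The `<|im_end|>` marker at the end of the assistant answer is the tokenizer's EOS token appended in `preprocess_function`. Qwen models use a ChatML-style chat template, in which each turn is wrapped between `<|im_start|>role` and `<|im_end|>`. The sketch below is a simplified, hypothetical renderer (`render_chatml` is not part of any library) meant only to show that layout; the real Qwen3 template adds extras such as thinking blocks, so in practice always use `tokenizer.apply_chat_template`.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Simplified ChatML rendering: one <|im_start|>role ... <|im_end|> block per turn."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its answer
        text += "<|im_start|>assistant\n"
    return text


example = [
    {"role": "user", "content": "How can I make banana bread?"},
    {"role": "assistant", "content": "- Preheat oven to 350 degrees"},
]
print(render_chatml(example))
```

During training, the whole conversation is rendered (no generation prompt); at inference time only the user turn is rendered and the assistant turn is left open for the model to complete.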
3. Fine-tune Qwen 3 on AWS Trainium using the NeuronSFTTrainer and PEFT
Usually, to fine-tune PyTorch-based transformer models, you would use PEFT to add LoRA adapters, which saves memory, and TRL's `SFTTrainer` to perform supervised fine-tuning.
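To see why LoRA saves memory: instead of updating a full weight matrix of shape d_out × d_in, LoRA trains two small matrices, B (d_out × r) and A (r × d_in), whose product scaled by lora_alpha / r is added to the frozen weight. The helpers below are a hypothetical illustration that counts trainable parameters for a single projection; the 4096 × 4096 layer size is an assumption chosen for the example, while r=64 and lora_alpha=128 mirror the `LoraConfig` used later in this section.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    # LoRA trains A (r x d_in) and B (d_out x r) instead of the full d_out x d_in matrix
    return r * d_in + d_out * r


def full_params(d_in: int, d_out: int) -> int:
    # Number of entries in the original weight matrix
    return d_in * d_out


d_in = d_out = 4096  # hypothetical projection size
r, lora_alpha = 64, 128

lora = lora_trainable_params(d_in, d_out, r)
full = full_params(d_in, d_out)
print(f"LoRA params: {lora:,} vs full: {full:,} ({lora / full:.2%})")
print(f"scaling factor lora_alpha / r = {lora_alpha / r}")
```

For this layer size, the adapter trains about 3% of the parameters of the full matrix, which is where the memory savings for gradients and optimizer states come from.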
On AWS Trainium, `optimum-neuron` offers a 1-to-1 replacement, the `NeuronSFTTrainer`, which is optimized to take advantage of the multiple NeuronCores available on this setup.
When it comes to distributed training on AWS Trainium, there are a few things we need to take care of. Since Qwen3-8B is a big model, it does not fit on a single accelerator. The `NeuronSFTTrainer` supports different distributed training techniques (DDP, tensor parallelism, etc.) to solve this.
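A back-of-the-envelope calculation shows why sharding is needed. The instance table gives 512 GB of accelerator memory for the 32 NeuronCores of a trn1.32xlarge, that is 16 GB per NeuronCore. The sketch below assumes roughly 8 billion parameters stored in bfloat16 (2 bytes each) and counts the weights only; gradients, optimizer states, and activations add more, so treat the result as a lower bound. It also assumes, as a simplification, that tensor parallelism splits all weights evenly across cores.

```python
def weights_gb(num_params: float, bytes_per_param: int = 2) -> float:
    # Memory for the model weights alone (bfloat16 = 2 bytes per parameter), in GB (1e9 bytes)
    return num_params * bytes_per_param / 1e9


def per_core_gb(total_gb: float, tensor_parallel_size: int) -> float:
    # Simplification: assume every weight is sharded evenly across the tensor-parallel group
    return total_gb / tensor_parallel_size


accelerator_memory_gb = 512  # trn1.32xlarge, from the instance table
neuron_cores = 32
core_memory_gb = accelerator_memory_gb / neuron_cores

total = weights_gb(8e9)  # assumption: ~8B parameters in bfloat16
print(f"weights: {total:.1f} GB, memory per NeuronCore: {core_memory_gb:.1f} GB")
print(f"weights per core with tensor_parallel_size=8: {per_core_gb(total, 8):.1f} GB")
```

The weights alone would fill a whole NeuronCore before any training state is allocated, whereas with a tensor-parallel degree of 8 each core holds roughly 2 GB of them.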
Loading the model and preparing the LoRA adapter is very similar to what you would do with other accelerators:
```python
import torch
from peft import LoraConfig
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.models.training import Qwen3ForCausalLM
from optimum.neuron.models.training.config import TrainingNeuronConfig

# This is necessary to pass the training configuration:
# shard each layer across 8 NeuronCores (tensor parallelism), no pipeline parallelism
trn_config = TrainingNeuronConfig(tensor_parallel_size=8, pipeline_parallel_size=1)

# Define your own training arguments
training_args = NeuronTrainingArguments()

dtype = torch.bfloat16  # Use bfloat16 to enable mixed-precision training
model = Qwen3ForCausalLM.from_pretrained(model_id, trn_config, torch_dtype=dtype)

# LoRA configuration: rank-64 adapters on the embeddings, attention and MLP projections
config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=["embed_tokens", "q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

args = training_args.to_dict()

# Packing combines several short examples into sequences of up to max_seq_length tokens
packing = True
sft_config = NeuronSFTConfig(
    max_seq_length=8192,
    packing=packing,
    **args,
)


def formatting_function(examples):
    # Render the conversations into text with the model's chat template
    return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)


trainer = NeuronSFTTrainer(
    args=sft_config,
    model=model,
    peft_config=config,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_function,
)

# Start training
train_result = trainer.train()
```
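With `packing=True`, several short tokenized examples are combined into sequences of up to `max_seq_length` tokens, so less compute is spent on padding. The function below is a simplified concatenate-and-chunk sketch of that idea (the name `pack_sequences` is hypothetical); the trainer's actual packing strategy may differ, for example in how it separates or places examples.

```python
def pack_sequences(sequences, max_len):
    """Concatenate token-id lists and split them into chunks of at most max_len tokens."""
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    flat = [token for seq in sequences for token in seq]
    return [flat[i : i + max_len] for i in range(0, len(flat), max_len)]


# Three short "examples" packed into sequences of length 4
print(pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_len=4))
# [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
```

Because `preprocess_function` appends the EOS token to every answer, example boundaries remain visible to the model inside a packed sequence.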
We prepared a script, `sft_finetuning_qwen3.py`, that fine-tunes Qwen3 with everything covered in this tutorial. You can launch it with the `torchrun` command:
```bash
PROCESSES_PER_NODE=32  # one process per NeuronCore on a trn1.32xlarge
NUM_EPOCHS=3
TP_DEGREE=8
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=2
MODEL_NAME="Qwen/Qwen3-8B"
OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-finetuned"
MAX_STEPS=-1  # -1: the training length is set by NUM_EPOCHS

DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE"

torchrun $DISTRIBUTED_ARGS notebooks/text-generation/scripts/sft_finetuning_qwen3.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 true \
  --save_steps 20 \
  --tensor_parallel_size $TP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit -1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "cosine" \
  --overwrite_output_dir
```
For convenience, we provide this command in a shell script called sft_finetuning_qwen3.sh.
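It can help to work out what these launch settings imply. With `torchrun`, the world size is the number of processes (32 here). Tensor parallelism groups `TP_DEGREE` processes to hold one copy of the model, so the remaining factor is the data-parallel size, and each optimizer step consumes per-device batch × gradient accumulation steps × data-parallel size sequences. The helpers below sketch that arithmetic, assuming a single node and no pipeline parallelism (`pipeline_parallel_size=1`, as in the configuration shown earlier); the function names are hypothetical.

```python
def data_parallel_size(world_size: int, tp: int, pp: int = 1) -> int:
    # Each model replica spans tp * pp processes; the rest are data-parallel copies
    if world_size % (tp * pp) != 0:
        raise ValueError("world size must be divisible by tp * pp")
    return world_size // (tp * pp)


def global_batch_size(per_device_bs: int, grad_accum: int, dp: int) -> int:
    # Sequences consumed per optimizer step across all data-parallel replicas
    return per_device_bs * grad_accum * dp


dp = data_parallel_size(world_size=32, tp=8)  # PROCESSES_PER_NODE=32, TP_DEGREE=8
print(f"data-parallel size: {dp}")
print(f"global batch size: {global_batch_size(1, 8, dp)}")  # BS=1, GRADIENT_ACCUMULATION_STEPS=8
```

With the values above this gives 4 data-parallel replicas and 32 sequences per optimizer step; changing `TP_DEGREE` changes both numbers.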
4. Consolidate and test the fine-tuned model
During distributed training, Optimum Neuron saves each model shard to a separate file, so the shards need to be consolidated (i.e., re-merged) before the model can be used.
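Conceptually, with tensor parallelism each rank stores only a slice of every sharded weight matrix, and consolidation stitches the slices back together. The sketch below illustrates this with plain nested lists (rows of a PyTorch-style `d_out × d_in` weight) for the column-parallel and row-parallel layouts popularized by Megatron-LM: column-parallel layers split the output dimension, row-parallel layers split the input dimension. It is an illustration of the idea only, not how `optimum-cli` is implemented, and the function name is hypothetical.

```python
def consolidate(shards, axis):
    """Merge 2-D weight shards (lists of rows) along axis 0 (rows) or axis 1 (columns)."""
    if axis == 0:
        # Column-parallel layers: each rank holds a block of output rows
        return [row for shard in shards for row in shard]
    if axis == 1:
        # Row-parallel layers: each rank holds a block of input columns
        num_rows = len(shards[0])
        if any(len(shard) != num_rows for shard in shards):
            raise ValueError("shards must have the same number of rows")
        return [[x for shard in shards for x in shard[i]] for i in range(num_rows)]
    raise ValueError("axis must be 0 or 1")


# The same 2 x 2 weight [[1, 2], [3, 4]], split across two ranks in both layouts
print(consolidate([[[1, 2]], [[3, 4]]], axis=0))      # [[1, 2], [3, 4]]
print(consolidate([[[1], [3]], [[2], [4]]], axis=1))  # [[1, 2], [3, 4]]
```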
To do this, you can use the Optimum CLI:
```bash
optimum-cli neuron consolidate Qwen3-8B-finetuned Qwen3-8B-finetuned/adapter_default
```
This creates an `adapter_model.safetensors` file containing the LoRA adapter weights trained in the previous step. We can now reload the base model and merge the adapter into it, so the result can be loaded for evaluation:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

MODEL_NAME = "Qwen/Qwen3-8B"
ADAPTER_PATH = "Qwen3-8B-finetuned/adapter_default"
MERGED_MODEL_PATH = "Qwen3-8B-recipes"

# Load the base model and tokenizer
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the adapter configuration and attach the consolidated LoRA adapter to the base model
adapter_config = PeftConfig.from_pretrained(ADAPTER_PATH)
finetuned_model = PeftModel.from_pretrained(model, ADAPTER_PATH, config=adapter_config)

print("Saving tokenizer")
tokenizer.save_pretrained(MERGED_MODEL_PATH)

print("Saving model")
# Fold the LoRA weights into the base weights and remove the PEFT wrappers
finetuned_model = finetuned_model.merge_and_unload()
finetuned_model.save_pretrained(MERGED_MODEL_PATH)
```
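`merge_and_unload()` folds each adapter back into the weight it was attached to. For a LoRA adapter of rank r with scaling parameter lora_alpha (α), the merged weight is the frozen base weight plus the scaled low-rank update; with the r=64 and lora_alpha=128 used in this tutorial, the scaling factor is 2:

```latex
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\, B A,
\qquad B \in \mathbb{R}^{d_{\text{out}} \times r},\;
A \in \mathbb{R}^{r \times d_{\text{in}}},\;
\frac{\alpha}{r} = \frac{128}{64} = 2
```

After merging, the model has the same architecture and size as the base model, so it can be loaded with plain `AutoModelForCausalLM` without PEFT.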
Once this step is done, you can test the model with a new prompt.
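A minimal sketch of such a test follows, assuming the merged model was saved to `Qwen3-8B-recipes` as above and that the machine has enough memory to load it. It reuses the question format from training. The helper names are hypothetical, and `max_new_tokens=256` is an illustrative choice rather than a value from this tutorial. Transformers is imported inside `generate_recipe` so that the prompt helper can be used without loading a model.

```python
def build_messages(name: str) -> list[dict]:
    # Same user-turn wording as preprocess_function used during training
    return [{"role": "user", "content": f"How can I make {name}?"}]


def generate_recipe(model_path: str, name: str, max_new_tokens: int = 256) -> str:
    # Imported lazily so build_messages works without transformers or model weights
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # add_generation_prompt=True opens the assistant turn so the model answers
    inputs = tokenizer.apply_chat_template(
        build_messages(name),
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, not the prompt
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
```

For example, `print(generate_recipe("Qwen3-8B-recipes", "Beat this banana bread"))` should print a recipe in the bulleted style of the training data.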
You have successfully created a fine-tuned model from Qwen3!