# train_llama4.py
# Script to fine-tune Llama 4 Maverick for healthcare fraud detection
from transformers import AutoProcessor, Llama4ForConditionalGeneration, Trainer, TrainingArguments
from transformers import BitsAndBytesConfig
import datasets
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
import huggingface_hub
import os

# Version and CUDA check
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Authenticate with Hugging Face
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
huggingface_hub.login(token=LLama)

# Load Llama 4 model and processor
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Quantization config for A100 80 GB VRAM
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flex_attention"
)

# Prepare for LoRA
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Load dataset
dataset = datasets.load_dataset("json", data_files="Bingaman_training_data.json", field="training_pairs")
print("First example from dataset:", dataset["train"][0])

# Tokenization: format each input/output pair with the chat template, then tokenize
def tokenize_data(example):
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": example["input"]}]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["output"]}]
        }
    ]
    formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
    # Pass the formatted string via the text= keyword so it is not mistaken for image input
    inputs = processor(text=formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
    input_ids = inputs["input_ids"].squeeze(0).tolist()
    attention_mask = inputs["attention_mask"].squeeze(0).tolist()
    # Copy input_ids as labels, but mask padding positions with -100 so they are ignored by the loss
    labels = [tok if mask == 1 else -100 for tok, mask in zip(input_ids, attention_mask)]
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask
    }

tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)
print("First tokenized example:", {k: (type(v), len(v)) for k, v in tokenized_dataset[0].items()})

# Data collator: stack fixed-length examples into batch tensors
def custom_data_collator(features):
    input_ids = [torch.tensor(f["input_ids"]) for f in features]
    attention_mask = [torch.tensor(f["attention_mask"]) for f in features]
    labels = [torch.tensor(f["labels"]) for f in features]
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attention_mask),
        "labels": torch.stack(labels)
    }

# Training setup
accelerator = Accelerator()
training_args = TrainingArguments(
    output_dir="./fine_tuned_llama4_healthcare",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=5,
    bf16=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_steps=50
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset.select(range(min(5, len(tokenized_dataset)))),
    data_collator=custom_data_collator
)

# Start training
trainer.train()
model.save_pretrained("./fine_tuned_llama4_healthcare")
processor.save_pretrained("./fine_tuned_llama4_healthcare")
print("Training complete. Model and processor saved to ./fine_tuned_llama4_healthcare")
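
# Optional smoke test after training: a minimal sketch, not part of the original script.
# It reuses the in-memory PEFT-wrapped `model` and `processor` from above; the prompt text
# below is a hypothetical placeholder, not drawn from the training data.
model.eval()
test_messages = [
    {"role": "user", "content": [{"type": "text", "text": "Summarize the billing anomalies in this claim: ..."}]}
]
test_prompt = processor.apply_chat_template(test_messages, add_generation_prompt=True)
test_inputs = processor(text=test_prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    generated = model.generate(**test_inputs, max_new_tokens=256)
# Decode only the newly generated tokens, skipping the prompt portion
print(processor.decode(generated[0][test_inputs["input_ids"].shape[-1]:], skip_special_tokens=True))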