khulnasoft committed on
Commit
22e0e62
·
verified ·
1 Parent(s): e849fd6

Create aifixcode_trainer.py

Browse files
Files changed (1) hide show
  1. aifixcode_trainer.py +86 -0
aifixcode_trainer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### aifixcode_trainer.py
2
+
3
+ """
4
+ This script sets up a simple HuggingFace-based training + inference pipeline
5
+ for bug-fixing AI using a CodeT5 model and supports continual training.
6
+ You can upload this script to a Hugging Face Space or Hub repo.
7
+ """
8
+
9
import inspect
import os

import torch
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
13
+
14
# ========== CONFIG ==========
# Checkpoint to fine-tune, where the result is written, and the JSON data files.
MODEL_NAME = "Salesforce/codet5p-220m"
MODEL_OUT_DIR = "./aifixcode-model"
TRAIN_DATASET_PATH = "./data/train.json"
VAL_DATASET_PATH = "./data/val.json"

# ========== LOAD MODEL + TOKENIZER ==========
print("Loading model and tokenizer...")
# Tokenizer and seq2seq model are both resolved from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
24
+
25
# ========== LOAD DATASET ==========
print("Loading dataset...")


def load_json_dataset(train_path, val_path):
    """Build a DatasetDict with "train" and "validation" splits from JSON files.

    Each file is read separately with the ``json`` loader. Both results are
    indexed with ``["train"]`` because that is the split name under which the
    loader exposes a single data file.

    Args:
        train_path: Path to the JSON file used for training.
        val_path: Path to the JSON file used for validation.

    Returns:
        A DatasetDict with the keys "train" and "validation".
    """
    train_split = load_dataset("json", data_files=train_path)["train"]
    val_split = load_dataset("json", data_files=val_path)["train"]
    return DatasetDict({"train": train_split, "validation": val_split})


dataset = load_json_dataset(TRAIN_DATASET_PATH, VAL_DATASET_PATH)
35
+
36
# ========== PREPROCESS ==========
print("Tokenizing dataset...")


def preprocess(example):
    """Tokenize "input"/"output" code pairs into model inputs and labels.

    Args:
        example: A record, or a batch of records when mapped with
            ``batched=True``, holding the code to fix under "input" and the
            corrected code under "output".

    Returns:
        The tokenizer encoding of "input" with an added "labels" entry built
        from "output". Label positions that hold the padding token are
        replaced by -100, the index the loss ignores, so that padding does
        not contribute to the training loss.
    """
    input_code = example["input"]
    target_code = example["output"]
    model_inputs = tokenizer(input_code, truncation=True, padding="max_length", max_length=512)
    labels = tokenizer(target_code, truncation=True, padding="max_length", max_length=512)

    pad_id = tokenizer.pad_token_id

    def mask_padding(token_ids):
        # Targets are padded to a fixed length above; without this mask the
        # model would also be trained to predict the padding tokens.
        return [-100 if token_id == pad_id else token_id for token_id in token_ids]

    label_ids = labels["input_ids"]
    if isinstance(target_code, str):
        # Single example: the tokenizer returned one flat list of ids.
        model_inputs["labels"] = mask_padding(label_ids)
    else:
        # Batched call: one list of ids per example.
        model_inputs["labels"] = [mask_padding(ids) for ids in label_ids]
    return model_inputs


encoded_dataset = dataset.map(preprocess, batched=True)
47
+
48
# ========== TRAINING SETUP ==========
print("Setting up trainer...")


def _accepts_keyword(func, name):
    """Return True when the signature of *func* declares a parameter *name*."""
    return name in inspect.signature(func).parameters


# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# and `Trainer(tokenizer=...)` to `processing_class`. Use whichever keyword the
# installed version declares so the script runs on both sides of the rename.
_eval_strategy_key = (
    "eval_strategy"
    if _accepts_keyword(TrainingArguments.__init__, "eval_strategy")
    else "evaluation_strategy"
)
_tokenizer_key = (
    "processing_class"
    if _accepts_keyword(Trainer.__init__, "processing_class")
    else "tokenizer"
)

# Evaluate, checkpoint and log once per epoch; every saved checkpoint is also
# pushed to the Hub repository named in `hub_model_id`.
training_args = TrainingArguments(
    output_dir=MODEL_OUT_DIR,
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_strategy="epoch",
    push_to_hub=True,
    hub_model_id="khulnasoft/aifixcode-model",
    hub_strategy="every_save",
    **{_eval_strategy_key: "epoch"},
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    **{_tokenizer_key: tokenizer},
)
76
+
77
# ========== TRAIN ==========
print("Starting training...")
trainer.train()

# ========== SAVE FINAL MODEL ==========
print("Saving model...")
# The final model and the tokenizer files are written to the same directory.
trainer.save_model(MODEL_OUT_DIR)
tokenizer.save_pretrained(MODEL_OUT_DIR)

print("Training complete and model saved!")