amaltese committed
Commit 5e2f43e · verified · 1 Parent(s): 485668f

Upload app.py

Files changed (1)
  1. app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
+ import gradio as gr
+ import os
+ import torch
+ import json
+ import pandas as pd
+ from datasets import Dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ from peft import (
+     LoraConfig,
+     get_peft_model,
+     prepare_model_for_kbit_training
+ )
+
+ # Set environment variable for cache directory
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/hf_cache'
+ os.makedirs('/tmp/hf_cache', exist_ok=True)
+
+ def sample_from_csv(csv_file, sample_size=100):
+     """Sample from CSV file and format for training"""
+     df = pd.read_csv(csv_file)
+
+     # Display CSV info
+     print(f"CSV columns: {df.columns.tolist()}")
+     print(f"Total rows in CSV: {len(df)}")
+
+     # Try to identify teacher and student columns
+     teacher_col = None
+     student_col = None
+
+     for col in df.columns:
+         col_lower = col.lower()
+         if 'teacher' in col_lower or 'instructor' in col_lower or 'prompt' in col_lower:
+             teacher_col = col
+         elif 'student' in col_lower or 'response' in col_lower or 'answer' in col_lower:
+             student_col = col
+
+     # If we couldn't identify columns, fall back to the first two
+     if teacher_col is None or student_col is None:
+         teacher_col = df.columns[0]
+         student_col = df.columns[1]
+
+     # Sample rows
+     if sample_size >= len(df):
+         sampled_df = df
+     else:
+         sampled_df = df.sample(n=sample_size, random_state=42)
+
+     # Format data
+     texts = []
+     for _, row in sampled_df.iterrows():
+         teacher_text = str(row[teacher_col]).strip()
+         student_text = str(row[student_col]).strip()
+
+         # Skip rows with empty values
+         if not teacher_text or not student_text or teacher_text == 'nan' or student_text == 'nan':
+             continue
+
+         # Format according to the document format:
+         # <s> [INST] Teacher ** <Dialogue> [/INST] Student** <Dialogue> </s>
+         formatted_text = f"<s> [INST] Teacher ** {teacher_text} [/INST] Student** {student_text} </s>"
+         texts.append(formatted_text)
+
+     return Dataset.from_dict({"text": texts})
+
+ def finetune_model(csv_file, sample_size=100, num_epochs=3, progress=gr.Progress()):
+     """Fine-tune the model and return test results"""
+     # Check GPU
+     if torch.cuda.is_available():
+         print(f"GPU available: {torch.cuda.get_device_name(0)}")
+         device = torch.device("cuda")
+     else:
+         print("No GPU available, fine-tuning will be extremely slow!")
+         device = torch.device("cpu")
+
+     # Sample data
+     progress(0.1, "Sampling data from CSV...")
+     dataset = sample_from_csv(csv_file, sample_size)
+
+     # Split dataset
+     dataset_split = dataset.train_test_split(test_size=0.1)
+
+     # Load tokenizer
+     progress(0.2, "Loading tokenizer...")
+     model_name = "mistralai/Mistral-7B-v0.1"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Tokenize dataset
+     def tokenize_function(examples):
+         return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
+
+     progress(0.3, "Tokenizing dataset...")
+     tokenized_datasets = dataset_split.map(tokenize_function, batched=True)
+
+     # Load model with LoRA configuration
+     progress(0.4, "Loading model...")
+     lora_config = LoraConfig(
+         r=8,
+         lora_alpha=16,
+         target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+         lora_dropout=0.05,
+         bias="none",
+         task_type="CAUSAL_LM"
+     )
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16,
+         device_map="auto",
+     )
+
+     # Prepare model for LoRA training. Note: prepare_model_for_kbit_training is
+     # intended for quantized (4-/8-bit) checkpoints; on this fp16 model it freezes
+     # the base weights but also upcasts them to fp32, which increases memory use.
+     model = prepare_model_for_kbit_training(model)
+     model = get_peft_model(model, lora_config)
+
+     # Training arguments
+     output_dir = "mistral7b_finetuned"
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         num_train_epochs=num_epochs,
+         per_device_train_batch_size=1,
+         gradient_accumulation_steps=4,
+         save_steps=50,
+         logging_steps=10,
+         learning_rate=2e-4,
+         weight_decay=0.001,
+         fp16=True,
+         warmup_steps=50,
+         lr_scheduler_type="cosine",
+         report_to="none",  # Disable wandb
+     )
+
+     # Initialize trainer
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_datasets["train"],
+         eval_dataset=tokenized_datasets["test"],
+         data_collator=data_collator,
+     )
+
+     # Train model
+     progress(0.5, "Training model...")
+     trainer.train()
+
+     # Save model
+     progress(0.9, "Saving model...")
+     trainer.model.save_pretrained(output_dir)
+     tokenizer.save_pretrained(output_dir)
+
+     # Test with sample prompts
+     progress(0.95, "Testing model...")
+     test_prompts = [
+         "How was the Math exam?",
+         "Good morning students! How are you all?",
+         "What should you do if you get into a fight with a friend?",
+         "Did you complete your science project?",
+         "What did you learn in class today?"
+     ]
+
+     # Reuse the trained model in memory for inference; it is already PEFT-wrapped,
+     # so reloading the saved adapter onto it would stack a second adapter
+     fine_tuned_model = trainer.model
+     fine_tuned_model.eval()
+
+     # Generate responses
+     results = []
+     for prompt in test_prompts:
+         formatted_prompt = f"<s> [INST] Teacher ** {prompt} [/INST] Student**"
+         inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = fine_tuned_model.generate(
+                 **inputs,
+                 max_length=200,
+                 temperature=0.7,
+                 top_p=0.95,
+                 do_sample=True,
+             )
+
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         student_part = response.split("Student**")[1].strip() if "Student**" in response else response
+
+         results.append({
+             "prompt": prompt,
+             "response": student_part
+         })
+
+     # Save results
+     with open("test_results.json", "w") as f:
+         json.dump(results, f, indent=2)
+
+     progress(1.0, "Completed!")
+     return results
+
+ # Define Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Mistral 7B Fine-Tuning for Student Bot")
+
+     with gr.Tab("Fine-tune Model"):
+         with gr.Row():
+             # type="filepath" hands the handler a path string that pd.read_csv can open
+             csv_input = gr.File(label="Upload Teacher-Student CSV", type="filepath")
+
+         with gr.Row():
+             sample_size = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="Sample Size")
+             epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
+
+         with gr.Row():
+             start_btn = gr.Button("Start Fine-Tuning")
+
+         with gr.Row():
+             output = gr.JSON(label="Results")
+
+         start_btn.click(finetune_model, inputs=[csv_input, sample_size, epochs], outputs=[output])
+
+     with gr.Tab("About"):
+         gr.Markdown("""
+         ## Fine-Tuning Mistral 7B for Student Bot
+
+         This app fine-tunes the Mistral 7B model to respond like a student to teacher prompts.
+
+         ### Requirements
+         - CSV file with teacher-student conversation pairs
+         - GPU acceleration (provided by this Space)
+
+         ### Process
+         1. Upload your CSV file
+         2. Set sample size and number of epochs
+         3. Click "Start Fine-Tuning"
+         4. View test results with sample prompts
+         """)
+
+ # Launch app
+ demo.launch()
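
For reference, a minimal sketch of a CSV that the column-detection logic in sample_from_csv will pick up. The file name and column headers below are illustrative, not part of the app: any headers containing "teacher"/"instructor"/"prompt" and "student"/"response"/"answer" are matched, and otherwise the first two columns are used.

import pandas as pd

# Build a tiny two-column sample file (names are hypothetical)
pd.DataFrame({
    "teacher": ["How was the Math exam?", "Did you finish your homework?"],
    "student": ["It was tough, but I think I passed!", "Yes, I finished it last night."],
}).to_csv("sample_pairs.csv", index=False)

# Each row is flattened into the training template used in app.py:
# <s> [INST] Teacher ** How was the Math exam? [/INST] Student** It was tough, but I think I passed! </s>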