import gradio as gr
import pandas as pd
import torch
import os
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
import spaces  # Hugging Face Spaces helper (ZeroGPU); imported for Spaces compatibility, not referenced directly below

# Initialize logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Function to load and process data
def load_data(csv_file):
    try:
        df = pd.read_csv(csv_file)
        logger.info(f"CSV columns: {df.columns.tolist()}")
        logger.info(f"Total rows in CSV: {len(df)}")
        return df
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")
        return None

# Function to prepare dataset
def prepare_dataset(df, teacher_col, student_col, num_samples=100):
    # Extract and format data
    logger.info(f"Using columns: {teacher_col} (teacher) and {student_col} (student)")
    
    formatted_data = []
    for i in range(min(num_samples, len(df))):
        teacher_text = str(df.iloc[i][teacher_col])
        student_text = str(df.iloc[i][student_col])
        
        # Create prompt
        formatted_text = f"### Teacher: {teacher_text}\n### Student: {student_text}"
        formatted_data.append({"text": formatted_text})
    
    logger.info(f"Created {len(formatted_data)} formatted examples")
    
    # Create dataset
    dataset = Dataset.from_list(formatted_data)
    
    # Split dataset
    train_val_split = dataset.train_test_split(test_size=0.1, seed=42)
    
    return train_val_split

# Function to tokenize data
def tokenize_data(dataset, tokenizer, max_length=512):
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length"
        )
    
    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    return tokenized_dataset

# Main fine-tuning function with memory optimizations
def finetune_model(model_id, train_data, output_dir, epochs, batch_size=None):
    """
    Fine-tune a model with optimized memory settings to prevent CUDA OOM errors.
    """
    logger.info(f"Using model: {model_id}")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # ============ MEMORY OPTIMIZATION 1: REDUCED BATCH SIZE ============
    # A smaller batch size dramatically reduces memory usage during training
    actual_batch_size = 8 if batch_size is None else min(batch_size, 8)
    logger.info(f"Using batch size: {actual_batch_size} (capped at 8 to limit memory use)")
    
    # ============ MEMORY OPTIMIZATION 2: 8-bit QUANTIZATION ============
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=True,  # Use 8-bit quantization to reduce memory usage
        device_map="auto",  # Automatically handle model distribution
        use_cache=False,    # Disable KV cache which uses extra memory
        torch_dtype=torch.float16,  # Use lower precision
    )
    
    # Count model parameters
    logger.info(f"Model parameters: {model.num_parameters():,}")
    
    # Prepare model for training with quantization
    model = prepare_model_for_kbit_training(model)
    
    # ============ MEMORY OPTIMIZATION 3: GRADIENT CHECKPOINTING ============
    model.gradient_checkpointing_enable()
    logger.info("Gradient checkpointing enabled: trading computation for memory savings")
    
    # ============ MEMORY OPTIMIZATION 4: OPTIMIZED LORA CONFIG ============
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=4,              # REDUCED from default 8/16 to save memory
        lora_alpha=16,    # Scaling factor
        lora_dropout=0.1, # Dropout probability for regularization
        target_modules=["q_proj", "v_proj"],  # Only attention query and value projections
    )
    logger.info("Using optimized LoRA parameters with reduced rank (r=4) and targeted modules")
    
    # Apply LoRA adapters to the model
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # Print trainable parameters info
    
    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        # ============ MEMORY OPTIMIZATION 5: REDUCED BATCH SIZE IN ARGS ============
        per_device_train_batch_size=actual_batch_size,
        per_device_eval_batch_size=actual_batch_size,
        # ============ MEMORY OPTIMIZATION 6: MIXED PRECISION TRAINING ============
        fp16=True,  # Use FP16 for mixed precision training
        # ============ MEMORY OPTIMIZATION 7: GRADIENT ACCUMULATION ============
        gradient_accumulation_steps=4,  # Accumulate over 4 steps (effective batch = 4 x per-device batch)
        # ============ MEMORY OPTIMIZATION 8: GRADIENT CHECKPOINTING IN ARGS ============
        gradient_checkpointing=True,
        # Other parameters
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        learning_rate=2e-4,
        weight_decay=0.01,
        warmup_ratio=0.03,
        # ============ MEMORY OPTIMIZATION 9: REDUCED OPTIMIZER OVERHEAD ============
        optim="adamw_torch_fused",  # Fused AdamW implementation with lower per-step overhead
        # ============ MEMORY OPTIMIZATION 10: REDUCED LOGGING MEMORY ============
        report_to="none",  # Disable external experiment trackers (wandb, tensorboard)
    )
    
    # Initialize the Trainer
    # DataCollatorForLanguageModeling with mlm=False copies input_ids into labels,
    # which the Trainer needs to compute the causal-LM loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data["train"],
        eval_dataset=train_data["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    
    # ============ MEMORY OPTIMIZATION 11: MANAGE CUDA CACHE ============
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info("CUDA cache cleared before training")
    
    # Start training
    logger.info("Starting training...")
    trainer.train()
    
    # Save the model
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info(f"Model saved to {output_dir}")
    
    return model, tokenizer
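
# Illustrative sketch (not wired into the Gradio UI below): one way the saved LoRA
# adapter could be reloaded for inference. PeftModel.from_pretrained attaches the
# adapter weights in adapter_dir to a freshly loaded base model, and the prompt
# mirrors the "### Teacher: ... ### Student:" template used in prepare_dataset.
def generate_student_reply(model_id, adapter_dir, teacher_text, max_new_tokens=128):
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()

    prompt = f"### Teacher: {teacher_text}\n### Student:"
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)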

# Gradio interface functions
def process_csv(file, teacher_col, student_col, num_samples):
    df = load_data(file.name)
    if df is None:
        return "Error loading CSV file"
    return f"CSV loaded successfully with {len(df)} rows"

def start_fine_tuning(file, teacher_col, student_col, model_id, epochs, batch_size, num_samples):
    try:
        # Load and process data
        df = load_data(file.name)
        if df is None:
            return "Error loading CSV file"
        
        # Prepare dataset
        dataset = prepare_dataset(df, teacher_col, student_col, num_samples=int(num_samples))
        
        # Load tokenizer for preprocessing
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Tokenize dataset
        tokenized_dataset = {
            "train": tokenize_data(dataset["train"], tokenizer),
            "validation": tokenize_data(dataset["test"], tokenizer),
        }
        
        # Create output directory
        output_dir = "./fine_tuned_model"
        os.makedirs(output_dir, exist_ok=True)
        
        # Finetune model with memory optimizations
        finetune_model(
            model_id=model_id,
            train_data=tokenized_dataset,
            output_dir=output_dir,
            epochs=int(epochs),
            batch_size=int(batch_size),
        )
        
        return "Fine-tuning completed successfully!"
    
    except Exception as e:
        logger.error(f"Error during fine-tuning: {e}")
        return f"Error during fine-tuning: {str(e)}"

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Teacher-Student Bot Fine-Tuning")
    
    with gr.Tab("Upload Data"):
        file_input = gr.File(label="Upload CSV File")
        with gr.Row():
            teacher_col = gr.Textbox(label="Teacher Column", value="Unnamed: 0")
            student_col = gr.Textbox(label="Student Column", value="idx")
        num_samples = gr.Slider(label="Number of Samples", minimum=10, maximum=1000, value=100, step=10)
        upload_btn = gr.Button("Process CSV")
        csv_output = gr.Textbox(label="CSV Processing Result")
        upload_btn.click(process_csv, inputs=[file_input, teacher_col, student_col, num_samples], outputs=csv_output)
    
    with gr.Tab("Fine-Tune"):
        model_id = gr.Textbox(label="Model ID", value="mistralai/Mistral-7B-v0.1")
        with gr.Row():
            batch_size = gr.Number(label="Batch Size", value=8, info="Recommended: 8 or lower for 7B models")
            epochs = gr.Number(label="Number of Epochs", value=2)
        
        training_btn = gr.Button("Start Fine-Tuning")
        training_output = gr.Textbox(label="Training Progress")
        
        training_btn.click(
            start_fine_tuning,
            inputs=[file_input, teacher_col, student_col, model_id, epochs, batch_size, num_samples],
            outputs=training_output
        )

# Launch the app - REMOVED the spaces.zero.mount() call that was causing the error
demo.queue().launch(debug=True)