import gradio as gr
import pandas as pd
import torch
import os
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
import spaces  # Import the spaces library (Hugging Face Spaces helper)

# Initialize logging
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Function to load and process data
def load_data(csv_file):
    try:
        df = pd.read_csv(csv_file)
        logger.info(f"CSV columns: {df.columns.tolist()}")
        logger.info(f"Total rows in CSV: {len(df)}")
        return df
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")
        return None

# Function to prepare dataset
def prepare_dataset(df, teacher_col, student_col, num_samples=100):
    # Extract and format data
    logger.info(f"Using columns: {teacher_col} (teacher) and {student_col} (student)")
    formatted_data = []
    for i in range(min(num_samples, len(df))):
        teacher_text = str(df.iloc[i][teacher_col])
        student_text = str(df.iloc[i][student_col])
        # Create prompt
        formatted_text = f"### Teacher: {teacher_text}\n### Student: {student_text}"
        formatted_data.append({"text": formatted_text})
    logger.info(f"Created {len(formatted_data)} formatted examples")
    # Create dataset
    dataset = Dataset.from_list(formatted_data)
    # Split dataset into train/test (90/10)
    train_val_split = dataset.train_test_split(test_size=0.1, seed=42)
    return train_val_split

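# A minimal usage sketch (file name and column names are hypothetical):
#   df = load_data("dialogs.csv")
#   splits = prepare_dataset(df, "question", "answer", num_samples=200)
#   splits["train"][0]["text"]
#   # -> "### Teacher: What is 2+2?\n### Student: 4"
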
# Function to tokenize data
def tokenize_data(dataset, tokenizer, max_length=512):
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    return tokenized_dataset

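# What tokenize_data yields (assuming the default max_length=512): each example
# gains "input_ids" and "attention_mask" lists of length 512, alongside the
# original "text" column (the Trainer drops unused columns automatically).
#   tokenized = tokenize_data(splits["train"], tokenizer)
#   len(tokenized[0]["input_ids"])  # -> 512
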
# Main fine-tuning function with memory optimizations
def finetune_model(model_id, train_data, output_dir, epochs, batch_size=None):
    """
    Fine-tune a model with optimized memory settings to prevent CUDA OOM errors.
    """
    logger.info(f"Using model: {model_id}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ============ MEMORY OPTIMIZATION 1: REDUCED BATCH SIZE ============
    # A smaller batch size dramatically reduces memory usage during training
    actual_batch_size = 8 if batch_size is None else min(batch_size, 8)
    logger.info(f"Using batch size: {actual_batch_size} (capped at 8 to save memory)")

    # ============ MEMORY OPTIMIZATION 2: 8-bit QUANTIZATION ============
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=True,  # Use 8-bit quantization to reduce memory usage
        device_map="auto",  # Automatically handle model distribution
        use_cache=False,  # Disable KV cache, which uses extra memory
        torch_dtype=torch.float16,  # Use lower precision for non-quantized weights
    )

    # Count model parameters
    logger.info(f"Model parameters: {model.num_parameters():,}")

    # Prepare model for training with quantization
    model = prepare_model_for_kbit_training(model)
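    # Note (peft internals, roughly): prepare_model_for_kbit_training freezes
    # the quantized base weights, upcasts norm layers to fp32 for numerical
    # stability, and enables input gradients so gradient checkpointing can
    # backpropagate through a frozen base model.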
    # ============ MEMORY OPTIMIZATION 3: GRADIENT CHECKPOINTING ============
    model.gradient_checkpointing_enable()
    logger.info("Gradient checkpointing enabled: trading computation for memory savings")

    # ============ MEMORY OPTIMIZATION 4: OPTIMIZED LORA CONFIG ============
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=4,  # Reduced from the usual 8/16 to save memory
        lora_alpha=16,  # Scaling factor
        lora_dropout=0.1,  # Dropout probability for regularization
        target_modules=["q_proj", "v_proj"],  # Only attention query and value projections
    )
    logger.info("Using optimized LoRA parameters with reduced rank (r=4) and targeted modules")
    # Apply LoRA adapters to the model
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # Print trainable parameter counts

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        # ============ MEMORY OPTIMIZATION 5: REDUCED BATCH SIZE IN ARGS ============
        per_device_train_batch_size=actual_batch_size,
        per_device_eval_batch_size=actual_batch_size,
        # ============ MEMORY OPTIMIZATION 6: MIXED PRECISION TRAINING ============
        fp16=True,  # Use FP16 for mixed precision training
        # ============ MEMORY OPTIMIZATION 7: GRADIENT ACCUMULATION ============
        gradient_accumulation_steps=4,  # Accumulate gradients over 4 steps
        # ============ MEMORY OPTIMIZATION 8: GRADIENT CHECKPOINTING IN ARGS ============
        gradient_checkpointing=True,
        # Other parameters
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        learning_rate=2e-4,
        weight_decay=0.01,
        warmup_ratio=0.03,
        # ============ MEMORY OPTIMIZATION 9: REDUCED OPTIMIZER OVERHEAD ============
        optim="adamw_torch_fused",  # More memory-efficient fused optimizer
        # ============ MEMORY OPTIMIZATION 10: REDUCED LOGGING OVERHEAD ============
        report_to="none",  # Disable external logging integrations
    )
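    # Note: with gradient accumulation, the effective batch size is
    # per_device_train_batch_size * gradient_accumulation_steps
    # (e.g. 8 * 4 = 32 per device), while only one micro-batch of
    # activations is held in memory at a time.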
    # Initialize the Trainer. A causal-LM data collator is needed so that
    # "labels" are generated from input_ids; without it the model returns
    # no loss and training fails.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data["train"],
        eval_dataset=train_data["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # ============ MEMORY OPTIMIZATION 11: MANAGE CUDA CACHE ============
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info("CUDA cache cleared before training")

    # Start training
    logger.info("Starting training...")
    trainer.train()

    # Save the model and tokenizer
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info(f"Model saved to {output_dir}")
    return model, tokenizer

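# A minimal inference sketch (paths and prompt are hypothetical): the saved
# LoRA adapter can be reloaded on top of the base model after training.
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
#   tuned = PeftModel.from_pretrained(base, "./fine_tuned_model")
#   prompt = "### Teacher: Explain photosynthesis briefly.\n### Student:"
#   inputs = tokenizer(prompt, return_tensors="pt").to(tuned.device)
#   output = tuned.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(output[0], skip_special_tokens=True))
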
# Gradio interface functions
def process_csv(file, teacher_col, student_col, num_samples):
    df = load_data(file.name)
    if df is None:
        return "Error loading CSV file"
    return f"CSV loaded successfully with {len(df)} rows"

def start_fine_tuning(file, teacher_col, student_col, model_id, epochs, batch_size, num_samples):
    try:
        # Load and process data
        df = load_data(file.name)
        if df is None:
            return "Error loading CSV file"

        # Prepare dataset
        dataset = prepare_dataset(df, teacher_col, student_col, num_samples=int(num_samples))

        # Load tokenizer for preprocessing
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Tokenize dataset (train_test_split names its eval split "test")
        tokenized_dataset = {
            "train": tokenize_data(dataset["train"], tokenizer),
            "validation": tokenize_data(dataset["test"], tokenizer),
        }

        # Create output directory
        output_dir = "./fine_tuned_model"
        os.makedirs(output_dir, exist_ok=True)

        # Fine-tune model with memory optimizations
        finetune_model(
            model_id=model_id,
            train_data=tokenized_dataset,
            output_dir=output_dir,
            epochs=int(epochs),
            batch_size=int(batch_size),
        )
        return "Fine-tuning completed successfully!"
    except Exception as e:
        logger.error(f"Error during fine-tuning: {e}")
        return f"Error during fine-tuning: {str(e)}"

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Teacher-Student Bot Fine-Tuning")

    with gr.Tab("Upload Data"):
        file_input = gr.File(label="Upload CSV File")
        with gr.Row():
            teacher_col = gr.Textbox(label="Teacher Column", value="Unnamed: 0")
            student_col = gr.Textbox(label="Student Column", value="idx")
        num_samples = gr.Slider(label="Number of Samples", minimum=10, maximum=1000, value=100, step=10)
        upload_btn = gr.Button("Process CSV")
        csv_output = gr.Textbox(label="CSV Processing Result")
        upload_btn.click(process_csv, inputs=[file_input, teacher_col, student_col, num_samples], outputs=csv_output)

    with gr.Tab("Fine-Tune"):
        model_id = gr.Textbox(label="Model ID", value="mistralai/Mistral-7B-v0.1")
        with gr.Row():
            batch_size = gr.Number(label="Batch Size", value=8, info="Recommended: 8 or lower for 7B models")
            epochs = gr.Number(label="Number of Epochs", value=2)
        training_btn = gr.Button("Start Fine-Tuning")
        training_output = gr.Textbox(label="Training Progress")
        training_btn.click(
            start_fine_tuning,
            inputs=[file_input, teacher_col, student_col, model_id, epochs, batch_size, num_samples],
            outputs=training_output,
        )

# Launch the app (the spaces.zero.mount() call that caused the runtime error was removed)
demo.queue().launch(debug=True)