codewithdark committed on
Commit c4aca3b · verified · 1 Parent(s): 6ed954d

Upload 3 files

Files changed (3)
  1. utils/check_dataset.py +272 -0
  2. utils/model.py +552 -0
  3. utils/sample_dataset.py +105 -0
utils/check_dataset.py ADDED
@@ -0,0 +1,272 @@
import json
import os
import re
from collections import Counter

import pandas as pd


def validate_dataset(file_path, format_type):
    """
    Validate and analyze the dataset format, providing detailed feedback.

    Parameters:
        file_path (str): Path to the dataset file
        format_type (str): File format (csv, jsonl, text)

    Returns:
        dict: Validation results including format, structure, and statistics
    """
    validation_results = {
        "is_valid": False,
        "format": format_type,
        "detected_structure": None,
        "statistics": {},
        "issues": [],
        "recommendations": []
    }

    try:
        # Check that the file exists
        if not os.path.exists(file_path):
            validation_results["issues"].append(f"File not found: {file_path}")
            return validation_results

        # Check the file size
        file_size = os.path.getsize(file_path)
        validation_results["statistics"]["file_size_bytes"] = file_size
        validation_results["statistics"]["file_size_mb"] = round(file_size / (1024 * 1024), 2)

        if file_size == 0:
            validation_results["issues"].append("File is empty")
            return validation_results

        if format_type == "csv":
            try:
                df = pd.read_csv(file_path)
                validation_results["statistics"]["total_rows"] = len(df)
                validation_results["statistics"]["total_columns"] = len(df.columns)
                validation_results["statistics"]["column_names"] = list(df.columns)

                # Check for null values
                null_counts = df.isnull().sum().to_dict()
                validation_results["statistics"]["null_counts"] = null_counts

                if validation_results["statistics"]["total_rows"] == 0:
                    validation_results["issues"].append("CSV file has no rows")
                    return validation_results

                # Detect structure from well-known column pairs
                if "instruction" in df.columns and "response" in df.columns:
                    validation_results["detected_structure"] = "instruction-response"
                    validation_results["is_valid"] = True
                elif "input" in df.columns and "output" in df.columns:
                    validation_results["detected_structure"] = "input-output"
                    validation_results["is_valid"] = True
                elif "prompt" in df.columns and "completion" in df.columns:
                    validation_results["detected_structure"] = "prompt-completion"
                    validation_results["is_valid"] = True
                elif "text" in df.columns:
                    validation_results["detected_structure"] = "text-only"
                    validation_results["is_valid"] = True
                else:
                    # Fall back to any object-dtype (string) columns
                    text_columns = [col for col in df.columns if df[col].dtype == 'object']
                    if text_columns:
                        validation_results["detected_structure"] = "custom"
                        validation_results["statistics"]["potential_text_columns"] = text_columns
                        validation_results["is_valid"] = True
                        validation_results["recommendations"].append(
                            "Consider renaming columns to match standard formats: "
                            "instruction/response, input/output, prompt/completion, or text"
                        )
                    else:
                        validation_results["issues"].append("No suitable text columns found in CSV")

                # Flag very short text
                if validation_results["detected_structure"] == "instruction-response":
                    short_instructions = int((df["instruction"].str.len() < 10).sum())
                    short_responses = int((df["response"].str.len() < 10).sum())
                    validation_results["statistics"]["short_instructions"] = short_instructions
                    validation_results["statistics"]["short_responses"] = short_responses

                    if short_instructions > 0:
                        validation_results["issues"].append(
                            f"Found {short_instructions} instructions shorter than 10 characters")
                    if short_responses > 0:
                        validation_results["issues"].append(
                            f"Found {short_responses} responses shorter than 10 characters")

            except Exception as e:
                validation_results["issues"].append(f"Error parsing CSV: {str(e)}")
                return validation_results

        elif format_type == "jsonl":
            try:
                # Load the JSONL file line by line
                data = []
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data.append(json.loads(line))
                        except json.JSONDecodeError:
                            validation_results["issues"].append(f"Invalid JSON at line {line_num}")

                validation_results["statistics"]["total_examples"] = len(data)

                if len(data) == 0:
                    validation_results["issues"].append("No valid JSON objects found in file")
                    return validation_results

                # Sample the keys of the first object
                validation_results["statistics"]["sample_keys"] = list(data[0].keys())

                # Detect the structure of each record
                structures = []
                for item in data:
                    if "instruction" in item and "response" in item:
                        structures.append("instruction-response")
                    elif "input" in item and "output" in item:
                        structures.append("input-output")
                    elif "prompt" in item and "completion" in item:
                        structures.append("prompt-completion")
                    elif "text" in item:
                        structures.append("text-only")
                    else:
                        structures.append("custom")

                # Count structure types
                structure_counts = Counter(structures)
                validation_results["statistics"]["structure_counts"] = dict(structure_counts)

                # Use the most common structure
                most_common = structure_counts.most_common(1)[0][0]
                validation_results["detected_structure"] = most_common
                validation_results["is_valid"] = True

                # Warn about mixed structures
                if len(structure_counts) > 1:
                    validation_results["issues"].append(
                        f"Mixed structures detected: {dict(structure_counts)}")
                    validation_results["recommendations"].append(
                        "Consider standardizing all records to the same structure")

            except Exception as e:
                validation_results["issues"].append(f"Error parsing JSONL: {str(e)}")
                return validation_results

        elif format_type == "text":
            try:
                # Read the text file
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Basic statistics
                total_chars = len(content)
                total_words = len(content.split())
                total_lines = content.count('\n') + 1

                validation_results["statistics"]["total_characters"] = total_chars
                validation_results["statistics"]["total_words"] = total_words
                validation_results["statistics"]["total_lines"] = total_lines

                # Is it a single large document or multiple examples?
                paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
                validation_results["statistics"]["total_paragraphs"] = len(paragraphs)

                # Look for common patterns such as "Q: ... A: ..." or "Input: ... Output: ..."
                has_qa_pattern = re.search(r"Q:.*?A:", content, re.DOTALL) is not None
                has_input_output = re.search(r"Input:.*?Output:", content, re.DOTALL) is not None
                has_prompt_completion = re.search(r"Prompt:.*?Completion:", content, re.DOTALL) is not None

                if has_qa_pattern:
                    validation_results["detected_structure"] = "Q&A-format"
                elif has_input_output:
                    validation_results["detected_structure"] = "input-output-format"
                elif has_prompt_completion:
                    validation_results["detected_structure"] = "prompt-completion-format"
                elif len(paragraphs) > 1:
                    validation_results["detected_structure"] = "paragraphs"
                else:
                    validation_results["detected_structure"] = "continuous-text"

                validation_results["is_valid"] = True

                if validation_results["detected_structure"] == "continuous-text" and total_chars < 1000:
                    validation_results["issues"].append("Text file is very short for fine-tuning")
                    validation_results["recommendations"].append("Consider adding more content or examples")

            except Exception as e:
                validation_results["issues"].append(f"Error parsing text file: {str(e)}")
                return validation_results
        else:
            validation_results["issues"].append(f"Unsupported file format: {format_type}")
            return validation_results

        # General recommendations
        if validation_results["is_valid"]:
            if not validation_results["issues"]:
                validation_results["recommendations"].append("Dataset looks good and is ready for fine-tuning!")
            else:
                validation_results["recommendations"].append(
                    "Address the issues above before proceeding with fine-tuning")

        return validation_results

    except Exception as e:
        validation_results["issues"].append(f"Unexpected error: {str(e)}")
        return validation_results


def generate_dataset_report(validation_results):
    """
    Generate a user-friendly report from validation results.

    Parameters:
        validation_results (dict): Results from validate_dataset

    Returns:
        str: Formatted Markdown report
    """
    report = []

    # Header
    report.append("# Dataset Validation Report")
    report.append("")

    # Validation status
    if validation_results["is_valid"]:
        report.append("✅ Dataset is valid and can be used for fine-tuning")
    else:
        report.append("❌ Dataset has issues that need to be addressed")
    report.append("")

    # Format info
    report.append(f"**File Format:** {validation_results['format']}")
    report.append(f"**Detected Structure:** {validation_results['detected_structure']}")
    report.append("")

    # Statistics
    report.append("## Statistics")
    for key, value in validation_results["statistics"].items():
        # Format the key for readability, e.g. "file_size_mb" -> "File Size Mb"
        formatted_key = key.replace("_", " ").title()
        report.append(f"- {formatted_key}: {value}")
    report.append("")

    # Issues
    if validation_results["issues"]:
        report.append("## Issues")
        for issue in validation_results["issues"]:
            report.append(f"- ⚠️ {issue}")
        report.append("")

    # Recommendations
    if validation_results["recommendations"]:
        report.append("## Recommendations")
        for recommendation in validation_results["recommendations"]:
            report.append(f"- 💡 {recommendation}")

    return "\n".join(report)
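For quick testing, the two helpers above can be driven by a small script along these lines (a sketch only; the dataset path is a placeholder, assuming a CSV produced by utils/sample_dataset.py):

# Usage sketch — the dataset path below is a placeholder, not part of this upload.
from utils.check_dataset import validate_dataset, generate_dataset_report

results = validate_dataset("sample_datasets/instruction_response.csv", "csv")
print(generate_dataset_report(results))
if not results["is_valid"]:
    raise SystemExit("Fix the reported issues before fine-tuning.")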
utils/model.py ADDED
@@ -0,0 +1,552 @@
import os
import json
import torch
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback
)
from peft import prepare_model_for_kbit_training
from datasets import Dataset, DatasetDict
from unsloth import FastModel


class GemmaFineTuning:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self.trainer = None
        self.training_history = {"loss": [], "eval_loss": [], "step": []}
        self.model_save_path = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Unsloth 4-bit checkpoints for the Gemma 3 family
        self.fourbit_models = [
            "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
            "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
            "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
            "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
        ]
        # Default hyperparameters
        self.default_params = {
            "model_name": "google/gemma-2b",
            "learning_rate": 2e-5,
            "batch_size": 8,
            "epochs": 3,
            "max_length": 512,
            "weight_decay": 0.01,
            "warmup_ratio": 0.1,
            "use_lora": True,
            "lora_r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
            "eval_ratio": 0.1,
        }

    def load_model_and_tokenizer(self, model_name: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load the model and tokenizer"""
        try:
            # Map UI model names to the corresponding Unsloth 4-bit model IDs
            model_mapping = {
                "google/gemma-2b": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
                "google/gemma-7b": "unsloth/gemma-7b-it-unsloth-bnb-4bit",
                "google/gemma-2b-it": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
                "google/gemma-7b-it": "unsloth/gemma-7b-it-unsloth-bnb-4bit"
            }

            actual_model_name = model_mapping.get(model_name, model_name)

            model, tokenizer = FastModel.from_pretrained(
                model_name=actual_model_name,
                max_seq_length=2048,
                load_in_4bit=True,
                load_in_8bit=False,
                full_finetuning=False,
            )

            # Move the model to the selected device
            model = model.to(self.device)
            return model, tokenizer

        except Exception as e:
            raise ValueError(f"Error loading model {model_name}: {str(e)}")

    def prepare_dataset(self, file_path, format_type):
        """
        Prepare and normalize a dataset from various formats.

        Parameters:
            file_path (str): Path to the dataset file
            format_type (str): File format (csv, jsonl, text)

        Returns:
            DatasetDict: Dataset dictionary with a train split
        """
        try:
            if format_type == "csv":
                df = pd.read_csv(file_path)

                # Check whether the CSV has one of the expected column pairs, or a text column
                if "instruction" in df.columns and "response" in df.columns:
                    # Instruction-following dataset format
                    df = df.dropna(subset=["instruction", "response"])
                    # Combine instruction and response into a single formatted text field
                    df["text"] = df.apply(
                        lambda row: f"<instruction>{row['instruction']}</instruction>\n<response>{row['response']}</response>",
                        axis=1)
                elif "input" in df.columns and "output" in df.columns:
                    # Another common format
                    df = df.dropna(subset=["input", "output"])
                    df["text"] = df.apply(
                        lambda row: f"<input>{row['input']}</input>\n<output>{row['output']}</output>",
                        axis=1)
                elif "prompt" in df.columns and "completion" in df.columns:
                    # OpenAI-style format
                    df = df.dropna(subset=["prompt", "completion"])
                    df["text"] = df.apply(
                        lambda row: f"<prompt>{row['prompt']}</prompt>\n<completion>{row['completion']}</completion>",
                        axis=1)
                elif "text" in df.columns:
                    # Simple text format
                    df = df.dropna(subset=["text"])
                else:
                    # Try to infer the format from the first string column
                    text_columns = [col for col in df.columns if df[col].dtype == 'object']
                    if len(text_columns) > 0:
                        df["text"] = df[text_columns[0]]
                        df = df.dropna(subset=["text"])
                    else:
                        raise ValueError(
                            "CSV file must contain either 'instruction'/'response', "
                            "'input'/'output', 'prompt'/'completion', or 'text' columns")

                dataset = Dataset.from_pandas(df)

            elif format_type == "jsonl":
                # Load the JSONL file
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = [json.loads(line) for line in f if line.strip()]

                # Check and normalize each record
                normalized_data = []
                for item in data:
                    normalized_item = {}

                    # Look for one of the known field pairs, or a text field
                    if "instruction" in item and "response" in item:
                        normalized_item["text"] = f"<instruction>{item['instruction']}</instruction>\n<response>{item['response']}</response>"
                        normalized_item["instruction"] = item["instruction"]
                        normalized_item["response"] = item["response"]
                    elif "input" in item and "output" in item:
                        normalized_item["text"] = f"<input>{item['input']}</input>\n<output>{item['output']}</output>"
                        normalized_item["input"] = item["input"]
                        normalized_item["output"] = item["output"]
                    elif "prompt" in item and "completion" in item:
                        normalized_item["text"] = f"<prompt>{item['prompt']}</prompt>\n<completion>{item['completion']}</completion>"
                        normalized_item["prompt"] = item["prompt"]
                        normalized_item["completion"] = item["completion"]
                    elif "text" in item:
                        normalized_item["text"] = item["text"]
                    else:
                        # Fall back to the first non-empty string value
                        text_keys = [k for k, v in item.items() if isinstance(v, str) and len(v.strip()) > 0]
                        if text_keys:
                            normalized_item["text"] = item[text_keys[0]]
                        else:
                            continue  # Skip this item if no usable text is found

                    normalized_data.append(normalized_item)

                if not normalized_data:
                    raise ValueError("No valid data items found in the JSONL file")

                dataset = Dataset.from_list(normalized_data)

            elif format_type == "text":
                # For text files, read the whole file and decide how to split it
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                # If the file is larger than 10 KB, split it into paragraphs;
                # otherwise treat it as a single example
                if os.path.getsize(file_path) > 10240:
                    # Split on blank lines (paragraphs)
                    paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
                    # Filter out very short paragraphs (fewer than 20 characters)
                    paragraphs = [p for p in paragraphs if len(p) >= 20]
                    data = [{"text": p} for p in paragraphs]
                else:
                    data = [{"text": content}]

                dataset = Dataset.from_list(data)
            else:
                raise ValueError(f"Unsupported file format: {format_type}")

            # Return as a DatasetDict with a train split
            return DatasetDict({"train": dataset})

        except Exception as e:
            import traceback
            error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            raise ValueError(error_msg)

    def chunk_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks of approximately chunk_size characters"""
        words = text.split()
        chunks = []
        current_chunk = []
        current_length = 0

        for word in words:
            if current_length + len(word) + 1 > chunk_size and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [word]
                current_length = len(word)
            else:
                current_chunk.append(word)
                current_length += len(word) + 1  # +1 for the space

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def preprocess_dataset(self, dataset, tokenizer, max_length):
        """
        Tokenize and format the dataset for training.

        Parameters:
            dataset: Dataset dictionary with train (and optionally validation) splits
            tokenizer: Hugging Face tokenizer
            max_length (int): Maximum sequence length

        Returns:
            dict: Tokenized dataset ready for training
        """
        def tokenize_function(examples):
            # Prefer an explicit "text" column; otherwise fall back to the
            # first column whose values are all strings
            if "text" in examples:
                texts = examples["text"]
            else:
                potential_text_cols = [
                    col for col in examples.keys()
                    if isinstance(examples[col], list) and
                    all(isinstance(item, str) for item in examples[col])
                ]
                if not potential_text_cols:
                    raise ValueError("No suitable text columns found in the dataset")
                texts = examples[potential_text_cols[0]]

            inputs = tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )
            # For causal LM training, the labels are the input IDs themselves
            inputs["labels"] = inputs["input_ids"].clone()
            return inputs

        # Apply tokenization to each split
        tokenized_dataset = {}
        for split, ds in dataset.items():
            tokenized_dataset[split] = ds.map(
                tokenize_function,
                batched=True,
                remove_columns=ds.column_names
            )

        return tokenized_dataset

    def prepare_training_args(self, params: Dict) -> TrainingArguments:
        """Set up training arguments based on hyperparameters"""
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.model_save_path = f"gemma-finetuned-{timestamp}"

        args = TrainingArguments(
            output_dir=self.model_save_path,
            per_device_train_batch_size=params.get("batch_size", self.default_params["batch_size"]),
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=params.get("batch_size", self.default_params["batch_size"]),
            learning_rate=params.get("learning_rate", self.default_params["learning_rate"]),
            num_train_epochs=params.get("epochs", self.default_params["epochs"]),
            warmup_ratio=params.get("warmup_ratio", self.default_params["warmup_ratio"]),
            weight_decay=params.get("weight_decay", self.default_params["weight_decay"]),
            logging_steps=1,
            evaluation_strategy="steps" if params.get("eval_ratio", 0) > 0 else "no",
            eval_steps=100 if params.get("eval_ratio", 0) > 0 else None,
            save_strategy="steps",
            save_steps=100,
            save_total_limit=2,
            load_best_model_at_end=True if params.get("eval_ratio", 0) > 0 else False,
            report_to="none"
        )
        return args

    def train(self, training_params: Dict) -> str:
        """Main training method that handles the complete training pipeline"""
        try:
            if self.dataset is None:
                return "Error: No dataset loaded. Please preprocess a dataset first."

            # Reset training history
            self.training_history = {"loss": [], "eval_loss": [], "step": []}

            # Load the model and tokenizer if not already loaded, or if the model name changed
            current_model_name = training_params.get("model_name", self.default_params["model_name"])
            if (self.model is None or self.tokenizer is None or
                    getattr(self, '_current_model_name', None) != current_model_name):
                self.model, self.tokenizer = self.load_model_and_tokenizer(current_model_name)
                self._current_model_name = current_model_name

            # Create a validation split if needed
            eval_ratio = float(training_params.get("eval_ratio", self.default_params["eval_ratio"]))
            if eval_ratio > 0 and "validation" not in self.dataset:
                split_dataset = self.dataset["train"].train_test_split(test_size=eval_ratio)
                self.dataset = {
                    "train": split_dataset["train"],
                    "validation": split_dataset["test"]
                }

            # Apply LoRA if selected
            if training_params.get("use_lora", self.default_params["use_lora"]):
                self.model = self.setup_lora(self.model, {
                    "lora_r": int(training_params.get("lora_r", self.default_params["lora_r"])),
                    "lora_alpha": int(training_params.get("lora_alpha", self.default_params["lora_alpha"])),
                    "lora_dropout": float(training_params.get("lora_dropout", self.default_params["lora_dropout"]))
                })

            # Preprocess the dataset
            max_length = int(training_params.get("max_length", self.default_params["max_length"]))
            tokenized_dataset = self.preprocess_dataset(self.dataset, self.tokenizer, max_length)

            # Build training arguments with proper type conversion
            training_args = self.prepare_training_args({
                "batch_size": int(training_params.get("batch_size", self.default_params["batch_size"])),
                "learning_rate": float(training_params.get("learning_rate", self.default_params["learning_rate"])),
                "epochs": int(training_params.get("epochs", self.default_params["epochs"])),
                "weight_decay": float(training_params.get("weight_decay", self.default_params["weight_decay"])),
                "warmup_ratio": float(training_params.get("warmup_ratio", self.default_params["warmup_ratio"])),
                "eval_ratio": eval_ratio
            })

            # Create the trainer with the history-logging callback
            self.trainer = self.create_trainer(
                self.model,
                self.tokenizer,
                tokenized_dataset,
                training_args
            )

            # Start training
            self.trainer.train()

            # Save the model
            save_path = f"models/gemma-finetuned-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
            os.makedirs(save_path, exist_ok=True)
            self.trainer.save_model(save_path)
            self.tokenizer.save_pretrained(save_path)
            self.model_save_path = save_path

            return f"Training completed successfully! Model saved to {save_path}"

        except Exception as e:
            import traceback
            return f"Error during training: {str(e)}\n{traceback.format_exc()}"

    def setup_lora(self, model, params: Dict) -> torch.nn.Module:
        """Configure LoRA for parameter-efficient fine-tuning"""
        # Prepare the model for training if it is quantized (4-bit/8-bit)
        if hasattr(model, "is_quantized") and model.is_quantized:
            model = prepare_model_for_kbit_training(model)

        model = FastModel.get_peft_model(
            model,
            finetune_vision_layers=False,     # Turn off for text-only fine-tuning
            finetune_language_layers=True,    # Should be left on
            finetune_attention_modules=True,  # Attention modules matter most here
            finetune_mlp_modules=True,        # Should always be left on
            r=params["lora_r"],               # Larger = higher capacity, but may overfit
            lora_alpha=params["lora_alpha"],  # A common rule of thumb is alpha >= r
            lora_dropout=params["lora_dropout"],
            bias="none",
            random_state=3407,
        )
        model.print_trainable_parameters()
        model = model.to(self.device)
        return model

    def create_trainer(self, model, tokenizer, dataset, training_args):
        """Set up the Trainer for model fine-tuning"""
        # Causal LM data collator (no masked-LM objective)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Custom callback to record training history for plotting
        class TrainingCallback(TrainerCallback):
            def __init__(self, app):
                self.app = app

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    if 'loss' in logs:
                        self.app.training_history['loss'].append(logs['loss'])
                        # The step comes from the trainer state, not the logs dict
                        self.app.training_history['step'].append(state.global_step)
                    if 'eval_loss' in logs:
                        self.app.training_history['eval_loss'].append(logs['eval_loss'])

        # Create the trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"] if "validation" in dataset else None,
            data_collator=data_collator,
            callbacks=[TrainingCallback(self)]  # pass an instance, not the class
        )

        return trainer

    def plot_training_progress(self):
        """Generate a plot of the training progress"""
        if not self.training_history["loss"]:
            return None

        plt.figure(figsize=(10, 6))
        plt.plot(self.training_history["step"], self.training_history["loss"], label="Training Loss")

        if self.training_history["eval_loss"]:
            # Approximate the steps at which evaluation happened
            eval_steps = self.training_history["step"][:len(self.training_history["eval_loss"])]
            plt.plot(eval_steps, self.training_history["eval_loss"], label="Validation Loss", linestyle="--")

        plt.xlabel("Training Steps")
        plt.ylabel("Loss")
        plt.title("Training Progress")
        plt.legend()
        plt.grid(True)

        return plt

    def export_model(self, output_format: str) -> str:
        """Export the fine-tuned model in various formats"""
        if self.model is None or self.model_save_path is None:
            return "No model has been trained yet."

        export_path = f"{self.model_save_path}/exported_{output_format}"
        os.makedirs(export_path, exist_ok=True)

        if output_format == "pytorch":
            # Save in PyTorch format
            self.model.save_pretrained(export_path)
            self.tokenizer.save_pretrained(export_path)
            return f"Model exported in PyTorch format to {export_path}"

        elif output_format == "tensorflow":
            # Convert to TensorFlow format by reloading the PyTorch weights
            try:
                from transformers import TFAutoModelForCausalLM

                # First save the PyTorch model
                self.model.save_pretrained(export_path)
                self.tokenizer.save_pretrained(export_path)

                # Then reload it as a TF model and save in SavedModel layout
                tf_model = TFAutoModelForCausalLM.from_pretrained(export_path, from_pt=True)
                tf_model.save_pretrained(f"{export_path}/tf_saved_model")

                return f"Model exported in TensorFlow format to {export_path}/tf_saved_model"
            except Exception as e:
                return f"Failed to export as TensorFlow model: {str(e)}"

        elif output_format == "gguf":
            # Export in GGUF format for local inference
            try:
                import subprocess

                # First save the model in PyTorch format
                self.model.save_pretrained(export_path)
                self.tokenizer.save_pretrained(export_path)

                # Hand off to an external conversion script; this assumes a
                # llama.cpp-style converter is importable as llama_cpp.convert
                subprocess.run([
                    "python", "-m", "llama_cpp.convert",
                    "--outtype", "gguf",
                    "--outfile", f"{export_path}/model.gguf",
                    export_path
                ])

                return f"Model exported in GGUF format to {export_path}/model.gguf"
            except Exception as e:
                return f"Failed to export as GGUF model: {str(e)}"

        else:
            return f"Unsupported export format: {output_format}"

    def generate_text(self, prompt: str, max_length: int = 100) -> str:
        """Generate text using the fine-tuned model"""
        if self.model is None or self.tokenizer is None:
            return "No model has been loaded or fine-tuned yet."

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length + inputs.input_ids.shape[1],
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                # Fall back to the EOS token if the tokenizer has no pad token
                pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
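A minimal end-to-end sketch of how GemmaFineTuning is meant to be driven (the dataset path and hyperparameters are illustrative assumptions, not part of this upload):

# Usage sketch — the path and hyperparameters are illustrative assumptions.
from utils.model import GemmaFineTuning

ft = GemmaFineTuning()
ft.dataset = ft.prepare_dataset("sample_datasets/instruction_response.csv", "csv")
status = ft.train({
    "model_name": "google/gemma-2b",
    "batch_size": 4,
    "epochs": 1,
    "use_lora": True,
})
print(status)
print(ft.generate_text("Explain the water cycle.", max_length=80))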
utils/sample_dataset.py ADDED
@@ -0,0 +1,105 @@
import pandas as pd
import json
import os
import random

def generate_sample_datasets(output_dir="./sample_datasets"):
    """
    Generate sample datasets in CSV, JSONL, and text formats for testing the Gemma fine-tuning UI.

    Parameters:
        output_dir (str): Directory to save the sample datasets
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Sample data for instruction-response pairs
    instruction_data = [
        {"instruction": "Explain the concept of machine learning in simple terms.",
         "response": "Machine learning is when computers learn from examples rather than being explicitly programmed. It's like teaching a child by showing examples instead of giving exact rules."},
        {"instruction": "What are the health benefits of regular exercise?",
         "response": "Regular exercise improves cardiovascular health, strengthens muscles and bones, helps control weight, reduces stress and anxiety, improves sleep quality, and increases overall energy levels."},
        {"instruction": "How do I make a basic pasta sauce?",
         "response": "For a basic tomato pasta sauce, sauté minced garlic and onions in olive oil until soft. Add canned crushed tomatoes, salt, pepper, and Italian herbs. Simmer for 20-30 minutes, stirring occasionally. Add basil at the end for fresh flavor."},
        {"instruction": "Summarize the water cycle.",
         "response": "The water cycle is the continuous movement of water on, above, and below Earth's surface. It involves evaporation (water turning to vapor), condensation (vapor forming clouds), precipitation (rain or snow), and collection (water returning to bodies of water)."},
        {"instruction": "What is the difference between a simile and a metaphor?",
         "response": "A simile compares two things using 'like' or 'as' (e.g., 'as brave as a lion'). A metaphor directly states that one thing is another (e.g., 'he is a lion in battle'). Both are figurative language techniques used to create vivid descriptions."}
    ]

    # Generate more instruction-response pairs
    topics = ["history", "science", "literature", "cooking", "technology", "health", "travel", "sports", "music", "art"]
    question_starters = ["Explain", "Describe", "How to", "What is", "Why does", "Compare", "Summarize", "List ways to", "Define", "Analyze"]

    for _ in range(20):
        topic = random.choice(topics)
        starter = random.choice(question_starters)
        instruction = f"{starter} {topic.lower()} {random.choice(['concepts', 'principles', 'ideas', 'techniques', 'methods'])}"
        response = f"This is a sample response about {topic} that would be more detailed in a real dataset. It would contain multiple sentences explaining {topic} concepts in depth."
        instruction_data.append({"instruction": instruction, "response": response})

    # Collect all sample datasets in a dictionary keyed by filename
    datasets = {}

    # 1. CSV in instruction-response format
    df_instruction = pd.DataFrame(instruction_data)
    datasets["instruction_response.csv"] = df_instruction

    # 2. CSV in input-output format
    input_output_data = [{"input": item["instruction"], "output": item["response"]} for item in instruction_data]
    df_input_output = pd.DataFrame(input_output_data)
    datasets["input_output.csv"] = df_input_output

    # 3. CSV in text-only format
    text_data = [{"text": f"Q: {item['instruction']}\nA: {item['response']}"} for item in instruction_data]
    df_text = pd.DataFrame(text_data)
    datasets["text_only.csv"] = df_text

    # 4. CSV with a non-standard format
    custom_data = [{"question": item["instruction"], "answer": item["response"]} for item in instruction_data]
    df_custom = pd.DataFrame(custom_data)
    datasets["custom_format.csv"] = df_custom

    # 5. JSONL in instruction-response format
    datasets["instruction_response.jsonl"] = instruction_data

    # 6. JSONL in prompt-completion format
    prompt_completion_data = [{"prompt": item["instruction"], "completion": item["response"]} for item in instruction_data]
    datasets["prompt_completion.jsonl"] = prompt_completion_data

    # 7. JSONL with a non-standard format
    jsonl_custom = [{"query": item["instruction"], "result": item["response"]} for item in instruction_data]
    datasets["custom_format.jsonl"] = jsonl_custom

    # 8. Text format (paragraphs)
    text_paragraphs = "\n\n".join([f"Q: {item['instruction']}\nA: {item['response']}" for item in instruction_data])
    datasets["paragraphs.txt"] = text_paragraphs

    # 9. Text format (one example per line)
    text_lines = "\n".join([f"{item['instruction']} => {item['response']}" for item in instruction_data])
    datasets["single_lines.txt"] = text_lines

    # Save all datasets
    for filename, data in datasets.items():
        file_path = os.path.join(output_dir, filename)

        if filename.endswith('.csv'):
            data.to_csv(file_path, index=False)
        elif filename.endswith('.jsonl'):
            with open(file_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item) + '\n')
        elif filename.endswith('.txt'):
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(data)

    print(f"Sample datasets generated in {output_dir}")
    return list(datasets.keys())

# if __name__ == "__main__":
#     # Generate sample datasets
#     generated_files = generate_sample_datasets()
#     print(f"Generated {len(generated_files)} sample dataset files:")
#     for file in generated_files:
#         print(f" - {file}")
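A one-liner driver for the sample-data generator, mirroring the commented-out block above (a sketch; the output directory shown is the function's default):

# Usage sketch.
from utils.sample_dataset import generate_sample_datasets

files = generate_sample_datasets(output_dir="./sample_datasets")
print(f"Generated {len(files)} sample dataset files: {files}")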