Upload 3 files
- utils/check_dataset.py +272 -0
- utils/model.py +552 -0
- utils/sample_dataset.py +105 -0
utils/check_dataset.py
ADDED
@@ -0,0 +1,272 @@
def validate_dataset(self, file_path, format_type):
    """
    Validate and analyze the dataset format, providing detailed feedback

    Parameters:
    file_path (str): Path to the dataset file
    format_type (str): File format (csv, jsonl, text)

    Returns:
    dict: Validation results including format, structure, and statistics
    """
    import pandas as pd
    import json
    import os
    import re

    validation_results = {
        "is_valid": False,
        "format": format_type,
        "detected_structure": None,
        "statistics": {},
        "issues": [],
        "recommendations": []
    }

    try:
        # Check if file exists
        if not os.path.exists(file_path):
            validation_results["issues"].append(f"File not found: {file_path}")
            return validation_results

        # Check file size
        file_size = os.path.getsize(file_path)
        validation_results["statistics"]["file_size_bytes"] = file_size
        validation_results["statistics"]["file_size_mb"] = round(file_size / (1024 * 1024), 2)

        if file_size == 0:
            validation_results["issues"].append("File is empty")
            return validation_results

        if format_type == "csv":
            # Load CSV file
            try:
                df = pd.read_csv(file_path)
                validation_results["statistics"]["total_rows"] = len(df)
                validation_results["statistics"]["total_columns"] = len(df.columns)
                validation_results["statistics"]["column_names"] = list(df.columns)

                # Check for null values
                null_counts = df.isnull().sum().to_dict()
                validation_results["statistics"]["null_counts"] = null_counts

                if validation_results["statistics"]["total_rows"] == 0:
                    validation_results["issues"].append("CSV file has no rows")
                    return validation_results

                # Detect structure
                if "instruction" in df.columns and "response" in df.columns:
                    validation_results["detected_structure"] = "instruction-response"
                    validation_results["is_valid"] = True
                elif "input" in df.columns and "output" in df.columns:
                    validation_results["detected_structure"] = "input-output"
                    validation_results["is_valid"] = True
                elif "prompt" in df.columns and "completion" in df.columns:
                    validation_results["detected_structure"] = "prompt-completion"
                    validation_results["is_valid"] = True
                elif "text" in df.columns:
                    validation_results["detected_structure"] = "text-only"
                    validation_results["is_valid"] = True
                else:
                    # Look for text columns
                    text_columns = [col for col in df.columns if df[col].dtype == 'object']
                    if text_columns:
                        validation_results["detected_structure"] = "custom"
                        validation_results["statistics"]["potential_text_columns"] = text_columns
                        validation_results["is_valid"] = True
                        validation_results["recommendations"].append(
                            f"Consider renaming columns to match standard formats: instruction/response, input/output, prompt/completion, or text"
                        )
                    else:
                        validation_results["issues"].append("No suitable text columns found in CSV")

                # Check for short text
                if validation_results["detected_structure"] == "instruction-response":
                    short_instructions = (df["instruction"].str.len() < 10).sum()
                    short_responses = (df["response"].str.len() < 10).sum()
                    validation_results["statistics"]["short_instructions"] = short_instructions
                    validation_results["statistics"]["short_responses"] = short_responses

                    if short_instructions > 0:
                        validation_results["issues"].append(f"Found {short_instructions} instructions shorter than 10 characters")
                    if short_responses > 0:
                        validation_results["issues"].append(f"Found {short_responses} responses shorter than 10 characters")

            except Exception as e:
                validation_results["issues"].append(f"Error parsing CSV: {str(e)}")
                return validation_results

        elif format_type == "jsonl":
            try:
                # Load JSONL file
                data = []
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            json_obj = json.loads(line)
                            data.append(json_obj)
                        except json.JSONDecodeError:
                            validation_results["issues"].append(f"Invalid JSON at line {line_num}")

                validation_results["statistics"]["total_examples"] = len(data)

                if len(data) == 0:
                    validation_results["issues"].append("No valid JSON objects found in file")
                    return validation_results

                # Get sample of keys from first object
                if data:
                    validation_results["statistics"]["sample_keys"] = list(data[0].keys())

                # Detect structure
                structures = []
                for item in data:
                    if "instruction" in item and "response" in item:
                        structures.append("instruction-response")
                    elif "input" in item and "output" in item:
                        structures.append("input-output")
                    elif "prompt" in item and "completion" in item:
                        structures.append("prompt-completion")
                    elif "text" in item:
                        structures.append("text-only")
                    else:
                        structures.append("custom")

                # Count structure types
                from collections import Counter
                structure_counts = Counter(structures)
                validation_results["statistics"]["structure_counts"] = structure_counts

                # Set detected structure to most common
                if structures:
                    most_common = structure_counts.most_common(1)[0][0]
                    validation_results["detected_structure"] = most_common
                    validation_results["is_valid"] = True

                    # Check if mixed
                    if len(structure_counts) > 1:
                        validation_results["issues"].append(f"Mixed structures detected: {dict(structure_counts)}")
                        validation_results["recommendations"].append("Consider standardizing all records to the same structure")

            except Exception as e:
                validation_results["issues"].append(f"Error parsing JSONL: {str(e)}")
                return validation_results

        elif format_type == "text":
            try:
                # Read text file
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Get basic stats
                total_chars = len(content)
                total_words = len(content.split())
                total_lines = content.count('\n') + 1

                validation_results["statistics"]["total_characters"] = total_chars
                validation_results["statistics"]["total_words"] = total_words
                validation_results["statistics"]["total_lines"] = total_lines

                # Check if it's a single large document or multiple examples
                paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
                validation_results["statistics"]["total_paragraphs"] = len(paragraphs)

                # Try to detect structure
                # Look for common patterns like "Q: ... A: ...", "Input: ... Output: ..."
                has_qa_pattern = re.search(r"Q:.*?A:", content, re.DOTALL) is not None
                has_input_output = re.search(r"Input:.*?Output:", content, re.DOTALL) is not None
                has_prompt_completion = re.search(r"Prompt:.*?Completion:", content, re.DOTALL) is not None

                if has_qa_pattern:
                    validation_results["detected_structure"] = "Q&A-format"
                elif has_input_output:
                    validation_results["detected_structure"] = "input-output-format"
                elif has_prompt_completion:
                    validation_results["detected_structure"] = "prompt-completion-format"
                elif len(paragraphs) > 1:
                    validation_results["detected_structure"] = "paragraphs"
                else:
                    validation_results["detected_structure"] = "continuous-text"

                validation_results["is_valid"] = True

                if validation_results["detected_structure"] == "continuous-text" and total_chars < 1000:
                    validation_results["issues"].append("Text file is very short for fine-tuning")
                    validation_results["recommendations"].append("Consider adding more content or examples")

            except Exception as e:
                validation_results["issues"].append(f"Error parsing text file: {str(e)}")
                return validation_results
        else:
            validation_results["issues"].append(f"Unsupported file format: {format_type}")
            return validation_results

        # General recommendations
        if validation_results["is_valid"]:
            if not validation_results["issues"]:
                validation_results["recommendations"].append("Dataset looks good and ready for fine-tuning!")
            else:
                validation_results["recommendations"].append("Address the issues above before proceeding with fine-tuning")

        return validation_results

    except Exception as e:
        validation_results["issues"].append(f"Unexpected error: {str(e)}")
        return validation_results

def generate_dataset_report(validation_results):
    """
    Generate a user-friendly report from validation results

    Parameters:
    validation_results (dict): Results from validate_dataset

    Returns:
    str: Formatted report
    """
    report = []

    # Add header
    report.append("# Dataset Validation Report")
    report.append("")

    # Add validation status
    if validation_results["is_valid"]:
        report.append("✅ Dataset is valid and can be used for fine-tuning")
    else:
        report.append("❌ Dataset has issues that need to be addressed")
    report.append("")

    # Add format info
    report.append(f"**File Format:** {validation_results['format']}")
    report.append(f"**Detected Structure:** {validation_results['detected_structure']}")
    report.append("")

    # Add statistics
    report.append("## Statistics")
    for key, value in validation_results["statistics"].items():
        # Format the key for better readability
        formatted_key = key.replace("_", " ").title()
        report.append(f"- {formatted_key}: {value}")
    report.append("")

    # Add issues
    if validation_results["issues"]:
        report.append("## Issues")
        for issue in validation_results["issues"]:
            report.append(f"- ⚠️ {issue}")
        report.append("")

    # Add recommendations
    if validation_results["recommendations"]:
        report.append("## Recommendations")
        for recommendation in validation_results["recommendations"]:
            report.append(f"- 💡 {recommendation}")

    return "\n".join(report)
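A minimal usage sketch for the two helpers above (not part of the upload): validate_dataset is declared with a self parameter even though it sits at module level, so it appears intended to be bound to the app class; the sketch simply passes None for it, and the dataset path is a hypothetical example.

from utils.check_dataset import validate_dataset, generate_dataset_report

# Hypothetical path; any CSV with instruction/response columns exercises the same branch
results = validate_dataset(None, "sample_datasets/instruction_response.csv", "csv")
print(generate_dataset_report(results))
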
utils/model.py
ADDED
@@ -0,0 +1,552 @@
import os
import json
import torch
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from datasets import load_dataset
from unsloth import FastModel


class GemmaFineTuning:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self.trainer = None
        self.training_history = {"loss": [], "eval_loss": [], "step": []}
        self.model_save_path = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.fourbit_models = [
            "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
            "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
            "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
            "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
        ]
        # Default hyperparameters
        self.default_params = {
            "model_name": "google/gemma-2b",
            "learning_rate": 2e-5,
            "batch_size": 8,
            "epochs": 3,
            "max_length": 512,
            "weight_decay": 0.01,
            "warmup_ratio": 0.1,
            "use_lora": True,
            "lora_r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
            "eval_ratio": 0.1,
        }

    def load_model_and_tokenizer(self, model_name: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """Load the model and tokenizer"""
        try:
            # Map UI model names to actual model IDs
            model_mapping = {
                "google/gemma-2b": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
                "google/gemma-7b": "unsloth/gemma-7b-it-unsloth-bnb-4bit",
                "google/gemma-2b-it": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
                "google/gemma-7b-it": "unsloth/gemma-7b-it-unsloth-bnb-4bit"
            }

            actual_model_name = model_mapping.get(model_name, model_name)

            model, tokenizer = FastModel.from_pretrained(
                model_name=actual_model_name,
                max_seq_length=2048,
                load_in_4bit=True,
                load_in_8bit=False,
                full_finetuning=False,
            )

            # Move model to device
            model = model.to(self.device)
            return model, tokenizer

        except Exception as e:
            raise ValueError(f"Error loading model {model_name}: {str(e)}")

    def prepare_dataset(self, file_path, format_type):
        """
        Prepare and normalize dataset from various formats

        Parameters:
        file_path (str): Path to the dataset file
        format_type (str): File format (csv, jsonl, text)

        Returns:
        dict: Dataset dictionary with train split
        """
        import pandas as pd
        import json
        import os
        from datasets import Dataset, DatasetDict

        try:
            if format_type == "csv":
                # Load CSV file
                df = pd.read_csv(file_path)

                # Check if the CSV has the expected columns (looking for either instruction-response pairs or text)
                if "instruction" in df.columns and "response" in df.columns:
                    # Instruction-following dataset format
                    dataset_format = "instruction-response"
                    # Ensure no nulls
                    df = df.dropna(subset=["instruction", "response"])
                    # Create formatted text by combining instruction and response
                    df["text"] = df.apply(lambda row: f"<instruction>{row['instruction']}</instruction>\n<response>{row['response']}</response>", axis=1)
                elif "input" in df.columns and "output" in df.columns:
                    # Another common format
                    dataset_format = "input-output"
                    df = df.dropna(subset=["input", "output"])
                    df["text"] = df.apply(lambda row: f"<input>{row['input']}</input>\n<output>{row['output']}</output>", axis=1)
                elif "prompt" in df.columns and "completion" in df.columns:
                    # OpenAI-style format
                    dataset_format = "prompt-completion"
                    df = df.dropna(subset=["prompt", "completion"])
                    df["text"] = df.apply(lambda row: f"<prompt>{row['prompt']}</prompt>\n<completion>{row['completion']}</completion>", axis=1)
                elif "text" in df.columns:
                    # Simple text format
                    dataset_format = "text-only"
                    df = df.dropna(subset=["text"])
                else:
                    # Try to infer format from the first text column
                    text_columns = [col for col in df.columns if df[col].dtype == 'object']
                    if len(text_columns) > 0:
                        dataset_format = "inferred"
                        df["text"] = df[text_columns[0]]
                        df = df.dropna(subset=["text"])
                    else:
                        raise ValueError("CSV file must contain either 'instruction'/'response', 'input'/'output', 'prompt'/'completion', or 'text' columns")

                # Create dataset
                dataset = Dataset.from_pandas(df)

            elif format_type == "jsonl":
                # Load JSONL file
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = [json.loads(line) for line in f if line.strip()]

                # Check and normalize the format
                normalized_data = []
                for item in data:
                    normalized_item = {}

                    # Try to find either instruction-response pairs or text
                    if "instruction" in item and "response" in item:
                        normalized_item["text"] = f"<instruction>{item['instruction']}</instruction>\n<response>{item['response']}</response>"
                        normalized_item["instruction"] = item["instruction"]
                        normalized_item["response"] = item["response"]
                    elif "input" in item and "output" in item:
                        normalized_item["text"] = f"<input>{item['input']}</input>\n<output>{item['output']}</output>"
                        normalized_item["input"] = item["input"]
                        normalized_item["output"] = item["output"]
                    elif "prompt" in item and "completion" in item:
                        normalized_item["text"] = f"<prompt>{item['prompt']}</prompt>\n<completion>{item['completion']}</completion>"
                        normalized_item["prompt"] = item["prompt"]
                        normalized_item["completion"] = item["completion"]
                    elif "text" in item:
                        normalized_item["text"] = item["text"]
                    else:
                        # Try to infer from the first string value
                        text_keys = [k for k, v in item.items() if isinstance(v, str) and len(v.strip()) > 0]
                        if text_keys:
                            normalized_item["text"] = item[text_keys[0]]
                        else:
                            continue  # Skip this item if no usable text found

                    normalized_data.append(normalized_item)

                if not normalized_data:
                    raise ValueError("No valid data items found in the JSONL file")

                # Create dataset
                dataset = Dataset.from_list(normalized_data)

            elif format_type == "text":
                # For text files, split by newlines and create entries
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Check if it's a single large document or multiple examples
                # If file size > 10KB, try to split into paragraphs
                if os.path.getsize(file_path) > 10240:
                    # Split by double newlines (paragraphs)
                    paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
                    # Filter out very short paragraphs (less than 20 chars)
                    paragraphs = [p for p in paragraphs if len(p) >= 20]
                    data = [{"text": p} for p in paragraphs]
                else:
                    # Treat as a single example
                    data = [{"text": content}]

                # Create dataset
                dataset = Dataset.from_list(data)
            else:
                raise ValueError(f"Unsupported file format: {format_type}")

            # Return as a DatasetDict with a train split
            return DatasetDict({"train": dataset})

        except Exception as e:
            import traceback
            error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            raise ValueError(error_msg)

    def chunk_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks of approximately chunk_size characters"""
        words = text.split()
        chunks = []
        current_chunk = []
        current_length = 0

        for word in words:
            if current_length + len(word) + 1 > chunk_size and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [word]
                current_length = len(word)
            else:
                current_chunk.append(word)
                current_length += len(word) + 1  # +1 for the space

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def preprocess_dataset(self, dataset, tokenizer, max_length):
        """
        Tokenize and format the dataset for training

        Parameters:
        dataset (DatasetDict): Dataset dictionary with train and validation splits
        tokenizer: HuggingFace tokenizer
        max_length (int): Maximum sequence length

        Returns:
        DatasetDict: Tokenized dataset ready for training
        """
        def tokenize_function(examples):
            # Check if the dataset has both input and target text columns
            if "text" in examples:
                texts = examples["text"]
                inputs = tokenizer(
                    texts,
                    padding="max_length",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                inputs["labels"] = inputs["input_ids"].clone()
                return inputs
            else:
                # Try to find text columns based on common naming patterns
                potential_text_cols = [col for col in examples.keys() if isinstance(examples[col], list) and
                                       all(isinstance(item, str) for item in examples[col])]

                if not potential_text_cols:
                    raise ValueError("No suitable text columns found in the dataset")

                # Use the first text column found
                text_col = potential_text_cols[0]
                texts = examples[text_col]

                inputs = tokenizer(
                    texts,
                    padding="max_length",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                inputs["labels"] = inputs["input_ids"].clone()
                return inputs

        # Apply tokenization to each split
        tokenized_dataset = {}
        for split, ds in dataset.items():
            tokenized_dataset[split] = ds.map(
                tokenize_function,
                batched=True,
                remove_columns=ds.column_names
            )

        return tokenized_dataset

    def prepare_training_args(self, params: Dict) -> TrainingArguments:
        """Set up training arguments based on hyperparameters"""
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.model_save_path = f"gemma-finetuned-{timestamp}"

        args = TrainingArguments(
            output_dir=self.model_save_path,
            per_device_train_batch_size=params.get("batch_size", self.default_params["batch_size"]),
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=params.get("batch_size", self.default_params["batch_size"]),
            learning_rate=params.get("learning_rate", self.default_params["learning_rate"]),
            num_train_epochs=params.get("epochs", self.default_params["epochs"]),
            warmup_ratio=params.get("warmup_ratio", self.default_params["warmup_ratio"]),
            weight_decay=params.get("weight_decay", self.default_params["weight_decay"]),
            logging_steps=1,
            evaluation_strategy="steps" if params.get("eval_ratio", 0) > 0 else "no",
            eval_steps=100 if params.get("eval_ratio", 0) > 0 else None,
            save_strategy="steps",
            save_steps=100,
            save_total_limit=2,
            load_best_model_at_end=True if params.get("eval_ratio", 0) > 0 else False,
            report_to="none"
        )
        return args

    def train(self, training_params: Dict) -> str:
        """Main training method that handles the complete training pipeline"""
        try:
            if self.dataset is None:
                return "Error: No dataset loaded. Please preprocess a dataset first."

            # Reset training history
            self.training_history = {"loss": [], "eval_loss": [], "step": []}

            # Load model and tokenizer if not already loaded or if model name changed
            current_model_name = training_params.get("model_name", self.default_params["model_name"])
            if (self.model is None or self.tokenizer is None or
                    getattr(self, '_current_model_name', None) != current_model_name):
                self.model, self.tokenizer = self.load_model_and_tokenizer(current_model_name)
                self._current_model_name = current_model_name

            # Create validation split if needed
            eval_ratio = float(training_params.get("eval_ratio", self.default_params["eval_ratio"]))
            if eval_ratio > 0 and "validation" not in self.dataset:
                split_dataset = self.dataset["train"].train_test_split(test_size=eval_ratio)
                self.dataset = {
                    "train": split_dataset["train"],
                    "validation": split_dataset["test"]
                }

            # Apply LoRA if selected
            if training_params.get("use_lora", self.default_params["use_lora"]):
                self.model = self.setup_lora(self.model, {
                    "lora_r": int(training_params.get("lora_r", self.default_params["lora_r"])),
                    "lora_alpha": int(training_params.get("lora_alpha", self.default_params["lora_alpha"])),
                    "lora_dropout": float(training_params.get("lora_dropout", self.default_params["lora_dropout"]))
                })

            # Preprocess dataset
            max_length = int(training_params.get("max_length", self.default_params["max_length"]))
            tokenized_dataset = self.preprocess_dataset(self.dataset, self.tokenizer, max_length)

            # Update training arguments with proper type conversion
            training_args = self.prepare_training_args({
                "batch_size": int(training_params.get("batch_size", self.default_params["batch_size"])),
                "learning_rate": float(training_params.get("learning_rate", self.default_params["learning_rate"])),
                "epochs": int(training_params.get("epochs", self.default_params["epochs"])),
                "weight_decay": float(training_params.get("weight_decay", self.default_params["weight_decay"])),
                "warmup_ratio": float(training_params.get("warmup_ratio", self.default_params["warmup_ratio"])),
                "eval_ratio": eval_ratio
            })

            # Create trainer with proper callback
            self.trainer = self.create_trainer(
                self.model,
                self.tokenizer,
                tokenized_dataset,
                training_args
            )

            # Start training
            self.trainer.train()

            # Save the model
            save_path = f"models/gemma-finetuned-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
            os.makedirs(save_path, exist_ok=True)
            self.trainer.save_model(save_path)
            self.tokenizer.save_pretrained(save_path)
            self.model_save_path = save_path

            return f"Training completed successfully! Model saved to {save_path}"

        except Exception as e:
            import traceback
            return f"Error during training: {str(e)}\n{traceback.format_exc()}"

    def setup_lora(self, model, params: Dict) -> torch.nn.Module:
        """Configure LoRA for parameter-efficient fine-tuning"""
        # Prepare the model for training if using 8-bit or 4-bit quantization
        if hasattr(model, "is_quantized") and model.is_quantized:
            model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
            r=params["lora_r"],
            lora_alpha=params["lora_alpha"],
            target_modules=["q_proj", "k_proj", "v_proj"],
            lora_dropout=params["lora_dropout"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        model = FastModel.get_peft_model(
            model,
            finetune_vision_layers=False,     # Turn off for just text!
            finetune_language_layers=True,    # Should leave on!
            finetune_attention_modules=True,  # Attention good for GRPO
            finetune_mlp_modules=True,        # Should leave on always!
            r=8,                              # Larger = higher accuracy, but might overfit
            lora_alpha=8,                     # Recommended alpha == r at least
            lora_dropout=0,
            bias="none",
            random_state=3407,
        )
        model.print_trainable_parameters()
        model = model.to(self.device)
        return model

    def create_trainer(self, model, tokenizer, dataset, training_args):
        """Set up the Trainer for model fine-tuning"""
        # Create data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Custom callback to store training history
        class TrainingCallback(TrainerCallback):
            def __init__(self, app):
                self.app = app

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    for key in ['loss', 'eval_loss']:
                        if key in logs:
                            self.app.training_history[key].append(logs[key])
                    if 'step' in logs:
                        self.app.training_history['step'].append(logs['step'])

        # Create trainer; the callback is instantiated with this app so it can record history
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"] if "validation" in dataset else None,
            data_collator=data_collator,
            callbacks=[TrainingCallback(self)]
        )

        return trainer

    def plot_training_progress(self):
        """Generate a plot of the training progress"""
        if not self.training_history["loss"]:
            return None

        plt.figure(figsize=(10, 6))
        plt.plot(self.training_history["step"], self.training_history["loss"], label="Training Loss")

        if self.training_history["eval_loss"]:
            # Get the steps where eval happened
            eval_steps = self.training_history["step"][:len(self.training_history["eval_loss"])]
            plt.plot(eval_steps, self.training_history["eval_loss"], label="Validation Loss", linestyle="--")

        plt.xlabel("Training Steps")
        plt.ylabel("Loss")
        plt.title("Training Progress")
        plt.legend()
        plt.grid(True)

        return plt

    def export_model(self, output_format: str) -> str:
        """Export the fine-tuned model in various formats"""
        if self.model is None or self.model_save_path is None:
            return "No model has been trained yet."

        export_path = f"{self.model_save_path}/exported_{output_format}"
        os.makedirs(export_path, exist_ok=True)

        if output_format == "pytorch":
            # Save as PyTorch format
            self.model.save_pretrained(export_path)
            self.tokenizer.save_pretrained(export_path)
            return f"Model exported in PyTorch format to {export_path}"

        elif output_format == "tensorflow":
            # Convert to TensorFlow format
            try:
                from transformers.modeling_tf_utils import convert_pt_to_tf

                # First save the PyTorch model
                self.model.save_pretrained(export_path)
                self.tokenizer.save_pretrained(export_path)

                # Then convert to TF SavedModel format
                tf_model = convert_pt_to_tf(self.model)
                tf_model.save_pretrained(f"{export_path}/tf_saved_model")

                return f"Model exported in TensorFlow format to {export_path}/tf_saved_model"
            except Exception as e:
                return f"Failed to export as TensorFlow model: {str(e)}"

        elif output_format == "gguf":
            # Export as GGUF format for local inference
            try:
                import subprocess

                # First save the model in PyTorch format
                self.model.save_pretrained(export_path)
                self.tokenizer.save_pretrained(export_path)

                # Use llama.cpp's conversion script (must be installed)
                subprocess.run([
                    "python", "-m", "llama_cpp.convert",
                    "--outtype", "gguf",
                    "--outfile", f"{export_path}/model.gguf",
                    export_path
                ])

                return f"Model exported in GGUF format to {export_path}/model.gguf"
            except Exception as e:
                return f"Failed to export as GGUF model: {str(e)}"

        else:
            return f"Unsupported export format: {output_format}"

    def generate_text(self, prompt: str, max_length: int = 100) -> str:
        """Generate text using the fine-tuned model"""
        if self.model is None or self.tokenizer is None:
            return "No model has been loaded or fine-tuned yet."

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length + inputs.input_ids.shape[1],
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
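A short end-to-end sketch of how the class above appears to be driven (presumably by the Space's Gradio app, which is not part of this upload). The dataset path and hyperparameter values below are placeholders, and running it assumes a CUDA GPU plus the unsloth/peft/transformers stack imported at the top of the file.

from utils.model import GemmaFineTuning

app = GemmaFineTuning()
app.dataset = app.prepare_dataset("sample_datasets/instruction_response.csv", "csv")  # placeholder path
status = app.train({
    "model_name": "google/gemma-2b",  # mapped internally to an unsloth 4-bit checkpoint
    "epochs": 1,                      # placeholder values, not the class defaults
    "batch_size": 2,
    "use_lora": True,
})
print(status)
print(app.generate_text("Explain the concept of machine learning in simple terms."))
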
utils/sample_dataset.py
ADDED
@@ -0,0 +1,105 @@
import pandas as pd
import json
import os
import random

def generate_sample_datasets(output_dir="./sample_datasets"):
    """
    Generate sample datasets in CSV, JSONL, and text formats for testing the Gemma fine-tuning UI

    Parameters:
    output_dir (str): Directory to save the sample datasets
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Sample data for instruction-response pairs
    instruction_data = [
        {"instruction": "Explain the concept of machine learning in simple terms.",
         "response": "Machine learning is when computers learn from examples rather than being explicitly programmed. It's like teaching a child by showing examples instead of giving exact rules."},
        {"instruction": "What are the health benefits of regular exercise?",
         "response": "Regular exercise improves cardiovascular health, strengthens muscles and bones, helps control weight, reduces stress and anxiety, improves sleep quality, and increases overall energy levels."},
        {"instruction": "How do I make a basic pasta sauce?",
         "response": "For a basic tomato pasta sauce, sauté minced garlic and onions in olive oil until soft. Add canned crushed tomatoes, salt, pepper, and Italian herbs. Simmer for 20-30 minutes, stirring occasionally. Add basil at the end for fresh flavor."},
        {"instruction": "Summarize the water cycle.",
         "response": "The water cycle is the continuous movement of water on, above, and below Earth's surface. It involves evaporation (water turning to vapor), condensation (vapor forming clouds), precipitation (rain or snow), and collection (water returning to bodies of water)."},
        {"instruction": "What is the difference between a simile and a metaphor?",
         "response": "A simile compares two things using 'like' or 'as' (e.g., 'as brave as a lion'). A metaphor directly states that one thing is another (e.g., 'he is a lion in battle'). Both are figurative language techniques used to create vivid descriptions."}
    ]

    # Generate more instruction-response pairs
    topics = ["history", "science", "literature", "cooking", "technology", "health", "travel", "sports", "music", "art"]
    question_starters = ["Explain", "Describe", "How to", "What is", "Why does", "Compare", "Summarize", "List ways to", "Define", "Analyze"]

    for _ in range(20):
        topic = random.choice(topics)
        starter = random.choice(question_starters)
        instruction = f"{starter} {topic.lower()} {random.choice(['concepts', 'principles', 'ideas', 'techniques', 'methods'])}"
        response = f"This is a sample response about {topic} that would be more detailed in a real dataset. It would contain multiple sentences explaining {topic} concepts in depth."
        instruction_data.append({"instruction": instruction, "response": response})

    # Create a dictionary to store sample datasets
    datasets = {}

    # 1. Create CSV in instruction-response format
    df_instruction = pd.DataFrame(instruction_data)
    datasets["instruction_response.csv"] = df_instruction

    # 2. Create CSV in input-output format
    input_output_data = [{"input": item["instruction"], "output": item["response"]} for item in instruction_data]
    df_input_output = pd.DataFrame(input_output_data)
    datasets["input_output.csv"] = df_input_output

    # 3. Create CSV in text-only format
    text_data = [{"text": f"Q: {item['instruction']}\nA: {item['response']}"} for item in instruction_data]
    df_text = pd.DataFrame(text_data)
    datasets["text_only.csv"] = df_text

    # 4. Create CSV with non-standard format
    custom_data = [{"question": item["instruction"], "answer": item["response"]} for item in instruction_data]
    df_custom = pd.DataFrame(custom_data)
    datasets["custom_format.csv"] = df_custom

    # 5. Create JSONL in instruction-response format
    jsonl_instruction = instruction_data
    datasets["instruction_response.jsonl"] = jsonl_instruction

    # 6. Create JSONL in prompt-completion format
    prompt_completion_data = [{"prompt": item["instruction"], "completion": item["response"]} for item in instruction_data]
    datasets["prompt_completion.jsonl"] = prompt_completion_data

    # 7. Create JSONL with non-standard format
    jsonl_custom = [{"query": item["instruction"], "result": item["response"]} for item in instruction_data]
    datasets["custom_format.jsonl"] = jsonl_custom

    # 8. Create text format (paragraphs)
    text_paragraphs = "\n\n".join([f"Q: {item['instruction']}\nA: {item['response']}" for item in instruction_data])
    datasets["paragraphs.txt"] = text_paragraphs

    # 9. Create text format (single examples per line)
    text_lines = "\n".join([f"{item['instruction']} => {item['response']}" for item in instruction_data])
    datasets["single_lines.txt"] = text_lines

    # Save all datasets
    for filename, data in datasets.items():
        file_path = os.path.join(output_dir, filename)

        if filename.endswith('.csv'):
            data.to_csv(file_path, index=False)
        elif filename.endswith('.jsonl'):
            with open(file_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item) + '\n')
        elif filename.endswith('.txt'):
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(data)

    print(f"Sample datasets generated in {output_dir}")
    return list(datasets.keys())

# if __name__ == "__main__":
#     # Generate sample datasets
#     generated_files = generate_sample_datasets()
#     print(f"Generated {len(generated_files)} sample dataset files:")
#     for file in generated_files:
#         print(f"  - {file}")
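A small driver for the generator above, mirroring the commented-out __main__ block; the random.seed call is an addition (not in the original) so the filler instruction-response pairs come out reproducibly.

import random
from utils.sample_dataset import generate_sample_datasets

random.seed(0)  # assumption: fixed seed for reproducible filler pairs
generated_files = generate_sample_datasets(output_dir="./sample_datasets")
for name in generated_files:
    print(f" - {name}")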