Update app.py
Browse files
app.py
CHANGED
@@ -6,12 +6,16 @@ from torchvision import models
|
|
6 |
from transformers import BertTokenizer, BertModel
|
7 |
import pandas as pd
|
8 |
from datasets import load_dataset
|
9 |
-
from torch.utils.data import DataLoader, Dataset
|
10 |
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Load dataset and filter out null/none values
|
13 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
14 |
-
# Filter out entries where Model is None or empty
|
15 |
dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
|
16 |
|
17 |
# Preprocess text data
|
@@ -26,6 +30,9 @@ class CustomDataset(Dataset):
|
|
26 |
])
|
27 |
self.label_encoder = LabelEncoder()
|
28 |
self.labels = self.label_encoder.fit_transform(dataset['Model'])
|
|
|
|
|
|
|
29 |
|
30 |
def __len__(self):
|
31 |
return len(self.dataset)
|
@@ -41,69 +48,183 @@ class CustomDataset(Dataset):
|
|
41 |
label = self.labels[idx]
|
42 |
return image, text, label
|
43 |
|
44 |
-
#
|
45 |
-
|
46 |
-
def __init__(self):
|
47 |
-
super(ImageModel, self).__init__()
|
48 |
-
self.model = models.resnet18(pretrained=True)
|
49 |
-
self.model.fc = nn.Linear(self.model.fc.in_features, 512)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
def
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
def
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
-
#
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
|
109 |
-
|
|
|
6 |
from transformers import BertTokenizer, BertModel
|
7 |
import pandas as pd
|
8 |
from datasets import load_dataset
|
9 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
10 |
from sklearn.preprocessing import LabelEncoder
|
11 |
+
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
12 |
+
import seaborn as sns
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
+
from tqdm import tqdm
|
16 |
|
17 |
# Load the first 10,000 training rows and keep only those with a usable 'Model' value.
def _has_model_name(example):
    """Return True when the row's 'Model' field is neither None nor blank."""
    model_name = example['Model']
    if model_name is None:
        return False
    return model_name.strip() != ''


dataset = load_dataset(
    'thefcraft/civitai-stable-diffusion-337k',
    split='train[:10000]',
).filter(_has_model_name)
|
20 |
|
21 |
# Preprocess text data
|
|
|
30 |
])
|
31 |
self.label_encoder = LabelEncoder()
|
32 |
self.labels = self.label_encoder.fit_transform(dataset['Model'])
|
33 |
+
|
34 |
+
# Save unique model names for later use
|
35 |
+
self.unique_models = self.label_encoder.classes_
|
36 |
|
37 |
def __len__(self):
|
38 |
return len(self.dataset)
|
|
|
48 |
label = self.labels[idx]
|
49 |
return image, text, label
|
50 |
|
51 |
+
# NOTE: the ImageModel, TextModel and CombinedModel class definitions are elided here.
# They must be defined in this module, because main() instantiates CombinedModel().
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
class ModelTrainerEvaluator:
    """Train a model and evaluate it on train/validation/test splits.

    The dataset is split 70% / 15% / remainder into train, validation and
    test subsets. Every batch is unpacked as ``(images, texts, labels)`` and
    the model is called as ``model(images, texts)``; its output is fed to
    ``nn.CrossEntropyLoss`` and arg-maxed over dimension 1 to get predictions.

    Args:
        model: Module to train; it is moved to CUDA when available, else CPU.
        dataset: Dataset supporting ``len()`` and exposing a
            ``unique_models`` attribute holding the class names in encoded
            label order.
        batch_size: Batch size used by all three data loaders.
        learning_rate: Learning rate for the Adam optimizer.
        seed: Optional integer seed for the train/val/test split. ``None``
            (the default) keeps the previous behavior of drawing the split
            from torch's global RNG. Only the split is seeded, not the
            shuffling order of the training loader.
    """

    def __init__(self, model, dataset, batch_size=32, learning_rate=0.001,
                 seed=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.batch_size = batch_size
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Split dataset into train, validation, and test; the test split takes
        # whatever is left so the three sizes always sum to len(dataset).
        total_size = len(dataset)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        test_size = total_size - train_size - val_size

        split_kwargs = {}
        if seed is not None:
            # A dedicated generator makes the split reproducible without
            # touching the global RNG state.
            split_kwargs['generator'] = torch.Generator().manual_seed(seed)

        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size)

        self.unique_models = dataset.unique_models

    def _label_ids(self):
        """Return every encoded class id, including classes absent from a split."""
        return list(range(len(self.unique_models)))

    def _run_batches(self, loader, desc, training):
        """Run one pass over ``loader`` and collect loss, predictions and labels.

        When ``training`` is true, gradients are enabled and the optimizer is
        stepped after every batch; otherwise gradients are disabled.
        """
        total_loss = 0
        predictions = []
        actual_labels = []

        with torch.set_grad_enabled(training):
            for images, texts, labels in tqdm(loader, desc=desc):
                # NOTE(review): only images and labels are moved to the device;
                # texts are passed through as collated - confirm the model
                # handles device placement for them.
                images = images.to(self.device)
                labels = labels.to(self.device)

                if training:
                    self.optimizer.zero_grad()

                outputs = self.model(images, texts)
                loss = self.criterion(outputs, labels)

                if training:
                    loss.backward()
                    self.optimizer.step()

                total_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                predictions.extend(preds.cpu().numpy())
                actual_labels.extend(labels.cpu().numpy())

        return total_loss / len(loader), predictions, actual_labels

    def train_epoch(self):
        """Train for one epoch over the training loader.

        Returns:
            Tuple ``(mean_batch_loss, predictions, actual_labels)``.
        """
        self.model.train()
        return self._run_batches(self.train_loader, "Training", training=True)

    def evaluate(self, loader, mode="Validation"):
        """Evaluate the model on ``loader`` without updating weights.

        Args:
            loader: Data loader to evaluate on.
            mode: Label shown on the progress bar.

        Returns:
            Tuple ``(mean_batch_loss, predictions, actual_labels)``.
        """
        self.model.eval()
        return self._run_batches(loader, mode, training=False)

    def plot_confusion_matrix(self, y_true, y_pred, title):
        """Save a confusion-matrix heatmap as ``<title>.png``.

        The file name is ``title`` lower-cased with spaces replaced by
        underscores. The matrix always covers every known class so plots
        from different splits and epochs have the same shape.
        """
        cm = confusion_matrix(y_true, y_pred, labels=self._label_ids())
        plt.figure(figsize=(15, 15))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(title)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(f'{title.lower().replace(" ", "_")}.png')
        plt.close()

    def generate_evaluation_report(self, y_true, y_pred, title):
        """Print and save a per-class classification report.

        The report is written to ``<title>_report.csv`` (title lower-cased,
        spaces replaced by underscores) and printed together with the
        overall accuracy.

        Returns:
            Tuple ``(accuracy, df_report)`` where ``df_report`` is the report
            as a DataFrame.
        """
        # Passing explicit labels keeps target_names aligned with the encoded
        # ids; without it sklearn raises ValueError whenever a split does not
        # contain every class. zero_division=0 scores such absent classes as
        # 0 without emitting a warning per class.
        report_kwargs = {
            'labels': self._label_ids(),
            'target_names': self.unique_models,
            'zero_division': 0,
        }
        report = classification_report(y_true, y_pred, output_dict=True,
                                       **report_kwargs)
        df_report = pd.DataFrame(report).transpose()
        df_report.to_csv(f'{title.lower().replace(" ", "_")}_report.csv')

        accuracy = accuracy_score(y_true, y_pred)

        print(f"\n{title} Results:")
        print(f"Accuracy: {accuracy:.4f}")
        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, **report_kwargs))

        return accuracy, df_report

    def train_and_evaluate(self, num_epochs=5):
        """Train for ``num_epochs``, then evaluate the best model on the test split.

        After every epoch a report and confusion matrix are produced for the
        training and validation data, and the weights are saved to
        ``best_model.pth`` whenever the validation loss improves. Finally the
        accuracy history is plotted to ``training_history.png`` and the test
        split is evaluated.
        """
        best_val_loss = float('inf')
        checkpoint_saved = False
        train_accuracies = []
        val_accuracies = []

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            # Training
            train_loss, train_preds, train_labels = self.train_epoch()
            train_accuracy, _ = self.generate_evaluation_report(
                train_labels, train_preds, f"Training Epoch {epoch+1}"
            )
            self.plot_confusion_matrix(
                train_labels, train_preds, f"Training Confusion Matrix Epoch {epoch+1}"
            )

            # Validation
            val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
            val_accuracy, _ = self.generate_evaluation_report(
                val_labels, val_preds, f"Validation Epoch {epoch+1}"
            )
            self.plot_confusion_matrix(
                val_labels, val_preds, f"Validation Confusion Matrix Epoch {epoch+1}"
            )

            train_accuracies.append(train_accuracy)
            val_accuracies.append(val_accuracy)

            print(f"\nTraining Loss: {train_loss:.4f}")
            print(f"Validation Loss: {val_loss:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_model.pth')
                checkpoint_saved = True

        # Plot training history
        plt.figure(figsize=(10, 6))
        plt.plot(train_accuracies, label='Training Accuracy')
        plt.plot(val_accuracies, label='Validation Accuracy')
        plt.title('Model Accuracy over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.savefig('training_history.png')
        plt.close()

        # Final test evaluation. Reload only a checkpoint written during this
        # run: with zero epochs or a NaN validation loss nothing was saved,
        # and a leftover file from an earlier run must not be picked up.
        if checkpoint_saved:
            self.model.load_state_dict(
                torch.load('best_model.pth', map_location=self.device)
            )
        test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
        self.generate_evaluation_report(test_labels, test_preds, "Final Test")
        self.plot_confusion_matrix(test_labels, test_preds, "Final Test Confusion Matrix")
|
209 |
|
210 |
+
# Usage example
|
211 |
+
def main():
    """Build the dataset, model and trainer, then run five training epochs."""
    # The dataset wrapper is built before the model, matching the original order.
    training_data = CustomDataset(dataset)
    network = CombinedModel()

    runner = ModelTrainerEvaluator(
        model=network,
        dataset=training_data,
        batch_size=32,
        learning_rate=0.001,
    )
    runner.train_and_evaluate(num_epochs=5)


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|