bgaspra commited on
Commit
3ea3100
·
verified ·
1 Parent(s): d140f33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -60
app.py CHANGED
@@ -6,12 +6,16 @@ from torchvision import models
6
  from transformers import BertTokenizer, BertModel
7
  import pandas as pd
8
  from datasets import load_dataset
9
- from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
 
 
11
 
12
  # Load dataset and filter out null/none values
13
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
- # Filter out entries where Model is None or empty
15
  dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
16
 
17
  # Preprocess text data
@@ -26,6 +30,9 @@ class CustomDataset(Dataset):
26
  ])
27
  self.label_encoder = LabelEncoder()
28
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
 
 
 
29
 
30
  def __len__(self):
31
  return len(self.dataset)
@@ -41,69 +48,183 @@ class CustomDataset(Dataset):
41
  label = self.labels[idx]
42
  return image, text, label
43
 
44
- # Define CNN for image processing
45
- class ImageModel(nn.Module):
46
- def __init__(self):
47
- super(ImageModel, self).__init__()
48
- self.model = models.resnet18(pretrained=True)
49
- self.model.fc = nn.Linear(self.model.fc.in_features, 512)
50
 
51
- def forward(self, x):
52
- return self.model(x)
53
-
54
- # Define MLP for text processing
55
- class TextModel(nn.Module):
56
- def __init__(self):
57
- super(TextModel, self).__init__()
58
- self.bert = BertModel.from_pretrained('bert-base-uncased')
59
- self.fc = nn.Linear(768, 512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- def forward(self, x):
62
- output = self.bert(**x)
63
- return self.fc(output.pooler_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Combined model
66
- class CombinedModel(nn.Module):
67
- def __init__(self):
68
- super(CombinedModel, self).__init__()
69
- self.image_model = ImageModel()
70
- self.text_model = TextModel()
71
- self.fc = nn.Linear(1024, len(dataset['Model']))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- def forward(self, image, text):
74
- image_features = self.image_model(image)
75
- text_features = self.text_model(text)
76
- combined = torch.cat((image_features, text_features), dim=1)
77
- return self.fc(combined)
 
 
 
 
78
 
79
- # Instantiate model
80
- model = CombinedModel()
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- # Define predict function
83
- def predict(image):
84
- model.eval()
85
- with torch.no_grad():
86
- image = transforms.ToTensor()(image).unsqueeze(0)
87
- image = transforms.Resize((224, 224))(image)
88
- text_input = tokenizer(
89
- "Sample prompt",
90
- return_tensors='pt',
91
- padding=True,
92
- truncation=True
93
- )
94
- output = model(image, text_input)
95
- _, indices = torch.topk(output, 5)
96
- recommended_models = [dataset['Model'][i] for i in indices[0]]
97
- return recommended_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Set up Gradio interface
100
- interface = gr.Interface(
101
- fn=predict,
102
- inputs=gr.Image(type="pil"),
103
- outputs=gr.Textbox(label="Recommended Models"),
104
- title="AI Image Model Recommender",
105
- description="Upload an AI-generated image to receive model recommendations."
106
- )
 
 
 
 
 
 
 
 
 
 
107
 
108
- # Launch the app
109
- interface.launch()
 
6
  from transformers import BertTokenizer, BertModel
7
  import pandas as pd
8
  from datasets import load_dataset
9
+ from torch.utils.data import DataLoader, Dataset, random_split
10
  from sklearn.preprocessing import LabelEncoder
11
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
12
+ import seaborn as sns
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from tqdm import tqdm
16
 
17
  # Load dataset and filter out null/none values
18
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
 
19
  dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
20
 
21
  # Preprocess text data
 
30
  ])
31
  self.label_encoder = LabelEncoder()
32
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
33
+
34
+ # Save unique model names for later use
35
+ self.unique_models = self.label_encoder.classes_
36
 
37
  def __len__(self):
38
  return len(self.dataset)
 
48
  label = self.labels[idx]
49
  return image, text, label
50
 
51
+ # Model classes remain the same as before
52
+ # ... (ImageModel, TextModel, CombinedModel classes stay unchanged)
 
 
 
 
53
 
54
+ class ModelTrainerEvaluator:
55
+ def __init__(self, model, dataset, batch_size=32, learning_rate=0.001):
56
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57
+ self.model = model.to(self.device)
58
+ self.batch_size = batch_size
59
+ self.criterion = nn.CrossEntropyLoss()
60
+ self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
61
+
62
+ # Split dataset into train, validation, and test
63
+ total_size = len(dataset)
64
+ train_size = int(0.7 * total_size)
65
+ val_size = int(0.15 * total_size)
66
+ test_size = total_size - train_size - val_size
67
+
68
+ train_dataset, val_dataset, test_dataset = random_split(
69
+ dataset, [train_size, val_size, test_size]
70
+ )
71
+
72
+ self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
73
+ self.val_loader = DataLoader(val_dataset, batch_size=batch_size)
74
+ self.test_loader = DataLoader(test_dataset, batch_size=batch_size)
75
+
76
+ self.unique_models = dataset.unique_models
77
 
78
+ def train_epoch(self):
79
+ self.model.train()
80
+ total_loss = 0
81
+ predictions = []
82
+ actual_labels = []
83
+
84
+ for batch in tqdm(self.train_loader, desc="Training"):
85
+ images, texts, labels = batch
86
+ images = images.to(self.device)
87
+ labels = labels.to(self.device)
88
+
89
+ # Forward pass
90
+ self.optimizer.zero_grad()
91
+ outputs = self.model(images, texts)
92
+ loss = self.criterion(outputs, labels)
93
+
94
+ # Backward pass
95
+ loss.backward()
96
+ self.optimizer.step()
97
+
98
+ total_loss += loss.item()
99
+
100
+ # Store predictions
101
+ _, preds = torch.max(outputs, 1)
102
+ predictions.extend(preds.cpu().numpy())
103
+ actual_labels.extend(labels.cpu().numpy())
104
+
105
+ return total_loss / len(self.train_loader), predictions, actual_labels
106
 
107
+ def evaluate(self, loader, mode="Validation"):
108
+ self.model.eval()
109
+ total_loss = 0
110
+ predictions = []
111
+ actual_labels = []
112
+
113
+ with torch.no_grad():
114
+ for batch in tqdm(loader, desc=mode):
115
+ images, texts, labels = batch
116
+ images = images.to(self.device)
117
+ labels = labels.to(self.device)
118
+
119
+ outputs = self.model(images, texts)
120
+ loss = self.criterion(outputs, labels)
121
+
122
+ total_loss += loss.item()
123
+
124
+ _, preds = torch.max(outputs, 1)
125
+ predictions.extend(preds.cpu().numpy())
126
+ actual_labels.extend(labels.cpu().numpy())
127
+
128
+ return total_loss / len(loader), predictions, actual_labels
129
 
130
+ def plot_confusion_matrix(self, y_true, y_pred, title):
131
+ cm = confusion_matrix(y_true, y_pred)
132
+ plt.figure(figsize=(15, 15))
133
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
134
+ plt.title(title)
135
+ plt.ylabel('True Label')
136
+ plt.xlabel('Predicted Label')
137
+ plt.savefig(f'{title.lower().replace(" ", "_")}.png')
138
+ plt.close()
139
 
140
+ def generate_evaluation_report(self, y_true, y_pred, title):
141
+ report = classification_report(y_true, y_pred,
142
+ target_names=self.unique_models,
143
+ output_dict=True)
144
+ df_report = pd.DataFrame(report).transpose()
145
+ df_report.to_csv(f'{title.lower().replace(" ", "_")}_report.csv')
146
+
147
+ accuracy = accuracy_score(y_true, y_pred)
148
+
149
+ print(f"\n{title} Results:")
150
+ print(f"Accuracy: {accuracy:.4f}")
151
+ print("\nClassification Report:")
152
+ print(classification_report(y_true, y_pred, target_names=self.unique_models))
153
+
154
+ return accuracy, df_report
155
 
156
+ def train_and_evaluate(self, num_epochs=5):
157
+ best_val_loss = float('inf')
158
+ train_accuracies = []
159
+ val_accuracies = []
160
+
161
+ for epoch in range(num_epochs):
162
+ print(f"\nEpoch {epoch+1}/{num_epochs}")
163
+
164
+ # Training
165
+ train_loss, train_preds, train_labels = self.train_epoch()
166
+ train_accuracy, _ = self.generate_evaluation_report(
167
+ train_labels, train_preds, f"Training Epoch {epoch+1}"
168
+ )
169
+ self.plot_confusion_matrix(
170
+ train_labels, train_preds, f"Training Confusion Matrix Epoch {epoch+1}"
171
+ )
172
+
173
+ # Validation
174
+ val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
175
+ val_accuracy, _ = self.generate_evaluation_report(
176
+ val_labels, val_preds, f"Validation Epoch {epoch+1}"
177
+ )
178
+ self.plot_confusion_matrix(
179
+ val_labels, val_preds, f"Validation Confusion Matrix Epoch {epoch+1}"
180
+ )
181
+
182
+ train_accuracies.append(train_accuracy)
183
+ val_accuracies.append(val_accuracy)
184
+
185
+ print(f"\nTraining Loss: {train_loss:.4f}")
186
+ print(f"Validation Loss: {val_loss:.4f}")
187
+
188
+ # Save best model
189
+ if val_loss < best_val_loss:
190
+ best_val_loss = val_loss
191
+ torch.save(self.model.state_dict(), 'best_model.pth')
192
+
193
+ # Plot training history
194
+ plt.figure(figsize=(10, 6))
195
+ plt.plot(train_accuracies, label='Training Accuracy')
196
+ plt.plot(val_accuracies, label='Validation Accuracy')
197
+ plt.title('Model Accuracy over Epochs')
198
+ plt.xlabel('Epoch')
199
+ plt.ylabel('Accuracy')
200
+ plt.legend()
201
+ plt.savefig('training_history.png')
202
+ plt.close()
203
+
204
+ # Final test evaluation
205
+ self.model.load_state_dict(torch.load('best_model.pth'))
206
+ test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
207
+ self.generate_evaluation_report(test_labels, test_preds, "Final Test")
208
+ self.plot_confusion_matrix(test_labels, test_preds, "Final Test Confusion Matrix")
209
 
210
+ # Usage example
211
+ def main():
212
+ # Create dataset
213
+ custom_dataset = CustomDataset(dataset)
214
+
215
+ # Create model
216
+ model = CombinedModel()
217
+
218
+ # Create trainer/evaluator
219
+ trainer = ModelTrainerEvaluator(
220
+ model=model,
221
+ dataset=custom_dataset,
222
+ batch_size=32,
223
+ learning_rate=0.001
224
+ )
225
+
226
+ # Train and evaluate
227
+ trainer.train_and_evaluate(num_epochs=5)
228
 
229
+ if __name__ == "__main__":
230
+ main()