Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -152,7 +152,8 @@ def extract_zip(zip_file):
|
|
152 |
return temp_dir
|
153 |
|
154 |
# Function to classify images in a folder
|
155 |
-
|
|
|
156 |
images = []
|
157 |
labels = []
|
158 |
preds = []
|
@@ -164,15 +165,19 @@ def classify_images(image_dir, model_pipeline):
|
|
164 |
img_path = os.path.join(folder_path, img_name)
|
165 |
try:
|
166 |
img = Image.open(img_path).convert("RGB")
|
|
|
|
|
167 |
pred = model_pipeline(img)
|
168 |
pred_label = np.argmax([x['score'] for x in pred])
|
|
|
169 |
preds.append(pred_label)
|
170 |
labels.append(ground_truth_label)
|
171 |
images.append(img_name)
|
172 |
except Exception as e:
|
173 |
-
print(f"Error processing image {img_name}: {e}")
|
174 |
return labels, preds, images
|
175 |
|
|
|
176 |
# Function to generate evaluation metrics
|
177 |
def evaluate_model(labels, preds):
|
178 |
cm = confusion_matrix(labels, preds)
|
@@ -200,20 +205,22 @@ def evaluate_model(labels, preds):
|
|
200 |
|
201 |
return accuracy, roc_score, report, fig, fig_roc
|
202 |
|
|
|
203 |
# Batch processing for all models
|
204 |
# Batch processing for all models
|
205 |
def process_zip(zip_file):
|
206 |
extracted_dir = extract_zip(zip_file.name)
|
207 |
-
|
208 |
-
# Initialize model pipelines (
|
209 |
model_pipelines = [pipe0, pipe1, pipe2]
|
210 |
|
211 |
# Run classification for each model
|
212 |
results = {}
|
213 |
for idx, pipe in enumerate(model_pipelines): # Ensure each model pipeline is used separately
|
214 |
-
|
|
|
215 |
accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
|
216 |
-
|
217 |
# Store results for each model
|
218 |
results[f'Model_{idx}_accuracy'] = accuracy
|
219 |
results[f'Model_{idx}_roc_score'] = roc_score
|
@@ -232,6 +239,7 @@ def process_zip(zip_file):
|
|
232 |
results['Model_2_cm_fig'], results['Model_2_roc_fig'])
|
233 |
|
234 |
|
|
|
235 |
# Single image section
|
236 |
def load_url(url):
|
237 |
try:
|
|
|
152 |
return temp_dir
|
153 |
|
154 |
# Function to classify images in a folder
|
155 |
+
# Function to classify images in a folder
|
156 |
+
def classify_images(image_dir, model_pipeline, model_idx):
|
157 |
images = []
|
158 |
labels = []
|
159 |
preds = []
|
|
|
165 |
img_path = os.path.join(folder_path, img_name)
|
166 |
try:
|
167 |
img = Image.open(img_path).convert("RGB")
|
168 |
+
|
169 |
+
# Now use the specific model pipeline passed in
|
170 |
pred = model_pipeline(img)
|
171 |
pred_label = np.argmax([x['score'] for x in pred])
|
172 |
+
|
173 |
preds.append(pred_label)
|
174 |
labels.append(ground_truth_label)
|
175 |
images.append(img_name)
|
176 |
except Exception as e:
|
177 |
+
print(f"Error processing image {img_name} in model {model_idx}: {e}")
|
178 |
return labels, preds, images
|
179 |
|
180 |
+
|
181 |
# Function to generate evaluation metrics
|
182 |
def evaluate_model(labels, preds):
|
183 |
cm = confusion_matrix(labels, preds)
|
|
|
205 |
|
206 |
return accuracy, roc_score, report, fig, fig_roc
|
207 |
|
208 |
+
# Batch processing for all models
|
209 |
# Batch processing for all models
|
210 |
# Batch processing for all models
|
211 |
def process_zip(zip_file):
|
212 |
extracted_dir = extract_zip(zip_file.name)
|
213 |
+
|
214 |
+
# Initialize model pipelines (already initialized outside)
|
215 |
model_pipelines = [pipe0, pipe1, pipe2]
|
216 |
|
217 |
# Run classification for each model
|
218 |
results = {}
|
219 |
for idx, pipe in enumerate(model_pipelines): # Ensure each model pipeline is used separately
|
220 |
+
print(f"Processing with model {idx}")
|
221 |
+
labels, preds, images = classify_images(extracted_dir, pipe, idx) # Pass in model index for debugging
|
222 |
accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
|
223 |
+
|
224 |
# Store results for each model
|
225 |
results[f'Model_{idx}_accuracy'] = accuracy
|
226 |
results[f'Model_{idx}_roc_score'] = roc_score
|
|
|
239 |
results['Model_2_cm_fig'], results['Model_2_roc_fig'])
|
240 |
|
241 |
|
242 |
+
|
243 |
# Single image section
|
244 |
def load_url(url):
|
245 |
try:
|