Spaces:
Sleeping
Sleeping
File size: 5,793 Bytes
e1ab149 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from models.model import EfficientNetModel, CNNModel
class AnimalClassifierApp:
    """Gradio app that compares an EfficientNet and a CNN animal classifier.

    On construction the app picks a device, builds the image preprocessing
    pipeline and tries to load both model checkpoints from ``checkpoints/``.
    Models whose checkpoint is missing or fails to load are skipped, so
    ``self.models`` may contain zero, one or two entries.
    """

    def __init__(self):
        """Initialize the application."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): the index order here presumably matches the class
        # indices used at training time -- confirm against the training code.
        self.labels = ["bird", "cat", "dog", "horse"]

        # Image preprocessing: fixed 224x224 input, tensor conversion and
        # per-channel normalisation.
        # NOTE(review): mean/std look like the usual ImageNet statistics;
        # verify they match what the models were trained with.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Load models
        self.models = self.load_models()
        if not self.models:
            print("Warning: No models found in checkpoints directory!")

    def load_models(self):
        """Load both trained models.

        Returns:
            dict mapping a display name ("EfficientNet", "CNN") to the loaded
            model, in that order. A model is omitted when its checkpoint file
            does not exist or loading raises.
        """
        # The builders are lambdas so a model is only constructed once we know
        # its checkpoint exists.
        specs = (
            (
                "EfficientNet",
                lambda: EfficientNetModel(num_classes=len(self.labels)),
                os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth"),
            ),
            (
                "CNN",
                lambda: CNNModel(num_classes=len(self.labels)),
                os.path.join("checkpoints", "cnn", "cnn_best_model.pth"),
            ),
        )

        models = {}
        for name, build, path in specs:
            model = self._load_model(name, build, path)
            if model is not None:
                models[name] = model
        return models

    def _load_model(self, name, build, path):
        """Build one model and load its checkpoint; return None on failure."""
        if not os.path.exists(path):
            print(f"Checkpoint for {name} not found at {path}; skipping")
            return None

        try:
            model = build()
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            # Accept both a wrapped checkpoint ({'model_state_dict': ...}) and
            # a bare state dict.
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            load_result = model.load_state_dict(state_dict, strict=False)
            # The input tensor is moved to self.device in predict(), so the
            # model has to live on the same device.
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Best-effort boundary: one broken checkpoint must not prevent the
            # other model (or the UI) from coming up.
            print(f"Error loading {name} model: {e}")
            return None

        # strict=False hides key mismatches; report them so a model running
        # with partly uninitialised weights does not go unnoticed.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Warning: {name} checkpoint key mismatch "
                f"(missing: {list(load_result.missing_keys)}, "
                f"unexpected: {list(load_result.unexpected_keys)})"
            )
        print(f"Successfully loaded {name} model")
        return model

    def predict(self, image: "Image.Image"):
        """Make predictions with both models and create comparison visualizations.

        Args:
            image: PIL image supplied by the Gradio image input, or None when
                nothing was uploaded.

        Returns:
            A two-element list ``[figure, text]`` matching the two outputs
            declared in ``create_interface``. ``figure`` is None when no
            prediction could be made; ``text`` then carries the explanation.
        """
        # Always return one value per declared output, also on the error paths.
        if not self.models:
            return [None, "No trained models found. Please train the models first."]
        if image is None:
            return [None, "No image provided. Please upload an image."]

        # Force three channels: Normalize above is configured for exactly 3,
        # while uploads may be RGBA, palette or greyscale.
        img_tensor = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)

        # Per-model class probabilities for the single input image.
        probabilities = {}
        with torch.no_grad():
            for model_name, model in self.models.items():
                output = model(img_tensor)
                probabilities[model_name] = (
                    F.softmax(output, dim=1).squeeze(0).cpu().numpy()
                )

        return [
            self._plot_probabilities(probabilities),  # Probability plots
            self._format_results(probabilities),      # Detailed text results
        ]

    def _plot_probabilities(self, probabilities):
        """Draw one probability bar chart per available model, side by side."""
        # Build the Figure directly instead of via plt.figure(): pyplot keeps
        # every figure it creates alive in a global registry until it is
        # closed, which leaks memory when this runs once per request.
        fig = plt.Figure(figsize=(12, 5))
        panels = (
            ("EfficientNet", 1, "skyblue"),
            ("CNN", 2, "lightcoral"),
        )
        for model_name, position, color in panels:
            if model_name not in probabilities:
                continue
            ax = fig.add_subplot(1, 2, position)
            ax.bar(self.labels, probabilities[model_name], color=color)
            ax.set_title(f"{model_name} Predictions")
            ax.set_ylim(0, 1)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_ylabel("Probability")
        fig.tight_layout()
        return fig

    def _format_results(self, probabilities):
        """Render the top prediction and all class probabilities as text."""
        lines = ["Model Predictions:", ""]
        for model_name, probs in probabilities.items():
            top_idx = int(np.argmax(probs))
            lines.append(f"{model_name}:")
            lines.append(f"Top prediction: {self.labels[top_idx]} ({probs[top_idx]:.2%})")
            lines.append("All probabilities:")
            lines.extend(
                f" {label}: {prob:.2%}" for label, prob in zip(self.labels, probs)
            )
            lines.append("")
        return "\n".join(lines) + "\n"

    def create_interface(self):
        """Create Gradio interface.

        Returns:
            A ``gr.Interface`` wired to ``predict`` with one PIL image input
            and two outputs (plot, text box).
        """
        return gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type="pil"),
            outputs=[
                gr.Plot(label="Prediction Probabilities"),
                gr.Textbox(label="Detailed Results", lines=10),
            ],
            title="Animal Classifier - Model Comparison",
            description="Upload an image of an animal to see predictions from both EfficientNet and CNN models.",
        )
def main():
    """Build the classifier app and serve its Gradio interface."""
    # Listen on every network interface, port 7860, and request a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    AnimalClassifierApp().create_interface().launch(**launch_options)
# Start the web application only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()