import os
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from ResNet_for_CC import CC_model  # Import fixed model

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained CC_model
model_path = "CC_net.pt"
model = CC_model(num_classes=14)

# Load model weights
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()
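
# Note: strict=False lets load_state_dict() silently skip checkpoint keys that
# do not match the model, which can mask a wrong architecture. An optional
# sanity check (load_state_dict returns the mismatched key lists):
#
#   missing, unexpected = model.load_state_dict(state_dict, strict=False)
#   if missing or unexpected:
#       print(f"[WARNING] Missing keys: {missing} / Unexpected keys: {unexpected}")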

# Clothing1M class labels
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear"
]

# Define image transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
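
# Quick sanity check for the pipeline above (a sketch, not part of the app):
# Resize(256) + CenterCrop(224) + ToTensor() yields a (3, 224, 224) float
# tensor for any RGB input.
#
#   x = transform(Image.new("RGB", (640, 480)))
#   assert x.shape == (3, 224, 224)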

# 🔹 **Dynamically Fetch All Example Images (Including .webp)**
def get_example_images():
    examples_dir = "examples"
    if not os.path.exists(examples_dir):
        print("[WARNING] 'examples/' directory does not exist.")
        return []

    # Fetch all image files (including .webp)
    image_files = sorted([
        os.path.join(examples_dir, f) for f in os.listdir(examples_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
    ])
    if not image_files:
        print("[WARNING] No images found in 'examples/' directory.")
        return []

    print(f"[INFO] Found {len(image_files)} images in 'examples/'")
    return image_files
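
# Usage sketch (hypothetical filenames, assuming an examples/ folder sits next
# to this script):
#
#   get_example_images()  # -> ["examples/coat.jpg", "examples/dress.webp", ...]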

# 🔹 **Classification Function**
def classify_image(image):
    print("\n[DEBUG] Received image for classification.")
    try:
        image = transform(image).unsqueeze(0).to(device)
        print("[DEBUG] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(image)
        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        if output.shape[1] != 14:
            return f"[ERROR] Model output mismatch! Expected 14 but got {output.shape[1]}."

        # Convert logits to probabilities
        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        # Print class predictions and probabilities
        for i, prob in enumerate(probabilities[0].tolist()):
            print(f"[INFO] {class_labels[i]}: {prob * 100:.2f}%")

        # Get predicted class index
        predicted_class = torch.argmax(probabilities, dim=1).item()
        print(f"[DEBUG] Predicted class index: {predicted_class} (Class: {class_labels[predicted_class]})")

        # Validate prediction
        if 0 <= predicted_class < len(class_labels):
            predicted_label = class_labels[predicted_class]
            confidence = probabilities[0][predicted_class].item() * 100
            return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
        else:
            return "[ERROR] Model returned an invalid class index."
    except Exception as e:
        print(f"[ERROR] Exception during classification: {e}")
        return "Error in classification. Check console for details."

# 🔹 **Create Gradio Interface with Dynamic Examples**
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Clothing1M Image Classifier",
    description="Upload a clothing image, or select an example below to classify it.",
    examples=get_example_images()  # Dynamically load all images, including .webp
)

# Run the interface
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()
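    # Note: on Hugging Face Spaces the default launch() is sufficient; when
    # running locally, launch(share=True) additionally creates a temporary
    # public link (a standard Gradio option, not used in the original).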