"""Gradio demo that classifies clothing images with a trained CC_model."""

import os
import traceback

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from ResNet_for_CC import CC_model  # Import fixed model

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clothing1M class labels. The model's predicted class index is used to
# index into this list, so its length also defines the expected number of
# model outputs.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]

# Load the trained CC_model. The class count is derived from class_labels so
# the model head and the label list cannot drift apart.
model_path = "CC_net.pt"
model = CC_model(num_classes=len(class_labels))

# Load model weights.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch has it).
state_dict = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches between checkpoint and model, which
# would otherwise pass silently; report them so a wrong checkpoint is visible.
# NOTE(review): assumes CC_model inherits torch.nn.Module.load_state_dict,
# whose result exposes missing_keys / unexpected_keys - confirm.
load_result = model.load_state_dict(state_dict, strict=False)
missing_keys = getattr(load_result, "missing_keys", None)
unexpected_keys = getattr(load_result, "unexpected_keys", None)
if missing_keys:
    print(f"[WARNING] Keys missing from '{model_path}': {list(missing_keys)}")
if unexpected_keys:
    print(f"[WARNING] Unexpected keys in '{model_path}': {list(unexpected_keys)}")

model.to(device)
model.eval()

# Define image transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_example_images() -> list:
    """Return the sorted paths of all example images in ``examples/``.

    Files are matched case-insensitively on the extensions ``.png``,
    ``.jpg``, ``.jpeg`` and ``.webp``.

    Returns:
        A sorted list of file paths (each prefixed with ``examples``); an
        empty list when the directory is missing or holds no matching files.
    """
    examples_dir = "examples"

    # isdir (rather than exists) so a regular file named "examples" is
    # reported as missing instead of making os.listdir raise.
    if not os.path.isdir(examples_dir):
        print("[WARNING] 'examples/' directory does not exist.")
        return []

    # Fetch all image files (including .webp); skip sub-directories whose
    # names merely look like image files.
    candidates = (
        os.path.join(examples_dir, name)
        for name in os.listdir(examples_dir)
        if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
    )
    image_files = sorted(path for path in candidates if os.path.isfile(path))

    if not image_files:
        print("[WARNING] No images found in 'examples/' directory.")

    print(f"[INFO] Found {len(image_files)} images in 'examples/'")
    return image_files


def classify_image(image) -> str:
    """Classify one image and return a human-readable result string.

    Args:
        image: The image handed over by the Gradio ``Image`` input (declared
            with ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        ``"Predicted Class: <label> (Confidence: <pct>%)"`` on success, or an
        error message when no image was given, the model output width does
        not match ``class_labels``, or any exception occurs. Exceptions are
        reported on the console and never propagate to the caller.
    """
    print("\n[DEBUG] Received image for classification.")

    if image is None:
        print("[ERROR] No image provided.")
        return "[ERROR] No image provided. Please upload an image."

    try:
        # Normalize is configured with three-channel statistics, so bring
        # RGBA / grayscale / palette images to RGB before transforming.
        if image.mode != "RGB":
            image = image.convert("RGB")

        batch = transform(image).unsqueeze(0).to(device)
        print("[DEBUG] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(batch)

        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        num_classes = len(class_labels)
        if output.shape[1] != num_classes:
            return (
                f"[ERROR] Model output mismatch! Expected {num_classes} "
                f"but got {output.shape[1]}."
            )

        # Convert logits to probabilities
        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        # Print class predictions and probabilities
        for label, prob in zip(class_labels, probabilities[0].tolist()):
            print(f"[INFO] {label}: {prob * 100:.2f}%")

        # Get predicted class index. The width check above guarantees the
        # argmax over dim=1 is a valid index into class_labels, so no
        # separate range validation is needed.
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_label = class_labels[predicted_class]
        print(f"[DEBUG] Predicted class index: {predicted_class} (Class: {predicted_label})")

        confidence = probabilities[0][predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"

    except Exception as e:
        # UI boundary: never let an exception escape to Gradio, but leave a
        # full traceback on the console since the user is pointed there.
        print(f"[ERROR] Exception during classification: {e}")
        traceback.print_exc()
        return "Error in classification. Check console for details."


# Create Gradio interface with dynamic examples
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Clothing1M Image Classifier",
    description="Upload a clothing image, or select an example below to classify it.",
    examples=get_example_images(),  # Dynamically load all images including .webp
)

# Run the Interface
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()