import os
import time

import gradio as gr
import numpy as np  # noqa: F401  (currently unused; kept so nothing importing it breaks)
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

# -----------------------------
# Configuration and Setup
# -----------------------------

# Run the model on the GPU when one is available, otherwise on the CPU.
# (This only selects the torch device for the model; Gradio is unaffected.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Location of the fine-tuned processor and model (passed to from_pretrained).
model_path = "final_model"

# Number of predictions to display; clamped to the model's class count at runtime.
TOP_K = 5

# Optional example image; it is only offered in the UI if the file is present.
EXAMPLE_IMAGE = "examples/test_image.jpg"

# Load image processor and model
try:
    print("Loading image processor...")
    processor = ViTImageProcessor.from_pretrained(model_path)
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(model_path)
    model = model.to(device)
    model.eval()  # Important for deterministic behavior (disables dropout etc.)
except Exception as e:
    # Chain the cause so the original traceback is not lost.
    raise RuntimeError(f"Error loading model: {e}") from e


def _load_labels(config):
    """Return the config's ``id2label`` mapping, or generic names as a fallback.

    An explicit check is used instead of ``assert`` so validation still happens
    when Python runs with ``-O``.

    Args:
        config: The loaded model config (``model.config``).

    Returns:
        A dict mapping class index -> label. Falls back to ``"Class {i}"`` for
        each of ``config.num_labels`` indices when ``id2label`` is missing,
        not a dict, or empty.
    """
    id2label = getattr(config, "id2label", None)
    if isinstance(id2label, dict) and id2label:
        return id2label
    print("⚠️ Labels not found in model config: Invalid or empty id2label mapping")
    return {i: f"Class {i}" for i in range(config.num_labels)}


# Attempt to load label mappings
labels = _load_labels(model.config)


# -----------------------------
# Standalone Test Mode (Optional)
# -----------------------------
def test_inference():
    """Run one inference on a dummy image outside Gradio to verify the model works.

    Failures are printed rather than raised, so launching can continue.

    Returns:
        True if preprocessing and the forward pass succeeded, False otherwise.
    """
    dummy_img = Image.new("RGB", (224, 224), color="red")  # Create a dummy image
    print("Running standalone inference test...")
    try:
        inputs = processor(images=dummy_img, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
    except Exception as e:
        print(f"❌ Inference test failed: {e}")
        return False
    print(f"✅ Model inference test successful (logits shape: {tuple(outputs.logits.shape)})")
    return True


# -----------------------------
# Prediction Function
# -----------------------------
def predict(image):
    """Classify *image* and return the top predictions as a numbered list.

    Args:
        image: The uploaded image (a PIL image when wired to
            ``gr.Image(type="pil")``), or None when nothing was uploaded.

    Returns:
        One line per prediction, formatted ``"N. <label> — <prob>%"``, for the
        top ``min(TOP_K, num_classes)`` classes. On failure it returns a
        human-readable error message instead of raising, so the message shows
        up in the UI.
    """
    if image is None:
        return "No image uploaded."

    print("\n[INFO] Starting prediction pipeline...")

    # Step 1: Preprocessing
    print("[STEP 1] Preprocessing image...")
    try:
        start = time.perf_counter()
        # Uploads may be RGBA/palette/greyscale; the ViT processor expects
        # 3-channel RGB input, so normalise the mode first.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        print(f"[DEBUG] Input shape: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] Time taken: {time.perf_counter() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in preprocessing: {e}"

    # Step 2: Inference
    print("[STEP 2] Running inference...")
    try:
        start = time.perf_counter()
        with torch.inference_mode():
            outputs = model(**inputs)
        print(f"[DEBUG] Inference completed in {time.perf_counter() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in model inference: {e}"

    # Step 3: Post-processing
    print("[STEP 3] Processing output...")
    try:
        # A single image was preprocessed, so take the first (only) row.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(TOP_K, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, k)
        lines = [
            f"{rank}. {labels.get(idx, f'Unknown class {idx}')} — {prob * 100:.2f}%"
            for rank, (idx, prob) in enumerate(
                zip(top_indices.tolist(), top_probs.tolist()), start=1
            )
        ]
    except Exception as e:
        return f"❌ Error post-processing: {e}"

    print("[INFO] Prediction complete ✅\n")
    return "\n".join(lines).strip()


# -----------------------------
# Gradio Interface
# -----------------------------
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Top 5 Predictions"),
    title="Fine-Tuned ViT Image Classifier",
    description="Upload an image to get the top 5 predicted classes with confidence scores.",
    # NOTE(review): newer Gradio releases deprecate this in favour of
    # ``flagging_mode``; kept as-is for compatibility with older versions.
    allow_flagging="never",
    # Only offer the example image when it actually exists on disk.
    examples=[[EXAMPLE_IMAGE]] if os.path.isfile(EXAMPLE_IMAGE) else None,
)

if __name__ == "__main__":
    print("\n🚀 Launching Gradio interface...\n")
    test_inference()  # Optional: Run test before launching
    # share=True also creates a temporary public share link to this app.
    interface.launch(share=True)