# Hugging Face Spaces — status: Sleeping
import time
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
# -----------------------------
# Configuration and Setup
# -----------------------------
# Choose the torch device once for the whole app: the GPU when CUDA is usable,
# otherwise the CPU. This decides where the model runs, not how Gradio behaves.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Local directory that both the image processor and the model are loaded from.
model_path = "final_model"
# Load image processor and model.
# Any failure here is fatal: the app cannot serve predictions without both,
# so re-raise as RuntimeError while keeping the original exception as the
# explicit cause (visible in the traceback).
try:
    print("Loading image processor...")
    processor = ViTImageProcessor.from_pretrained(model_path)
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(model_path)
    model = model.to(device)
    # Switch the module out of training mode; important for deterministic behavior.
    model.eval()
except Exception as e:
    raise RuntimeError(f"Error loading model: {e}") from e
# Attempt to load label mappings.
# Use the id -> label mapping stored in the model config when it is a
# non-empty dict; otherwise synthesize generic "Class N" names so predictions
# can still be displayed. An explicit check is used instead of `assert`
# because asserts are stripped when Python runs with -O.
labels = getattr(model.config, "id2label", None)
if not isinstance(labels, dict) or not labels:
    print("⚠️ Labels not found in model config: Invalid or empty id2label mapping")
    labels = {i: f"Class {i}" for i in range(model.config.num_labels)}
# -----------------------------
# Standalone Test Mode (Optional)
# -----------------------------
def test_inference():
    """Run inference outside Gradio to verify the model works.

    Feeds a solid-red 224x224 RGB image through the module-level
    ``processor`` and ``model`` on ``device`` and prints the outcome.
    Failures are reported rather than raised, so a caller can still go on
    to launch the UI.

    Returns:
        bool: ``True`` if preprocessing and the forward pass succeeded,
        ``False`` if either raised.
    """
    dummy_img = Image.new("RGB", (224, 224), color="red")
    print("Running standalone inference test...")
    try:
        inputs = processor(images=dummy_img, return_tensors="pt").to(device)
        with torch.inference_mode():
            # Only checking that the forward pass runs; the output is unused.
            model(**inputs)
    except Exception as e:
        print(f"❌ Inference test failed: {e}")
        return False
    print("✅ Model inference test successful")
    return True
# -----------------------------
# Prediction Function
# -----------------------------
def predict(image, top_k=5):
    """Classify *image* and return its top predictions as display text.

    Args:
        image: The uploaded image (a PIL image when called from the Gradio
            interface), or ``None`` when nothing was uploaded.
        top_k: Maximum number of predictions to list. It is clamped to the
            number of classes the model outputs, so models with fewer
            classes than ``top_k`` still work.

    Returns:
        str: One line per prediction, ``"<rank>. <label> — <prob>%"``, in
        descending probability order. Alternatively, a human-readable
        message if no image was given or a pipeline stage failed. Errors are
        returned rather than raised so they appear in the output textbox.
    """
    if image is None:
        return "No image uploaded."
    print("\n[INFO] Starting prediction pipeline...")

    # Step 1: Preprocessing
    print("[STEP 1] Preprocessing image...")
    try:
        start = time.time()
        # Defensive: hand the processor 3-channel input even if the caller
        # passes an RGBA, grayscale or palette PIL image.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        print(f"[DEBUG] Input shape: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] Time taken: {time.time() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in preprocessing: {e}"

    # Step 2: Inference
    print("[STEP 2] Running inference...")
    try:
        start = time.time()
        with torch.inference_mode():
            outputs = model(**inputs)
        print(f"[DEBUG] Inference completed in {time.time() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in model inference: {e}"

    # Step 3: Post-processing
    print("[STEP 3] Processing output...")
    try:
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(top_k, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs[0], k)
        rows = []
        for rank, (prob, idx) in enumerate(
            zip(top_probs.tolist(), top_indices.tolist()), start=1
        ):
            label = labels.get(idx, f"Unknown class {idx}")
            rows.append(f"{rank}. {label} — {prob * 100:.2f}%")
    except Exception as e:
        return f"❌ Error post-processing: {e}"

    print("[INFO] Prediction complete ✅\n")
    return "\n".join(rows)
# -----------------------------
# Gradio Interface
# -----------------------------
# Example image offered under the upload widget. It is included only when the
# file actually exists on disk (relative to the working directory), so a
# missing example does not break startup.
_EXAMPLE_IMAGE = "examples/test_image.jpg"

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Top 5 Predictions"),
    title="Fine-Tuned ViT Image Classifier",
    description="Upload an image to get the top 5 predicted classes with confidence scores.",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",
    examples=[[_EXAMPLE_IMAGE]] if Path(_EXAMPLE_IMAGE).is_file() else None,
)
if __name__ == "__main__":
    print("\n🚀 Launching Gradio interface...\n")
    # Optional sanity check before serving; it reports its outcome but never
    # raises, so the interface launches either way.
    test_inference()
    # NOTE(review): share=True requests a public share link when run locally;
    # confirm this is intended (it is reportedly ignored on Hugging Face Spaces).
    interface.launch(share=True)