# DermaScanBeta / app.py
import torch
import gradio as gr
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import numpy as np
import os
import time

# -----------------------------
# Configuration and Setup
# -----------------------------
# Use CUDA if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Model path
model_path = "final_model"
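# NOTE: "final_model" is assumed to be a local directory bundled with this app,
# containing the fine-tuned ViT checkpoint and its preprocessor config.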
# Load image processor and model
try:
print("Loading image processor...")
processor = ViTImageProcessor.from_pretrained(model_path)
print("Loading model...")
model = AutoModelForImageClassification.from_pretrained(model_path)
model = model.to(device)
model.eval() # Important for deterministic behavior
except Exception as e:
    raise RuntimeError(f"Error loading model: {e}")

# Attempt to load label mappings
try:
    labels = model.config.id2label
    assert isinstance(labels, dict) and len(labels) > 0, "Invalid or empty id2label mapping"
except Exception as e:
print(f"⚠️ Labels not found in model config: {e}")
labels = {i: f"Class {i}" for i in range(model.config.num_labels)}
# -----------------------------
# Standalone Test Mode (Optional)
# -----------------------------
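# Quick smoke test with a dummy image: a failure here points to a model/device
# problem rather than anything in the Gradio UI.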
def test_inference():
"""Run inference outside Gradio to verify model works"""
dummy_img = Image.new('RGB', (224, 224), color='red') # Create a dummy image
print("Running standalone inference test...")
try:
inputs = processor(images=dummy_img, return_tensors="pt").to(device)
with torch.inference_mode():
outputs = model(**inputs)
print("βœ… Model inference test successful")
except Exception as e:
print(f"❌ Inference test failed: {e}")
# -----------------------------
# Prediction Function
# -----------------------------
def predict(image):
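    """Preprocess the uploaded image, run the ViT classifier, and return the top-5 labels with confidences."""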
    if image is None:
        return "No image uploaded."
    print("\n[INFO] Starting prediction pipeline...")

    # Step 1: Preprocessing
    print("[STEP 1] Preprocessing image...")
    try:
        start = time.time()
        inputs = processor(images=image, return_tensors="pt").to(device)
        print(f"[DEBUG] Input shape: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] Time taken: {time.time() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in preprocessing: {e}"

    # Step 2: Inference
    print("[STEP 2] Running inference...")
    try:
        start = time.time()
        with torch.inference_mode():
            outputs = model(**inputs)
        print(f"[DEBUG] Inference completed in {time.time() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in model inference: {e}"

    # Step 3: Post-processing
    print("[STEP 3] Processing output...")
    try:
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        top5_probs, top5_indices = torch.topk(probs, 5)
        result = ""
        for i in range(5):
            idx = top5_indices[0][i].item()
            label = labels.get(idx, f"Unknown class {idx}")
            prob = top5_probs[0][i].item() * 100
            result += f"{i + 1}. {label} - {prob:.2f}%\n"
    except Exception as e:
        return f"❌ Error post-processing: {e}"

    print("[INFO] Prediction complete ✅\n")
    return result.strip()

# -----------------------------
# Gradio Interface
# -----------------------------
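# Single image in, plain text out: Gradio hands the upload to predict() as a PIL image.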
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Top 5 Predictions"),
    title="Fine-Tuned ViT Image Classifier",
    description="Upload an image to get the top 5 predicted classes with confidence scores.",
    allow_flagging="never",
    # Only offer the example image if it actually ships alongside the app
    examples=[["examples/test_image.jpg"]] if os.path.exists("examples/test_image.jpg") else None
)
if __name__ == "__main__":
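    # share=True additionally requests a temporary public Gradio link when run locally.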
print("\nπŸš€ Launching Gradio interface...\n")
test_inference() # Optional: Run test before launching
interface.launch(share=True)