File size: 4,117 Bytes
56877e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import time
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

# -----------------------------
# Configuration and Setup
# -----------------------------

# Run on CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Location passed to ``from_pretrained`` for both the processor and the model.
model_path = "final_model"

# Load image processor and model. Any failure aborts startup; the original
# exception is chained (``from e``) so its traceback is not lost.
try:
    print("Loading image processor...")
    processor = ViTImageProcessor.from_pretrained(model_path)

    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(model_path)
    model = model.to(device)
    model.eval()  # Eval mode (e.g. no dropout) for deterministic inference
except Exception as e:
    raise RuntimeError(f"Error loading model: {e}") from e

# Class-index -> display-name mapping used when formatting predictions.
# An explicit check (not ``assert``) keeps this validation active under
# ``python -O``; fall back to generic names when the config has none.
labels = getattr(model.config, "id2label", None)
if not isinstance(labels, dict) or not labels:
    print("⚠️ Labels not found in model config: Invalid or empty id2label mapping")
    labels = {i: f"Class {i}" for i in range(model.config.num_labels)}


# -----------------------------
# Standalone Test Mode (Optional)
# -----------------------------
def test_inference():
    """Smoke-test the loaded model outside Gradio.

    Runs one forward pass on a solid red 224x224 RGB image using the
    module-level ``processor``, ``model`` and ``device``. It then checks that
    the number of logits matches the number of entries in ``labels``.
    The outcome is printed; exceptions are caught and reported rather than
    raised, so a failed check does not stop the caller.
    """
    dummy_img = Image.new("RGB", (224, 224), color="red")  # synthetic input
    print("Running standalone inference test...")
    try:
        inputs = processor(images=dummy_img, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
        # A classifier head wider/narrower than the label map would make
        # predictions show "Unknown class" entries or silently skip labels.
        num_logits = outputs.logits.shape[-1]
        if num_logits != len(labels):
            print(
                f"⚠️ Model returned {num_logits} logits but "
                f"{len(labels)} labels are defined"
            )
        print("✅ Model inference test successful")
    except Exception as e:
        print(f"❌ Inference test failed: {e}")


# -----------------------------
# Prediction Function
# -----------------------------

def _format_top_predictions(logits, id2label, k=5):
    """Format the ``k`` most probable classes for the first row of ``logits``.

    Args:
        logits: Raw classifier scores; the last dimension indexes classes and
            row 0 is reported (assumes shape (batch, num_classes) — confirm).
        id2label: Mapping from class index to display name. Indices missing
            from it are shown as ``"Unknown class <idx>"``.
        k: Number of predictions to list. It is clamped to the number of
            classes so that models with fewer than ``k`` labels do not make
            ``torch.topk`` raise.

    Returns:
        One line per prediction, ``"<rank>. <label> — <percent>%"``, ordered
        from most to least probable and joined with newlines.
    """
    probs = torch.softmax(logits, dim=-1)[0]
    top_probs, top_indices = torch.topk(probs, min(k, probs.shape[-1]))

    ranked = zip(top_indices.tolist(), top_probs.tolist())
    lines = []
    for rank, (idx, prob) in enumerate(ranked, start=1):
        label = id2label.get(idx, f"Unknown class {idx}")
        lines.append(f"{rank}. {label} — {prob * 100:.2f}%")
    return "\n".join(lines)


def predict(image):
    """Classify ``image`` and return its top-5 predictions as display text.

    Called by the Gradio interface with a PIL image, or ``None`` when nothing
    was uploaded. It runs preprocessing, inference and post-processing in
    turn, printing progress and timings to stdout.

    Returns:
        A newline-separated ranking such as ``"1. cat — 97.31%"``, or a
        human-readable message if no image was given or a stage failed.
        Errors are returned as text rather than raised, so they appear in
        the output textbox.
    """
    if image is None:
        return "No image uploaded."

    print("\n[INFO] Starting prediction pipeline...")

    # Step 1: Preprocessing
    print("[STEP 1] Preprocessing image...")
    try:
        start = time.perf_counter()
        # The processor presumably expects 3-channel input; convert other
        # PIL modes (e.g. RGBA, L) to RGB first.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        print(f"[DEBUG] Input shape: {inputs['pixel_values'].shape}")
        print(f"[DEBUG] Time taken: {time.perf_counter() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in preprocessing: {e}"

    # Step 2: Inference
    print("[STEP 2] Running inference...")
    try:
        start = time.perf_counter()
        with torch.inference_mode():
            outputs = model(**inputs)
        print(f"[DEBUG] Inference completed in {time.perf_counter() - start:.2f}s")
    except Exception as e:
        return f"❌ Error in model inference: {e}"

    # Step 3: Post-processing
    print("[STEP 3] Processing output...")
    try:
        result = _format_top_predictions(outputs.logits, labels, k=5)
    except Exception as e:
        return f"❌ Error post-processing: {e}"

    print("[INFO] Prediction complete ✅\n")
    return result


# -----------------------------
# Gradio Interface
# -----------------------------

# Offer the bundled example image only when it actually exists on disk.
_EXAMPLE_IMAGE = "examples/test_image.jpg"

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Top 5 Predictions"),
    title="Fine-Tuned ViT Image Classifier",
    description="Upload an image to get the top 5 predicted classes with confidence scores.",
    allow_flagging="never",
    examples=[[_EXAMPLE_IMAGE]] if Path(_EXAMPLE_IMAGE).is_file() else None,
)

if __name__ == "__main__":
    print("\n🚀 Launching Gradio interface...\n")
    # Smoke-test the model before serving; failures are printed, not raised.
    test_inference()
    # share=True asks Gradio for a publicly reachable share link as well as
    # the local server — review before deploying.
    interface.launch(share=True)