ZDPLI commited on
Commit
56877e9
Β·
verified Β·
1 Parent(s): 5a3ccc8

Upload appSWA.py

Browse files
Files changed (1) hide show
  1. appSWA.py +122 -0
appSWA.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import ViTImageProcessor, AutoModelForImageClassification
4
+ from PIL import Image
5
+ import numpy as np
6
+ import time
7
+
8
+ # -----------------------------
9
+ # Configuration and Setup
10
+ # -----------------------------
11
+
12
+ # Force Gradio to use CUDA (if available)
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"Using device: {device}")
15
+
16
+ # Model path
17
+ model_path = "final_model"
18
+
19
+ # Load image processor and model
20
+ try:
21
+ print("Loading image processor...")
22
+ processor = ViTImageProcessor.from_pretrained(model_path)
23
+
24
+ print("Loading model...")
25
+ model = AutoModelForImageClassification.from_pretrained(model_path)
26
+ model = model.to(device)
27
+ model.eval() # Important for deterministic behavior
28
+ except Exception as e:
29
+ raise RuntimeError(f"Error loading model: {e}")
30
+
31
+ # Attempt to load label mappings
32
+ try:
33
+ labels = model.config.id2label
34
+ assert isinstance(labels, dict) and len(labels) > 0, "Invalid or empty id2label mapping"
35
+ except Exception as e:
36
+ print(f"⚠️ Labels not found in model config: {e}")
37
+ labels = {i: f"Class {i}" for i in range(model.config.num_labels)}
38
+
39
+
40
+ # -----------------------------
41
+ # Standalone Test Mode (Optional)
42
+ # -----------------------------
43
+ def test_inference():
44
+ """Run inference outside Gradio to verify model works"""
45
+ dummy_img = Image.new('RGB', (224, 224), color='red') # Create a dummy image
46
+ print("Running standalone inference test...")
47
+ try:
48
+ inputs = processor(images=dummy_img, return_tensors="pt").to(device)
49
+ with torch.inference_mode():
50
+ outputs = model(**inputs)
51
+ print("βœ… Model inference test successful")
52
+ except Exception as e:
53
+ print(f"❌ Inference test failed: {e}")
54
+
55
+
56
+ # -----------------------------
57
+ # Prediction Function
58
+ # -----------------------------
59
+
60
+ def predict(image):
61
+ if image is None:
62
+ return "No image uploaded."
63
+
64
+ print("\n[INFO] Starting prediction pipeline...")
65
+
66
+ # Step 1: Preprocessing
67
+ print("[STEP 1] Preprocessing image...")
68
+ try:
69
+ start = time.time()
70
+ inputs = processor(images=image, return_tensors="pt").to(device)
71
+ print(f"[DEBUG] Input shape: {inputs['pixel_values'].shape}")
72
+ print(f"[DEBUG] Time taken: {time.time() - start:.2f}s")
73
+ except Exception as e:
74
+ return f"❌ Error in preprocessing: {e}"
75
+
76
+ # Step 2: Inference
77
+ print("[STEP 2] Running inference...")
78
+ try:
79
+ start = time.time()
80
+ with torch.inference_mode():
81
+ outputs = model(**inputs)
82
+ print(f"[DEBUG] Inference completed in {time.time() - start:.2f}s")
83
+ except Exception as e:
84
+ return f"❌ Error in model inference: {e}"
85
+
86
+ # Step 3: Post-processing
87
+ print("[STEP 3] Processing output...")
88
+ try:
89
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
90
+ top5_probs, top5_indices = torch.topk(probs, 5)
91
+
92
+ result = ""
93
+ for i in range(5):
94
+ idx = top5_indices[0][i].item()
95
+ label = labels.get(idx, f"Unknown class {idx}")
96
+ prob = top5_probs[0][i].item() * 100
97
+ result += f"{i + 1}. {label} β€” {prob:.2f}%\n"
98
+ except Exception as e:
99
+ return f"❌ Error post-processing: {e}"
100
+
101
+ print("[INFO] Prediction complete βœ…\n")
102
+ return result.strip()
103
+
104
+
105
+ # -----------------------------
106
+ # Gradio Interface
107
+ # -----------------------------
108
+
109
+ interface = gr.Interface(
110
+ fn=predict,
111
+ inputs=gr.Image(type="pil", label="Upload an Image"),
112
+ outputs=gr.Textbox(label="Top 5 Predictions"),
113
+ title="Fine-Tuned ViT Image Classifier",
114
+ description="Upload an image to get the top 5 predicted classes with confidence scores.",
115
+ allow_flagging="never",
116
+ examples=[["examples/test_image.jpg"]] if "examples" in locals() else None
117
+ )
118
+
119
+ if __name__ == "__main__":
120
+ print("\nπŸš€ Launching Gradio interface...\n")
121
+ test_inference() # Optional: Run test before launching
122
+ interface.launch(share=True)