File size: 4,043 Bytes
7f2d78e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import json
import os
from train_resnet50 import ResNet, Bottleneck

# Load ImageNet class labels.
# predict() looks labels up by integer class index and bounds-checks the index
# against len(class_labels), so the fallback has to be a list: a dict keyed by
# str(i) would raise KeyError on every integer lookup.
class_labels = None
try:
    with open("imagenet_classes.json", "r", encoding="utf-8") as f:
        class_labels = json.load(f)
    print(f"Loaded {len(class_labels)} class labels")
except FileNotFoundError:
    print("Warning: imagenet_classes.json not found, creating simplified labels")
    class_labels = None
except json.JSONDecodeError:
    print("Warning: Error parsing imagenet_classes.json, using simplified labels")
    class_labels = None
except Exception as e:
    # Best-effort startup: any other failure (permissions, bad content, ...)
    # still leaves the app usable with placeholder labels.
    print(f"Warning: Unexpected error loading class labels: {e}")
    class_labels = None

if class_labels is None:
    # Placeholder labels "class_0" .. "class_999", indexable by class id.
    class_labels = [f"class_{i}" for i in range(1000)]


def create_model():
    """Instantiate the project's ResNet with Bottleneck blocks and layers [3, 4, 6, 3].

    NOTE(review): [3, 4, 6, 3] is the usual ResNet-50 layout, assuming
    train_resnet50.ResNet follows the torchvision-style constructor -- confirm.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3])


def load_model(model_path):
    """Load model weights from ``model_path`` and return the model in eval mode.

    The checkpoint may be either a dict holding the weights under
    ``"model_state_dict"`` or a bare state dict. Keys saved from a
    DataParallel/DDP-wrapped model (leading ``"module."``) are normalised
    before loading.

    If anything goes wrong while reading or applying the checkpoint, the error
    is printed and a pretrained ResNet50 from torch.hub is returned instead
    (this fallback may need network access).

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        A model with ``eval()`` already applied.
    """
    model = create_model()
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at checkpoints you trust; consider
        # passing weights_only=True where the installed torch supports it.
        checkpoint = torch.load(model_path, map_location="cpu")

        # Accept both a training checkpoint and a bare state dict.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        # Handle DataParallel/DDP state dict: strip only the *leading*
        # "module." so a parameter path that happens to contain "module."
        # further in is left intact.
        prefix = "module."
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        model.load_state_dict(new_state_dict)
        model.eval()
        print("Model loaded successfully!")
        return model
    except Exception as e:
        # Deliberate best-effort: keep the demo usable without the checkpoint.
        print(f"Error loading model: {e}")
        print("Loading pretrained ResNet50 as fallback...")
        model = torch.hub.load("pytorch/vision:v0.10.0", "resnet50", pretrained=True)
        model.eval()
        return model


# Preprocessing pipeline: resize to 256, center-crop to 224, convert to a
# tensor, then normalise each channel with the mean/std below (the values
# conventionally used for ImageNet-trained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_PREPROCESS_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_PREPROCESS_STEPS)

# Model cache: stays None until predict() loads the model on its first use.
global_model = None


def _label_for(class_idx):
    """Return the label for ``class_idx``, or ``"class_<idx>"`` if unknown.

    Works whether ``class_labels`` is a list indexed by class id or a dict
    keyed by the id (as a string, which is what JSON objects produce, or as
    an int).
    """
    fallback = f"class_{class_idx}"
    if isinstance(class_labels, dict):
        if str(class_idx) in class_labels:
            return class_labels[str(class_idx)]
        return class_labels.get(class_idx, fallback)
    if 0 <= class_idx < len(class_labels):
        return class_labels[class_idx]
    return fallback


def predict(image):
    """Classify ``image`` and return its top predictions.

    Args:
        image: Image as an array accepted by ``PIL.Image.fromarray``, or None
            when nothing was submitted.

    Returns:
        A list of up to five dicts, ordered by decreasing confidence, each
        with keys ``"label"``, ``"class_id"`` and ``"confidence"``; or None
        if no image was given, the model could not be loaded, or prediction
        failed (the error is printed).
    """
    global global_model

    # Nothing to classify: return before paying for a model load.
    if image is None:
        return None

    # Load model only once
    if global_model is None:
        try:
            global_model = load_model("best_model.pth")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

    top_catid = None  # kept for the debug print in the error path
    try:
        # Force three channels so grayscale / RGBA inputs match the
        # 3-channel normalisation in `transform`.
        pil_image = Image.fromarray(image).convert("RGB")
        batch = transform(pil_image).unsqueeze(0)

        # Make prediction
        with torch.no_grad():
            outputs = global_model(batch)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

        # Top 5, or fewer when the model has fewer than five outputs.
        k = min(5, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)

        return [
            {
                "label": _label_for(class_idx),
                "class_id": class_idx,
                "confidence": float(confidence),
            }
            for confidence, class_idx in zip(top_prob.tolist(), top_catid.tolist())
        ]
    except Exception as e:
        print(f"Error during prediction: {e}")
        # Only print indices if the failure happened after top-k was computed;
        # otherwise the name would be unbound and the handler itself would fail.
        if top_catid is not None:
            print(f"Class indices: {top_catid.tolist()}")  # Debug info
        return None


# Gradio UI: one image input wired to predict(), with the result shown as JSON.
iface = gr.Interface(
    predict,
    gr.Image(),
    gr.JSON(),
    title="ResNet50 ImageNet Classifier",
    description="Upload an image to get top-5 predictions from our trained ResNet50 model.",
)

# Start the web app only when run as a script; share=True asks Gradio for a
# public share link in addition to the local server.
if __name__ == "__main__":
    iface.launch(share=True)