Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,8 @@ import torch.nn.functional as F
|
|
11 |
import matplotlib.pyplot as plt
|
12 |
from PIL import Image, ImageEnhance
|
13 |
import torchvision.transforms as transforms
|
|
|
|
|
14 |
|
15 |
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
|
16 |
|
@@ -19,34 +21,123 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
19 |
|
20 |
# Number of classes
|
21 |
num_classes = 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Load the pre-trained ResNet model
|
24 |
-
model = models.
|
25 |
-
for param in model.parameters():
|
26 |
-
param.requires_grad = False # Freeze feature extractor
|
27 |
|
28 |
# Modify the classifier for 6 classes with an additional hidden layer
|
29 |
-
model.fc = nn.Sequential(
|
30 |
-
|
31 |
-
)
|
|
|
|
|
32 |
|
33 |
# Load trained weights
|
34 |
-
model.load_state_dict(torch.load('
|
35 |
model.eval()
|
36 |
|
37 |
# Class labels
|
38 |
-
class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# Image transformation function
|
41 |
def transform_image(image):
|
42 |
"""Preprocess the input image."""
|
43 |
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
|
44 |
img_size=224
|
45 |
-
transform = transforms.Compose([
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
|
51 |
img_tensor = transform(image).unsqueeze(0).to(device)
|
52 |
return img_tensor
|
@@ -109,18 +200,37 @@ def predict(image, brightness, contrast, hue, overlay_image, alpha):
|
|
109 |
with torch.no_grad():
|
110 |
output = model(image_tensor)
|
111 |
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
# Generate Bar Chart
|
114 |
with plt.xkcd():
|
115 |
-
fig, ax = plt.subplots(figsize=(
|
116 |
-
|
|
|
|
|
|
|
117 |
ax.set_ylabel("Probability")
|
118 |
ax.set_title("Class Probabilities")
|
119 |
ax.set_ylim([0, 1])
|
120 |
-
|
|
|
|
|
121 |
ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
|
122 |
|
123 |
-
return final_image, fig
|
124 |
|
125 |
# Gradio Interface
|
126 |
with gr.Blocks() as interface:
|
@@ -137,10 +247,11 @@ with gr.Blocks() as interface:
|
|
137 |
|
138 |
with gr.Column():
|
139 |
processed_image = gr.Image(label="Final Processed Image")
|
|
|
140 |
bar_chart = gr.Plot(label="Class Probabilities")
|
141 |
|
142 |
inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
|
143 |
-
outputs = [processed_image, bar_chart]
|
144 |
|
145 |
# Event listeners for real-time updates
|
146 |
image_input.change(predict, inputs=inputs, outputs=outputs)
|
|
|
11 |
import matplotlib.pyplot as plt
|
12 |
from PIL import Image, ImageEnhance
|
13 |
import torchvision.transforms as transforms
|
14 |
+
import urllib
|
15 |
+
import json
|
16 |
|
17 |
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
|
18 |
|
|
|
21 |
|
22 |
# Number of classes
|
23 |
num_classes = 6
|
24 |
+
'''
|
25 |
+
resnet imagenet
|
26 |
+
'''
|
27 |
+
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
|
28 |
+
with urllib.request.urlopen(url) as f:
|
29 |
+
imagenet_classes = json.load(f)
|
30 |
+
|
31 |
+
## Convert to dictionary format {0: "tench", 1: "goldfish", ..., 999: "toilet tissue"}
|
32 |
+
cifar10_classes = {int(k): v[1] for k, v in imagenet_classes.items()}
|
33 |
|
34 |
# Load the pre-trained ResNet model
|
35 |
+
model = models.resnet152(pretrained=True)
|
|
|
|
|
36 |
|
37 |
# Modify the classifier for 6 classes with an additional hidden layer
|
38 |
+
# model.fc = nn.Sequential(
|
39 |
+
# nn.Linear(model.fc.in_features, 512),
|
40 |
+
# nn.ReLU(),
|
41 |
+
# nn.Linear(512, num_classes)
|
42 |
+
# )
|
43 |
|
44 |
# Load trained weights
|
45 |
+
# model.load_state_dict(torch.load('model_old.pth', map_location=torch.device('cpu')))
|
46 |
model.eval()
|
47 |
|
48 |
# Class labels
|
49 |
+
# class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
|
50 |
+
class_labels = [cifar10_classes[i] for i in range(len(cifar10_classes))]
|
51 |
+
|
52 |
+
class MultiLayerGradCAM:
|
53 |
+
def __init__(self, model, target_layers=None):
|
54 |
+
self.model = model
|
55 |
+
self.target_layers = target_layers if target_layers else ['layer4']
|
56 |
+
self.activations = []
|
57 |
+
self.gradients = []
|
58 |
+
self.handles = []
|
59 |
+
|
60 |
+
# Register hooks
|
61 |
+
for name, module in self.model.named_modules():
|
62 |
+
if name in self.target_layers:
|
63 |
+
self.handles.append(
|
64 |
+
module.register_forward_hook(self._forward_hook)
|
65 |
+
)
|
66 |
+
self.handles.append(
|
67 |
+
module.register_backward_hook(self._backward_hook)
|
68 |
+
)
|
69 |
+
|
70 |
+
def _forward_hook(self, module, input, output):
|
71 |
+
self.activations.append(output.detach())
|
72 |
+
|
73 |
+
def _backward_hook(self, module, grad_input, grad_output):
|
74 |
+
self.gradients.append(grad_output[0].detach())
|
75 |
+
|
76 |
+
def _find_layer(self, layer_name):
|
77 |
+
for name, module in self.model.named_modules():
|
78 |
+
if name == layer_name:
|
79 |
+
return module
|
80 |
+
raise ValueError(f"Layer {layer_name} not found in model")
|
81 |
+
|
82 |
+
def generate(self, input_tensor, target_class=None):
|
83 |
+
device = next(self.model.parameters()).device
|
84 |
+
self.model.zero_grad()
|
85 |
+
|
86 |
+
# Forward pass
|
87 |
+
output = self.model(input_tensor.to(device))
|
88 |
+
pred_class = torch.argmax(output).item() if target_class is None else target_class
|
89 |
+
|
90 |
+
# Backward pass
|
91 |
+
one_hot = torch.zeros_like(output)
|
92 |
+
one_hot[0][pred_class] = 1
|
93 |
+
output.backward(gradient=one_hot)
|
94 |
+
|
95 |
+
# Process activations and gradients
|
96 |
+
heatmaps = []
|
97 |
+
for act, grad in zip(self.activations, reversed(self.gradients)):
|
98 |
+
# Compute weights
|
99 |
+
weights = F.adaptive_avg_pool2d(grad, 1)
|
100 |
+
|
101 |
+
# Create weighted combination of activation maps
|
102 |
+
cam = torch.mul(act, weights).sum(dim=1, keepdim=True)
|
103 |
+
cam = F.relu(cam)
|
104 |
+
print(cam.shape)
|
105 |
+
# Upsample to input size
|
106 |
+
cam = F.interpolate(cam, size=input_tensor.shape[2:],
|
107 |
+
mode='bilinear', align_corners=False)
|
108 |
+
heatmaps.append(cam.squeeze().cpu().numpy())
|
109 |
+
|
110 |
+
# Combine heatmaps from different layers
|
111 |
+
combined_heatmap = np.mean(heatmaps, axis=0)
|
112 |
+
# print(combined_heatmap.shape)
|
113 |
+
# Normalize
|
114 |
+
combined_heatmap = np.maximum(combined_heatmap, 0)
|
115 |
+
combined_heatmap = (combined_heatmap - combined_heatmap.min()) / \
|
116 |
+
(combined_heatmap.max() - combined_heatmap.min() + 1e-10)
|
117 |
+
|
118 |
+
|
119 |
+
return combined_heatmap, pred_class
|
120 |
+
|
121 |
+
def __del__(self):
|
122 |
+
for handle in self.handles:
|
123 |
+
handle.remove()
|
124 |
+
|
125 |
+
gradcam = MultiLayerGradCAM(model, target_layers=['layer3', 'layer4'])
|
126 |
|
127 |
# Image transformation function
|
128 |
def transform_image(image):
|
129 |
"""Preprocess the input image."""
|
130 |
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
|
131 |
img_size=224
|
132 |
+
transform = transforms.Compose([ #IMAGENET
|
133 |
+
transforms.Resize(256), # Resize shorter side to 256, keeping aspect ratio
|
134 |
+
transforms.CenterCrop(224), # Crop the center 224x224 region
|
135 |
+
transforms.ToTensor(), # Convert to tensor (scales to [0,1])
|
136 |
+
transforms.Normalize( # Normalize using ImageNet mean & std
|
137 |
+
mean=[0.485, 0.456, 0.406],
|
138 |
+
std=[0.229, 0.224, 0.225]
|
139 |
+
)
|
140 |
+
])
|
141 |
|
142 |
img_tensor = transform(image).unsqueeze(0).to(device)
|
143 |
return img_tensor
|
|
|
200 |
with torch.no_grad():
|
201 |
output = model(image_tensor)
|
202 |
probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
|
203 |
+
# pred_class = np.argmax(probabilities)
|
204 |
+
# top_5 = torch.topk(probabilities, 5)
|
205 |
+
|
206 |
+
heatmap, _ = gradcam.generate(image_tensor)
|
207 |
+
|
208 |
+
# Create GradCAM overlay
|
209 |
+
final_np = np.array(final_image)
|
210 |
+
heatmap = cv2.resize(heatmap, (final_np.shape[1], final_np.shape[0]))
|
211 |
+
heatmap = np.uint8(255 * heatmap)
|
212 |
+
heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
213 |
+
heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
|
214 |
+
superimposed = cv2.addWeighted(heatmap_rgb, 0.5, final_np, 0.5, 0)
|
215 |
+
gradcam_image = Image.fromarray(superimposed)
|
216 |
+
|
217 |
|
218 |
# Generate Bar Chart
|
219 |
with plt.xkcd():
|
220 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
221 |
+
top5_indices = np.argsort(probabilities)[-5:][::-1] # Indices of top 5 probabilities
|
222 |
+
top5_probs = probabilities[top5_indices]
|
223 |
+
top5_labels = [class_labels[i] for i in top5_indices]
|
224 |
+
ax.bar(top5_labels, top5_probs, color='skyblue')
|
225 |
ax.set_ylabel("Probability")
|
226 |
ax.set_title("Class Probabilities")
|
227 |
ax.set_ylim([0, 1])
|
228 |
+
plt.tight_layout(pad=3)
|
229 |
+
ax.set_xticklabels(top5_labels, rotation=45, ha="right", fontsize=8)
|
230 |
+
for i, v in enumerate(top5_probs):
|
231 |
ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
|
232 |
|
233 |
+
return final_image, gradcam_image, fig
|
234 |
|
235 |
# Gradio Interface
|
236 |
with gr.Blocks() as interface:
|
|
|
247 |
|
248 |
with gr.Column():
|
249 |
processed_image = gr.Image(label="Final Processed Image")
|
250 |
+
gradcam_output = gr.Image(label="GradCAM Overlay")
|
251 |
bar_chart = gr.Plot(label="Class Probabilities")
|
252 |
|
253 |
inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
|
254 |
+
outputs = [processed_image, gradcam_output, bar_chart]
|
255 |
|
256 |
# Event listeners for real-time updates
|
257 |
image_input.change(predict, inputs=inputs, outputs=outputs)
|