darpanaswal commited on
Commit
d5ca01c
·
verified ·
1 Parent(s): c2c4a52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -18
app.py CHANGED
@@ -11,6 +11,8 @@ import torch.nn.functional as F
11
  import matplotlib.pyplot as plt
12
  from PIL import Image, ImageEnhance
13
  import torchvision.transforms as transforms
 
 
14
 
15
  ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
16
 
@@ -19,34 +21,123 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  # Number of classes
21
  num_classes = 6
 
 
 
 
 
 
 
 
 
22
 
23
  # Load the pre-trained ResNet model
24
- model = models.resnet18(pretrained=True)
25
- for param in model.parameters():
26
- param.requires_grad = False # Freeze feature extractor
27
 
28
  # Modify the classifier for 6 classes with an additional hidden layer
29
- model.fc = nn.Sequential(
30
- nn.Linear(model.fc.in_features, num_classes)
31
- )
 
 
32
 
33
  # Load trained weights
34
- model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
35
  model.eval()
36
 
37
  # Class labels
38
- class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # Image transformation function
41
  def transform_image(image):
42
  """Preprocess the input image."""
43
  mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]
44
  img_size=224
45
- transform = transforms.Compose([
46
- transforms.Resize((img_size, img_size)),
47
- transforms.ToTensor(),
48
- transforms.Normalize(mean, std)
49
- ])
 
 
 
 
50
 
51
  img_tensor = transform(image).unsqueeze(0).to(device)
52
  return img_tensor
@@ -109,18 +200,37 @@ def predict(image, brightness, contrast, hue, overlay_image, alpha):
109
  with torch.no_grad():
110
  output = model(image_tensor)
111
  probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  # Generate Bar Chart
114
  with plt.xkcd():
115
- fig, ax = plt.subplots(figsize=(5, 3))
116
- ax.bar(class_labels, probabilities, color='skyblue')
 
 
 
117
  ax.set_ylabel("Probability")
118
  ax.set_title("Class Probabilities")
119
  ax.set_ylim([0, 1])
120
- for i, v in enumerate(probabilities):
 
 
121
  ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
122
 
123
- return final_image, fig
124
 
125
  # Gradio Interface
126
  with gr.Blocks() as interface:
@@ -137,10 +247,11 @@ with gr.Blocks() as interface:
137
 
138
  with gr.Column():
139
  processed_image = gr.Image(label="Final Processed Image")
 
140
  bar_chart = gr.Plot(label="Class Probabilities")
141
 
142
  inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
143
- outputs = [processed_image, bar_chart]
144
 
145
  # Event listeners for real-time updates
146
  image_input.change(predict, inputs=inputs, outputs=outputs)
 
11
  import matplotlib.pyplot as plt
12
  from PIL import Image, ImageEnhance
13
  import torchvision.transforms as transforms
14
+ import urllib
15
+ import json
16
 
17
  ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
18
 
 
21
 
22
  # Number of classes
23
  num_classes = 6
24
# --- ImageNet label and model setup ---------------------------------------
# Download the ImageNet class-index mapping once at startup.
# NOTE: `urllib.request` must be imported explicitly; a bare `import urllib`
# does not reliably make the `request` submodule available.
import urllib.request

url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
with urllib.request.urlopen(url) as f:
    imagenet_classes = json.load(f)

# Map class index -> human-readable label, e.g. {0: "tench", ..., 999: "toilet_tissue"}.
# (Name kept for backward compatibility with the rest of the file; these are
# ImageNet labels, not CIFAR-10.)
cifar10_classes = {int(k): v[1] for k, v in imagenet_classes.items()}

# Load a ResNet-152 pre-trained on ImageNet and keep its stock 1000-way head.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
# favour of the `weights=` argument — confirm the installed version.
model = models.resnet152(pretrained=True)
model.eval()  # inference only; the model is never fine-tuned here

# Ordered label list indexed by class id, used for the probability bar chart.
class_labels = [cifar10_classes[i] for i in range(len(cifar10_classes))]
51
+
52
class MultiLayerGradCAM:
    """Grad-CAM that averages class-activation maps from several layers.

    Forward/backward hooks capture each target layer's activations and the
    gradient of the chosen class score w.r.t. those activations; the
    gradient-weighted, ReLU'd combinations are upsampled to the input's
    spatial size and averaged into a single [0, 1]-normalized heatmap.
    """

    def __init__(self, model, target_layers=None):
        self.model = model
        self.target_layers = target_layers if target_layers else ['layer4']
        self.activations = []
        self.gradients = []
        self.handles = []

        # Register capture hooks on every requested layer.
        # `register_full_backward_hook` replaces the deprecated
        # `register_backward_hook`, which is documented as unreliable on
        # container modules such as ResNet's `layer3`/`layer4` Sequentials.
        for name, module in self.model.named_modules():
            if name in self.target_layers:
                self.handles.append(
                    module.register_forward_hook(self._forward_hook)
                )
                self.handles.append(
                    module.register_full_backward_hook(self._backward_hook)
                )

    def _forward_hook(self, module, input, output):
        self.activations.append(output.detach())

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients.append(grad_output[0].detach())

    def _find_layer(self, layer_name):
        # NOTE(review): currently unused — candidate for removal.
        for name, module in self.model.named_modules():
            if name == layer_name:
                return module
        raise ValueError(f"Layer {layer_name} not found in model")

    def generate(self, input_tensor, target_class=None):
        """Return (heatmap, predicted_class) for a (1, C, H, W) input.

        heatmap is an HxW numpy array normalized to [0, 1]; target_class
        defaults to the model's argmax prediction.
        """
        # Reset captured state: the hooks append on every pass, so without
        # this a second call would mix stale activations/gradients from the
        # previous image with the current ones.
        self.activations = []
        self.gradients = []

        device = next(self.model.parameters()).device
        self.model.zero_grad()

        # Forward pass (hooks record activations of the target layers).
        output = self.model(input_tensor.to(device))
        pred_class = torch.argmax(output).item() if target_class is None else target_class

        # Backward pass w.r.t. the chosen class score only.
        one_hot = torch.zeros_like(output)
        one_hot[0][pred_class] = 1
        output.backward(gradient=one_hot)

        # Backward hooks fire in reverse layer order relative to the forward
        # hooks, so reverse the gradients list to pair each layer's
        # activation with its own gradient.
        heatmaps = []
        for act, grad in zip(self.activations, reversed(self.gradients)):
            # Global-average-pool the gradients to get per-channel weights.
            weights = F.adaptive_avg_pool2d(grad, 1)

            # Weighted combination of activation maps, rectified.
            cam = torch.mul(act, weights).sum(dim=1, keepdim=True)
            cam = F.relu(cam)

            # Upsample to the input's spatial size so maps from different
            # layers can be averaged together.
            cam = F.interpolate(cam, size=input_tensor.shape[2:],
                                mode='bilinear', align_corners=False)
            heatmaps.append(cam.squeeze().cpu().numpy())

        # Average across layers, then normalize to [0, 1] (epsilon guards a
        # constant map).
        combined_heatmap = np.mean(heatmaps, axis=0)
        combined_heatmap = np.maximum(combined_heatmap, 0)
        combined_heatmap = (combined_heatmap - combined_heatmap.min()) / \
                           (combined_heatmap.max() - combined_heatmap.min() + 1e-10)

        return combined_heatmap, pred_class

    def __del__(self):
        for handle in self.handles:
            handle.remove()
124
# Module-level Grad-CAM helper shared by predict(); hooks layer3 and layer4
# so the overlay blends mid- and high-level features of the ResNet.
gradcam = MultiLayerGradCAM(model, target_layers=['layer3', 'layer4'])
126
 
127
# Image transformation function
def transform_image(image):
    """Preprocess a PIL image into a normalized (1, 3, 224, 224) tensor.

    Uses the standard ImageNet evaluation pipeline: resize the shorter side
    to 256, center-crop 224x224, convert to a [0, 1] tensor, then normalize
    with the ImageNet channel statistics expected by the pretrained ResNet.
    (The previously defined CIFAR-style mean/std locals were dead code — they
    were never applied — and have been removed.)
    """
    transform = transforms.Compose([
        transforms.Resize(256),       # shorter side -> 256, aspect ratio kept
        transforms.CenterCrop(224),   # central 224x224 crop
        transforms.ToTensor(),        # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(         # ImageNet mean & std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # Add the batch dimension and move to the module-level `device`.
    img_tensor = transform(image).unsqueeze(0).to(device)
    return img_tensor
 
200
  with torch.no_grad():
201
  output = model(image_tensor)
202
  probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
203
+ # pred_class = np.argmax(probabilities)
204
+ # top_5 = torch.topk(probabilities, 5)
205
+
206
+ heatmap, _ = gradcam.generate(image_tensor)
207
+
208
+ # Create GradCAM overlay
209
+ final_np = np.array(final_image)
210
+ heatmap = cv2.resize(heatmap, (final_np.shape[1], final_np.shape[0]))
211
+ heatmap = np.uint8(255 * heatmap)
212
+ heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
213
+ heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
214
+ superimposed = cv2.addWeighted(heatmap_rgb, 0.5, final_np, 0.5, 0)
215
+ gradcam_image = Image.fromarray(superimposed)
216
+
217
 
218
  # Generate Bar Chart
219
  with plt.xkcd():
220
+ fig, ax = plt.subplots(figsize=(6, 4))
221
+ top5_indices = np.argsort(probabilities)[-5:][::-1] # Indices of top 5 probabilities
222
+ top5_probs = probabilities[top5_indices]
223
+ top5_labels = [class_labels[i] for i in top5_indices]
224
+ ax.bar(top5_labels, top5_probs, color='skyblue')
225
  ax.set_ylabel("Probability")
226
  ax.set_title("Class Probabilities")
227
  ax.set_ylim([0, 1])
228
+ plt.tight_layout(pad=3)
229
+ ax.set_xticklabels(top5_labels, rotation=45, ha="right", fontsize=8)
230
+ for i, v in enumerate(top5_probs):
231
  ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
232
 
233
+ return final_image, gradcam_image, fig
234
 
235
  # Gradio Interface
236
  with gr.Blocks() as interface:
 
247
 
248
  with gr.Column():
249
  processed_image = gr.Image(label="Final Processed Image")
250
+ gradcam_output = gr.Image(label="GradCAM Overlay")
251
  bar_chart = gr.Plot(label="Class Probabilities")
252
 
253
  inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
254
+ outputs = [processed_image, gradcam_output, bar_chart]
255
 
256
  # Event listeners for real-time updates
257
  image_input.change(predict, inputs=inputs, outputs=outputs)