AkashDataScience commited on
Commit
61d765f
·
1 Parent(s): 655a5f6

Addin GradCAM

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image
10
  from models.common import DetectMultiBackend
11
  from utils.augmentations import letterbox
12
  from utils.plots import Annotator, colors
 
 
13
  from utils.torch_utils import select_device, smart_inference_mode
14
  from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes
15
 
@@ -18,6 +20,7 @@ data = "data.yaml"
18
  # Load model
19
  device = select_device('cpu')
20
  model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
 
21
 
22
  false_detection_data = glob(os.path.join("false_detection", '*.jpg'))
23
  false_detection_data = [x.replace('\\', '/') for x in false_detection_data]
@@ -61,6 +64,8 @@ def display_false_detection_data(false_detection_data, number_of_samples):
61
 
62
  def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True, num_false_detection_images=10):
63
  im0 = input_img.copy()
 
 
64
  stride, names, pt = model.stride, model.names, model.pt
65
  imgsz = check_img_size((640, 640), s=stride) # check image size
66
 
@@ -107,7 +112,11 @@ def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True,
107
  else:
108
  misclassified_images = None
109
 
110
- return im0, misclassified_images
 
 
 
 
111
 
112
  title = "YOLOv9 model to detect shirt/tshirt"
113
  description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"
@@ -129,6 +138,7 @@ demo = gr.Interface(inference,
129
  gr.Checkbox(label="Show False Detection"),
130
  gr.Slider(5, 35, value=10, step=5, label="Number of False Detection")],
131
  outputs= [gr.Image(width=640, height=640, label="Output"),
 
132
  gr.Plot(label="False Detection")],
133
  title=title,
134
  description=description,
 
10
  from models.common import DetectMultiBackend
11
  from utils.augmentations import letterbox
12
  from utils.plots import Annotator, colors
13
+ from pytorch_grad_cam import EigenCAM
14
+ from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
15
  from utils.torch_utils import select_device, smart_inference_mode
16
  from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes
17
 
 
20
  # Load model
21
  device = select_device('cpu')
22
  model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
23
+ target_layers = [model.model[-2]]
24
 
25
  false_detection_data = glob(os.path.join("false_detection", '*.jpg'))
26
  false_detection_data = [x.replace('\\', '/') for x in false_detection_data]
 
64
 
65
  def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True, num_false_detection_images=10):
66
  im0 = input_img.copy()
67
+ im_resized = cv2.resize(im0, (640, 640))
68
+ rgb_img = im_resized.copy()
69
  stride, names, pt = model.stride, model.names, model.pt
70
  imgsz = check_img_size((640, 640), s=stride) # check image size
71
 
 
112
  else:
113
  misclassified_images = None
114
 
115
+ cam = EigenCAM(model, target_layers,task='od')
116
+ grayscale_cam = cam(rgb_img)[0, :, :]
117
+ cam_image = show_cam_on_image(im_resized, grayscale_cam, use_rgb=True)
118
+
119
+ return im0, cam_image, misclassified_images
120
 
121
  title = "YOLOv9 model to detect shirt/tshirt"
122
  description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"
 
138
  gr.Checkbox(label="Show False Detection"),
139
  gr.Slider(5, 35, value=10, step=5, label="Number of False Detection")],
140
  outputs= [gr.Image(width=640, height=640, label="Output"),
141
+ gr.Plot(label="GradCAM"),
142
  gr.Plot(label="False Detection")],
143
  title=title,
144
  description=description,