AkashDataScience commited on
Commit
5d43fef
·
1 Parent(s): 1ac603d

Updated inferencing code

Browse files
Files changed (1) hide show
  1. app.py +24 -36
app.py CHANGED
@@ -1,27 +1,24 @@
1
  import os
2
  import cv2
 
3
  import math
4
  import torch
5
  import numpy as np
6
  import gradio as gr
7
- import albumentations
8
  import matplotlib.pyplot as plt
9
- from glob import glob
10
  from PIL import Image
11
- from pytorch_grad_cam import EigenCAM
12
- from models.common import DetectMultiBackend
13
- from albumentations.pytorch import ToTensorV2
14
- from utils.augmentations import letterbox
15
  from utils.plots import Annotator, colors
16
- from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
 
 
17
  from utils.torch_utils import select_device, smart_inference_mode
18
- from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes
19
 
20
  weights = "runs/train/best_striped.pt"
21
  data = "data.yaml"
22
  # Load model
23
  device = select_device('cpu')
24
- model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
25
  #target_layers = [model.model.model[-1]]
26
 
27
  false_detection_data = glob(os.path.join("false_detection", '*.jpg'))
@@ -65,41 +62,32 @@ def display_false_detection_data(false_detection_data, number_of_samples):
65
  return fig
66
 
67
  def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True, num_false_detection_images=10):
68
- im0 = input_img.copy()
69
- rgb_img = cv2.resize(im0, (640, 640))
70
  stride, names, pt = model.stride, model.names, model.pt
71
- imgsz = check_img_size((640, 640), s=stride) # check image size
72
-
73
- bs = 1
74
- # Run inference
75
- model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
76
- seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
77
-
78
- with dt[0]:
79
- im = letterbox(input_img, imgsz, stride=stride, auto=True)[0] # padded resize
80
- im = im.transpose((2, 0, 1))[::-1]
81
- im = np.ascontiguousarray(im)
82
- im = torch.from_numpy(im).to(model.device)
83
- im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
84
- im /= 255 # 0 - 255 to 0.0 - 1.0
85
- if len(im.shape) == 3:
86
- im = im[None] # expand for batch dim
87
-
88
  # Inference
89
- with dt[1]:
90
- pred = model(im, augment=False, visualize=False)
91
 
92
- # NMS
93
- with dt[2]:
94
- pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)
95
 
96
  # Process predictions
 
97
  for i, det in enumerate(pred): # per image
98
  seen += 1
99
- annotator = Annotator(im0, line_width=2, example=str(model.names))
100
  if len(det):
101
  # Rescale boxes from img_size to im0 size
102
- det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
103
 
104
  # Write results
105
  for *xyxy, conf, cls in reversed(det):
@@ -117,7 +105,7 @@ def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True,
117
  # grayscale_cam = cam(im)[0, :]
118
  # cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
119
 
120
- return im0, misclassified_images
121
 
122
  title = "YOLOv9 model to detect shirt/tshirt"
123
  description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"
 
1
  import os
2
  import cv2
3
+ import glob
4
  import math
5
  import torch
6
  import numpy as np
7
  import gradio as gr
 
8
  import matplotlib.pyplot as plt
9
+
10
  from PIL import Image
 
 
 
 
11
  from utils.plots import Annotator, colors
12
+ from utils.augmentations import letterbox
13
+ from models.common import DetectMultiBackend
14
+ from utils.general import non_max_suppression, scale_boxes
15
  from utils.torch_utils import select_device, smart_inference_mode
 
16
 
17
  weights = "runs/train/best_striped.pt"
18
  data = "data.yaml"
19
  # Load model
20
  device = select_device('cpu')
21
+ model = DetectMultiBackend(weights=weights, device=device, fp16=False, data=data)
22
  #target_layers = [model.model.model[-1]]
23
 
24
  false_detection_data = glob(os.path.join("false_detection", '*.jpg'))
 
62
  return fig
63
 
64
  def inference(input_img, conf_thres, iou_thres, is_false_detection_images=True, num_false_detection_images=10):
 
 
65
  stride, names, pt = model.stride, model.names, model.pt
66
+
67
+ # Load image
68
+ img0 = input_img.copy()
69
+ img = letterbox(img0, 640, stride=stride, auto=True)[0]
70
+ img = img[:, :, ::-1].transpose(2, 0, 1)
71
+ img = np.ascontiguousarray(img)
72
+ img = torch.from_numpy(img).to(device).float()
73
+ img /= 255.0
74
+ if img.ndimension() == 3:
75
+ img = img.unsqueeze(0)
76
+
 
 
 
 
 
 
77
  # Inference
78
+ pred = model(img, augment=False, visualize=False)
 
79
 
80
+ # Apply NMS
81
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, max_det=1000)
 
82
 
83
  # Process predictions
84
+ seen = 0
85
  for i, det in enumerate(pred): # per image
86
  seen += 1
87
+ annotator = Annotator(img0, line_width=2, example=str(model.names))
88
  if len(det):
89
  # Rescale boxes from img_size to im0 size
90
+ det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()
91
 
92
  # Write results
93
  for *xyxy, conf, cls in reversed(det):
 
105
  # grayscale_cam = cam(im)[0, :]
106
  # cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
107
 
108
+ return img0, misclassified_images
109
 
110
  title = "YOLOv9 model to detect shirt/tshirt"
111
  description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"