Nabeel Raza committed on
Commit
a0630af
·
1 Parent(s): 9205e98

add: OD option

Browse files
Files changed (2) hide show
  1. app.py +24 -4
  2. explain.py +27 -20
app.py CHANGED
@@ -4,18 +4,38 @@ from PIL import Image
4
  from explain import get_results, reproduce
5
 
6
 
7
- def classify_and_explain(image):
8
  reproduce()
9
  # This function will classify the image and return a list of image paths
10
- list_of_images = get_results(img_for_testing=image)
11
  return list_of_images
12
 
13
 
14
  def get_examples():
15
- return [Image.open(i) for i in glob("samples/*")]
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  demo = gr.Interface(
19
- fn=classify_and_explain, inputs="image", outputs="gallery", examples=get_examples()
 
 
 
 
 
 
 
 
20
  )
21
  demo.launch()
 
4
  from explain import get_results, reproduce
5
 
6
 
7
def classify_and_explain(image, object_detection=False):
    """Classify `image` and return the explanation images to display.

    Args:
        image: Input image as a numpy array (Gradio "image" component).
        object_detection: When True, leaves are extracted via object
            detection before classification (forwarded as `od`).

    Returns:
        List of file paths to the generated explanation images.
    """
    # reproduce() presumably fixes random seeds for repeatable explanations
    # — defined in explain.py; confirm there.
    reproduce()
    return get_results(img_for_testing=image, od=object_detection)
12
 
13
 
14
def get_examples():
    """Build the Gradio example list as ``[image, extract_leaves]`` pairs.

    Samples in ``od_off`` are shown with the "Extract Leaves" checkbox off;
    every other file under ``samples/`` gets it on. Detection-on examples
    are listed first, matching the original ordering.

    Returns:
        List of ``[PIL.Image, bool]`` example rows for ``gr.Interface``.
    """
    # Set gives O(1) membership tests instead of repeated list scans.
    od_off = {
        "samples/DSC_0315.jpg",
        "samples/20210401_123624.jpg",
        "samples/IMG_1299.jpg",
        "samples/20210420_112400.jpg",
        "samples/IMG_1300.jpg",
        "samples/20210420_112406.jpg",
    }
    # Single directory scan instead of two separate glob() calls.
    od_on_examples, od_off_examples = [], []
    for path in glob("samples/*"):
        if path in od_off:
            od_off_examples.append([Image.open(path), False])
        else:
            od_on_examples.append([Image.open(path), True])
    return od_on_examples + od_off_examples
28
 
29
 
30
# Wire the classifier into a Gradio UI: an image plus a checkbox that
# toggles leaf extraction (object detection) before classification.
demo = gr.Interface(
    fn=classify_and_explain,
    inputs=[
        "image",
        gr.Checkbox(
            label="Extract Leaves",
            # Fixed user-facing typo: was "What to extract leafs before
            # classification".
            info="Whether to extract leaves before classification",
        ),
    ],
    outputs="gallery",
    examples=get_examples(),
)
demo.launch()
explain.py CHANGED
@@ -15,6 +15,7 @@ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
15
  from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
16
 
17
  from ultralytics import YOLO
 
18
  # from rembg import remove
19
  import uuid
20
 
@@ -132,7 +133,7 @@ def save_explanation_results(res, path):
132
 
133
 
134
  model, image_transform = get_model(model_name)
135
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
136
  model.train()
137
  target_layers = [model.conv_head]
138
  gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
@@ -140,38 +141,44 @@ gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
140
  yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
141
 
142
 
143
- def get_results(img_path=None, img_for_testing=None):
144
  if img_path is None and img_for_testing is None:
145
  raise ValueError("Either img_path or img_for_testing should be provided.")
146
 
147
  if img_path is not None:
148
- results = yolo_model(img_path)
149
  image = Image.open(img_path)
150
 
151
  if img_for_testing is not None:
152
- results = yolo_model(img_for_testing)
153
  image = Image.fromarray(img_for_testing)
154
 
155
  result_paths = []
156
 
157
- for i, result in enumerate(results):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  unique_id = uuid.uuid4().hex
159
- save_path = f"/tmp/with-white-bg-result-{unique_id}.png"
160
- bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
161
- bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
162
-
163
- # bbox_image = remove(bbox_image).convert("RGB")
164
- # bbox_image = Image.fromarray(
165
- # np.where(
166
- # np.array(bbox_image) == [0, 0, 0],
167
- # [255, 255, 255],
168
- # np.array(bbox_image),
169
- # ).astype(np.uint8)
170
- # )
171
-
172
- res = make_prediction_and_explain(bbox_image)
173
  save_explanation_results(res, save_path)
174
-
175
  result_paths.append(save_path)
176
 
177
  return result_paths
 
15
  from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image
16
 
17
  from ultralytics import YOLO
18
+
19
  # from rembg import remove
20
  import uuid
21
 
 
133
 
134
 
135
# Build the classifier and its preprocessing transform, then load the
# fine-tuned weights on CPU.
# NOTE(review): model.train() (not eval()) is used before CAM inference —
# presumably required by the Grad-CAM / guided-backprop hooks; confirm this
# is intentional, since train mode changes dropout/batch-norm behavior.
model, image_transform = get_model(model_name)
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.train()
# Last convolutional layer is the Grad-CAM target.
target_layers = [model.conv_head]
gbp_model = GuidedBackpropReLUModel(model=model, device="cpu")
# YOLO detector used when object detection ("Extract Leaves") is enabled.
yolo_model = YOLO(YOLO_MODEL_WEIGHTS)
142
 
143
 
144
def _explain_and_save(img):
    """Run the explanation pipeline on `img`, save the result, return its path."""
    # Unique name so concurrent requests never overwrite each other.
    save_path = f"/tmp/with-bg-result-{uuid.uuid4().hex}.png"
    res = make_prediction_and_explain(img)
    save_explanation_results(res, save_path)
    return save_path


def get_results(img_path=None, img_for_testing=None, od=False):
    """Classify an image (optionally per detected leaf) and save explanations.

    Args:
        img_path: Path to an image file on disk.
        img_for_testing: Image as a numpy array (e.g. from Gradio).
        od: When True, run YOLO object detection first and explain each
            detected region separately; otherwise explain the whole image.

    Returns:
        List of file paths of the saved explanation images.

    Raises:
        ValueError: If neither `img_path` nor `img_for_testing` is given.
    """
    if img_path is None and img_for_testing is None:
        raise ValueError("Either img_path or img_for_testing should be provided.")

    if img_path is not None:
        image = Image.open(img_path)
    if img_for_testing is not None:
        # img_for_testing wins when both are supplied (original behavior).
        image = Image.fromarray(img_for_testing)

    result_paths = []

    if od:
        results = yolo_model(img_path if img_path else img_for_testing)
        for result in results:
            # NOTE(review): only the FIRST detected box of each result is
            # explained — confirm this is intended rather than looping over
            # all of result.boxes.
            bbox = result.boxes.xyxy[0].cpu().numpy().astype(int)
            bbox_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
            result_paths.append(_explain_and_save(bbox_image))
    else:
        result_paths.append(_explain_and_save(image))

    return result_paths