import gradio as gr
import torch
import numpy as np
from PIL import Image
import PIL.ImageDraw as ImageDraw
import math

from dlclive import DLCLive, Processor

import matplotlib.pyplot as plt  # only used by the commented-out debugging helpers below

#########################################
# https://www.programcreek.com/python/?code=fjchange%2Fobject_centric_VAD%2Fobject_centric_VAD-master%2Fobject_detection%2Futils%2Fvisualization_utils.py
def draw_keypoints_on_image(image,
                            keypoints,
                            color='red',
                            radius=2,
                            use_normalized_coordinates=True):
  """Draws keypoints on an image.

  Args:
    image: a PIL.Image object.
    keypoints: a numpy array with shape [num_keypoints, 2].
    color: color to draw the keypoints with. Default is red.
    radius: keypoint radius. Default value is 2.
    use_normalized_coordinates: if True (default), treat keypoint values as
      relative to the image.  Otherwise treat them as absolute.
  """
  # get a drawing context
  draw = ImageDraw.Draw(image)

  im_width, im_height = image.size
  keypoints_x = [k[1] for k in keypoints]
  keypoints_y = [k[0] for k in keypoints]

  # adjust keypoints coords if required
  if use_normalized_coordinates:
    keypoints_x = tuple([im_width * x for x in keypoints_x])
    keypoints_y = tuple([im_height * y for y in keypoints_y])

  # draw ellipses around keypoints
  for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
    draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
                  (keypoint_x + radius, keypoint_y + radius)],
                  outline=color, fill=color)
############################################
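
# Minimal usage sketch (illustrative, not from the original repo): the helper
# above expects (y, x) pairs, so a keypoint at pixel column 25, row 50 of a
# hypothetical 100x100 test image is passed as [50, 25]:
#
#   img = Image.new("RGB", (100, 100))
#   draw_keypoints_on_image(img,
#                           np.array([[50, 25]]),
#                           use_normalized_coordinates=False)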

# Predict detections with the MegaDetector v5a model
def predict_md(im, size=640):
    # resize image so its longest side is `size` pixels
    g = size / max(im.size)  # gain
    im = im.resize(tuple(int(x * g) for x in im.size), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10

    ## detect objects
    results = MD_model(im)  # inference; vars(results).keys() = dict_keys(['imgs', 'pred', 'names', 'files', 'times', 'xyxy', 'xywh', 'xyxyn', 'xywhn', 'n', 't', 's'])
    results.render()  # updates results.imgs with boxes and labels

    return results  # Image.fromarray(results.imgs[0]) --- TODO: return animals only?

def crop_animal_detections(yolo_results, likelihood_th):
    ## crop if animal and return list of crops

    list_labels_as_str = yolo_results.names #['animal', 'person', 'vehicle']
    list_np_animal_crops = []

    # for every image
    for img, det_array in zip(yolo_results.imgs,
                              yolo_results.xyxy):

        # for every detection
        for j in range(det_array.shape[0]):

            # compute bbox coords rounded outwards to the nearest integer (for pasting later)
            xmin_rd = int(math.floor(det_array[j, 0]))
            ymin_rd = int(math.floor(det_array[j, 1]))

            xmax_rd = int(math.ceil(det_array[j, 2]))
            ymax_rd = int(math.ceil(det_array[j, 3]))

            pred_llk = det_array[j, 4]  # detection confidence
            pred_label = det_array[j, 5]  # class index into yolo_results.names

            if (pred_label == list_labels_as_str.index('animal')) and \
                (pred_llk >= likelihood_th):
                area = (xmin_rd, ymin_rd, xmax_rd, ymax_rd)

                crop = Image.fromarray(img).crop(area)
                crop_np = np.asarray(crop)

                # add to list
                list_np_animal_crops.append(crop_np)

    # for detections_dict in img_data["detections"]:
    #     index = img_data["detections"].index(detections_dict)
    #     if detections_dict["conf"] > 0.8: 
    #         x1, y1,w_box, h_box = detections_dict["bbox"]
    #         ymin,xmin,ymax, xmax = y1, x1, y1 + h_box, x1 + w_box
            
    #         imageWidth=img.size[0]
    #         imageHeight= img.size[1]
    #         area = (xmin * imageWidth, ymin * imageHeight, xmax * imageWidth,
    #                 ymax * imageHeight)
    #         crop = img.crop(area)
    #         crop_np = np.asarray(crop)
    # 
    # if detections_dict["category"] == "1":
    return list_np_animal_crops
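
# For reference (my summary, worth verifying against your YOLOv5 version):
# each entry of yolo_results.xyxy is a tensor of shape [num_detections, 6]
# with columns (xmin, ymin, xmax, ymax, confidence, class), in pixels of the
# resized image, which is what the indexing above relies on.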

def predict_dlc(list_np_crops, DLCmodel, dlc_proc):
    # run DLC inference on every crop
    if len(list_np_crops) == 0:
        return []

    dlc_live = DLCLive(DLCmodel, processor=dlc_proc)
    dlc_live.init_inference(list_np_crops[0])

    list_kpts_per_crop = []
    for crop in list_np_crops:
        # TODO: scale crop?
        keypts = dlc_live.get_pose(crop)  # shape [num_keypoints, 3]: (x, y, likelihood)
        list_kpts_per_crop.append(keypts)

    return list_kpts_per_crop
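
# Quick sanity-check sketch (illustrative; assumes the DLC model directory
# used in predict_pipeline below exists locally):
#
#   poses = predict_dlc([np.zeros((200, 200, 3), dtype=np.uint8)],
#                       "DLC_models/DLC_Cat_resnet_50_iteration-0_shuffle-0",
#                       Processor())
#   poses[0].shape  # -> (num_keypoints, 3)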



def predict_pipeline(img_input):

    # these will eventually be user inputs...
    path_to_DLCmodel = "DLC_models/DLC_Cat_resnet_50_iteration-0_shuffle-0"
    likelihood_th = 0.8

    # Run MegaDetector
    md_results = predict_md(img_input)  # Image.fromarray(results.imgs[0])

    # Obtain animal crops with confidence above the threshold
    list_crops = crop_animal_detections(md_results,
                                        likelihood_th)

    # Run DLC on the crops
    # TODO: add llk threshold for kpts too?
    dlc_proc = Processor()
    list_kpts_per_crop = predict_dlc(list_crops,
                                     path_to_DLCmodel,
                                     dlc_proc)
    

    # # Produce final image
    # fig = plt.Figure(md_results.imgs[0].shape[:2]) #figsize=(10,10)) #md_results.imgs[0].shape)
    for ic, (np_crop, kpts_crop) in enumerate(zip(list_crops,
                                                  list_kpts_per_crop)):

        ## Draw keypoints on crop
        # DLC returns (x, y, likelihood) per keypoint, while
        # draw_keypoints_on_image expects (y, x): swap the first two columns
        img_crop = Image.fromarray(np_crop)
        draw_keypoints_on_image(img_crop,
                                kpts_crop[:, [1, 0]],
                                color='red',
                                radius=2,
                                use_normalized_coordinates=False)  # if True, use md_results.xyxyn instead

        ## Paste crop back into the original image
        # https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.paste
        # NOTE: indexing detections by `ic` assumes every detection yielded a
        # crop, i.e. that none were dropped by the likelihood/label filter
        img_input.paste(img_crop,
                        box=tuple(int(math.floor(t)) for t in md_results.xyxy[0][ic, :2]))
                                                                   
        # plt.imshow(np_crop)
        # plt.scatter(kpts_crop[:,0], kpts_crop[:,1], 40,
        #             color='r')
        # img_overlay = Image.frombytes('RGB', 
        #                                fig.canvas.get_width_height(),
        #                                fig.canvas.tostring_rgb())
    return img_input  # Image.fromarray(list_crops[0])  # Image.fromarray(md_results.imgs[0])


##########################################################
# Get MegaDetector model
# TODO: Allow user selectable model?
# models = ["model_weights/md_v5a.0.0.pt","model_weights/md_v5b.0.0.pt"]
MD_model = torch.hub.load('ultralytics/yolov5', 'custom', "model_weights/md_v5a.0.0.pt")



####################################################
# Create user interface and launch
#inputs = [image, chosen_model, size]
inputs = gr.inputs.Image(type="pil", label="Input Image")
outputs = gr.outputs.Image(type="pil", label="Output Image")
#image = gr.inputs.Image(type="pil", label="Input Image")
#chosen_model = gr.inputs.Dropdown(choices = models, value = "model_weights/md_v5a.0.0.pt",type = "value", label="Model Weight")
#size = 640

title = "MegaDetector v5 + DLC live"
description = "Detect animals in camera trap images and estimate their pose, using MegaDetector v5a + DLClive"
# article = "<p style='text-align: center'>This app makes predictions using a YOLOv5x6 model that was trained to detect animals, humans, and vehicles in camera trap images; find out more about the project on <a href='https://github.com/microsoft/CameraTraps'>GitHub</a>. This app was built by Henry Lydecker but really depends on code and models developed by <a href='http://ecologize.org/'>Ecologize</a> and <a href='http://aka.ms/aiforearth'>Microsoft AI for Earth</a>. Find out more about the YOLO model from the original creator, <a href='https://pjreddie.com/darknet/yolo/'>Joseph Redmon</a>. YOLOv5 is a family of compound-scaled object detection models trained on the COCO dataset and developed by Ultralytics, and includes simple functionality for Test Time Augmentation (TTA), model ensembling, hyperparameter evolution, and export to ONNX, CoreML and TFLite. <a href='https://github.com/ultralytics/yolov5'>Source code</a> | <a href='https://pytorch.org/hub/ultralytics_yolov5'>PyTorch Hub</a></p>"
# examples = [['data/Macropod.jpg'], ['data/koala2.jpg'],['data/cat.jpg'],['data/BrushtailPossum.jpg']]

gr.Interface(predict_pipeline, 
             inputs, 
             outputs, 
             title=title, 
             description=description,
             theme="huggingface").launch(enable_queue=True)


# def dlclive_pose(model, crop_np, crop, fname, index,dlc_proc):
#     dlc_live = DLCLive(model, processor=dlc_proc) 
#     dlc_live.init_inference(crop_np)
#     keypts = dlc_live.get_pose(crop_np) 
#     savetxt(str(fname)+ '_' + str(index) + '.csv' , keypts, delimiter=',')
#     xpose = []
#     ypose = []
#     for key in keypts[:,2]:
#        # if key > 0.05: # which value do we need here?
#             i = np.where(keypts[:,2]==key)
#             xpose.append(keypts[i,0])
#             ypose.append(keypts[i,1])
#     plt.imshow(crop)
#     plt.scatter(xpose[:], ypose[:], 40, color='cyan')
#     plt.savefig(str(fname)+ '_' + str(index) + '.png')
#     plt.show()
#     plt.clf()
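
# ----------------------------------------------------------------------
# Local smoke-test sketch (my addition, not part of the app; assumes one
# of the commented example images above, e.g. data/cat.jpg, is present).
# Bypasses Gradio to exercise the pipeline directly:
#
#   test_img = Image.open("data/cat.jpg").convert("RGB")
#   out = predict_pipeline(test_img)
#   out.save("pipeline_output.png")
# ----------------------------------------------------------------------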