File size: 3,645 Bytes
7206ed3
 
 
 
 
6db6a87
63056c5
ea66b57
7206ed3
13575f0
 
6db6a87
0da9dca
 
 
 
56f79a3
 
 
 
 
 
a8badba
 
 
 
 
 
 
 
 
 
 
0da9dca
 
 
 
 
 
 
 
 
7206ed3
3e17a60
 
 
 
 
 
 
 
 
 
 
 
 
7206ed3
3e17a60
 
 
 
 
 
 
 
 
 
 
 
 
b11b05c
7206ed3
b11b05c
 
 
 
 
 
 
 
 
 
7206ed3
a8badba
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import yaml
import numpy as np
from matplotlib import cm
import gradio as gr
import deeplabcut
import dlclibrary
import transformers

from PIL import Image
import requests

from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w_text, save_results_only_dlc
from detection_utils import predict_md, crop_animal_detections
from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples

from deeplabcut.utils import auxiliaryfunctions
from dlclibrary.dlcmodelzoo.modelzoo_download import (
    download_huggingface_model,
    MODELOPTIONS,
)



# Manual smoke test (verified to pass): download a SuperAnimal model directly.
#   model = 'superanimal_topviewmouse'
#   train_dir = 'DLC_models/sa-tvm'
#   download_huggingface_model(model, train_dir)

# Grab demo data: a COCO val2017 cat image.
# NOTE: this performs a network request at import time. The timeout keeps app
# start-up from blocking forever on a stalled connection, and
# raise_for_status() reports HTTP failures (404/5xx) directly instead of
# letting PIL fail later with an opaque "cannot identify image" error.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
_demo_response = requests.get(url, stream=True, timeout=30)
_demo_response.raise_for_status()
image = Image.open(_demo_response.raw)

# MegaDetector model name -> local weights file path.
MD_models_dict = {'md_v5a': "MD_models/md_v5a.0.0.pt",
                  'md_v5b': "MD_models/md_v5b.0.0.pt"}

# DeepLabCut model name -> local target directory.
# When the directory is missing or empty, predict_pipeline passes the key
# itself to download_huggingface_model, so each key must also be a valid
# model-zoo name.
# NOTE(review): 'superanimal_quadreped' looks like a misspelling of
# 'superanimal_quadruped'; if so, a fresh download of that model will fail.
# The key is left unchanged because UI code elsewhere may depend on it —
# confirm against dlclibrary's MODELOPTIONS before renaming.
DLC_models_dict = {'superanimal_topviewmouse': "DLC_models/sa-tvm",
                   'superanimal_quadreped': "DLC_models/sa-q",
                   'full_human': "DLC_models/DLC_human_dancing/"}


#####################################################
def predict_pipeline(img_input,
                     mega_model_input,
                     dlc_model_input_str,
                     flag_dlc_only,
                     flag_show_str_labels,
                     bbox_likelihood_th,
                     kpts_likelihood_th,
                     font_style,
                     font_size,
                     keypt_color,
                     marker_size,
                     ):

    if not flag_dlc_only:
        ############################################################                                               
        # ### Run Megadetector
        md_results = predict_md(img_input, 
                                MD_models_dict[mega_model_input], #mega_model_input,
                                size=640) #Image.fromarray(results.imgs[0])

        ################################################################
        # Obtain animal crops for bboxes with confidence above th
        list_crops = crop_animal_detections(img_input,
                                            md_results,
                                            bbox_likelihood_th)

        ############################################################

    ## Get DLC model and label map  
    
    # If model is found: do not download (previous execution is likely within same day)
    # TODO: can we ask the user whether to reload dlc model if a directory is found?
    if os.path.isdir(DLC_models_dict[dlc_model_input_str]) and \
        len(os.listdir(DLC_models_dict[dlc_model_input_str])) > 0:
        path_to_DLCmodel = DLC_models_dict[dlc_model_input_str]
    else:
        path_to_DLCmodel = download_huggingface_model(dlc_model_input_str, 
                                         DLC_models_dict[dlc_model_input_str])

    # extract map label ids to strings
    pose_cfg_path = os.path.join(DLC_models_dict[dlc_model_input_str],
                                 'pose_cfg.yaml')
    with open(pose_cfg_path, "r") as stream:
        pose_cfg_dict = yaml.safe_load(stream) 
    map_label_id_to_str = dict([(k,v) for k,v in zip([el[0] for el in pose_cfg_dict['all_joints']],  # pose_cfg_dict['all_joints'] is a list of one-element lists,
                                                     pose_cfg_dict['all_joints_names'])])