File size: 5,074 Bytes
ea89305
e20f20e
 
 
ea89305
7a4ce8c
 
 
e20f20e
 
f021f8a
dd7e161
5fa7919
b58e2bf
 
e20f20e
 
 
 
 
 
2a7b6f5
 
dd7e161
 
 
2a7b6f5
 
 
dd7e161
 
 
e20f20e
dd314fc
e20f20e
dd314fc
e20f20e
ea89305
e20f20e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b58e2bf
03b969c
46676c0
e20f20e
49c76ed
b58e2bf
da2d10b
79cf760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da2d10b
a6d62b8
 
bacee3c
cc1cb8d
b58e2bf
848b9af
b58e2bf
 
af92f07
b58e2bf
 
 
 
bacee3c
 
 
 
 
cc1cb8d
79cf760
 
da2d10b
 
 
cc1cb8d
b58e2bf
 
79cf760
 
 
a2ecd8f
da2d10b
b58e2bf
 
06cf3c0
b58e2bf
06cf3c0
b58e2bf
 
 
e20f20e
 
c2c6f96
e20f20e
 
 
 
 
 
b58e2bf
ea89305
e20f20e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import spaces
import torch, torchvision
import sys
# sys.path.insert(0, 'test_mmpose/')

# Install the pinned mmcv-full build at process start-up using `mim`.
# NOTE(review): presumably this must run before `import mmpose` below because
# mmpose needs mmcv at import time — confirm before reordering these lines.
import mim
mim.install('mmcv-full==1.5.0')

import mmpose
import gradio as gr


from PIL import Image
import cv2
import numpy as np

# Config / checkpoint paths handed to init_pose_model() and init_detector()
# further down. The paths are relative, so they resolve against the process
# working directory.
pose_config = 'configs/topdown_heatmap_hrnet_w48_coco_256x192.py'
pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
det_config = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

@spaces.GPU
def cache_features():
    """Build the pose model and the detector once on CUDA, then drop them.

    Both models are constructed from the module-level config/checkpoint
    paths and the returned objects are discarded; nothing is returned.

    NOTE(review): presumably a warm-up so that later initialisation is
    cheaper — confirm the intent before removing this call.
    """
    from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                             vis_pose_result, process_mmdet_results)
    from mmdet.apis import inference_detector, init_detector

    # (builder, config, checkpoint) triples, in the order they must run.
    warmup_plan = (
        (init_pose_model, pose_config, pose_checkpoint),
        (init_detector, det_config, det_checkpoint),
    )
    for build_model, config_path, checkpoint_path in warmup_plan:
        build_model(config_path, checkpoint_path, device='cuda')


cache_features()
from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                         vis_pose_result, process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
# Build the module-level pose model; predict() reads it as the global
# `pose_model`. Constructed from pose_config / pose_checkpoint on CUDA.
pose_model = init_pose_model(pose_config, pose_checkpoint, device='cuda')
# Build the module-level detector; predict() reads it as the global
# `det_model`. Constructed from det_config / det_checkpoint on CUDA.
det_model = init_detector(det_config, det_checkpoint, device='cuda')

# COCO keypoint name -> row index inside each person's keypoint array.
_COCO_KEYPOINT_INDEX = {
    "nose": 0,
    "left_eye": 1,
    "right_eye": 2,
    "left_ear": 3,
    "right_ear": 4,
    "left_shoulder": 5,
    "right_shoulder": 6,
    "left_elbow": 7,
    "right_elbow": 8,
    "left_wrist": 9,
    "right_wrist": 10,
    "left_hip": 11,
    "right_hip": 12,
    "left_knee": 13,
    "right_knee": 14,
    "left_ankle": 15,
    "right_ankle": 16,
}

# Limb colours expressed as OpenCV BGR triples.
_ORANGE = (51, 153, 255)
_BLUE = (255, 128, 0)
_GREEN = (0, 255, 0)

# (start keypoint, end keypoint, colour) for every limb that gets drawn.
_SKELETON = (
    ("right_eye", "left_eye", _ORANGE),
    ("nose", "left_eye", _ORANGE),
    ("left_eye", "left_ear", _ORANGE),
    ("nose", "right_eye", _ORANGE),
    ("right_eye", "right_ear", _ORANGE),
    ("left_shoulder", "left_ear", _ORANGE),
    ("right_shoulder", "right_ear", _ORANGE),
    ("left_shoulder", "right_shoulder", _ORANGE),
    ("left_shoulder", "left_elbow", _GREEN),
    ("right_shoulder", "right_elbow", _BLUE),
    ("left_elbow", "left_wrist", _GREEN),
    ("right_elbow", "right_wrist", _BLUE),
    ("left_shoulder", "left_hip", _ORANGE),
    ("right_shoulder", "right_hip", _ORANGE),
    ("left_hip", "right_hip", _ORANGE),
    ("left_hip", "left_knee", _GREEN),
    ("right_hip", "right_knee", _BLUE),
    ("left_knee", "left_ankle", _GREEN),
    ("right_knee", "right_ankle", _BLUE),
)

# A limb is drawn only when both of its endpoints score above this value.
_MIN_KEYPOINT_SCORE = 0.1


def _draw_skeletons(pose_results, height, width):
    """Draw every person's skeleton on a black BGR canvas of height x width."""
    canvas = np.zeros((height, width, 3), np.uint8)
    for person in pose_results:
        # Each keypoint row is indexed as [0]=x, [1]=y, [2]=score below.
        keypoints = person['keypoints']
        for start_part, end_part, color in _SKELETON:
            start = keypoints[_COCO_KEYPOINT_INDEX[start_part]]
            end = keypoints[_COCO_KEYPOINT_INDEX[end_part]]
            if start[2] > _MIN_KEYPOINT_SCORE and end[2] > _MIN_KEYPOINT_SCORE:
                cv2.line(
                    canvas,
                    (int(start[0]), int(start[1])),
                    (int(end[0]), int(end[1])),
                    color,
                    thickness=2,
                    lineType=cv2.LINE_AA,
                )
    return canvas


@spaces.GPU
def predict(img):
    """Detect people in ``img``, estimate their poses and render the results.

    Args:
        img: Image supplied by the Gradio ``Image`` input, expected to be an
            RGB ``numpy`` array of shape (height, width, 3).

    Returns:
        A ``(prediction, poses)`` pair of RGB ``numpy`` arrays: the input
        image with the pose overlay produced by ``vis_pose_result``, and a
        black canvas of the same size showing only the skeleton lines.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        raise gr.Error("Please provide an image.")

    # Gradio delivers RGB, whereas the OpenMMLab inference and visualisation
    # helpers treat a raw ndarray as BGR (as loaded by mmcv/OpenCV). Convert
    # once here so the models see the channel order they were trained with.
    bgr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    mmdet_results = inference_detector(det_model, bgr_img)
    # cat_id=1 keeps only the detections of the first category.
    # NOTE(review): presumably "person" for the COCO detector config — confirm.
    person_results = process_mmdet_results(mmdet_results, cat_id=1)

    dataset = pose_model.cfg.data.test.type
    pose_results, _ = inference_top_down_pose_model(
        pose_model,
        bgr_img,
        person_results,
        bbox_thr=0.3,
        format='xyxy',
        dataset=dataset)

    vis_result = vis_pose_result(
        pose_model,
        bgr_img,
        pose_results,
        dataset=dataset,
        show=False)

    print(f"POSE_RESULTS: {pose_results}")

    # ndarray.shape is (rows, cols, ...), i.e. height first.
    height, width = img.shape[:2]
    skeleton_img = _draw_skeletons(pose_results, height, width)

    # Return in-memory RGB arrays instead of going through a shared
    # "output.jpg": no cross-request clobbering, no lossy JPEG round trip and
    # no OpenCV GUI calls, which have no purpose in a server process.
    return (cv2.cvtColor(vis_result, cv2.COLOR_BGR2RGB),
            cv2.cvtColor(skeleton_img, cv2.COLOR_BGR2RGB))

example_list = ['examples/demo2.png']
title = "MMPose estimation"
description = ""
article = ""

# Assemble the Gradio interface: one image in, two images out (the annotated
# prediction and the skeleton-only rendering returned by predict()).
# `example_list` is defined above but deliberately not passed as `examples`.
_interface_kwargs = {
    "fn": predict,
    "inputs": gr.Image(),
    "outputs": [
        gr.Image(label='Prediction'),
        gr.Image(label='Poses'),
    ],
    "title": title,
    "description": description,
    "article": article,
}
demo = gr.Interface(**_interface_kwargs)

# Start serving the demo.
demo.launch()