Spaces:
Sleeping
Sleeping
Update gradio_utils/utils.py
Browse files- gradio_utils/utils.py +116 -137
gradio_utils/utils.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import spaces
|
|
|
2 |
import random
|
3 |
import collections
|
4 |
import gradio as gr
|
@@ -13,14 +15,18 @@ from mmpose.core import wrap_fp16_model
|
|
13 |
from mmpose.models import build_posenet
|
14 |
from torchvision import transforms
|
15 |
import matplotlib.patheffects as mpe
|
|
|
|
|
16 |
from demo import Resize_Pad
|
17 |
from EdgeCape.models import *
|
18 |
|
19 |
|
20 |
def process_img(support_image, global_state):
|
21 |
global_state['images']['image_orig'] = support_image
|
22 |
-
global_state[
|
23 |
-
|
|
|
|
|
24 |
return support_image, global_state
|
25 |
|
26 |
|
@@ -28,7 +34,7 @@ def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
|
|
28 |
adj_mx = torch.empty(0, device=device)
|
29 |
batch_size = len(skeleton)
|
30 |
for b in range(batch_size):
|
31 |
-
edges = torch.tensor(skeleton[b])
|
32 |
adj = torch.zeros(num_pts, num_pts, device=device)
|
33 |
adj[edges[:, 0], edges[:, 1]] = 1
|
34 |
adj_mx = torch.concatenate((adj_mx, adj.unsqueeze(0)), dim=0)
|
@@ -37,7 +43,6 @@ def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
|
|
37 |
adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
|
38 |
return adj
|
39 |
|
40 |
-
|
41 |
@spaces.GPU(duration=30)
|
42 |
def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
|
43 |
skeleton=None, prediction=None, radius=6, in_color=None,
|
@@ -46,104 +51,79 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
|
|
46 |
prediction = prediction[-1] * h
|
47 |
if isinstance(prediction, torch.Tensor):
|
48 |
prediction = prediction.cpu().numpy()
|
49 |
-
if isinstance(
|
50 |
-
|
51 |
-
original_skeleton = skeleton
|
52 |
-
support_img = (support_img - np.min(support_img)) / (np.max(support_img) - np.min(support_img))
|
53 |
query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
max_skel_val = np.max(draw_skeleton)
|
103 |
-
draw_skeleton = draw_skeleton / max_skel_val * 6
|
104 |
-
for i in range(1, keypoint.shape[0]):
|
105 |
-
for j in range(0, i):
|
106 |
-
if w[i] > 0 and w[j] > 0 and original_skeleton[i][j] > 0:
|
107 |
-
if color is None:
|
108 |
-
num_colors = int((skeleton > 0.05).sum() / 2)
|
109 |
-
color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
|
110 |
-
c = next(color)
|
111 |
-
elif isinstance(color, str):
|
112 |
-
c = color
|
113 |
-
elif isinstance(color, collections.Iterable):
|
114 |
-
c = next(color)
|
115 |
-
else:
|
116 |
-
raise ValueError("Color must be a string or an iterable")
|
117 |
-
if w[i] > 0 and w[j] > 0 and skeleton[i][j] > 0:
|
118 |
-
width = draw_skeleton[i][j]
|
119 |
-
stroke_width = width + (width / 3)
|
120 |
-
patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
|
121 |
-
[keypoint[i, 1], keypoint[j, 1]],
|
122 |
-
linewidth=width, color=c, alpha=0.6,
|
123 |
-
path_effects=[mpe.withStroke(linewidth=stroke_width, foreground='black')],
|
124 |
-
zorder=1)
|
125 |
-
axes.add_artist(patch)
|
126 |
|
127 |
plt.axis('off') # command for hiding the axis.
|
128 |
-
plt.subplots_adjust(0, 0, 1, 1, 0, 0)
|
129 |
return plt
|
130 |
|
131 |
@spaces.GPU(duration=30)
|
132 |
def process(query_img, state,
|
133 |
cfg_path='configs/test/1shot_split1.py',
|
134 |
checkpoint_path='ckpt/1shot_split1.pth'):
|
|
|
|
|
135 |
cfg = Config.fromfile(cfg_path)
|
136 |
-
width, height, _ = state['images']['image_orig'].shape
|
137 |
kp_src_np = np.array(state['points']).copy().astype(np.float32)
|
138 |
-
kp_src_np[:, 0] = kp_src_np[:, 0] /
|
139 |
-
kp_src_np[:, 1] = kp_src_np[:, 1] /
|
140 |
-
kp_src_np =
|
141 |
kp_src_tensor = torch.tensor(kp_src_np).float()
|
142 |
preprocess = transforms.Compose([
|
143 |
transforms.ToTensor(),
|
144 |
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
145 |
-
Resize_Pad(
|
146 |
-
|
147 |
|
148 |
if len(state['skeleton']) == 0:
|
149 |
state['skeleton'] = [(0, 0)]
|
@@ -154,8 +134,7 @@ def process(query_img, state,
|
|
154 |
# Create heatmap from keypoints
|
155 |
genHeatMap = TopDownGenerateTargetFewShot()
|
156 |
data_cfg = cfg.data_cfg
|
157 |
-
data_cfg['image_size'] = np.array([
|
158 |
-
cfg.model.encoder_config.img_size])
|
159 |
data_cfg['joint_weights'] = None
|
160 |
data_cfg['use_different_joint_weights'] = False
|
161 |
kp_src_3d = torch.cat(
|
@@ -172,22 +151,23 @@ def process(query_img, state,
|
|
172 |
torch.tensor(target_weight_s).float()[None])
|
173 |
|
174 |
data = {
|
175 |
-
'img_s': [support_img],
|
176 |
-
'img_q': q_img,
|
177 |
-
'target_s': [target_s],
|
178 |
-
'target_weight_s': [target_weight_s],
|
179 |
'target_q': None,
|
180 |
'target_weight_q': None,
|
181 |
'return_loss': False,
|
182 |
'img_metas': [{'sample_skeleton': [state['skeleton']],
|
183 |
'query_skeleton': state['skeleton'],
|
184 |
-
'sample_joints_3d': [kp_src_3d],
|
185 |
-
'query_joints_3d': kp_src_3d,
|
186 |
'sample_center': [kp_src_tensor.mean(dim=0)],
|
187 |
'query_center': kp_src_tensor.mean(dim=0),
|
188 |
'sample_scale': [
|
189 |
kp_src_tensor.max(dim=0)[0] -
|
190 |
-
kp_src_tensor.min(dim=0)[0]
|
|
|
191 |
'query_scale': kp_src_tensor.max(dim=0)[0] -
|
192 |
kp_src_tensor.min(dim=0)[0],
|
193 |
'sample_rotation': [0],
|
@@ -204,7 +184,7 @@ def process(query_img, state,
|
|
204 |
if fp16_cfg is not None:
|
205 |
wrap_fp16_model(model)
|
206 |
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
207 |
-
model.eval()
|
208 |
with torch.no_grad():
|
209 |
outputs = model(**data)
|
210 |
# visualize results
|
@@ -218,39 +198,47 @@ def process(query_img, state,
|
|
218 |
vis_s_weight,
|
219 |
None,
|
220 |
vis_s_weight,
|
221 |
-
outputs['skeleton'],
|
222 |
torch.tensor(outputs['points']).squeeze(),
|
223 |
original_skeleton=state['skeleton'],
|
224 |
img_alpha=1.0,
|
225 |
)
|
226 |
-
return out
|
227 |
-
|
228 |
-
|
229 |
-
def update_examples(support_img,
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
|
256 |
def get_select_coords(global_state,
|
@@ -260,17 +248,6 @@ def get_select_coords(global_state,
|
|
260 |
"""
|
261 |
xy = evt.index
|
262 |
global_state["points"].append(xy)
|
263 |
-
# point_idx = get_latest_points_pair(points)
|
264 |
-
# if point_idx is None:
|
265 |
-
# points[0] = {'start': xy, 'target': None}
|
266 |
-
# print(f'Click Image - Start - {xy}')
|
267 |
-
# elif points[point_idx].get('target', None) is None:
|
268 |
-
# points[point_idx]['target'] = xy
|
269 |
-
# print(f'Click Image - Target - {xy}')
|
270 |
-
# else:
|
271 |
-
# points[point_idx + 1] = {'start': xy, 'target': None}
|
272 |
-
# print(f'Click Image - Start - {xy}')
|
273 |
-
|
274 |
image_raw = global_state['images']['image_kp']
|
275 |
image_draw = update_image_draw(
|
276 |
image_raw,
|
@@ -362,7 +339,9 @@ def print_memory_usage():
|
|
362 |
torch.cuda.max_memory_allocated()
|
363 |
print(f"Available GPU memory: {available_memory / 1e9} GB")
|
364 |
else:
|
|
|
365 |
print("No GPU available")
|
|
|
366 |
|
367 |
def draw_limbs_on_image(image,
|
368 |
points,):
|
|
|
1 |
+
import copy
|
2 |
import spaces
|
3 |
+
import json
|
4 |
import random
|
5 |
import collections
|
6 |
import gradio as gr
|
|
|
15 |
from mmpose.models import build_posenet
|
16 |
from torchvision import transforms
|
17 |
import matplotlib.patheffects as mpe
|
18 |
+
|
19 |
+
from EdgeCape import TopDownGenerateTargetFewShot
|
20 |
from demo import Resize_Pad
|
21 |
from EdgeCape.models import *
|
22 |
|
23 |
|
24 |
def process_img(support_image, global_state):
|
25 |
global_state['images']['image_orig'] = support_image
|
26 |
+
if global_state["load_example"]:
|
27 |
+
global_state["load_example"] = False
|
28 |
+
return global_state['images']['image_kp'], global_state
|
29 |
+
_, _ = reset_kp(global_state)
|
30 |
return support_image, global_state
|
31 |
|
32 |
|
|
|
34 |
adj_mx = torch.empty(0, device=device)
|
35 |
batch_size = len(skeleton)
|
36 |
for b in range(batch_size):
|
37 |
+
edges = torch.tensor(skeleton[b]).long()
|
38 |
adj = torch.zeros(num_pts, num_pts, device=device)
|
39 |
adj[edges[:, 0], edges[:, 1]] = 1
|
40 |
adj_mx = torch.concatenate((adj_mx, adj.unsqueeze(0)), dim=0)
|
|
|
43 |
adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
|
44 |
return adj
|
45 |
|
|
|
46 |
@spaces.GPU(duration=30)
|
47 |
def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
|
48 |
skeleton=None, prediction=None, radius=6, in_color=None,
|
|
|
51 |
prediction = prediction[-1] * h
|
52 |
if isinstance(prediction, torch.Tensor):
|
53 |
prediction = prediction.cpu().numpy()
|
54 |
+
if isinstance(original_skeleton, list):
|
55 |
+
original_skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).cpu().numpy()[0]
|
|
|
|
|
56 |
query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
|
57 |
+
img = query_img
|
58 |
+
w = query_w
|
59 |
+
keypoint = prediction
|
60 |
+
adj = skeleton
|
61 |
+
color = None
|
62 |
+
f, axes = plt.subplots()
|
63 |
+
plt.imshow(img, alpha=img_alpha)
|
64 |
+
for k in range(keypoint.shape[0]):
|
65 |
+
if w[k] > 0:
|
66 |
+
kp = keypoint[k, :2]
|
67 |
+
c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
|
68 |
+
patch = plt.Circle(kp,
|
69 |
+
radius,
|
70 |
+
color=c,
|
71 |
+
path_effects=[mpe.withStroke(linewidth=2, foreground='black')],
|
72 |
+
zorder=200)
|
73 |
+
axes.add_patch(patch)
|
74 |
+
axes.text(kp[0], kp[1], k, fontsize=(radius + 4), color='white', ha="center", va="center",
|
75 |
+
zorder=300,
|
76 |
+
path_effects=[
|
77 |
+
mpe.withStroke(linewidth=max(1, int((radius + 4) / 5)), foreground='black')])
|
78 |
+
plt.draw()
|
79 |
+
|
80 |
+
if adj is not None:
|
81 |
+
max_skel_val = np.max(adj)
|
82 |
+
draw_skeleton = adj / max_skel_val * 6
|
83 |
+
for i in range(1, keypoint.shape[0]):
|
84 |
+
for j in range(0, i):
|
85 |
+
if w[i] > 0 and w[j] > 0 and original_skeleton[i][j] > 0:
|
86 |
+
if color is None:
|
87 |
+
num_colors = int((adj > 0.05).sum() / 2)
|
88 |
+
color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
|
89 |
+
c = next(color)
|
90 |
+
elif isinstance(color, str):
|
91 |
+
c = color
|
92 |
+
elif isinstance(color, collections.Iterable):
|
93 |
+
c = next(color)
|
94 |
+
else:
|
95 |
+
raise ValueError("Color must be a string or an iterable")
|
96 |
+
if w[i] > 0 and w[j] > 0 and adj[i][j] > 0:
|
97 |
+
width = draw_skeleton[i][j]
|
98 |
+
stroke_width = width + (width / 3)
|
99 |
+
patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
|
100 |
+
[keypoint[i, 1], keypoint[j, 1]],
|
101 |
+
linewidth=width, color=c, alpha=0.6,
|
102 |
+
path_effects=[mpe.withStroke(linewidth=stroke_width, foreground='black')],
|
103 |
+
zorder=1)
|
104 |
+
axes.add_artist(patch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
plt.axis('off') # command for hiding the axis.
|
|
|
107 |
return plt
|
108 |
|
109 |
@spaces.GPU(duration=30)
|
110 |
def process(query_img, state,
|
111 |
cfg_path='configs/test/1shot_split1.py',
|
112 |
checkpoint_path='ckpt/1shot_split1.pth'):
|
113 |
+
print(state)
|
114 |
+
device = print_memory_usage()
|
115 |
cfg = Config.fromfile(cfg_path)
|
116 |
+
width, height, _ = np.array(state['images']['image_orig']).shape
|
117 |
kp_src_np = np.array(state['points']).copy().astype(np.float32)
|
118 |
+
kp_src_np[:, 0] = kp_src_np[:, 0] / width * 256
|
119 |
+
kp_src_np[:, 1] = kp_src_np[:, 1] / height * 256
|
120 |
+
kp_src_np = kp_src_np.copy()
|
121 |
kp_src_tensor = torch.tensor(kp_src_np).float()
|
122 |
preprocess = transforms.Compose([
|
123 |
transforms.ToTensor(),
|
124 |
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
125 |
+
Resize_Pad(256, 256)
|
126 |
+
])
|
127 |
|
128 |
if len(state['skeleton']) == 0:
|
129 |
state['skeleton'] = [(0, 0)]
|
|
|
134 |
# Create heatmap from keypoints
|
135 |
genHeatMap = TopDownGenerateTargetFewShot()
|
136 |
data_cfg = cfg.data_cfg
|
137 |
+
data_cfg['image_size'] = np.array([256, 256])
|
|
|
138 |
data_cfg['joint_weights'] = None
|
139 |
data_cfg['use_different_joint_weights'] = False
|
140 |
kp_src_3d = torch.cat(
|
|
|
151 |
torch.tensor(target_weight_s).float()[None])
|
152 |
|
153 |
data = {
|
154 |
+
'img_s': [support_img.to(device)],
|
155 |
+
'img_q': q_img.to(device),
|
156 |
+
'target_s': [target_s.to(device)],
|
157 |
+
'target_weight_s': [target_weight_s.to(device)],
|
158 |
'target_q': None,
|
159 |
'target_weight_q': None,
|
160 |
'return_loss': False,
|
161 |
'img_metas': [{'sample_skeleton': [state['skeleton']],
|
162 |
'query_skeleton': state['skeleton'],
|
163 |
+
'sample_joints_3d': [kp_src_3d.to(device)],
|
164 |
+
'query_joints_3d': kp_src_3d.to(device),
|
165 |
'sample_center': [kp_src_tensor.mean(dim=0)],
|
166 |
'query_center': kp_src_tensor.mean(dim=0),
|
167 |
'sample_scale': [
|
168 |
kp_src_tensor.max(dim=0)[0] -
|
169 |
+
kp_src_tensor.min(dim=0)[0]
|
170 |
+
],
|
171 |
'query_scale': kp_src_tensor.max(dim=0)[0] -
|
172 |
kp_src_tensor.min(dim=0)[0],
|
173 |
'sample_rotation': [0],
|
|
|
184 |
if fp16_cfg is not None:
|
185 |
wrap_fp16_model(model)
|
186 |
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
187 |
+
model.eval().to(device)
|
188 |
with torch.no_grad():
|
189 |
outputs = model(**data)
|
190 |
# visualize results
|
|
|
198 |
vis_s_weight,
|
199 |
None,
|
200 |
vis_s_weight,
|
201 |
+
outputs['skeleton'][1],
|
202 |
torch.tensor(outputs['points']).squeeze(),
|
203 |
original_skeleton=state['skeleton'],
|
204 |
img_alpha=1.0,
|
205 |
)
|
206 |
+
return out
|
207 |
+
|
208 |
+
|
209 |
+
def update_examples(support_img, query_image, global_state_str):
|
210 |
+
example_state = json.loads(global_state_str)
|
211 |
+
example_state["load_example"] = True
|
212 |
+
example_state["curr_type_point"] = "start"
|
213 |
+
example_state["prev_point"] = None
|
214 |
+
example_state['images'] = {}
|
215 |
+
example_state['images']['image_orig'] = support_img
|
216 |
+
example_state['images']['image_kp'] = support_img
|
217 |
+
example_state['images']['image_skeleton'] = support_img
|
218 |
+
image_draw = example_state['images']['image_orig'].copy()
|
219 |
+
for xy in example_state['points']:
|
220 |
+
image_draw = update_image_draw(
|
221 |
+
image_draw,
|
222 |
+
xy,
|
223 |
+
example_state
|
224 |
+
)
|
225 |
+
kp_image = image_draw.copy()
|
226 |
+
example_state['images']['image_kp'] = kp_image
|
227 |
+
pts_list = example_state['points']
|
228 |
+
for limb in example_state['skeleton']:
|
229 |
+
prev_point = pts_list[limb[0]]
|
230 |
+
curr_point = pts_list[limb[1]]
|
231 |
+
points = [prev_point, curr_point]
|
232 |
+
image_draw = draw_limbs_on_image(image_draw,
|
233 |
+
points
|
234 |
+
)
|
235 |
+
skel_image = image_draw.copy()
|
236 |
+
example_state['images']['image_skel'] = skel_image
|
237 |
+
return (support_img,
|
238 |
+
kp_image,
|
239 |
+
skel_image,
|
240 |
+
query_image,
|
241 |
+
example_state)
|
242 |
|
243 |
|
244 |
def get_select_coords(global_state,
|
|
|
248 |
"""
|
249 |
xy = evt.index
|
250 |
global_state["points"].append(xy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
image_raw = global_state['images']['image_kp']
|
252 |
image_draw = update_image_draw(
|
253 |
image_raw,
|
|
|
339 |
torch.cuda.max_memory_allocated()
|
340 |
print(f"Available GPU memory: {available_memory / 1e9} GB")
|
341 |
else:
|
342 |
+
device = "cpu"
|
343 |
print("No GPU available")
|
344 |
+
return device
|
345 |
|
346 |
def draw_limbs_on_image(image,
|
347 |
points,):
|