orhir committed on
Commit 0e6ae18 · verified · 1 Parent(s): 5318ebc

Update gradio_utils/utils.py

Files changed (1)
  1. gradio_utils/utils.py +116 -137
gradio_utils/utils.py CHANGED
@@ -1,4 +1,6 @@
+import copy
 import spaces
+import json
 import random
 import collections
 import gradio as gr
@@ -13,14 +15,18 @@ from mmpose.core import wrap_fp16_model
 from mmpose.models import build_posenet
 from torchvision import transforms
 import matplotlib.patheffects as mpe
+
+from EdgeCape import TopDownGenerateTargetFewShot
 from demo import Resize_Pad
 from EdgeCape.models import *
 
 
 def process_img(support_image, global_state):
     global_state['images']['image_orig'] = support_image
-    global_state['images']['image_kp'] = support_image
-    reset_kp(global_state)
+    if global_state["load_example"]:
+        global_state["load_example"] = False
+        return global_state['images']['image_kp'], global_state
+    _, _ = reset_kp(global_state)
     return support_image, global_state
 
 
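Note on the new load_example guard: Gradio fires the image component's change event both when a user uploads a picture and when update_examples (further down in this diff) fills the component programmatically, so without a flag the example's pre-drawn keypoints would immediately be wiped by reset_kp. A minimal, self-contained sketch of this consume-once-flag pattern (on_image_change is a hypothetical stand-in, not the repo's code):

def on_image_change(image, state):
    if state.get("load_example"):       # change came from loading an example
        state["load_example"] = False   # consume the flag; the next change resets again
        return state["images"]["image_kp"], state
    state["points"] = []                # genuine upload: drop stale annotations
    return image, state

state = {"load_example": True, "images": {"image_kp": "annotated"}, "points": [(1, 2)]}
img, state = on_image_change("raw", state)
assert img == "annotated" and state["load_example"] is False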
@@ -28,7 +34,7 @@ def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
     adj_mx = torch.empty(0, device=device)
     batch_size = len(skeleton)
     for b in range(batch_size):
-        edges = torch.tensor(skeleton[b])
+        edges = torch.tensor(skeleton[b]).long()
         adj = torch.zeros(num_pts, num_pts, device=device)
         adj[edges[:, 0], edges[:, 1]] = 1
         adj_mx = torch.concatenate((adj_mx, adj.unsqueeze(0)), dim=0)
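Note on the .long() cast: PyTorch advanced indexing (adj[edges[:, 0], edges[:, 1]] = 1) only accepts integer-typed index tensors, and edge lists that have passed through NumPy float arrays or a JSON round-trip arrive as floats (that trigger is a guess; the cast itself is what the diff adds). A runnable illustration of the failure and the fix:

import torch

edges = torch.tensor([[0.0, 1.0], [1.0, 2.0]])   # float edges, e.g. after a JSON round-trip
adj = torch.zeros(3, 3)
try:
    adj[edges[:, 0], edges[:, 1]] = 1            # float indices are rejected
except IndexError as err:
    print("float indexing fails:", err)
edges = edges.long()                             # the cast added in this commit
adj[edges[:, 0], edges[:, 1]] = 1                # marks (0, 1) and (1, 2)
print(adj)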
@@ -37,7 +43,6 @@ def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
     adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
     return adj
 
-
 @spaces.GPU(duration=30)
 def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
                  skeleton=None, prediction=None, radius=6, in_color=None,
@@ -46,104 +51,79 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
     prediction = prediction[-1] * h
     if isinstance(prediction, torch.Tensor):
         prediction = prediction.cpu().numpy()
-    if isinstance(skeleton, list):
-        skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).cpu().numpy()[0]
-    original_skeleton = skeleton
-    support_img = (support_img - np.min(support_img)) / (np.max(support_img) - np.min(support_img))
+    if isinstance(original_skeleton, list):
+        original_skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).cpu().numpy()[0]
     query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
-    error_mask = None
-    for id, (img, w, keypoint, adj) in enumerate(zip([support_img, support_img, query_img],
-                                                     [support_w, support_w, query_w],
-                                                     # [support_kp, query_kp])):
-                                                     [support_kp, support_kp, prediction],
-                                                     [original_skeleton, skeleton, skeleton])):
-        color = in_color
-        f, axes = plt.subplots()
-        plt.imshow(img, alpha=img_alpha)
-
-        # On qeury image plot
-        if id == 2 and target_keypoints is not None:
-            error = np.linalg.norm(keypoint - target_keypoints, axis=-1)
-            error_mask = error > (256 * 0.05)
-
-        for k in range(keypoint.shape[0]):
-            if w[k] > 0:
-                kp = keypoint[k, :2]
-                c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
-                if error_mask is not None and error_mask[k]:
-                    c = (1, 1, 0, 0.75)
-                    patch = plt.Circle(kp,
-                                       radius,
-                                       color=c,
-                                       path_effects=[mpe.withStroke(linewidth=8, foreground='black'),
-                                                     mpe.withStroke(linewidth=4, foreground='white'),
-                                                     mpe.withStroke(linewidth=2, foreground='black'),
-                                                     ],
-                                       zorder=260)
-                    axes.add_patch(patch)
-                    axes.text(kp[0], kp[1], k, fontsize=10, color='black', ha="center", va="center", zorder=320, )
-                else:
-                    patch = plt.Circle(kp,
-                                       radius,
-                                       color=c,
-                                       path_effects=[mpe.withStroke(linewidth=2, foreground='black')],
-                                       zorder=200)
-                    axes.add_patch(patch)
-                    axes.text(kp[0], kp[1], k, fontsize=(radius + 4), color='white', ha="center", va="center",
-                              zorder=300,
-                              path_effects=[
-                                  mpe.withStroke(linewidth=max(1, int((radius + 4) / 5)), foreground='black')])
-                # axes.text(kp[0], kp[1], k)
-        plt.draw()
-
-        if adj is not None:
-            # Make max value 6
-            draw_skeleton = adj ** 1
-            max_skel_val = np.max(draw_skeleton)
-            draw_skeleton = draw_skeleton / max_skel_val * 6
-            for i in range(1, keypoint.shape[0]):
-                for j in range(0, i):
-                    if w[i] > 0 and w[j] > 0 and original_skeleton[i][j] > 0:
-                        if color is None:
-                            num_colors = int((skeleton > 0.05).sum() / 2)
-                            color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
-                            c = next(color)
-                        elif isinstance(color, str):
-                            c = color
-                        elif isinstance(color, collections.Iterable):
-                            c = next(color)
-                        else:
-                            raise ValueError("Color must be a string or an iterable")
-                    if w[i] > 0 and w[j] > 0 and skeleton[i][j] > 0:
-                        width = draw_skeleton[i][j]
-                        stroke_width = width + (width / 3)
-                        patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
-                                           [keypoint[i, 1], keypoint[j, 1]],
-                                           linewidth=width, color=c, alpha=0.6,
-                                           path_effects=[mpe.withStroke(linewidth=stroke_width, foreground='black')],
-                                           zorder=1)
-                        axes.add_artist(patch)
+    img = query_img
+    w = query_w
+    keypoint = prediction
+    adj = skeleton
+    color = None
+    f, axes = plt.subplots()
+    plt.imshow(img, alpha=img_alpha)
+    for k in range(keypoint.shape[0]):
+        if w[k] > 0:
+            kp = keypoint[k, :2]
+            c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
+            patch = plt.Circle(kp,
+                               radius,
+                               color=c,
+                               path_effects=[mpe.withStroke(linewidth=2, foreground='black')],
+                               zorder=200)
+            axes.add_patch(patch)
+            axes.text(kp[0], kp[1], k, fontsize=(radius + 4), color='white', ha="center", va="center",
+                      zorder=300,
+                      path_effects=[
+                          mpe.withStroke(linewidth=max(1, int((radius + 4) / 5)), foreground='black')])
+    plt.draw()
+
+    if adj is not None:
+        max_skel_val = np.max(adj)
+        draw_skeleton = adj / max_skel_val * 6
+        for i in range(1, keypoint.shape[0]):
+            for j in range(0, i):
+                if w[i] > 0 and w[j] > 0 and original_skeleton[i][j] > 0:
+                    if color is None:
+                        num_colors = int((adj > 0.05).sum() / 2)
+                        color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
+                        c = next(color)
+                    elif isinstance(color, str):
+                        c = color
+                    elif isinstance(color, collections.Iterable):
+                        c = next(color)
+                    else:
+                        raise ValueError("Color must be a string or an iterable")
+                if w[i] > 0 and w[j] > 0 and adj[i][j] > 0:
+                    width = draw_skeleton[i][j]
+                    stroke_width = width + (width / 3)
+                    patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
+                                       [keypoint[i, 1], keypoint[j, 1]],
+                                       linewidth=width, color=c, alpha=0.6,
+                                       path_effects=[mpe.withStroke(linewidth=stroke_width, foreground='black')],
+                                       zorder=1)
+                    axes.add_artist(patch)
 
     plt.axis('off')  # command for hiding the axis.
-    plt.subplots_adjust(0, 0, 1, 1, 0, 0)
     return plt
 
 @spaces.GPU(duration=30)
 def process(query_img, state,
             cfg_path='configs/test/1shot_split1.py',
             checkpoint_path='ckpt/1shot_split1.pth'):
+    print(state)
+    device = print_memory_usage()
     cfg = Config.fromfile(cfg_path)
-    width, height, _ = state['images']['image_orig'].shape
+    width, height, _ = np.array(state['images']['image_orig']).shape
     kp_src_np = np.array(state['points']).copy().astype(np.float32)
-    kp_src_np[:, 0] = kp_src_np[:, 0] / (width // 4) * cfg.model.encoder_config.img_size
-    kp_src_np[:, 1] = kp_src_np[:, 1] / (height // 4) * cfg.model.encoder_config.img_size
-    kp_src_np = np.flip(kp_src_np, 1).copy()
+    kp_src_np[:, 0] = kp_src_np[:, 0] / width * 256
+    kp_src_np[:, 1] = kp_src_np[:, 1] / height * 256
+    kp_src_np = kp_src_np.copy()
     kp_src_tensor = torch.tensor(kp_src_np).float()
     preprocess = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-        Resize_Pad(cfg.model.encoder_config.img_size,
-                   cfg.model.encoder_config.img_size)])
+        Resize_Pad(256, 256)
+    ])
 
     if len(state['skeleton']) == 0:
         state['skeleton'] = [(0, 0)]
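Two things worth noting in the rewritten preprocessing: the config-derived cfg.model.encoder_config.img_size (along with the width // 4 divisor and the axis flip) is replaced by a hard-coded 256x256 model input, and np.array(image).shape is (rows, cols, channels), i.e. height first, so the width, height names in the new line only agree for square inputs. A small sketch of the new coordinate mapping (rescale_points is a hypothetical helper; it assumes the clicked points are (x, y) pixels in the original image):

import numpy as np

def rescale_points(points, width, height, img_size=256):
    kp = np.asarray(points, dtype=np.float32).copy()
    kp[:, 0] = kp[:, 0] / width * img_size    # x into the model input range
    kp[:, 1] = kp[:, 1] / height * img_size   # y into the model input range
    return kp

print(rescale_points([(512, 256)], width=1024, height=512))  # -> [[128. 128.]]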
@@ -154,8 +134,7 @@ def process(query_img, state,
     # Create heatmap from keypoints
     genHeatMap = TopDownGenerateTargetFewShot()
     data_cfg = cfg.data_cfg
-    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,
-                                       cfg.model.encoder_config.img_size])
+    data_cfg['image_size'] = np.array([256, 256])
     data_cfg['joint_weights'] = None
     data_cfg['use_different_joint_weights'] = False
     kp_src_3d = torch.cat(
@@ -172,22 +151,23 @@ def process(query_img, state,
         torch.tensor(target_weight_s).float()[None])
 
     data = {
-        'img_s': [support_img],
-        'img_q': q_img,
-        'target_s': [target_s],
-        'target_weight_s': [target_weight_s],
+        'img_s': [support_img.to(device)],
+        'img_q': q_img.to(device),
+        'target_s': [target_s.to(device)],
+        'target_weight_s': [target_weight_s.to(device)],
         'target_q': None,
         'target_weight_q': None,
         'return_loss': False,
         'img_metas': [{'sample_skeleton': [state['skeleton']],
                        'query_skeleton': state['skeleton'],
-                       'sample_joints_3d': [kp_src_3d],
-                       'query_joints_3d': kp_src_3d,
+                       'sample_joints_3d': [kp_src_3d.to(device)],
+                       'query_joints_3d': kp_src_3d.to(device),
                        'sample_center': [kp_src_tensor.mean(dim=0)],
                        'query_center': kp_src_tensor.mean(dim=0),
                        'sample_scale': [
                            kp_src_tensor.max(dim=0)[0] -
-                           kp_src_tensor.min(dim=0)[0]],
+                           kp_src_tensor.min(dim=0)[0]
+                       ],
                        'query_scale': kp_src_tensor.max(dim=0)[0] -
                                       kp_src_tensor.min(dim=0)[0],
                        'sample_rotation': [0],
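The batch dict is now moved onto the device returned by print_memory_usage() (see the last hunk of this diff). A sketch of the same move for a nested batch, assuming the usual cuda-if-available choice:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {
    "img_s": [torch.randn(1, 3, 256, 256)],   # list of support tensors
    "img_q": torch.randn(1, 3, 256, 256),     # single query tensor
}
batch = {k: [t.to(device) for t in v] if isinstance(v, list) else v.to(device)
         for k, v in batch.items()}
print(batch["img_q"].device)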
@@ -204,7 +184,7 @@ def process(query_img, state,
     if fp16_cfg is not None:
         wrap_fp16_model(model)
     load_checkpoint(model, checkpoint_path, map_location='cpu')
-    model.eval()
+    model.eval().to(device)
     with torch.no_grad():
         outputs = model(**data)
     # visualize results
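model.eval().to(device) works because both methods return the module itself; eval() switches off dropout and batch-norm statistics updates, while the surrounding torch.no_grad() stops autograd from recording the forward pass. A self-contained example of the same inference setup:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model = model.eval().to(device)        # both calls return the module, so chaining is safe
with torch.no_grad():                  # no gradient bookkeeping during inference
    out = model(torch.randn(2, 4, device=device))
print(out.shape)                       # torch.Size([2, 4])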
@@ -218,39 +198,47 @@ def process(query_img, state,
                        vis_s_weight,
                        None,
                        vis_s_weight,
-                       outputs['skeleton'],
+                       outputs['skeleton'][1],
                        torch.tensor(outputs['points']).squeeze(),
                        original_skeleton=state['skeleton'],
                        img_alpha=1.0,
                        )
-    return out, state
-
-
-def update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):
-    state['color_idx'] = 0
-    state['images']['image_orig'] = np.array(support_img)[:, :, ::-1].copy()
-    support_img, posed_support, _ = set_query(support_img, state, example=True)
-    w, h = support_img.size
-    draw_pose = ImageDraw.Draw(support_img)
-    draw_limb = ImageDraw.Draw(posed_support)
-    r = int(r * w)
-    width = int(width * w)
-    for pixel in state['kp_src']:
-        leftUpPoint = (pixel[1] - r, pixel[0] - r)
-        rightDownPoint = (pixel[1] + r, pixel[0] + r)
-        twoPointList = [leftUpPoint, rightDownPoint]
-        draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
-        draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
-    for limb in state['skeleton']:
-        point_a = state['kp_src'][limb[0]][::-1]
-        point_b = state['kp_src'][limb[1]][::-1]
-        if state['color_idx'] < len(COLORS):
-            c = COLORS[state['color_idx']]
-            state['color_idx'] += 1
-        else:
-            c = random.choices(range(256), k=3)
-        draw_limb.line([point_a, point_b], fill=tuple(c), width=width)
-    return support_img, posed_support, query_img, state
+    return out
+
+
+def update_examples(support_img, query_image, global_state_str):
+    example_state = json.loads(global_state_str)
+    example_state["load_example"] = True
+    example_state["curr_type_point"] = "start"
+    example_state["prev_point"] = None
+    example_state['images'] = {}
+    example_state['images']['image_orig'] = support_img
+    example_state['images']['image_kp'] = support_img
+    example_state['images']['image_skeleton'] = support_img
+    image_draw = example_state['images']['image_orig'].copy()
+    for xy in example_state['points']:
+        image_draw = update_image_draw(
+            image_draw,
+            xy,
+            example_state
+        )
+    kp_image = image_draw.copy()
+    example_state['images']['image_kp'] = kp_image
+    pts_list = example_state['points']
+    for limb in example_state['skeleton']:
+        prev_point = pts_list[limb[0]]
+        curr_point = pts_list[limb[1]]
+        points = [prev_point, curr_point]
+        image_draw = draw_limbs_on_image(image_draw,
+                                         points
+                                         )
+    skel_image = image_draw.copy()
+    example_state['images']['image_skel'] = skel_image
+    return (support_img,
+            kp_image,
+            skel_image,
+            query_image,
+            example_state)
 
 
 def get_select_coords(global_state,
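The rewritten update_examples receives the example's annotation state as a JSON string (global_state_str) and rebuilds the dict with json.loads, presumably because Gradio example tables only carry simple serializable values. One consequence of the round-trip: tuples such as skeleton edges come back as lists, which still index the same way. A minimal round-trip sketch:

import json

state = {"points": [[10, 20], [30, 40]], "skeleton": [[0, 1]]}
encoded = json.dumps(state)              # what the examples table would store
example_state = json.loads(encoded)
limb = example_state["skeleton"][0]      # tuples have become lists; indexing unchanged
print(example_state["points"][limb[0]], example_state["points"][limb[1]])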
@@ -260,17 +248,6 @@ def get_select_coords(global_state,
     """
     xy = evt.index
     global_state["points"].append(xy)
-    # point_idx = get_latest_points_pair(points)
-    # if point_idx is None:
-    #     points[0] = {'start': xy, 'target': None}
-    #     print(f'Click Image - Start - {xy}')
-    # elif points[point_idx].get('target', None) is None:
-    #     points[point_idx]['target'] = xy
-    #     print(f'Click Image - Target - {xy}')
-    # else:
-    #     points[point_idx + 1] = {'start': xy, 'target': None}
-    #     print(f'Click Image - Start - {xy}')
-
     image_raw = global_state['images']['image_kp']
     image_draw = update_image_draw(
         image_raw,
@@ -362,7 +339,9 @@ def print_memory_usage():
         torch.cuda.max_memory_allocated()
         print(f"Available GPU memory: {available_memory / 1e9} GB")
     else:
+        device = "cpu"
         print("No GPU available")
+    return device
 
 def draw_limbs_on_image(image,
                         points,):
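print_memory_usage now doubles as device detection: this hunk shows only the CPU fallback and the new return device, so the GPU branch above it must set device before returning. A minimal reconstruction of the pattern, under that assumption (a sketch, not the file's exact body):

import torch

def print_memory_usage():
    # report free memory and hand back the device string
    if torch.cuda.is_available():
        device = "cuda"
        free, total = torch.cuda.mem_get_info()
        print(f"Available GPU memory: {free / 1e9} GB")
    else:
        device = "cpu"
        print("No GPU available")
    return device

device = print_memory_usage()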