yeq6x commited on
Commit
23281d5
·
1 Parent(s): 19d010a
Files changed (1) hide show
  1. app.py +134 -69
app.py CHANGED
@@ -81,81 +81,146 @@ def load_keypoints(device, img_dir="resources/trainB/", image_size=112, batch_si
81
  try:
82
  @spaces.GPU
83
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
84
- return _get_heatmaps(source_num, x_coords, y_coords, uploaded_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  except:
86
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
87
- return _get_heatmaps(source_num, x_coords, y_coords, uploaded_image)
88
-
89
- def _get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
90
- if type(uploaded_image) == str:
91
- uploaded_image = Image.open(uploaded_image)
92
- if type(source_num) == str:
93
- source_num = int(source_num)
94
- if type(x_coords) == str:
95
- x_coords = int(x_coords)
96
- if type(y_coords) == str:
97
- y_coords = int(y_coords)
98
-
99
- dec5, _ = model(x)
100
- feature_map = dec5
101
- # アップロード画像の前処理
102
- if uploaded_image is not None:
103
- uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
104
- else:
105
- uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
106
- target_feature_map, _ = model(uploaded_image)
107
- img = torch.cat((x, uploaded_image))
108
- feature_map = torch.cat((feature_map, target_feature_map))
109
-
110
- source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
111
- keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
112
-
113
- # Matplotlibでプロットして画像として保存
114
- fig, axs = plt.subplots(2, 3, figsize=(10, 6))
115
- axs[0, 0].imshow(source_map, cmap='hot')
116
- axs[0, 0].set_title("Source Map")
117
- axs[0, 1].imshow(target_map, cmap='hot')
118
- axs[0, 1].set_title("Target Map")
119
- axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
120
- axs[0, 2].set_title("Keypoint Map")
121
- axs[1, 0].imshow(blended_source.permute(1, 2, 0))
122
- axs[1, 0].set_title("Blended Source")
123
- axs[1, 1].imshow(blended_target.permute(1, 2, 0))
124
- axs[1, 1].set_title("Blended Target")
125
- axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
126
- axs[1, 2].set_title("Blended Keypoint")
127
- for ax in axs.flat:
128
- ax.axis('off')
129
-
130
- plt.tight_layout()
131
- plt.close(fig)
132
- return fig
133
-
134
- def setup(model_dict, input_image=None):
135
- global model, device, x, test_imgs, points, mean_vector_list
136
- # str -> dictに変換
137
- if type(model_dict) == str:
138
- model_dict = eval(model_dict)
139
- model_name = model_dict["name"]
140
- feature_dim = model_dict["feature_dim"]
141
- model_path = f"checkpoints/{model_name}"
142
- model, device = load_model(model_path, feature_dim)
143
- x = load_data(device)
144
- test_imgs, points = load_keypoints(device)
145
- feature_map, _ = model(test_imgs)
146
- mean_vector_list = utils.get_mean_vector(feature_map, points)
147
-
148
- if input_image is not None:
149
- fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
150
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
 
 
 
 
152
 
153
- models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
154
- {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
155
- {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
156
- {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
157
 
158
- setup(models[0])
159
 
160
 
161
  with gr.Blocks() as demo:
 
81
  try:
82
  @spaces.GPU
83
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
84
+ if type(uploaded_image) == str:
85
+ uploaded_image = Image.open(uploaded_image)
86
+ if type(source_num) == str:
87
+ source_num = int(source_num)
88
+ if type(x_coords) == str:
89
+ x_coords = int(x_coords)
90
+ if type(y_coords) == str:
91
+ y_coords = int(y_coords)
92
+
93
+ dec5, _ = model(x)
94
+ feature_map = dec5
95
+ # アップロード画像の前処理
96
+ if uploaded_image is not None:
97
+ uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
98
+ else:
99
+ uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
100
+ target_feature_map, _ = model(uploaded_image)
101
+ img = torch.cat((x, uploaded_image))
102
+ feature_map = torch.cat((feature_map, target_feature_map))
103
+
104
+ source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
105
+ keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
106
+
107
+ # Matplotlibでプロットして画像として保存
108
+ fig, axs = plt.subplots(2, 3, figsize=(10, 6))
109
+ axs[0, 0].imshow(source_map, cmap='hot')
110
+ axs[0, 0].set_title("Source Map")
111
+ axs[0, 1].imshow(target_map, cmap='hot')
112
+ axs[0, 1].set_title("Target Map")
113
+ axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
114
+ axs[0, 2].set_title("Keypoint Map")
115
+ axs[1, 0].imshow(blended_source.permute(1, 2, 0))
116
+ axs[1, 0].set_title("Blended Source")
117
+ axs[1, 1].imshow(blended_target.permute(1, 2, 0))
118
+ axs[1, 1].set_title("Blended Target")
119
+ axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
120
+ axs[1, 2].set_title("Blended Keypoint")
121
+ for ax in axs.flat:
122
+ ax.axis('off')
123
+
124
+ plt.tight_layout()
125
+ plt.close(fig)
126
+ return fig
127
+
128
+ def setup(model_dict, input_image=None):
129
+ global model, device, x, test_imgs, points, mean_vector_list
130
+ # str -> dictに変換
131
+ if type(model_dict) == str:
132
+ model_dict = eval(model_dict)
133
+ model_name = model_dict["name"]
134
+ feature_dim = model_dict["feature_dim"]
135
+ model_path = f"checkpoints/{model_name}"
136
+ model, device = load_model(model_path, feature_dim)
137
+ x = load_data(device)
138
+ test_imgs, points = load_keypoints(device)
139
+ feature_map, _ = model(test_imgs)
140
+ mean_vector_list = utils.get_mean_vector(feature_map, points)
141
+
142
+ if input_image is not None:
143
+ fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
144
+ return fig
145
+
146
+
147
+ models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
148
+ {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
149
+ {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
150
+ {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
151
+
152
+ setup(models[0])
153
  except:
154
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
155
+ if type(uploaded_image) == str:
156
+ uploaded_image = Image.open(uploaded_image)
157
+ if type(source_num) == str:
158
+ source_num = int(source_num)
159
+ if type(x_coords) == str:
160
+ x_coords = int(x_coords)
161
+ if type(y_coords) == str:
162
+ y_coords = int(y_coords)
163
+
164
+ dec5, _ = model(x)
165
+ feature_map = dec5
166
+ # アップロード画像の前処理
167
+ if uploaded_image is not None:
168
+ uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
169
+ else:
170
+ uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
171
+ target_feature_map, _ = model(uploaded_image)
172
+ img = torch.cat((x, uploaded_image))
173
+ feature_map = torch.cat((feature_map, target_feature_map))
174
+
175
+ source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
176
+ keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
177
+
178
+ # Matplotlibでプロットして画像として保存
179
+ fig, axs = plt.subplots(2, 3, figsize=(10, 6))
180
+ axs[0, 0].imshow(source_map, cmap='hot')
181
+ axs[0, 0].set_title("Source Map")
182
+ axs[0, 1].imshow(target_map, cmap='hot')
183
+ axs[0, 1].set_title("Target Map")
184
+ axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
185
+ axs[0, 2].set_title("Keypoint Map")
186
+ axs[1, 0].imshow(blended_source.permute(1, 2, 0))
187
+ axs[1, 0].set_title("Blended Source")
188
+ axs[1, 1].imshow(blended_target.permute(1, 2, 0))
189
+ axs[1, 1].set_title("Blended Target")
190
+ axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
191
+ axs[1, 2].set_title("Blended Keypoint")
192
+ for ax in axs.flat:
193
+ ax.axis('off')
194
+
195
+ plt.tight_layout()
196
+ plt.close(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  return fig
198
+
199
+ def setup(model_dict, input_image=None):
200
+ global model, device, x, test_imgs, points, mean_vector_list
201
+ # str -> dictに変換
202
+ if type(model_dict) == str:
203
+ model_dict = eval(model_dict)
204
+ model_name = model_dict["name"]
205
+ feature_dim = model_dict["feature_dim"]
206
+ model_path = f"checkpoints/{model_name}"
207
+ model, device = load_model(model_path, feature_dim)
208
+ x = load_data(device)
209
+ test_imgs, points = load_keypoints(device)
210
+ feature_map, _ = model(test_imgs)
211
+ mean_vector_list = utils.get_mean_vector(feature_map, points)
212
 
213
+ if input_image is not None:
214
+ fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
215
+ return fig
216
+
217
 
218
+ models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
219
+ {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
220
+ {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
221
+ {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
222
 
223
+ setup(models[0])
224
 
225
 
226
  with gr.Blocks() as demo: