yeq6x committed
Commit 8ecd333 · 1 Parent(s): 156c303
Files changed (1)
  1. app.py +47 -98
app.py CHANGED
@@ -13,16 +13,12 @@ import dataset
 from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
 import utils
 
-try:
-    import spaces
-except ImportError:
-    print("Spaces is not installed.")
+import spaces
 
 image_size = 112
 batch_size = 32
 
-
-# Load the model and data
+@spaces.GPU
 def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
     model = AutoencoderModule(feature_dim=feature_dim)
     state_dict = torch.load(model_path)
@@ -78,98 +74,51 @@ def load_keypoints(device, img_dir="resources/trainB/", image_size=112, batch_si
     return test_imgs, points
 
 # Heatmap generation function
-try:
-    @spaces.GPU
-    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
-        if type(uploaded_image) == str:
-            uploaded_image = Image.open(uploaded_image)
-        if type(source_num) == str:
-            source_num = int(source_num)
-        if type(x_coords) == str:
-            x_coords = int(x_coords)
-        if type(y_coords) == str:
-            y_coords = int(y_coords)
-
-        dec5, _ = model(x)
-        feature_map = dec5
-        # Preprocess the uploaded image
-        if uploaded_image is not None:
-            uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
-        else:
-            uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
-        target_feature_map, _ = model(uploaded_image)
-        img = torch.cat((x, uploaded_image))
-        feature_map = torch.cat((feature_map, target_feature_map))
-
-        source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
-        keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
-
-        # Plot with Matplotlib and save as an image
-        fig, axs = plt.subplots(2, 3, figsize=(10, 6))
-        axs[0, 0].imshow(source_map, cmap='hot')
-        axs[0, 0].set_title("Source Map")
-        axs[0, 1].imshow(target_map, cmap='hot')
-        axs[0, 1].set_title("Target Map")
-        axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
-        axs[0, 2].set_title("Keypoint Map")
-        axs[1, 0].imshow(blended_source.permute(1, 2, 0))
-        axs[1, 0].set_title("Blended Source")
-        axs[1, 1].imshow(blended_target.permute(1, 2, 0))
-        axs[1, 1].set_title("Blended Target")
-        axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
-        axs[1, 2].set_title("Blended Keypoint")
-        for ax in axs.flat:
-            ax.axis('off')
-
-        plt.tight_layout()
-        plt.close(fig)
-        return fig
-
-except:
-    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
-        if type(uploaded_image) == str:
-            uploaded_image = Image.open(uploaded_image)
-        if type(source_num) == str:
-            source_num = int(source_num)
-        if type(x_coords) == str:
-            x_coords = int(x_coords)
-        if type(y_coords) == str:
-            y_coords = int(y_coords)
-
-        dec5, _ = model(x)
-        feature_map = dec5
-        # Preprocess the uploaded image
-        if uploaded_image is not None:
-            uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
-        else:
-            uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
-        target_feature_map, _ = model(uploaded_image)
-        img = torch.cat((x, uploaded_image))
-        feature_map = torch.cat((feature_map, target_feature_map))
-
-        source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
-        keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
-
-        # Plot with Matplotlib and save as an image
-        fig, axs = plt.subplots(2, 3, figsize=(10, 6))
-        axs[0, 0].imshow(source_map, cmap='hot')
-        axs[0, 0].set_title("Source Map")
-        axs[0, 1].imshow(target_map, cmap='hot')
-        axs[0, 1].set_title("Target Map")
-        axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
-        axs[0, 2].set_title("Keypoint Map")
-        axs[1, 0].imshow(blended_source.permute(1, 2, 0))
-        axs[1, 0].set_title("Blended Source")
-        axs[1, 1].imshow(blended_target.permute(1, 2, 0))
-        axs[1, 1].set_title("Blended Target")
-        axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
-        axs[1, 2].set_title("Blended Keypoint")
-        for ax in axs.flat:
-            ax.axis('off')
-
-        plt.tight_layout()
-        plt.close(fig)
-        return fig
+@spaces.GPU
+def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+    if type(uploaded_image) == str:
+        uploaded_image = Image.open(uploaded_image)
+    if type(source_num) == str:
+        source_num = int(source_num)
+    if type(x_coords) == str:
+        x_coords = int(x_coords)
+    if type(y_coords) == str:
+        y_coords = int(y_coords)
+
+    dec5, _ = model(x)
+    feature_map = dec5
+    # Preprocess the uploaded image
+    if uploaded_image is not None:
+        uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
+    else:
+        uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
+    target_feature_map, _ = model(uploaded_image)
+    img = torch.cat((x, uploaded_image))
+    feature_map = torch.cat((feature_map, target_feature_map))
+
+    source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
+    keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
+
+    # Plot with Matplotlib and save as an image
+    fig, axs = plt.subplots(2, 3, figsize=(10, 6))
+    axs[0, 0].imshow(source_map, cmap='hot')
+    axs[0, 0].set_title("Source Map")
+    axs[0, 1].imshow(target_map, cmap='hot')
+    axs[0, 1].set_title("Target Map")
+    axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
+    axs[0, 2].set_title("Keypoint Map")
+    axs[1, 0].imshow(blended_source.permute(1, 2, 0))
+    axs[1, 0].set_title("Blended Source")
+    axs[1, 1].imshow(blended_target.permute(1, 2, 0))
+    axs[1, 1].set_title("Blended Target")
+    axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
+    axs[1, 2].set_title("Blended Keypoint")
+    for ax in axs.flat:
+        ax.axis('off')
+
+    plt.tight_layout()
+    plt.close(fig)
+    return fig
 
 def setup(model_dict, input_image=None):
     global model, device, x, test_imgs, points, mean_vector_list
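For context, `spaces` is the helper package available on Hugging Face Spaces (ZeroGPU): functions decorated with `@spaces.GPU` are attached to a GPU only for the duration of each call, which is presumably why the try/except import fallback was dropped in favor of a hard dependency. A minimal sketch of that pattern, assuming a ZeroGPU Space; `toy_model` and `run_on_gpu` below are illustrative names, not part of app.py:

import spaces
import torch

# Placeholder model; in app.py this role is played by AutoencoderModule.
toy_model = torch.nn.Linear(8, 8)

@spaces.GPU  # a GPU is allocated only while this function runs
def run_on_gpu(batch: torch.Tensor) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return toy_model.to(device)(batch.to(device)).cpu()

if __name__ == "__main__":
    print(run_on_gpu(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

Because the decorator has to wrap the callables at import time, the module now fails fast outside a Spaces environment instead of printing a warning and continuing.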
 
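The string-to-int coercion of the inputs and the `uploaded_image['composite']` access suggest get_heatmaps is driven by a Gradio UI (an ImageEditor value is a dict with a 'composite' key). A hypothetical wiring sketch with a stub standing in for the real handler; the component choices are assumptions, not taken from the rest of app.py:

import gradio as gr
import matplotlib.pyplot as plt

def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
    # Stub in place of the real get_heatmaps from the diff above.
    fig, ax = plt.subplots()
    ax.set_title(f"source={int(source_num)}, x={int(x_coords)}, y={int(y_coords)}")
    ax.axis('off')
    plt.close(fig)
    return fig

with gr.Blocks() as demo:
    source_num = gr.Number(label="Source image index", precision=0)
    x_coords = gr.Slider(0, 111, step=1, label="x")
    y_coords = gr.Slider(0, 111, step=1, label="y")
    editor = gr.ImageEditor(label="Target image", type="pil")  # value is a dict incl. 'composite'
    plot = gr.Plot()
    gr.Button("Generate heatmaps").click(
        get_heatmaps, inputs=[source_num, x_coords, y_coords, editor], outputs=plot
    )

if __name__ == "__main__":
    demo.launch()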