Spaces:
Sleeping
Sleeping
gpu
Browse files
app.py
CHANGED
@@ -13,16 +13,12 @@ import dataset
|
|
13 |
from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
|
14 |
import utils
|
15 |
|
16 |
-
|
17 |
-
import spaces
|
18 |
-
except ImportError:
|
19 |
-
print("Spaces is not installed.")
|
20 |
|
21 |
image_size = 112
|
22 |
batch_size = 32
|
23 |
|
24 |
-
|
25 |
-
# モデルとデータの読み込み
|
26 |
def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
|
27 |
model = AutoencoderModule(feature_dim=feature_dim)
|
28 |
state_dict = torch.load(model_path)
|
@@ -78,98 +74,51 @@ def load_keypoints(device, img_dir="resources/trainB/", image_size=112, batch_si
|
|
78 |
return test_imgs, points
|
79 |
|
80 |
# ヒートマップの生成関数
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
return fig
|
127 |
-
|
128 |
-
except:
|
129 |
-
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
130 |
-
if type(uploaded_image) == str:
|
131 |
-
uploaded_image = Image.open(uploaded_image)
|
132 |
-
if type(source_num) == str:
|
133 |
-
source_num = int(source_num)
|
134 |
-
if type(x_coords) == str:
|
135 |
-
x_coords = int(x_coords)
|
136 |
-
if type(y_coords) == str:
|
137 |
-
y_coords = int(y_coords)
|
138 |
-
|
139 |
-
dec5, _ = model(x)
|
140 |
-
feature_map = dec5
|
141 |
-
# アップロード画像の前処理
|
142 |
-
if uploaded_image is not None:
|
143 |
-
uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
|
144 |
-
else:
|
145 |
-
uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
|
146 |
-
target_feature_map, _ = model(uploaded_image)
|
147 |
-
img = torch.cat((x, uploaded_image))
|
148 |
-
feature_map = torch.cat((feature_map, target_feature_map))
|
149 |
-
|
150 |
-
source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
|
151 |
-
keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
|
152 |
-
|
153 |
-
# Matplotlibでプロットして画像として保存
|
154 |
-
fig, axs = plt.subplots(2, 3, figsize=(10, 6))
|
155 |
-
axs[0, 0].imshow(source_map, cmap='hot')
|
156 |
-
axs[0, 0].set_title("Source Map")
|
157 |
-
axs[0, 1].imshow(target_map, cmap='hot')
|
158 |
-
axs[0, 1].set_title("Target Map")
|
159 |
-
axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
|
160 |
-
axs[0, 2].set_title("Keypoint Map")
|
161 |
-
axs[1, 0].imshow(blended_source.permute(1, 2, 0))
|
162 |
-
axs[1, 0].set_title("Blended Source")
|
163 |
-
axs[1, 1].imshow(blended_target.permute(1, 2, 0))
|
164 |
-
axs[1, 1].set_title("Blended Target")
|
165 |
-
axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
|
166 |
-
axs[1, 2].set_title("Blended Keypoint")
|
167 |
-
for ax in axs.flat:
|
168 |
-
ax.axis('off')
|
169 |
-
|
170 |
-
plt.tight_layout()
|
171 |
-
plt.close(fig)
|
172 |
-
return fig
|
173 |
|
174 |
def setup(model_dict, input_image=None):
|
175 |
global model, device, x, test_imgs, points, mean_vector_list
|
|
|
13 |
from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
|
14 |
import utils
|
15 |
|
16 |
+
import spaces
|
|
|
|
|
|
|
17 |
|
18 |
image_size = 112
|
19 |
batch_size = 32
|
20 |
|
21 |
+
@spaces.GPU
|
|
|
22 |
def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
|
23 |
model = AutoencoderModule(feature_dim=feature_dim)
|
24 |
state_dict = torch.load(model_path)
|
|
|
74 |
return test_imgs, points
|
75 |
|
76 |
# ヒートマップの生成関数
|
77 |
+
@spaces.GPU
|
78 |
+
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
79 |
+
if type(uploaded_image) == str:
|
80 |
+
uploaded_image = Image.open(uploaded_image)
|
81 |
+
if type(source_num) == str:
|
82 |
+
source_num = int(source_num)
|
83 |
+
if type(x_coords) == str:
|
84 |
+
x_coords = int(x_coords)
|
85 |
+
if type(y_coords) == str:
|
86 |
+
y_coords = int(y_coords)
|
87 |
+
|
88 |
+
dec5, _ = model(x)
|
89 |
+
feature_map = dec5
|
90 |
+
# アップロード画像の前処理
|
91 |
+
if uploaded_image is not None:
|
92 |
+
uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
|
93 |
+
else:
|
94 |
+
uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
|
95 |
+
target_feature_map, _ = model(uploaded_image)
|
96 |
+
img = torch.cat((x, uploaded_image))
|
97 |
+
feature_map = torch.cat((feature_map, target_feature_map))
|
98 |
+
|
99 |
+
source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
|
100 |
+
keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
|
101 |
+
|
102 |
+
# Matplotlibでプロットして画像として保存
|
103 |
+
fig, axs = plt.subplots(2, 3, figsize=(10, 6))
|
104 |
+
axs[0, 0].imshow(source_map, cmap='hot')
|
105 |
+
axs[0, 0].set_title("Source Map")
|
106 |
+
axs[0, 1].imshow(target_map, cmap='hot')
|
107 |
+
axs[0, 1].set_title("Target Map")
|
108 |
+
axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
|
109 |
+
axs[0, 2].set_title("Keypoint Map")
|
110 |
+
axs[1, 0].imshow(blended_source.permute(1, 2, 0))
|
111 |
+
axs[1, 0].set_title("Blended Source")
|
112 |
+
axs[1, 1].imshow(blended_target.permute(1, 2, 0))
|
113 |
+
axs[1, 1].set_title("Blended Target")
|
114 |
+
axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
|
115 |
+
axs[1, 2].set_title("Blended Keypoint")
|
116 |
+
for ax in axs.flat:
|
117 |
+
ax.axis('off')
|
118 |
+
|
119 |
+
plt.tight_layout()
|
120 |
+
plt.close(fig)
|
121 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
def setup(model_dict, input_image=None):
|
124 |
global model, device, x, test_imgs, points, mean_vector_list
|