Spaces:
Sleeping
Sleeping
gpu
Browse files
app.py
CHANGED
@@ -17,8 +17,8 @@ import spaces
|
|
17 |
|
18 |
image_size = 112
|
19 |
batch_size = 32
|
|
|
20 |
|
21 |
-
@spaces.GPU
|
22 |
def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
|
23 |
model = AutoencoderModule(feature_dim=feature_dim)
|
24 |
state_dict = torch.load(model_path)
|
@@ -35,10 +35,9 @@ def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt
|
|
35 |
model.load_state_dict(new_state_dict)
|
36 |
model.eval()
|
37 |
|
38 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
39 |
model.to(device)
|
40 |
print("Model loaded successfully.")
|
41 |
-
return model
|
42 |
|
43 |
def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
|
44 |
filenames = load_filenames(img_dir)
|
@@ -121,14 +120,14 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
|
121 |
return fig
|
122 |
|
123 |
def setup(model_dict, input_image=None):
|
124 |
-
global model,
|
125 |
# str -> dictに変換
|
126 |
if type(model_dict) == str:
|
127 |
model_dict = eval(model_dict)
|
128 |
model_name = model_dict["name"]
|
129 |
feature_dim = model_dict["feature_dim"]
|
130 |
model_path = f"checkpoints/{model_name}"
|
131 |
-
model
|
132 |
x = load_data(device)
|
133 |
test_imgs, points = load_keypoints(device)
|
134 |
feature_map, _ = model(test_imgs)
|
|
|
17 |
|
18 |
image_size = 112
|
19 |
batch_size = 32
|
20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
|
|
|
22 |
def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
|
23 |
model = AutoencoderModule(feature_dim=feature_dim)
|
24 |
state_dict = torch.load(model_path)
|
|
|
35 |
model.load_state_dict(new_state_dict)
|
36 |
model.eval()
|
37 |
|
|
|
38 |
model.to(device)
|
39 |
print("Model loaded successfully.")
|
40 |
+
return model
|
41 |
|
42 |
def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
|
43 |
filenames = load_filenames(img_dir)
|
|
|
120 |
return fig
|
121 |
|
122 |
def setup(model_dict, input_image=None):
|
123 |
+
global model, x, test_imgs, points, mean_vector_list
|
124 |
# str -> dictに変換
|
125 |
if type(model_dict) == str:
|
126 |
model_dict = eval(model_dict)
|
127 |
model_name = model_dict["name"]
|
128 |
feature_dim = model_dict["feature_dim"]
|
129 |
model_path = f"checkpoints/{model_name}"
|
130 |
+
model = load_model(model_path, feature_dim)
|
131 |
x = load_data(device)
|
132 |
test_imgs, points = load_keypoints(device)
|
133 |
feature_map, _ = model(test_imgs)
|