yeq6x commited on
Commit
50f5011
·
1 Parent(s): 8ecd333
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -17,8 +17,8 @@ import spaces
17
 
18
  image_size = 112
19
  batch_size = 32
 
20
 
21
- @spaces.GPU
22
  def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
23
  model = AutoencoderModule(feature_dim=feature_dim)
24
  state_dict = torch.load(model_path)
@@ -35,10 +35,9 @@ def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt
35
  model.load_state_dict(new_state_dict)
36
  model.eval()
37
 
38
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  model.to(device)
40
  print("Model loaded successfully.")
41
- return model, device
42
 
43
  def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
44
  filenames = load_filenames(img_dir)
@@ -121,14 +120,14 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
121
  return fig
122
 
123
  def setup(model_dict, input_image=None):
124
- global model, device, x, test_imgs, points, mean_vector_list
125
  # str -> dictに変換
126
  if type(model_dict) == str:
127
  model_dict = eval(model_dict)
128
  model_name = model_dict["name"]
129
  feature_dim = model_dict["feature_dim"]
130
  model_path = f"checkpoints/{model_name}"
131
- model, device = load_model(model_path, feature_dim)
132
  x = load_data(device)
133
  test_imgs, points = load_keypoints(device)
134
  feature_map, _ = model(test_imgs)
 
17
 
18
  image_size = 112
19
  batch_size = 32
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
 
22
  def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
23
  model = AutoencoderModule(feature_dim=feature_dim)
24
  state_dict = torch.load(model_path)
 
35
  model.load_state_dict(new_state_dict)
36
  model.eval()
37
 
 
38
  model.to(device)
39
  print("Model loaded successfully.")
40
+ return model
41
 
42
  def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
43
  filenames = load_filenames(img_dir)
 
120
  return fig
121
 
122
  def setup(model_dict, input_image=None):
123
+ global model, x, test_imgs, points, mean_vector_list
124
  # str -> dictに変換
125
  if type(model_dict) == str:
126
  model_dict = eval(model_dict)
127
  model_name = model_dict["name"]
128
  feature_dim = model_dict["feature_dim"]
129
  model_path = f"checkpoints/{model_name}"
130
+ model = load_model(model_path, feature_dim)
131
  x = load_data(device)
132
  test_imgs, points = load_keypoints(device)
133
  feature_map, _ = model(test_imgs)