Emaad commited on
Commit
871c359
1 Parent(s): b9ae49c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -14,16 +14,21 @@ class model:
14
  def __init__(self):
15
  self.model = None
16
  self.model_name = None
 
17
 
18
  def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  if self.model_name != model_name:
22
  self.model_name = model_name
23
- model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
24
- model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
25
- hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
26
- hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
 
 
 
 
27
 
28
 
29
  # Load model config and set ckpt_path if not provided in config
 
14
  def __init__(self):
15
  self.model = None
16
  self.model_name = None
17
+ self.model_dict = {}
18
 
19
  def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
  if self.model_name != model_name:
23
  self.model_name = model_name
24
+ if self.model_name not in self.model_dict.keys():
25
+ model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
26
+ model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
27
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
28
+ hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
29
+ self.model_dict.update({self.model_name:[model_ckpt_path, model_config_path]})
30
+ else:
31
+ model_ckpt_path, model_config_path = self.model_dict[self.model_name]
32
 
33
 
34
  # Load model config and set ckpt_path if not provided in config