Update app.py
Browse files
app.py
CHANGED
@@ -14,16 +14,21 @@ class model:
|
|
14 |
def __init__(self):
|
15 |
self.model = None
|
16 |
self.model_name = None
|
|
|
17 |
|
18 |
def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
|
21 |
if self.model_name != model_name:
|
22 |
self.model_name = model_name
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
# Load model config and set ckpt_path if not provided in config
|
|
|
14 |
def __init__(self):
|
15 |
self.model = None
|
16 |
self.model_name = None
|
17 |
+
self.model_dict = {}
|
18 |
|
19 |
def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
|
20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
|
22 |
if self.model_name != model_name:
|
23 |
self.model_name = model_name
|
24 |
+
if self.model_name not in self.model_dict.keys():
|
25 |
+
model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
|
26 |
+
model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
|
27 |
+
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
|
28 |
+
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
|
29 |
+
self.model_dict.update({self.model_name:[model_ckpt_path, model_config_path]})
|
30 |
+
else:
|
31 |
+
model_ckpt_path, model_config_path = self.model_dict[self.model_name]
|
32 |
|
33 |
|
34 |
# Load model config and set ckpt_path if not provided in config
|