Emaad commited on
Commit
658d610
·
1 Parent(s): d75590d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -36,12 +36,16 @@ class model:
36
  def __init__(self):
37
  self.model = None
38
  self.model_name = None
 
39
 
40
  def gradio_demo(self, model_name, sequence_input, image):
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  if self.model_name != model_name:
 
 
43
  self.model_name = model_name
44
  model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
 
45
  model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
46
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
47
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
 
36
  def __init__(self):
37
  self.model = None
38
  self.model_name = None
39
+ self.model_path = None
40
 
41
  def gradio_demo(self, model_name, sequence_input, image):
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  if self.model_name != model_name:
44
+ if self.model_path is not None:
45
+ os.remove(self.model_path)
46
  self.model_name = model_name
47
  model_ckpt_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
48
+ self.model_path = model_ckpt_path
49
  model_config_path = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
50
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
51
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")