apolinario committed on
Commit
aaee24f
·
1 Parent(s): 320ce78

debug faster by not downloading the model

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -9,7 +9,7 @@ sys.path.append('./latent-diffusion')
9
  from taming.models import vqgan
10
  from ldm.util import instantiate_from_config
11
 
12
- torch.hub.download_url_to_file('https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt','txt2img-f8-large.ckpt')
13
 
14
  #@title Import stuff
15
  import argparse, os, sys, glob
@@ -26,7 +26,7 @@ from ldm.models.diffusion.plms import PLMSSampler
26
 
27
  def load_model_from_config(config, ckpt, verbose=False):
28
  print(f"Loading model from {ckpt}")
29
- pl_sd = torch.load(ckpt, map_location="cuda:0")
30
  sd = pl_sd["state_dict"]
31
  model = instantiate_from_config(config.model)
32
  m, u = model.load_state_dict(sd, strict=False)
@@ -41,8 +41,8 @@ def load_model_from_config(config, ckpt, verbose=False):
41
  model.eval()
42
  return model
43
 
44
- config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic
45
- model = load_model_from_config(config, f"txt2img-f8-large.ckpt") # TODO: check path
46
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
47
  model = model.to(device)
48
 
 
9
  from taming.models import vqgan
10
  from ldm.util import instantiate_from_config
11
 
12
+ #torch.hub.download_url_to_file('https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt','txt2img-f8-large.ckpt')
13
 
14
  #@title Import stuff
15
  import argparse, os, sys, glob
 
26
 
27
  def load_model_from_config(config, ckpt, verbose=False):
28
  print(f"Loading model from {ckpt}")
29
+ pl_sd = torch.load(ckpt, map_location="cuda")
30
  sd = pl_sd["state_dict"]
31
  model = instantiate_from_config(config.model)
32
  m, u = model.load_state_dict(sd, strict=False)
 
41
  model.eval()
42
  return model
43
 
44
+ config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
45
+ model = load_model_from_config(config, f"txt2img-f8-large.ckpt")
46
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
47
  model = model.to(device)
48