xywwww commited on
Commit
1d4b712
·
verified ·
1 Parent(s): 4609937

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -3,10 +3,14 @@ import torch
3
  from annotator.util import resize_image, HWC3
4
  from cldm.model import create_model, load_state_dict
5
  from cldm.ddim_hacked import DDIMSampler
 
6
 
7
  # Initialize the model and other components
 
8
  model = create_model('./models/cldm_v21_512_latctrl_coltrans.yaml').cpu()
9
- model.load_state_dict(load_state_dict('xywwww/scene_diffusion/checkpoints/epoch=25-step=112553.ckpt', location='cuda'), strict=False)
 
 
10
  model = model.cuda()
11
  ddim_sampler = DDIMSampler(model)
12
 
 
3
  from annotator.util import resize_image, HWC3
4
  from cldm.model import create_model, load_state_dict
5
  from cldm.ddim_hacked import DDIMSampler
6
+ from huggingface_hub import hf_hub_download
7
 
8
  # Initialize the model and other components
9
+ # config = "./models/cldm_v21_512_latctrl_coltrans.yaml'"
10
  model = create_model('./models/cldm_v21_512_latctrl_coltrans.yaml').cpu()
11
+ ckpt = hf_hub_download(repo_id="xywwww/scene_diffusion", filename="checkpoints/epoch=25-step=112553.ckpt")
12
+ # model.load_state_dict(load_state_dict('xywwww/scene_diffusion/checkpoints/epoch=25-step=112553.ckpt', location='cuda'), strict=False)
13
+ model = load_model_checkpoint(model, ckpt)
14
  model = model.cuda()
15
  ddim_sampler = DDIMSampler(model)
16