fbnnb commited on
Commit
d722551
·
verified ·
1 Parent(s): e49d7ab

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +15 -15
gradio_app.py CHANGED
@@ -120,23 +120,23 @@ cn_config = OmegaConf.load(cn_config_file)
120
  cn_model_config = cn_config.pop("control_stage_config", OmegaConf.create())
121
 
122
  print("before init")
123
- model_list = []
124
- gpu_num = 1
125
- for gpu_id in range(gpu_num):
126
- model = instantiate_from_config(model_config)
127
- cn_model = instantiate_from_config(cn_model_config)
128
 
129
- # model = model.cuda(gpu_id)
130
- assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
131
- model = load_model_checkpoint(model, ckpt_path)
132
- model.eval()
133
 
134
- cn_model.load_state_dict(load_state_dict(cn_ckpt_path, location='cpu'))
135
- cn_model.eval()
136
 
137
- model.control_model = cn_model
138
 
139
- model_list.append(model)
140
 
141
  save_fps = 8
142
  print("resolution:", resolution)
@@ -161,8 +161,8 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
161
  if steps > 60:
162
  steps = 60
163
 
164
- global model_list
165
- model = model_list[gpu_id]
166
  model = model.cuda()
167
 
168
  batch_size=1
 
120
  cn_model_config = cn_config.pop("control_stage_config", OmegaConf.create())
121
 
122
  print("before init")
123
+ # model_list = []
124
+ # gpu_num = 1
125
+ # for gpu_id in range(gpu_num):
126
+ model = instantiate_from_config(model_config)
127
+ cn_model = instantiate_from_config(cn_model_config)
128
 
129
+ # model = model.cuda(gpu_id)
130
+ assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
131
+ model = load_model_checkpoint(model, ckpt_path)
132
+ model.eval()
133
 
134
+ cn_model.load_state_dict(load_state_dict(cn_ckpt_path, location='cpu'))
135
+ cn_model.eval()
136
 
137
+ model.control_model = cn_model
138
 
139
+ # model_list.append(model)
140
 
141
  save_fps = 8
142
  print("resolution:", resolution)
 
161
  if steps > 60:
162
  steps = 60
163
 
164
+ global model
165
+ # model = model_list[gpu_id]
166
  model = model.cuda()
167
 
168
  batch_size=1