caohy666 committed
Commit c07e913 · 1 Parent(s): 2b2259b

<feat> single thread

Files changed (1): app.py (+196 −190)
app.py CHANGED
@@ -11,6 +11,7 @@ import spaces
 import gc
 import tempfile
 import imageio
+import threading
 import gradio as gr
 import numpy as np
 
@@ -46,6 +47,8 @@ there's no need to manually input edge maps, depth maps, or other condition imag
 The corresponding condition images will be automatically extracted.
 """
 
+pipe_lock = threading.Lock()
+
 
 def init_basemodel():
     global transformer, scheduler, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, image_processor, pipe, current_task
@@ -100,202 +103,204 @@ def init_basemodel():
 
 
 @spaces.GPU
+@gr.queue(concurrency_count=1)
 def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, num_steps, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
     # set up the model
-    global pipe, current_task, transformer
-    if current_task != task:
-        if current_task is None:
-            # insert LoRA
-            lora_config = LoraConfig(
-                r=16,
-                lora_alpha=16,
-                init_lora_weights="gaussian",
-                target_modules=[
-                    'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
-                    'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
-                    'ff.net.0.proj', 'ff.net.2',
-                    'ff_context.net.0.proj', 'ff_context.net.2',
-                    'norm1_context.linear', 'norm1.linear',
-                    'norm.linear', 'proj_mlp', 'proj_out',
-                ]
-            )
-            transformer.add_adapter(lora_config)
-        else:
-            def restore_forward(module):
-                def restored_forward(self, x, *args, **kwargs):
-                    return module.original_forward(x, *args, **kwargs)
-                return restored_forward.__get__(module, type(module))
-
+    with pipe_lock:
+        global pipe, current_task, transformer
+        if current_task != task:
+            if current_task is None:
+                # insert LoRA
+                lora_config = LoraConfig(
+                    r=16,
+                    lora_alpha=16,
+                    init_lora_weights="gaussian",
+                    target_modules=[
+                        'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
+                        'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
+                        'ff.net.0.proj', 'ff.net.2',
+                        'ff_context.net.0.proj', 'ff_context.net.2',
+                        'norm1_context.linear', 'norm1.linear',
+                        'norm.linear', 'proj_mlp', 'proj_out',
+                    ]
+                )
+                transformer.add_adapter(lora_config)
+            else:
+                def restore_forward(module):
+                    def restored_forward(self, x, *args, **kwargs):
+                        return module.original_forward(x, *args, **kwargs)
+                    return restored_forward.__get__(module, type(module))
+
+                for n, m in transformer.named_modules():
+                    if isinstance(m, peft.tuners.lora.layer.Linear):
+                        m.forward = restore_forward(m)
+
+            current_task = task
+
+            # hack LoRA forward
+            def create_hacked_forward(module):
+                if not hasattr(module, 'original_forward'):
+                    module.original_forward = module.forward
+                lora_forward = module.forward
+                non_lora_forward = module.base_layer.forward
+                img_sequence_length = int((512 / 8 / 2) ** 2)
+                encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
+                num_imgs = 4
+                num_generated_imgs = 3
+                num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1
+
+                def hacked_lora_forward(self, x, *args, **kwargs):
+                    if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
+                        return torch.cat((
+                            lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
+                            non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
+                        ), dim=1)
+                    elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
+                        return lora_forward(x, *args, **kwargs)
+                    elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
+                        return torch.cat((
+                            lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
+                            non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
+                            lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
+                        ), dim=1)
+                    elif x.shape[1] == 3072:
+                        return non_lora_forward(x, *args, **kwargs)
+                    else:
+                        raise ValueError(
+                            f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
+                        )
+
+                return hacked_lora_forward.__get__(module, type(module))
+
             for n, m in transformer.named_modules():
                 if isinstance(m, peft.tuners.lora.layer.Linear):
-                    m.forward = restore_forward(m)
-
-        current_task = task
-
-        # hack LoRA forward
-        def create_hacked_forward(module):
-            if not hasattr(module, 'original_forward'):
-                module.original_forward = module.forward
-            lora_forward = module.forward
-            non_lora_forward = module.base_layer.forward
-            img_sequence_length = int((512 / 8 / 2) ** 2)
-            encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
-            num_imgs = 4
-            num_generated_imgs = 3
-            num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1
-
-            def hacked_lora_forward(self, x, *args, **kwargs):
-                if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
-                    return torch.cat((
-                        lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
-                        non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
-                    ), dim=1)
-                elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
-                    return lora_forward(x, *args, **kwargs)
-                elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
-                    return torch.cat((
-                        lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
-                        non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
-                        lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
-                    ), dim=1)
-                elif x.shape[1] == 3072:
-                    return non_lora_forward(x, *args, **kwargs)
-                else:
-                    raise ValueError(
-                        f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
+                    m.forward = create_hacked_forward(m)
+
+            # load LoRA weights
+            model_root = hf_hub_download(
+                repo_id="Kunbyte/DRA-Ctrl",
+                filename=f"{task}.safetensors",
+                resume_download=True)
+
+            try:
+                with safe_open(model_root, framework="pt") as f:
+                    lora_weights = {}
+                    for k in f.keys():
+                        param = f.get_tensor(k)
+                        if k.endswith(".weight"):
+                            k = k.replace('.weight', '.default.weight')
+                        lora_weights[k] = param
+                    transformer.load_state_dict(lora_weights, strict=False)
+            except Exception as e:
+                raise ValueError(f'{e}')
+
+        transformer.requires_grad_(False)
+
+        # start generation
+        c_txt = None if condition_image_prompt == "" else condition_image_prompt
+        c_img = condition_image.resize((512, 512))
+        t_txt = target_prompt
+
+        if task not in ['subject_driven', 'style_transfer']:
+            if task == "canny":
+                def get_canny_edge(img):
+                    img_np = np.array(img)
+                    img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+                    edges = cv2.Canny(img_gray, 100, 200)
+                    edges_tmp = Image.fromarray(edges).convert("RGB")
+                    edges[edges == 0] = 128
+                    return Image.fromarray(edges).convert("RGB")
+                c_img = get_canny_edge(c_img)
+            elif task == "coloring":
+                c_img = (
+                    c_img.resize((512, 512))
+                    .convert("L")
+                    .convert("RGB")
+                )
+            elif task == "deblurring":
+                blur_radius = 10
+                c_img = (
+                    c_img.convert("RGB")
+                    .filter(ImageFilter.GaussianBlur(blur_radius))
+                    .resize((512, 512))
+                    .convert("RGB")
+                )
+            elif task == "depth":
+                def get_depth_map(img):
+                    from transformers import pipeline
+
+                    depth_pipe = pipeline(
+                        task="depth-estimation",
+                        model="LiheYoung/depth-anything-small-hf",
+                        device="cpu",
                     )
-
-            return hacked_lora_forward.__get__(module, type(module))
-
-        for n, m in transformer.named_modules():
-            if isinstance(m, peft.tuners.lora.layer.Linear):
-                m.forward = create_hacked_forward(m)
-
-        # load LoRA weights
-        model_root = hf_hub_download(
-            repo_id="Kunbyte/DRA-Ctrl",
-            filename=f"{task}.safetensors",
-            resume_download=True)
-
-        try:
-            with safe_open(model_root, framework="pt") as f:
-                lora_weights = {}
-                for k in f.keys():
-                    param = f.get_tensor(k)
-                    if k.endswith(".weight"):
-                        k = k.replace('.weight', '.default.weight')
-                    lora_weights[k] = param
-                transformer.load_state_dict(lora_weights, strict=False)
-        except Exception as e:
-            raise ValueError(f'{e}')
-
-    transformer.requires_grad_(False)
-
-    # start generation
-    c_txt = None if condition_image_prompt == "" else condition_image_prompt
-    c_img = condition_image.resize((512, 512))
-    t_txt = target_prompt
-
-    if task not in ['subject_driven', 'style_transfer']:
-        if task == "canny":
-            def get_canny_edge(img):
-                img_np = np.array(img)
-                img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
-                edges = cv2.Canny(img_gray, 100, 200)
-                edges_tmp = Image.fromarray(edges).convert("RGB")
-                edges[edges == 0] = 128
-                return Image.fromarray(edges).convert("RGB")
-            c_img = get_canny_edge(c_img)
-        elif task == "coloring":
-            c_img = (
-                c_img.resize((512, 512))
-                .convert("L")
-                .convert("RGB")
-            )
-        elif task == "deblurring":
-            blur_radius = 10
-            c_img = (
-                c_img.convert("RGB")
-                .filter(ImageFilter.GaussianBlur(blur_radius))
-                .resize((512, 512))
-                .convert("RGB")
-            )
-        elif task == "depth":
-            def get_depth_map(img):
-                from transformers import pipeline
-
-                depth_pipe = pipeline(
-                    task="depth-estimation",
-                    model="LiheYoung/depth-anything-small-hf",
-                    device="cpu",
+                    return depth_pipe(img)["depth"].convert("RGB").resize((512, 512))
+                c_img = get_depth_map(c_img)
+                k = (255 - 128) / 255
+                b = 128
+                c_img = c_img.point(lambda x: k * x + b)
+            elif task == "depth_pred":
+                c_img = c_img
+            elif task == "fill":
+                c_img = c_img.resize((512, 512)).convert("RGB")
+                x1, x2 = fill_x1, fill_x2
+                y1, y2 = fill_y1, fill_y2
+                mask = Image.new("L", (512, 512), 0)
+                draw = ImageDraw.Draw(mask)
+                draw.rectangle((x1, y1, x2, y2), fill=255)
+                if inpainting:
+                    mask = Image.eval(mask, lambda a: 255 - a)
+                    c_img = Image.composite(
+                        c_img,
+                        Image.new("RGB", (512, 512), (255, 255, 255)),
+                        mask
                     )
-                return depth_pipe(img)["depth"].convert("RGB").resize((512, 512))
-            c_img = get_depth_map(c_img)
-            k = (255 - 128) / 255
-            b = 128
-            c_img = c_img.point(lambda x: k * x + b)
-        elif task == "depth_pred":
-            c_img = c_img
-        elif task == "fill":
-            c_img = c_img.resize((512, 512)).convert("RGB")
-            x1, x2 = fill_x1, fill_x2
-            y1, y2 = fill_y1, fill_y2
-            mask = Image.new("L", (512, 512), 0)
-            draw = ImageDraw.Draw(mask)
-            draw.rectangle((x1, y1, x2, y2), fill=255)
-            if inpainting:
-                mask = Image.eval(mask, lambda a: 255 - a)
-                c_img = Image.composite(
-                    c_img,
-                    Image.new("RGB", (512, 512), (255, 255, 255)),
-                    mask
-                )
-            c_img = Image.composite(
-                c_img,
-                Image.new("RGB", (512, 512), (128, 128, 128)),
-                mask
-            )
-        elif task == "sr":
-            c_img = c_img.resize((int(512 / 4), int(512 / 4))).convert("RGB")
-            c_img = c_img.resize((512, 512))
-
-    gen_img = pipe(
-        image=c_img,
-        prompt=[t_txt.strip()],
-        prompt_condition=[c_txt.strip()] if c_txt is not None else None,
-        prompt_2=[t_txt],
-        height=512,
-        width=512,
-        num_frames=5,
-        num_inference_steps=num_steps,
-        guidance_scale=6.0,
-        num_videos_per_prompt=1,
-        generator=torch.Generator(device=pipe.transformer.device).manual_seed(random_seed),
-        output_type='pt',
-        image_embed_interleave=4,
-        frame_gap=48,
-        mixup=True,
-        mixup_num_imgs=2,
-        enhance_tp=task in ['subject_driven'],
-    ).frames
-
-    output_images = []
-    for i in range(10):
-        out = gen_img[:, i:i+1, :, :, :]
-        out = out.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
-        out = np.transpose(out, (1, 2, 0))
-        out = (out * 255).astype(np.uint8)
-        out = Image.fromarray(out)
-        output_images.append(out)
-
-    # video = [np.array(img.convert('RGB')) for img in output_images[1:] + [output_images[0]]]
-    # video = np.stack(video, axis=0)
-
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
-        video_path = f.name
-    imageio.mimsave(video_path, output_images[1:]+[output_images[0]], fps=5)
-
-    return output_images[0], video_path
+                c_img = Image.composite(
+                    c_img,
+                    Image.new("RGB", (512, 512), (128, 128, 128)),
+                    mask
+                )
+            elif task == "sr":
+                c_img = c_img.resize((int(512 / 4), int(512 / 4))).convert("RGB")
+                c_img = c_img.resize((512, 512))
+
+        gen_img = pipe(
+            image=c_img,
+            prompt=[t_txt.strip()],
+            prompt_condition=[c_txt.strip()] if c_txt is not None else None,
+            prompt_2=[t_txt],
+            height=512,
+            width=512,
+            num_frames=5,
+            num_inference_steps=num_steps,
+            guidance_scale=6.0,
+            num_videos_per_prompt=1,
+            generator=torch.Generator(device=pipe.transformer.device).manual_seed(random_seed),
+            output_type='pt',
+            image_embed_interleave=4,
+            frame_gap=48,
+            mixup=True,
+            mixup_num_imgs=2,
+            enhance_tp=task in ['subject_driven'],
+        ).frames
+
+        output_images = []
+        for i in range(10):
+            out = gen_img[:, i:i+1, :, :, :]
+            out = out.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
+            out = np.transpose(out, (1, 2, 0))
+            out = (out * 255).astype(np.uint8)
+            out = Image.fromarray(out)
+            output_images.append(out)
+
+        # video = [np.array(img.convert('RGB')) for img in output_images[1:] + [output_images[0]]]
+        # video = np.stack(video, axis=0)
+
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
+            video_path = f.name
+        imageio.mimsave(video_path, output_images[1:]+[output_images[0]], fps=5)
+
+        return output_images[0], video_path
 
 def get_samples():
     sample_list = [
@@ -452,6 +457,7 @@ def get_samples():
 def create_app():
     with gr.Blocks() as app:
         gr.Markdown(header, elem_id="header")
+        gr.Markdown("🚦 To ensure stable model output, we are running the process in a single-threaded serial mode. If your request is queued, please wait patiently for the generation to complete.", elem_id="queue_notice")
         with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                condition_image = gr.Image(
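
For readers skimming the diff: the whole body of process_image_and_text is re-indented under a single `with pipe_lock:` block, so overlapping Gradio requests are serialized around the shared pipeline and its per-task LoRA state. Below is a minimal, self-contained sketch of that pattern (the `generate` function and `demo` app are hypothetical and not the repository's code; the Blocks-level `queue()` call is shown as a version-dependent alternative to the `@gr.queue(concurrency_count=1)` decorator used in the commit):

```python
import threading

import gradio as gr

pipe_lock = threading.Lock()  # single lock shared by all requests


def generate(prompt):
    # Only one request at a time may touch the shared pipeline / LoRA globals.
    with pipe_lock:
        # Placeholder for the real pipe(...) call that app.py guards.
        return f"result for: {prompt}"


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    inp.submit(generate, inputs=inp, outputs=out)

# Depending on the installed Gradio version, the request queue itself can also be
# capped at one concurrent worker (via a concurrency argument to demo.queue(...)),
# so waiting requests stay queued instead of blocking inside the handler.
demo.queue()
demo.launch()
```

The lock is what guarantees mutual exclusion over the module-level globals during LoRA swapping and generation; a queue-level concurrency cap mainly keeps additional requests waiting in line, which is what the "single-threaded serial mode" notice added to create_app tells users to expect.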