caohy666 committed
Commit 88d26a2 · 1 Parent(s): 8851fb5

<fix> fix some bugs in app.py.

Files changed (1)
  1. app.py +124 -114
app.py CHANGED
@@ -45,9 +45,11 @@ there's no need to manually input edge maps, depth maps, or other condition imag
45
  The corresponding condition images will be automatically extracted.
46
  """
47
 
 
 
48
 
49
- def init_pipeline():
50
- global pipe
51
 
52
  # init models
53
  transformer = HunyuanVideoTransformer3DModel.from_pretrained('hunyuanvideo-community/HunyuanVideo-I2V',
@@ -78,101 +80,106 @@ def init_pipeline():
78
  vae.enable_tiling()
79
  vae.enable_slicing()
80
 
81
- # insert LoRA
82
- lora_config = LoraConfig(
83
- r=16,
84
- lora_alpha=16,
85
- init_lora_weights="gaussian",
86
- target_modules=[
87
- 'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
88
- 'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
89
- 'ff.net.0.proj', 'ff.net.2',
90
- 'ff_context.net.0.proj', 'ff_context.net.2',
91
- 'norm1_context.linear', 'norm1.linear',
92
- 'norm.linear', 'proj_mlp', 'proj_out',
93
- ]
94
- )
95
- transformer.add_adapter(lora_config)
96
-
97
- # hack LoRA forward
98
- def create_hacked_forward(module):
99
- lora_forward = module.forward
100
- non_lora_forward = module.base_layer.forward
101
- img_sequence_length = int((args.img_size / 8 / 2) ** 2)
102
- encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
103
- num_imgs = 4
104
- num_generated_imgs = 3
105
- num_encoder_sequences = 2 if args.task in ['subject_driven', 'style_transfer'] else 1
106
-
107
- def hacked_lora_forward(self, x, *args, **kwargs):
108
- if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
109
- return torch.cat((
110
- lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
111
- non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
112
- ), dim=1)
113
- elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
114
- return lora_forward(x, *args, **kwargs)
115
- elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
116
- return torch.cat((
117
- lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
118
- non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
119
- lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
120
- ), dim=1)
121
- elif x.shape[1] == 3072:
122
- return non_lora_forward(x, *args, **kwargs)
123
- else:
124
- raise ValueError(
125
- f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
126
- )
127
-
128
- return hacked_lora_forward.__get__(module, type(module))
129
-
130
- for n, m in transformer.named_modules():
131
- if isinstance(m, peft.tuners.lora.layer.Linear):
132
- m.forward = create_hacked_forward(m)
133
-
134
- # load LoRA weights
135
- model_root = hf_hub_download(
136
- repo_id="Kunbyte/DRA-Ctrl",
137
- filename=f"{task}.safetensors",
138
- resume_download=True)
139
-
140
- try:
141
- with safe_open(model_root, framework="pt") as f:
142
- lora_weights = {}
143
- for k in f.keys():
144
- param = f.get_tensor(k)
145
- if k.endswith(".weight"):
146
- k = k.replace('.weight', '.default.weight')
147
- lora_weights[k] = param
148
- transformer.load_state_dict(lora_weights, strict=False)
149
- except Exception as e:
150
- raise ValueError(f'{e}')
151
-
152
- transformer.requires_grad_(False)
153
-
154
- pipe = HunyuanVideoImageToVideoPipeline(
155
- text_encoder=text_encoder,
156
- tokenizer=tokenizer,
157
- transformer=transformer,
158
- vae=vae,
159
- scheduler=copy.deepcopy(scheduler),
160
- text_encoder_2=text_encoder_2,
161
- tokenizer_2=tokenizer_2,
162
- image_processor=image_processor,
163
- )
164
-
165
 
166
  @spaces.GPU
167
- def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task):
168
 
169
  # start generation
170
  c_txt = None if condition_image_prompt == "" else condition_image_prompt
171
  c_img = condition_image.resize((512, 512))
172
  t_txt = target_prompt
173
 
174
- if args.task not in ['subject_driven', 'style_transfer']:
175
- if args.task == "canny":
176
  def get_canny_edge(img):
177
  img_np = np.array(img)
178
  img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
@@ -182,21 +189,21 @@ def process_image_and_text(condition_image, target_prompt, condition_image_promp
182
  edges[edges == 0] = 128
183
  return Image.fromarray(edges).convert("RGB")
184
  c_img = get_canny_edge(c_img)
185
- elif args.task == "coloring":
186
  c_img = (
187
- c_img.resize((args.img_size, args.img_size))
188
  .convert("L")
189
  .convert("RGB")
190
  )
191
- elif args.task == "deblurring":
192
  blur_radius = 10
193
  c_img = (
194
  c_img.convert("RGB")
195
  .filter(ImageFilter.GaussianBlur(blur_radius))
196
- .resize((args.img_size, args.img_size))
197
  .convert("RGB")
198
  )
199
- elif args.task == "depth":
200
  def get_depth_map(img):
201
  from transformers import pipeline
202
 
@@ -205,43 +212,40 @@ def process_image_and_text(condition_image, target_prompt, condition_image_promp
205
  model="LiheYoung/depth-anything-small-hf",
206
  device="cpu",
207
  )
208
- return depth_pipe(img)["depth"].convert("RGB").resize((args.img_size, args.img_size))
209
  c_img = get_depth_map(c_img)
210
  c_img.save(os.path.join(save_dir, f"depth.png"))
211
  k = (255 - 128) / 255
212
  b = 128
213
  c_img = c_img.point(lambda x: k * x + b)
214
- elif args.task == "depth_pred":
215
  c_img = c_img
216
- elif args.task == "fill":
217
- c_img = c_img.resize((args.img_size, args.img_size)).convert("RGB")
218
- x1, x2 = args.fill_x1, args.fill_x2
219
- y1, y2 = args.fill_y1, args.fill_y2
220
- mask = Image.new("L", (args.img_size, args.img_size), 0)
221
  draw = ImageDraw.Draw(mask)
222
  draw.rectangle((x1, y1, x2, y2), fill=255)
223
- if args.inpainting:
224
  mask = Image.eval(mask, lambda a: 255 - a)
225
  c_img = Image.composite(
226
  c_img,
227
- Image.new("RGB", (args.img_size, args.img_size), (255, 255, 255)),
228
  mask
229
  )
230
  c_img.save(os.path.join(save_dir, f"mask.png"))
231
  c_img = Image.composite(
232
  c_img,
233
- Image.new("RGB", (args.img_size, args.img_size), (128, 128, 128)),
234
  mask
235
  )
236
- elif args.task == "sr":
237
- c_img = c_img.resize((int(args.img_size / 4), int(args.img_size / 4))).convert("RGB")
238
  c_img.save(os.path.join(save_dir, f"low_resolution.png"))
239
- c_img = c_img.resize((args.img_size, args.img_size))
240
  c_img.save(os.path.join(save_dir, f"low_to_high.png"))
241
 
242
- if pipe is None:
243
- init_pipeline()
244
-
245
  gen_img = pipe(
246
  image=c_img,
247
  prompt=[t_txt.strip()],
@@ -253,7 +257,7 @@ def process_image_and_text(condition_image, target_prompt, condition_image_promp
253
  num_inference_steps=50,
254
  guidance_scale=6.0,
255
  num_videos_per_prompt=1,
256
- generator=torch.Generator(device=pipe.transformer.device).manual_seed(0),
257
  output_type='pt',
258
  image_embed_interleave=4,
259
  frame_gap=48,
@@ -295,8 +299,14 @@ def create_app():
295
  elem_id="task_selection"
296
  )
297
  gr.Markdown(notice, elem_id="notice")
298
- target_prompt = gr.Textbox(lines=2, label="Target Prompt", elem_id="text")
299
- condition_image_prompt = gr.Textbox(lines=2, label="Condition Image Prompt", elem_id="text")
300
  submit_btn = gr.Button("Run", elem_id="submit_btn")
301
 
302
  with gr.Column(variant="panel", elem_classes="outputPanel"):
@@ -304,7 +314,7 @@ def create_app():
304
 
305
  submit_btn.click(
306
  fn=process_image_and_text,
307
- inputs=[condition_image, target_prompt, condition_image_prompt, task],
308
  outputs=output_image,
309
  )
310
 
 
45
  The corresponding condition images will be automatically extracted.
46
  """
47
 
48
+ pipe = None
49
+ current_task = None
50
 
51
+ def init_basemodel():
52
+ global transformer, scheduler, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, image_processor
53
 
54
  # init models
55
  transformer = HunyuanVideoTransformer3DModel.from_pretrained('hunyuanvideo-community/HunyuanVideo-I2V',
 
80
  vae.enable_tiling()
81
  vae.enable_slicing()
82
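Note: the refactor above replaces the old init_pipeline() (which built everything, LoRA included, in one shot) with a base-model loader plus per-task pipeline caching keyed on the pipe/current_task globals. Below is a minimal, self-contained sketch of that caching pattern; load_base_models() and build_pipeline_for_task() are hypothetical stand-ins for the real from_pretrained and LoRA code.

pipe = None
current_task = None
base_models = None

def load_base_models():
    # placeholder for the heavy from_pretrained() calls in init_basemodel()
    return {"transformer": "...", "vae": "..."}

def build_pipeline_for_task(models, task):
    # placeholder for LoRA injection + HunyuanVideoImageToVideoPipeline(...)
    return f"pipeline<{task}>"

def get_pipeline(task):
    # Load the base models once; rebuild the pipeline only when the task changes.
    global pipe, current_task, base_models
    if base_models is None:
        base_models = load_base_models()
    if pipe is None or current_task != task:
        pipe = build_pipeline_for_task(base_models, task)
        current_task = task   # remember the task so the next call can reuse the pipeline
    return pipe

print(get_pipeline("canny"))   # builds
print(get_pipeline("canny"))   # reuses the cached pipeline
print(get_pipeline("depth"))   # rebuilds for the new task

Updating current_task after a rebuild is what lets a repeated task reuse the cached pipeline instead of re-injecting LoRA weights.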
83
 
84
  @spaces.GPU
85
+ def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
86
+ # set up models
87
+ required_models = [transformer, scheduler, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, image_processor]
88
+ if any(model is None for model in required_models):
89
+ init_basemodel()
90
+
91
+ if pipe is None or current_task != task:
92
+ # insert LoRA
93
+ lora_config = LoraConfig(
94
+ r=16,
95
+ lora_alpha=16,
96
+ init_lora_weights="gaussian",
97
+ target_modules=[
98
+ 'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
99
+ 'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
100
+ 'ff.net.0.proj', 'ff.net.2',
101
+ 'ff_context.net.0.proj', 'ff_context.net.2',
102
+ 'norm1_context.linear', 'norm1.linear',
103
+ 'norm.linear', 'proj_mlp', 'proj_out',
104
+ ]
105
+ )
106
+ transformer.add_adapter(lora_config)
107
+
108
+ # hack LoRA forward
109
+ def create_hacked_forward(module):
110
+ lora_forward = module.forward
111
+ non_lora_forward = module.base_layer.forward
112
+ img_sequence_length = int((512 / 8 / 2) ** 2)
113
+ encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
114
+ num_imgs = 4
115
+ num_generated_imgs = 3
116
+ num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1
117
+
118
+ def hacked_lora_forward(self, x, *args, **kwargs):
119
+ if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
120
+ return torch.cat((
121
+ lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
122
+ non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
123
+ ), dim=1)
124
+ elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
125
+ return lora_forward(x, *args, **kwargs)
126
+ elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
127
+ return torch.cat((
128
+ lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
129
+ non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
130
+ lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
131
+ ), dim=1)
132
+ elif x.shape[1] == 3072:
133
+ return non_lora_forward(x, *args, **kwargs)
134
+ else:
135
+ raise ValueError(
136
+ f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
137
+ )
138
+
139
+ return hacked_lora_forward.__get__(module, type(module))
140
+
141
+ for n, m in transformer.named_modules():
142
+ if isinstance(m, peft.tuners.lora.layer.Linear):
143
+ m.forward = create_hacked_forward(m)
144
+
145
+ # load LoRA weights
146
+ model_root = hf_hub_download(
147
+ repo_id="Kunbyte/DRA-Ctrl",
148
+ filename=f"{task}.safetensors",
149
+ resume_download=True)
150
+
151
+ try:
152
+ with safe_open(model_root, framework="pt") as f:
153
+ lora_weights = {}
154
+ for k in f.keys():
155
+ param = f.get_tensor(k)
156
+ if k.endswith(".weight"):
157
+ k = k.replace('.weight', '.default.weight')
158
+ lora_weights[k] = param
159
+ transformer.load_state_dict(lora_weights, strict=False)
160
+ except Exception as e:
161
+ raise ValueError(f'{e}')
162
+
163
+ transformer.requires_grad_(False)
164
+
165
+ pipe = HunyuanVideoImageToVideoPipeline(
166
+ text_encoder=text_encoder,
167
+ tokenizer=tokenizer,
168
+ transformer=transformer,
169
+ vae=vae,
170
+ scheduler=copy.deepcopy(scheduler),
171
+ text_encoder_2=text_encoder_2,
172
+ tokenizer_2=tokenizer_2,
173
+ image_processor=image_processor,
174
+ )
175
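Note: hacked_lora_forward above routes each token span either through the LoRA-augmented linear or the frozen base layer purely by sequence length: generated-image tokens bypass LoRA, while condition-image and encoder tokens go through it. A toy sketch of that routing idea, using plain torch instead of peft and tiny sizes (the real image sequence is 1024 tokens):

import torch
import torch.nn as nn

# Toy stand-ins: lora_forward applies the base weight plus a low-rank update,
# base_forward only the frozen base weight.
base = nn.Linear(8, 8, bias=False)
lora_down = nn.Linear(8, 2, bias=False)
lora_up = nn.Linear(2, 8, bias=False)

def base_forward(x):
    return base(x)

def lora_forward(x):
    return base(x) + lora_up(lora_down(x))

IMG_LEN, NUM_IMGS, NUM_GEN = 4, 4, 3   # toy sizes, not the app's real values

def routed_forward(x):
    # Send the trailing "generated image" tokens through the plain weights and
    # everything else through the LoRA-augmented weights, as in hacked_lora_forward.
    if x.dim() > 2 and x.shape[1] == IMG_LEN * NUM_IMGS:
        split = IMG_LEN * NUM_GEN
        return torch.cat((lora_forward(x[:, :-split]), base_forward(x[:, -split:])), dim=1)
    return lora_forward(x)

x = torch.randn(1, IMG_LEN * NUM_IMGS, 8)
print(routed_forward(x).shape)   # torch.Size([1, 16, 8])

The real code binds such a function onto every peft LoRA Linear via __get__, so each module keeps its original forward signature.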
 
176
  # start generation
177
  c_txt = None if condition_image_prompt == "" else condition_image_prompt
178
  c_img = condition_image.resize((512, 512))
179
  t_txt = target_prompt
180
 
181
+ if task not in ['subject_driven', 'style_transfer']:
182
+ if task == "canny":
183
  def get_canny_edge(img):
184
  img_np = np.array(img)
185
  img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
 
189
  edges[edges == 0] = 128
190
  return Image.fromarray(edges).convert("RGB")
191
  c_img = get_canny_edge(c_img)
192
+ elif task == "coloring":
193
  c_img = (
194
+ c_img.resize((512, 512))
195
  .convert("L")
196
  .convert("RGB")
197
  )
198
+ elif task == "deblurring":
199
  blur_radius = 10
200
  c_img = (
201
  c_img.convert("RGB")
202
  .filter(ImageFilter.GaussianBlur(blur_radius))
203
+ .resize((512, 512))
204
  .convert("RGB")
205
  )
206
+ elif task == "depth":
207
  def get_depth_map(img):
208
  from transformers import pipeline
209
 
 
212
  model="LiheYoung/depth-anything-small-hf",
213
  device="cpu",
214
  )
215
+ return depth_pipe(img)["depth"].convert("RGB").resize((512, 512))
216
  c_img = get_depth_map(c_img)
217
  c_img.save(os.path.join(save_dir, f"depth.png"))
218
  k = (255 - 128) / 255
219
  b = 128
220
  c_img = c_img.point(lambda x: k * x + b)
221
+ elif task == "depth_pred":
222
  c_img = c_img
223
+ elif task == "fill":
224
+ c_img = c_img.resize((512, 512)).convert("RGB")
225
+ x1, x2 = fill_x1, fill_x2
226
+ y1, y2 = fill_y1, fill_y2
227
+ mask = Image.new("L", (512, 512), 0)
228
  draw = ImageDraw.Draw(mask)
229
  draw.rectangle((x1, y1, x2, y2), fill=255)
230
+ if inpainting:
231
  mask = Image.eval(mask, lambda a: 255 - a)
232
  c_img = Image.composite(
233
  c_img,
234
+ Image.new("RGB", (512, 512), (255, 255, 255)),
235
  mask
236
  )
237
  c_img.save(os.path.join(save_dir, f"mask.png"))
238
  c_img = Image.composite(
239
  c_img,
240
+ Image.new("RGB", (512, 512), (128, 128, 128)),
241
  mask
242
  )
243
+ elif task == "sr":
244
+ c_img = c_img.resize((int(512 / 4), int(512 / 4))).convert("RGB")
245
  c_img.save(os.path.join(save_dir, f"low_resolution.png"))
246
+ c_img = c_img.resize((512, 512))
247
  c_img.save(os.path.join(save_dir, f"low_to_high.png"))
248
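Note: for the fill task above, the rectangle drawn on the "L" mask decides which pixels survive: Image.composite takes the first image where the mask is 255 and the second where it is 0, and the inpainting checkbox simply inverts the mask. A small Pillow sketch of that behaviour, at 64x64 instead of the app's 512x512:

from PIL import Image, ImageDraw

size = 64                                    # the app uses 512x512
src = Image.new("RGB", (size, size), (200, 50, 50))
mask = Image.new("L", (size, size), 0)       # 255 = keep src, 0 = take the fill colour
ImageDraw.Draw(mask).rectangle((16, 16, 48, 48), fill=255)

inpainting = True
if inpainting:
    mask = Image.eval(mask, lambda a: 255 - a)   # invert: blank the box, keep the outside

filled = Image.composite(src, Image.new("RGB", (size, size), (128, 128, 128)), mask)
print(filled.size)                           # (64, 64)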
 
 
 
 
249
  gen_img = pipe(
250
  image=c_img,
251
  prompt=[t_txt.strip()],
 
257
  num_inference_steps=50,
258
  guidance_scale=6.0,
259
  num_videos_per_prompt=1,
260
+ generator=torch.Generator(device=pipe.transformer.device).manual_seed(random_seed),
261
  output_type='pt',
262
  image_embed_interleave=4,
263
  frame_gap=48,
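Note: the generator seed now comes from the UI instead of being hard-coded to 0, so identical inputs plus an identical seed reproduce the same sample. A minimal torch-only sketch of why the explicit, seeded Generator matters:

import torch

def sample(seed: int, device: str = "cpu"):
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(2, 3, generator=gen, device=device)

print(torch.equal(sample(42), sample(42)))   # True: same seed, identical noise
print(torch.equal(sample(42), sample(43)))   # almost surely False: different seed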
 
299
  elem_id="task_selection"
300
  )
301
  gr.Markdown(notice, elem_id="notice")
302
+ target_prompt = gr.Textbox(lines=2, label="Target Prompt", elem_id="tp")
303
+ condition_image_prompt = gr.Textbox(lines=2, label="Condition Image Prompt", elem_id="cp")
304
+ random_seed = gr.Number(label="Random Seed", precision=0, value=0, elem_id="seed")
305
+ inpainting = gr.Checkbox(label="Inpainting", value=False, elem_id="inpainting")
306
+ fill_x1 = gr.Number(label="In/Out-painting Box Left Boundary", precision=0, value=128, elem_id="fill_x1")
307
+ fill_x2 = gr.Number(label="In/Out-painting Box Right Boundary", precision=0, value=384, elem_id="fill_x2")
308
+ fill_y1 = gr.Number(label="In/Out-painting Box Top Boundary", precision=0, value=128, elem_id="fill_y1")
309
+ fill_y2 = gr.Number(label="In/Out-painting Box Bottom Boundary", precision=0, value=384, elem_id="fill_y2")
310
  submit_btn = gr.Button("Run", elem_id="submit_btn")
311
 
312
  with gr.Column(variant="panel", elem_classes="outputPanel"):
 
314
 
315
  submit_btn.click(
316
  fn=process_image_and_text,
317
+ inputs=[condition_image, target_prompt, condition_image_prompt, task, random_seed, inpainting, fill_x1, fill_x2, fill_y1, fill_y2],
318
  outputs=output_image,
319
  )
320