JiantaoLin committed on
Commit
ebe241c
·
1 Parent(s): 235efa3
Files changed (2) hide show
  1. app.py +1 -1
  2. pipeline/kiss3d_wrapper.py +37 -40
app.py CHANGED
@@ -421,7 +421,7 @@ with gr.Blocks(css="""
421
  # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
422
 
423
  btn_gen_mesh = gr.Button("Generate Mesh")
424
- output_video1 = gr.Video(label="Generated Video", interactive=False, loop=True, autoplay=True)
425
  # btn_download1 = gr.Button("Download Mesh")
426
 
427
 
 
421
  # reconstruction_stage2_steps = gr.Number(value=50, label="reconstruction_stage2_steps")
422
 
423
  btn_gen_mesh = gr.Button("Generate Mesh")
424
+ output_video1 = gr.Video(label="Render Video", interactive=False, loop=True, autoplay=True)
425
  # btn_download1 = gr.Button("Download Mesh")
426
 
427
 
pipeline/kiss3d_wrapper.py CHANGED
@@ -74,15 +74,11 @@ def init_wrapper_from_config(config_path):
74
  flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
  else:
76
  flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
- # flux_pipe.enable_vae_slicing()
78
- # flux_pipe.enable_vae_tiling()
79
- # flux_pipe.vae = taef1
80
- flux_pipe.vae.enable_slicing() # 多批次生图优化
81
  flux_pipe.vae.enable_tiling()
82
 
83
- # flux_pipe.enable_sequential_cpu_offload()
84
  # load flux model and controlnet
85
- if flux_controlnet_pth is not None and False:
86
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
87
  flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
88
 
@@ -91,57 +87,55 @@ def init_wrapper_from_config(config_path):
91
  # load lora weights
92
  flux_pipe.load_lora_weights(flux_lora_pth)
93
  # flux_pipe.to(device=flux_device)
94
- # flux_pipe.enable_model_cpu_offload(device=flux_device)
95
- # flux_pipe = None
96
 
97
  # load redux model
98
  flux_redux_pipe = None
99
- if flux_redux_pth is not None and False:
100
  flux_redux_pipe = FluxPriorReduxPipeline.from_pretrained(flux_redux_pth, torch_dtype=torch.bfloat16, token=access_token)
101
  flux_redux_pipe.text_encoder = flux_pipe.text_encoder
102
  flux_redux_pipe.text_encoder_2 = flux_pipe.text_encoder_2
103
  flux_redux_pipe.tokenizer = flux_pipe.tokenizer
104
  flux_redux_pipe.tokenizer_2 = flux_pipe.tokenizer_2
105
 
106
- flux_redux_pipe.to(device=flux_device)
107
 
108
  # logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
109
 
110
  # TODO: load pulid model
111
 
112
  # init multiview model
113
- # logger.info('==> Loading multiview diffusion model ...')
114
- # multiview_device = config_['multiview'].get('device', 'cpu')
115
- # multiview_pipeline = DiffusionPipeline.from_pretrained(
116
- # config_['multiview']['base_model'],
117
- # custom_pipeline=config_['multiview']['custom_pipeline'],
118
- # torch_dtype=torch.float16,
119
- # )
120
- # multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
121
- # multiview_pipeline.scheduler.config, timestep_spacing='trailing'
122
- # )
123
 
124
- # # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
125
- # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
126
- # if unet_ckpt_path is not None:
127
- # state_dict = torch.load(unet_ckpt_path, map_location='cpu')
128
- # # state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
129
- # multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
130
 
131
  # multiview_pipeline.to(multiview_device)
132
  # logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
133
- multiview_pipeline = None
134
 
135
 
136
  # load caption model
137
- # logger.info('==> Loading caption model ...')
138
- # caption_device = config_['caption'].get('device', 'cpu')
139
- # caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
140
- # torch_dtype=torch.bfloat16, trust_remote_code=True).to(caption_device)
141
- # caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
142
  # logger.warning(f"GPU memory allocated after load caption model on {caption_device}: {torch.cuda.memory_allocated(device=caption_device) / 1024**3} GB")
143
- caption_processor = None
144
- caption_model = None
145
 
146
  # load reconstruction model
147
  logger.info('==> Loading reconstruction model ...')
@@ -156,8 +150,7 @@ def init_wrapper_from_config(config_path):
156
  recon_model.to(recon_device)
157
  recon_model.eval()
158
  # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
159
- # recon_model = None
160
- # recon_model_config = None
161
  # load llm
162
  llm_configs = config_.get('llm', None)
163
  if llm_configs is not None:
@@ -242,7 +235,7 @@ class kiss3d_wrapper(object):
242
  """
243
  torch_dtype = torch.bfloat16
244
  caption_device = self.config['caption'].get('device', 'cpu')
245
-
246
  if isinstance(image, str): # If image is a file path
247
  image = preprocess_input_image(Image.open(image))
248
  elif not isinstance(image, Image.Image):
@@ -264,7 +257,7 @@ class kiss3d_wrapper(object):
264
  logger.info(f"Auto caption result: \"{caption_text}\"")
265
 
266
  caption_text = self.get_detailed_prompt(caption_text)
267
-
268
  return caption_text
269
  # @spaces.GPU
270
  def get_detailed_prompt(self, prompt, seed=None):
@@ -290,7 +283,7 @@ class kiss3d_wrapper(object):
290
  def generate_multiview(self, image, seed=None, num_inference_steps=None):
291
  seed = seed or self.config['multiview'].get('seed', 0)
292
  mv_device = self.config['multiview'].get('device', 'cpu')
293
-
294
  generator = torch.Generator(device=mv_device).manual_seed(seed)
295
  with self.context():
296
  mv_image = self.multiview_pipeline(image,
@@ -298,6 +291,7 @@ class kiss3d_wrapper(object):
298
  width=512*2,
299
  height=512*2,
300
  generator=generator).images[0]
 
301
  return mv_image
302
 
303
  def reconstruct_from_multiview(self, mv_image, lrm_render_radius=4.15):
@@ -375,6 +369,7 @@ class kiss3d_wrapper(object):
375
  } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
376
 
377
  flux_device = self.config['flux'].get('device', 'cpu')
 
378
  seed = seed or self.config['flux'].get('seed', 0)
379
  num_inference_steps = num_inference_steps or self.config['flux'].get('num_inference_steps', 20)
380
 
@@ -401,6 +396,7 @@ class kiss3d_wrapper(object):
401
 
402
  # do redux
403
  if redux_hparam is not None:
 
404
  assert self.flux_redux_pipeline is not None
405
  assert 'image' in redux_hparam.keys()
406
  redux_hparam_ = {
@@ -413,6 +409,7 @@ class kiss3d_wrapper(object):
413
  redux_output = self.flux_redux_pipeline(**redux_hparam_)
414
 
415
  hparam_dict.update(redux_output)
 
416
 
417
  # append controlnet hparams
418
  if len(control_image) > 0:
@@ -442,7 +439,7 @@ class kiss3d_wrapper(object):
442
  torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
443
  logger.info(f"Save generated 3D bundle image to {save_path}")
444
  return gen_3d_bundle_image_, save_path
445
-
446
  return gen_3d_bundle_image_
447
 
448
  def preprocess_controlnet_cond_image(self, image, control_mode, save_intermediate_results=True, **kwargs):
 
74
  flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
  else:
76
  flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
+ flux_pipe.vae.enable_slicing()
 
 
 
78
  flux_pipe.vae.enable_tiling()
79
 
 
80
  # load flux model and controlnet
81
+ if flux_controlnet_pth is not None:
82
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
83
  flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
84
 
 
87
  # load lora weights
88
  flux_pipe.load_lora_weights(flux_lora_pth)
89
  # flux_pipe.to(device=flux_device)
 
 
90
 
91
  # load redux model
92
  flux_redux_pipe = None
93
+ if flux_redux_pth is not None:
94
  flux_redux_pipe = FluxPriorReduxPipeline.from_pretrained(flux_redux_pth, torch_dtype=torch.bfloat16, token=access_token)
95
  flux_redux_pipe.text_encoder = flux_pipe.text_encoder
96
  flux_redux_pipe.text_encoder_2 = flux_pipe.text_encoder_2
97
  flux_redux_pipe.tokenizer = flux_pipe.tokenizer
98
  flux_redux_pipe.tokenizer_2 = flux_pipe.tokenizer_2
99
 
100
+ # flux_redux_pipe.to(device=flux_device)
101
 
102
  # logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
103
 
104
  # TODO: load pulid model
105
 
106
  # init multiview model
107
+ logger.info('==> Loading multiview diffusion model ...')
108
+ multiview_device = config_['multiview'].get('device', 'cpu')
109
+ multiview_pipeline = DiffusionPipeline.from_pretrained(
110
+ config_['multiview']['base_model'],
111
+ custom_pipeline=config_['multiview']['custom_pipeline'],
112
+ torch_dtype=torch.float16,
113
+ )
114
+ multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
115
+ multiview_pipeline.scheduler.config, timestep_spacing='trailing'
116
+ )
117
 
118
+ # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
119
+ unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
120
+ if unet_ckpt_path is not None:
121
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
122
+ # state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
123
+ multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
124
 
125
  # multiview_pipeline.to(multiview_device)
126
  # logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
127
+ # multiview_pipeline = None
128
 
129
 
130
  # load caption model
131
+ logger.info('==> Loading caption model ...')
132
+ caption_device = config_['caption'].get('device', 'cpu')
133
+ caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
134
+ torch_dtype=torch.bfloat16, trust_remote_code=True)
135
+ caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
136
  # logger.warning(f"GPU memory allocated after load caption model on {caption_device}: {torch.cuda.memory_allocated(device=caption_device) / 1024**3} GB")
137
+ # caption_processor = None
138
+ # caption_model = None
139
 
140
  # load reconstruction model
141
  logger.info('==> Loading reconstruction model ...')
 
150
  recon_model.to(recon_device)
151
  recon_model.eval()
152
  # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
153
+
 
154
  # load llm
155
  llm_configs = config_.get('llm', None)
156
  if llm_configs is not None:
 
235
  """
236
  torch_dtype = torch.bfloat16
237
  caption_device = self.config['caption'].get('device', 'cpu')
238
+ self.caption_model.to(caption_device)
239
  if isinstance(image, str): # If image is a file path
240
  image = preprocess_input_image(Image.open(image))
241
  elif not isinstance(image, Image.Image):
 
257
  logger.info(f"Auto caption result: \"{caption_text}\"")
258
 
259
  caption_text = self.get_detailed_prompt(caption_text)
260
+ self.caption_model.to('cpu')
261
  return caption_text
262
  # @spaces.GPU
263
  def get_detailed_prompt(self, prompt, seed=None):
 
283
  def generate_multiview(self, image, seed=None, num_inference_steps=None):
284
  seed = seed or self.config['multiview'].get('seed', 0)
285
  mv_device = self.config['multiview'].get('device', 'cpu')
286
+ self.multiview_pipeline.to(mv_device)
287
  generator = torch.Generator(device=mv_device).manual_seed(seed)
288
  with self.context():
289
  mv_image = self.multiview_pipeline(image,
 
291
  width=512*2,
292
  height=512*2,
293
  generator=generator).images[0]
294
+ self.multiview_pipeline.to('cpu')
295
  return mv_image
296
 
297
  def reconstruct_from_multiview(self, mv_image, lrm_render_radius=4.15):
 
369
  } # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
370
 
371
  flux_device = self.config['flux'].get('device', 'cpu')
372
+ self.flux_pipeline.to(flux_device)
373
  seed = seed or self.config['flux'].get('seed', 0)
374
  num_inference_steps = num_inference_steps or self.config['flux'].get('num_inference_steps', 20)
375
 
 
396
 
397
  # do redux
398
  if redux_hparam is not None:
399
+ self.flux_redux_pipeline.to(flux_device)
400
  assert self.flux_redux_pipeline is not None
401
  assert 'image' in redux_hparam.keys()
402
  redux_hparam_ = {
 
409
  redux_output = self.flux_redux_pipeline(**redux_hparam_)
410
 
411
  hparam_dict.update(redux_output)
412
+ self.flux_redux_pipeline.to('cpu')
413
 
414
  # append controlnet hparams
415
  if len(control_image) > 0:
 
439
  torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
440
  logger.info(f"Save generated 3D bundle image to {save_path}")
441
  return gen_3d_bundle_image_, save_path
442
+ self.flux_pipeline.to('cpu')
443
  return gen_3d_bundle_image_
444
 
445
  def preprocess_controlnet_cond_image(self, image, control_mode, save_intermediate_results=True, **kwargs):