JiantaoLin committed
Commit 844009d · 1 Parent(s): af53f48
Files changed (1)
  1. pipeline/kiss3d_wrapper.py +15 -15
pipeline/kiss3d_wrapper.py CHANGED
@@ -75,8 +75,8 @@ def init_wrapper_from_config(config_path):
     else:
         flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
     # flux_pipe.enable_vae_slicing()
-    # flux_pipe.enable_xformers_memory_efficient_attention()
-    flux_pipe.enable_sequential_cpu_offload()
+    flux_pipe.enable_xformers_memory_efficient_attention()
+    # flux_pipe.enable_sequential_cpu_offload()
     # load flux model and controlnet
     if flux_controlnet_pth is not None and False:
         flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
@@ -139,20 +139,20 @@ def init_wrapper_from_config(config_path):
         caption_model = None
 
     # load reconstruction model
-    logger.info('==> Loading reconstruction model ...')
-    recon_device = config_['reconstruction'].get('device', 'cpu')
-    recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
-    recon_model = instantiate_from_config(recon_model_config.model_config)
-    # load recon model checkpoint
-    model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
-    state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-    state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
-    recon_model.load_state_dict(state_dict, strict=True)
-    recon_model.to(recon_device)
-    recon_model.eval()
+    # logger.info('==> Loading reconstruction model ...')
+    # recon_device = config_['reconstruction'].get('device', 'cpu')
+    # recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
+    # recon_model = instantiate_from_config(recon_model_config.model_config)
+    # # load recon model checkpoint
+    # model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
+    # state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+    # state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
+    # recon_model.load_state_dict(state_dict, strict=True)
+    # recon_model.to(recon_device)
+    # recon_model.eval()
     # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
-    # recon_model = None
-    # recon_model_config = None
+    recon_model = None
+    recon_model_config = None
     # load llm
     llm_configs = config_.get('llm', None)
     if llm_configs is not None and False:
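
The first hunk swaps diffusers' two memory strategies on the Flux pipeline: sequential CPU offload is commented out in favor of xformers memory-efficient attention. A minimal sketch of the trade-off, assuming a generic Flux checkpoint id and bfloat16 weights (both illustrative, not taken from the repo config):

import torch
from diffusers import FluxImg2ImgPipeline

# Illustrative model id; the real path comes from flux_base_model_pth in the config.
pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Enabled by this commit: memory-efficient attention via xformers.
# Weights stay on the GPU; only attention activation memory shrinks, so inference stays fast.
pipe.enable_xformers_memory_efficient_attention()

# Disabled by this commit: sequential CPU offload.
# Lowest VRAM footprint, but submodules are shuttled to the GPU on demand, which is much slower.
# pipe.enable_sequential_cpu_offload()

Note that enable_xformers_memory_efficient_attention() requires the xformers package to be installed and raises at runtime otherwise; whether a given pipeline's attention processors actually support it varies by model family.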
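
The second hunk disables loading of the PRM reconstruction model entirely (recon_model and recon_model_config are now None). For reference, a minimal sketch of the checkpoint download and key-prefix stripping that the removed lines performed; the len(prefix) form is an assumption standing in for the original magic number in k[14:] (14 == len('lrm_generator.')):

import torch
from huggingface_hub import hf_hub_download

prefix = 'lrm_generator.'
# Download the PRM checkpoint from the Hub, as the removed lines did.
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
# Keep only lrm_generator.* keys and strip the prefix; equivalent to the original k[14:].
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}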