JiantaoLin
commited on
Commit
·
844009d
1
Parent(s):
af53f48
- pipeline/kiss3d_wrapper.py +15 -15
pipeline/kiss3d_wrapper.py
CHANGED
@@ -75,8 +75,8 @@ def init_wrapper_from_config(config_path):
|
|
75 |
else:
|
76 |
flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
|
77 |
# flux_pipe.enable_vae_slicing()
|
78 |
-
|
79 |
-
flux_pipe.enable_sequential_cpu_offload()
|
80 |
# load flux model and controlnet
|
81 |
if flux_controlnet_pth is not None and False:
|
82 |
flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
|
@@ -139,20 +139,20 @@ def init_wrapper_from_config(config_path):
|
|
139 |
caption_model = None
|
140 |
|
141 |
# load reconstruction model
|
142 |
-
logger.info('==> Loading reconstruction model ...')
|
143 |
-
recon_device = config_['reconstruction'].get('device', 'cpu')
|
144 |
-
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
145 |
-
recon_model = instantiate_from_config(recon_model_config.model_config)
|
146 |
-
# load recon model checkpoint
|
147 |
-
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
148 |
-
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
149 |
-
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
150 |
-
recon_model.load_state_dict(state_dict, strict=True)
|
151 |
-
recon_model.to(recon_device)
|
152 |
-
recon_model.eval()
|
153 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
154 |
-
|
155 |
-
|
156 |
# load llm
|
157 |
llm_configs = config_.get('llm', None)
|
158 |
if llm_configs is not None and False:
|
|
|
75 |
else:
|
76 |
flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
|
77 |
# flux_pipe.enable_vae_slicing()
|
78 |
+
flux_pipe.enable_xformers_memory_efficient_attention()
|
79 |
+
# flux_pipe.enable_sequential_cpu_offload()
|
80 |
# load flux model and controlnet
|
81 |
if flux_controlnet_pth is not None and False:
|
82 |
flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
|
|
|
139 |
caption_model = None
|
140 |
|
141 |
# load reconstruction model
|
142 |
+
# logger.info('==> Loading reconstruction model ...')
|
143 |
+
# recon_device = config_['reconstruction'].get('device', 'cpu')
|
144 |
+
# recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
145 |
+
# recon_model = instantiate_from_config(recon_model_config.model_config)
|
146 |
+
# # load recon model checkpoint
|
147 |
+
# model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
148 |
+
# state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
149 |
+
# state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
150 |
+
# recon_model.load_state_dict(state_dict, strict=True)
|
151 |
+
# recon_model.to(recon_device)
|
152 |
+
# recon_model.eval()
|
153 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
154 |
+
recon_model = None
|
155 |
+
recon_model_config = None
|
156 |
# load llm
|
157 |
llm_configs = config_.get('llm', None)
|
158 |
if llm_configs is not None and False:
|