# import os
# from pathlib import Path
# from PIL import Image
# import onnx
# import onnx_graphsurgeon as gs
# import torch
# from onnx import shape_inference
# from packaging import version
# from polygraphy.backend.onnx.loader import fold_constants
# from torch.onnx import export

# from transformers import CLIPVisionModelWithProjection
# from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

# is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
# is_torch_2_0_1 = version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.1")

# class Optimizer:
#     def __init__(self, onnx_graph, verbose=False):
#         self.graph = gs.import_onnx(onnx_graph)
#         self.verbose = verbose

#     def info(self, prefix):
#         if self.verbose:
#             print(
#                 f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs"
#             )

#     def cleanup(self, return_onnx=False):
#         self.graph.cleanup().toposort()
#         if return_onnx:
#             return gs.export_onnx(self.graph)

#     def select_outputs(self, keep, names=None):
#         self.graph.outputs = [self.graph.outputs[o] for o in keep]
#         if names:
#             for i, name in enumerate(names):
#                 self.graph.outputs[i].name = name

#     def fold_constants(self, return_onnx=False):
#         onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
#         self.graph = gs.import_onnx(onnx_graph)
#         if return_onnx:
#             return onnx_graph

#     def infer_shapes(self, return_onnx=False):
#         onnx_graph = gs.export_onnx(self.graph)
#         if onnx_graph.ByteSize() > 4147483648:
#             raise TypeError("ERROR: model size exceeds supported 2GB limit")
#         else:
#             onnx_graph = shape_inference.infer_shapes(onnx_graph)

#         self.graph = gs.import_onnx(onnx_graph)
#         if return_onnx:
#             return onnx_graph


# def optimize(onnx_graph, name, verbose):
#     opt = Optimizer(onnx_graph, verbose=verbose)
#     opt.info(name + ": original")
#     opt.cleanup()
#     opt.info(name + ": cleanup")
#     opt.fold_constants()
#     opt.info(name + ": fold constants")
#     # opt.infer_shapes()
#     # opt.info(name + ': shape inference')
#     onnx_opt_graph = opt.cleanup(return_onnx=True)
#     opt.info(name + ": finished")
#     return onnx_opt_graph


# class CLIPVisionProj(torch.nn.Module):
#     def __init__(self, clip_model) -> None:
#         super().__init__()
#         self.clip_model = clip_model

#     def forward(self, image_embedding):
#         result = self.clip_model(image_embedding,return_dict = False)
#         return result[0]
    
# def onnx_export(
#     model,
#     model_args: tuple,
#     output_path: Path,
#     ordered_input_names,
#     output_names,
#     dynamic_axes,
#     opset: int,
#     use_external_data_format=False,
#     verbose=False,  # Added verbose parameter
# ):
#     output_path.parent.mkdir(parents=True, exist_ok=True)
#     with torch.inference_mode(), torch.autocast("cuda"):
#         if is_torch_less_than_1_11:
#             export(
#                 model,
#                 model_args,
#                 f=output_path.as_posix(),
#                 input_names=ordered_input_names,
#                 output_names=output_names,
#                 dynamic_axes=dynamic_axes,
#                 do_constant_folding=True,
#                 use_external_data_format=use_external_data_format,
#                 enable_onnx_checker=True,
#                 opset_version=opset,
#                 verbose=verbose,  # Pass verbose through here
#             )
#         else:
#             export(
#                 model,
#                 model_args,
#                 f=output_path.as_posix(),
#                 input_names=ordered_input_names,
#                 output_names=output_names,
#                 dynamic_axes=dynamic_axes,
#                 do_constant_folding=True,
#                 opset_version=opset,
#                 verbose=verbose,  # Pass verbose through here
#             )

# def convert_models(
#     image_path:str, 
#     output_path:str,
#     opset:int=16,
# ):
#         dtype =  torch.float32
#         device = 'cpu'
#         image = Image.open(image_path)
#         image_encoder_processor = CLIPImageProcessor()
#         image_embedding = image_encoder_processor(image, return_tensors="pt").pixel_values
#         clip_model= CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder = 'sdxl_models/image_encoder')
#         image_encoder = CLIPVisionProj(clip_model).to(device=device)
#         output_path = Path(output_path)

#         clip_path = output_path / "clip_vision_proj" / "model.onnx"
#         clip_optimize = output_path / 'clip_vision_proj' / 'optimize' / 'model.onnx'
#         #create folder for optimize clip
#         os.makedirs(output_path / 'optimize', exist_ok= True)
#         onnx_export(image_encoder,
#                     model_args= (image_embedding).to(dtype = torch.float32, device = device),
#                     output_path =clip_path,
#                     ordered_input_names= ['image_embedding'],
#                     output_names=["image_encoder"],
#                     dynamic_axes={'image_embedding': {0: 'Batch_size', 1: 'channel', 2: 'height', 3:'width'},
#                                  'image_encoder': {0:'Batch_size', 1: 'sequence_length'} },
#                     opset=opset,
#                     verbose=True,
#                     use_external_data_format=True, 
#                 )
#         clip_opt_graph =  onnx.load(clip_path)
#         onnx.save_model(
#             clip_opt_graph,
#             clip_optimize,  
#             save_as_external_data=True, 
#             all_tensors_to_one_file=True,  
#             convert_attribute=False,            
#             location="weights.pb",
#         )

# if __name__ == "__main__":
#     convert_models(image_path= '/home/SDXL_CNextAnimeCanny_IPAdapter_ONNX/code_inference/image_condition/control_canny_edge/condition_0.jpg',
#                       output_path='/home/new_onnx/image_encoder',
#                        opset=18,
#                      )


import onnx

def optimize_onnx_model(clip_path, output_model_path, weight_file: str = "weights.pb") -> None:
    """Re-save an ONNX model with all tensor data in one external file.

    Despite the name, no graph optimization is performed here: the model is
    loaded from ``clip_path`` and written back out to ``output_model_path``
    with its tensors stored externally in ``weight_file``.

    Args:
        clip_path: Path of the already-converted ONNX model to load.
        output_model_path: Path the re-saved model is written to. When this
            is a filesystem path, missing parent directories are created.
        weight_file: Name of the external-data file, forwarded to
            ``onnx.save_model`` as ``location`` (resolved relative to the
            directory that holds the saved model).
    """
    import os  # local import: the file header only imports onnx

    # Load first, so tensor data is read before any weights file is removed.
    model = onnx.load(clip_path)

    # Only touch the filesystem when the destination is a real path; anything
    # else (e.g. a file object) is handed to onnx untouched, as before.
    if isinstance(output_model_path, (str, os.PathLike)):
        output_dir = os.path.dirname(os.path.abspath(output_model_path))

        # onnx.save_model does not create the destination directory itself.
        os.makedirs(output_dir, exist_ok=True)

        # Drop a weights file left over from a previous run: onnx's
        # external-data writer appends to an existing file instead of
        # truncating it, so stale bytes would otherwise pile up on re-runs.
        try:
            os.remove(os.path.join(output_dir, weight_file))
        except FileNotFoundError:
            pass

    onnx.save_model(
        model,
        output_model_path,
        save_as_external_data=True,    # store tensor data outside the .onnx file
        all_tensors_to_one_file=True,  # gather every tensor into a single file
        location=weight_file,          # name of the external weights file
    )

# Example usage: re-save the exported CLIP vision model with external weights.
# NOTE(review): these statements run on import as well as when executed as a
# script; consider an `if __name__ == "__main__":` guard.
clip_path = "/home/new_onnx/image_encoder/clip_vision_proj/model.onnx"  # path of the converted model
output_model_path = "/home/new_onnx/image_encoder/optimize/model.onnx"  # path to save the model to
weight_file = "weights.pb"  # name of the file that holds the weights
optimize_onnx_model(
    clip_path=clip_path,
    output_model_path=output_model_path,
    weight_file=weight_file,
)