import torch
from huggingface_hub import hf_hub_download


class ImageProjModel(torch.nn.Module):
    """Project an image embedding into a fixed number of context tokens.

    A single linear layer maps each embedding of width
    ``clip_embeddings_dim`` to ``clip_extra_context_tokens`` vectors of width
    ``cross_attention_dim``; the vectors are then layer-normalized.
    """

    def __init__(
        self,
        cross_attention_dim: int = 2048,
        clip_embeddings_dim: int = 1280,
        clip_extra_context_tokens: int = 4,
    ) -> None:
        """Build the projection and normalization layers.

        Args:
            cross_attention_dim: Width of every produced context token.
            clip_embeddings_dim: Width of the incoming image embedding.
            clip_extra_context_tokens: Number of tokens produced per embedding.
        """
        super().__init__()

        # Kept from the original module; nothing in this file reads it.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim,
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Return the normalized context tokens for ``image_embeds``.

        Args:
            image_embeds: Tensor whose last dimension equals the
                ``clip_embeddings_dim`` given at construction, e.g.
                ``(batch, clip_embeddings_dim)``.

        Returns:
            Tensor of shape
            ``(-1, clip_extra_context_tokens, cross_attention_dim)``, where the
            leading dimension is inferred by ``reshape``.
        """
        projected = self.proj(image_embeds)
        # Split the flat projection into one row per context token.
        tokens = projected.reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(tokens)


image_proj_model = ImageProjModel()

# Fetch the checkpoint from the Hugging Face Hub (cached locally by the hub client).
model_filename = hf_hub_download(
    repo_id="h94/IP-Adapter",
    filename="sdxl_models/ip-adapter_sdxl.bin",
)
# weights_only=True restricts unpickling to tensors/primitive containers.
state_dict = torch.load(model_filename, map_location="cpu", weights_only=True)
# Only the "image_proj" sub-dictionary of the checkpoint belongs to this module.
image_proj_model.load_state_dict(state_dict["image_proj"])
# Export in inference mode.
image_proj_model.eval()

# Dummy input used for tracing; its width is taken from the projection layer
# so it cannot drift from the model definition.
clip_image_embeds = torch.rand((1, image_proj_model.proj.in_features))

onnx_output_path = 'model.onnx'

torch.onnx.export(
    image_proj_model,
    clip_image_embeds,
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['clip_image_embeds'],
    output_names=['image_prompt_embeds'],
    # Only the batch axis is dynamic. Axis 1 of the input feeds the Linear
    # layer and must always equal its in_features, so advertising it as a
    # dynamic 'embed_size' axis would describe inputs the graph cannot run.
    dynamic_axes={
        'clip_image_embeds': {0: 'batch_size'},
        'image_prompt_embeds': {0: 'batch_size'},
    },
)