from pathlib import Path
import gc

import numpy as np
import torch
import openvino as ov


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()


def convert_encoder(text_encoder: torch.nn.Module, ir_path: Path):
    """
    Convert Text Encoder model to IR.
    Function accepts text encoder model, prepares example inputs for conversion
    Parameters:
        text_encoder (torch.nn.Module): text encoder PyTorch model
        ir_path (Path): File for storing model
    Returns:
        None
    """
    if not ir_path.exists():
        input_ids = torch.ones((1, 77), dtype=torch.long)
        # switch model to inference mode
        text_encoder.eval()
        # disable gradients calculation for reducing memory consumption
        with torch.no_grad():
            # export model
            ov_model = ov.convert_model(
                text_encoder,  # model instance
                example_input=input_ids,  # example inputs for model tracing
                input=([1, 77],),  # input shape for conversion
            )
        ov.save_model(ov_model, ir_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()  # free memory held by the traced model
        print("Text Encoder successfully converted to IR")


def convert_unet(
    unet: torch.nn.Module,
    ir_path: Path,
    num_channels: int = 4,
    width: int = 64,
    height: int = 64,
):
    """
    Convert U-Net model to IR format.
    Function accepts U-Net model, prepares example inputs for conversion
    Parameters:
        unet (torch.nn.Module): UNet PyTorch model
        ir_path (Path): File for storing model
        num_channels (int, optional, 4): number of input channels
        width (int, optional, 64): input width
        height (int, optional, 64): input height
    Returns:
        None
    """
    # map torch dtypes of the example inputs to OpenVINO element types
    dtype_mapping = {torch.float32: ov.Type.f32, torch.float64: ov.Type.f64}
    if not ir_path.exists():
        # prepare example inputs: text embeddings, latents, and timestep
        encoder_hidden_state = torch.ones((2, 77, 1024))
        latents_shape = (2, num_channels, width, height)
        latents = torch.randn(latents_shape)
        t = torch.from_numpy(np.array([1], dtype=np.float32))
        unet.eval()
        dummy_inputs = (latents, t, encoder_hidden_state)
        input_info = []
        for input_tensor in dummy_inputs:
            shape = ov.PartialShape(tuple(input_tensor.shape))
            element_type = dtype_mapping[input_tensor.dtype]
            input_info.append((shape, element_type))

        with torch.no_grad():
            ov_model = ov.convert_model(unet, example_input=dummy_inputs, input=input_info)
        ov.save_model(ov_model, ir_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()  # free memory held by the traced model
        print("U-Net successfully converted to IR")


def convert_vae_encoder(vae: torch.nn.Module, ir_path: Path, width: int = 512, height: int = 512):
    """
    Convert VAE encoder model to IR format.
    Function accepts VAE model, creates a wrapper class to export only the part
    needed for inference, and prepares example inputs for conversion
    Parameters:
        vae (torch.nn.Module): VAE PyTorch model
        ir_path (Path): File for storing model
        width (int, optional, 512): input width
        height (int, optional, 512): input height
    Returns:
        None
    """

    class VAEEncoderWrapper(torch.nn.Module):
        """Wrapper exposing only the encoding path of the VAE for tracing."""

        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, image):
            return self.vae.encode(x=image)["latent_dist"].sample()

    if not ir_path.exists():
        vae_encoder = VAEEncoderWrapper(vae)
        vae_encoder.eval()
        image = torch.zeros((1, 3, width, height))
        with torch.no_grad():
            ov_model = ov.convert_model(vae_encoder, example_input=image, input=([1, 3, width, height],))
        ov.save_model(ov_model, ir_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()  # free memory held by the traced model
        print("VAE encoder successfully converted to IR")


def convert_vae_decoder(vae: torch.nn.Module, ir_path: Path, width: int = 64, height: int = 64):
    """
    Convert VAE decoder model to IR format.
    Function accepts VAE model, creates a wrapper class to export only the part
    needed for inference, and prepares example inputs for conversion
    Parameters:
        vae (torch.nn.Module): VAE model
        ir_path (Path): File for storing model
        width (int, optional, 64): input width
        height (int, optional, 64): input height
    Returns:
        None
    """

    class VAEDecoderWrapper(torch.nn.Module):
        """Wrapper exposing only the decoding path of the VAE for tracing."""

        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latents):
            return self.vae.decode(latents)

    if not ir_path.exists():
        vae_decoder = VAEDecoderWrapper(vae)
        latents = torch.zeros((1, 4, width, height))
        vae_decoder.eval()
        with torch.no_grad():
            ov_model = ov.convert_model(vae_decoder, example_input=latents, input=([1, 4, width, height],))
        ov.save_model(ov_model, ir_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()  # free memory held by the traced model
        print("VAE decoder successfully converted to IR")
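

# Example usage: a minimal sketch, assuming the helpers above are applied to a
# Hugging Face `diffusers` Stable Diffusion v2 pipeline. The checkpoint id and
# output paths below are illustrative assumptions, not part of the original code;
# the (2, 77, 1024) hidden-state shape in `convert_unet` corresponds to the SD v2
# text encoder, which is why a v2 checkpoint is shown here.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline  # assumed to be installed

    pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")  # assumed checkpoint
    models_dir = Path("sd2.1")
    models_dir.mkdir(exist_ok=True)

    # each helper skips conversion if its IR file already exists
    convert_encoder(pipe.text_encoder, models_dir / "text_encoder.xml")
    convert_unet(pipe.unet, models_dir / "unet.xml")
    convert_vae_encoder(pipe.vae, models_dir / "vae_encoder.xml")
    convert_vae_decoder(pipe.vae, models_dir / "vae_decoder.xml")

    # release the PyTorch pipeline once all IR files are saved
    del pipe
    gc.collect()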