# malvika2003's picture
# Upload folder using huggingface_hub
# db5855f verified
from pathlib import Path
import gc
import torch
import numpy as np
import openvino as ov
def cleanup_torchscript_cache() -> None:
    """
    Helper for removing cached model representation.

    Clears the TorchScript class registry, replaces the module-level concrete
    type store with a fresh instance and clears the JIT class state.

    NOTE: every call below goes through a private PyTorch API (leading
    underscore), so this helper may need adjusting when PyTorch is upgraded.
    Returns:
        None
    """
    torch._C._jit_clear_class_registry()
    # Rebind the store to a new, empty instance rather than mutating the old one.
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()
def convert_encoder(text_encoder: torch.nn.Module, ir_path: Path) -> None:
    """
    Convert Text Encoder model to IR.

    Function accepts the text encoder, prepares example inputs for conversion
    and saves the converted model to ``ir_path``. When ``ir_path`` already
    exists the function returns immediately and the model is left untouched.
    Parameters:
        text_encoder (torch.nn.Module): text encoder PyTorch model
        ir_path (Path): File for storing model (a plain string path is accepted too)
    Returns:
        None
    """
    # Coerce so that callers passing a string path do not fail on ``.exists()``.
    ir_path = Path(ir_path)
    if ir_path.exists():
        return

    # example token ids: batch of 1, sequence length 77
    input_ids = torch.ones((1, 77), dtype=torch.long)
    # switch model to inference mode
    text_encoder.eval()

    try:
        # disable gradients calculation for reducing memory consumption
        with torch.no_grad():
            # export model
            ov_model = ov.convert_model(
                text_encoder,  # model instance
                example_input=input_ids,  # example inputs for model tracing
                input=([1, 77],),  # input shape for conversion
            )
        ov.save_model(ov_model, ir_path)
        del ov_model
    finally:
        # Reset the TorchScript caches even when conversion or saving fails.
        cleanup_torchscript_cache()
    print("Text Encoder successfully converted to IR")
def convert_unet(
    unet: torch.nn.Module,
    ir_path: Path,
    num_channels: int = 4,
    width: int = 64,
    height: int = 64,
) -> None:
    """
    Convert Unet model to IR format.

    Function accepts the U-Net, prepares example inputs for conversion and
    saves the converted model to ``ir_path``. When ``ir_path`` already exists
    the function returns immediately and the model is left untouched.
    Parameters:
        unet (torch.nn.Module): UNet PyTorch model
        ir_path (Path): File for storing model (a plain string path is accepted too)
        num_channels (int, optional, 4): number of input channels
        width (int, optional, 64): input width
        height (int, optional, 64): input height
    Returns:
        None
    """
    # Coerce so that callers passing a string path do not fail on ``.exists()``.
    ir_path = Path(ir_path)
    if ir_path.exists():
        return

    # Only needed on the conversion path, so it is built after the early return.
    dtype_mapping = {torch.float32: ov.Type.f32, torch.float64: ov.Type.f64}

    # prepare inputs
    encoder_hidden_state = torch.ones((2, 77, 1024))
    # NOTE(review): width is placed before height here; an NCHW layout would
    # put height first. This only matters for non-square inputs - confirm
    # against the callers before changing the order.
    latents_shape = (2, num_channels, width, height)
    latents = torch.randn(latents_shape)
    t = torch.from_numpy(np.array([1], dtype=np.float32))
    unet.eval()
    dummy_inputs = (latents, t, encoder_hidden_state)

    # One (static shape, element type) pair per example input, in input order.
    input_info = [
        (ov.PartialShape(tuple(tensor.shape)), dtype_mapping[tensor.dtype])
        for tensor in dummy_inputs
    ]

    try:
        # disable gradients calculation for reducing memory consumption
        with torch.no_grad():
            ov_model = ov.convert_model(unet, example_input=dummy_inputs, input=input_info)
        ov.save_model(ov_model, ir_path)
        del ov_model
    finally:
        # Reset the TorchScript caches even when conversion or saving fails.
        cleanup_torchscript_cache()
    print("U-Net successfully converted to IR")
def convert_vae_encoder(vae: torch.nn.Module, ir_path: Path, width: int = 512, height: int = 512) -> None:
    """
    Convert VAE encoder model to IR format.

    Function accepts VAE model, creates wrapper class for export of only the part
    necessary for inference, prepares example inputs for conversion and saves
    the converted model to ``ir_path``. When ``ir_path`` already exists the
    function returns immediately and the model is left untouched.
    Parameters:
        vae (torch.nn.Module): VAE PyTorch model
        ir_path (Path): File for storing model (a plain string path is accepted too)
        width (int, optional, 512): input width
        height (int, optional, 512): input height
    Returns:
        None
    """
    # Coerce so that callers passing a string path do not fail on ``.exists()``.
    ir_path = Path(ir_path)
    if ir_path.exists():
        return

    class VAEEncoderWrapper(torch.nn.Module):
        """Expose only the encode step of the VAE as ``forward``."""

        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, image):
            # Return a sample drawn from the ``latent_dist`` entry of the encode result.
            return self.vae.encode(x=image)["latent_dist"].sample()

    vae_encoder = VAEEncoderWrapper(vae)
    vae_encoder.eval()
    # NOTE(review): width is placed before height here; an NCHW layout would
    # put height first. This only matters for non-square inputs - confirm
    # against the callers before changing the order.
    image = torch.zeros((1, 3, width, height))

    try:
        # disable gradients calculation for reducing memory consumption
        with torch.no_grad():
            ov_model = ov.convert_model(vae_encoder, example_input=image, input=([1, 3, width, height],))
        ov.save_model(ov_model, ir_path)
        del ov_model
    finally:
        # Reset the TorchScript caches even when conversion or saving fails.
        cleanup_torchscript_cache()
    print("VAE encoder successfully converted to IR")
def convert_vae_decoder(vae: torch.nn.Module, ir_path: Path, width: int = 64, height: int = 64) -> None:
    """
    Convert VAE decoder model to IR format.

    Function accepts VAE model, creates wrapper class for export of only the part
    necessary for inference, prepares example inputs for conversion and saves
    the converted model to ``ir_path``. When ``ir_path`` already exists the
    function returns immediately and the model is left untouched.
    Parameters:
        vae (torch.nn.Module): VAE model
        ir_path (Path): File for storing model (a plain string path is accepted too)
        width (int, optional, 64): input width
        height (int, optional, 64): input height
    Returns:
        None
    """
    # Coerce so that callers passing a string path do not fail on ``.exists()``.
    ir_path = Path(ir_path)
    if ir_path.exists():
        return

    class VAEDecoderWrapper(torch.nn.Module):
        """Expose only the decode step of the VAE as ``forward``."""

        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latents):
            return self.vae.decode(latents)

    vae_decoder = VAEDecoderWrapper(vae)
    # NOTE(review): width is placed before height here; an NCHW layout would
    # put height first. This only matters for non-square inputs - confirm
    # against the callers before changing the order.
    latents = torch.zeros((1, 4, width, height))
    vae_decoder.eval()

    try:
        # disable gradients calculation for reducing memory consumption
        with torch.no_grad():
            ov_model = ov.convert_model(vae_decoder, example_input=latents, input=([1, 4, width, height],))
        ov.save_model(ov_model, ir_path)
        del ov_model
    finally:
        # Reset the TorchScript caches even when conversion or saving fails.
        cleanup_torchscript_cache()
    print("VAE decoder successfully converted to IR")