import os

import clip
import torch
import torch.onnx
from torch import nn


class TextTransformer(nn.Module):
    """Wraps a CLIP model so that ONNX tracing sees only the text encoder."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, x: torch.Tensor):
        return self.clip_model.encode_text(x)


def export(model, dummy_input, path):
    print(f"Exporting to {path}")
    torch.onnx.export(
        model,
        dummy_input,
        path,
        export_params=True,
        opset_version=16,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        # Mark the batch dimension as dynamic so one graph serves any batch size.
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )
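

# Optional sanity check (a sketch, not called by this script): assumes the
# `onnxruntime` and `numpy` packages are installed, and compares the exported
# graph's output against the original PyTorch module on the same dummy input.
def verify(onnx_path, torch_module, dummy_input):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (onnx_out,) = session.run(["output"], {"input": dummy_input.numpy()})
    with torch.no_grad():
        torch_out = torch_module(dummy_input).numpy()
    # Loose tolerances: constant folding and op-level differences cause small drift.
    np.testing.assert_allclose(onnx_out, torch_out, rtol=1e-2, atol=1e-3)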


def convert(model_name, dashed_name):
    visual_path = f"{output_dir}/clip-{dashed_name}-visual.onnx"
    textual_path = f"{output_dir}/clip-{dashed_name}-textual.onnx"
    visual_exists = os.path.exists(visual_path)
    textual_exists = os.path.exists(textual_path)
    if visual_exists and textual_exists:
        print(f"{visual_path} exists, skipping")
        print(f"{textual_path} exists, skipping")
        return

    print(f"Model: {model_name}")
    print("Loading CLIP")
    # clip.load() already places the model on `device`, so no extra .to() is needed.
    model, _ = clip.load(model_name, device=device)

    if not visual_exists:
        input_res = model.visual.input_resolution
        export(
            model.visual,
            # Dummy image batch at the model's native input resolution.
            torch.rand(1, 3, input_res, input_res, device=device),
            visual_path,
        )
    else:
        print(f"{visual_path} exists, skipping")

    if not textual_exists:
        text_transformer = TextTransformer(model)
        export(
            text_transformer,
            # Dummy tokenized prompt (CLIP's 77-token context window).
            clip.tokenize(["hello onnx"]).to(device),
            textual_path,
        )
    else:
        print(f"{textual_path} exists, skipping")
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
device = "cpu" |
|
output_dir = "converted" |
|

if __name__ == "__main__":
    print(f"Torch device: {device}")
    # Make sure the output directory exists before torch.onnx.export writes to it.
    os.makedirs(output_dir, exist_ok=True)

    available_models = clip.available_models()
    print(f"Available models: {available_models}")

    models = [
        ("RN50", "resnet-50"),
        ("RN101", "resnet-101"),
        ("RN50x4", "resnet-50x4"),
        ("RN50x16", "resnet-50x16"),
        ("RN50x64", "resnet-50x64"),
("RN50", "resnet-50"), |
|
("RN50", "resnet-50"), |
|
("RN50", "resnet-50"), |
|
("ViT-B/16", "vit-base-patch16"), |
|
("ViT-B/32", "vit-base-patch32"), |
|
("ViT-L/14", "vit-large-patch14"), |
|
("ViT-L/14@336px", "vit-large-patch14-336"), |
|
] |
|
|
|
print(f"Converting models: {models}") |
|
|
|
for model in models: |
|
convert(*model) |
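
    # The per-model conversions are independent, so they could in principle be
    # fanned out across processes (a sketch, untested here; each worker loads
    # its own copy of the weights, so memory use grows with the pool size):
    #
    #     from multiprocessing import Pool
    #     with Pool(processes=2) as pool:
    #         pool.starmap(convert, models)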

    print("done")