import os
from multiprocessing import Pool

import clip
import torch
import torch.onnx
from torch import nn


class TextTransformer(nn.Module):
    """Wraps a CLIP model so that ONNX export traces encode_text as this module's forward."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, x: torch.Tensor):
        return self.clip_model.encode_text(x)


def export(model, input, path):
    print(f"Exporting to {path}")
    torch.onnx.export(
        model,                      # model being run
        input,                      # model input (or a tuple for multiple inputs)
        path,                       # where to save the model (can be a file or file-like object)
        export_params=True,         # store the trained parameter weights inside the model file
        opset_version=16,           # the ONNX version to export the model to
        do_constant_folding=True,   # whether to execute constant folding for optimization
        input_names=['input'],      # the model's input names
        output_names=['output'],    # the model's output names
        dynamic_axes={
            'input': {0: 'batch_size'},   # variable length axes
            'output': {0: 'batch_size'},
        },
    )


def convert(model_name, dashed_name):
    visual_path = f"{output_dir}/clip-{dashed_name}-visual.onnx"
    textual_path = f"{output_dir}/clip-{dashed_name}-textual.onnx"
    visual_exists = os.path.exists(visual_path)
    textual_exists = os.path.exists(textual_path)
    if visual_exists and textual_exists:
        print(f"{visual_path} exists, skipping")
        print(f"{textual_path} exists, skipping")
        return
    print(f"Model: {model_name}")
    print("Loading CLIP")
    model, _ = clip.load(model_name, device=device)
    model = model.to(device=device)
    if not visual_exists:
        input_res = model.visual.input_resolution
        export(
            model.visual,
            torch.rand(1, 3, input_res, input_res),
            visual_path,
        )
    else:
        print(f"{visual_path} exists, skipping")
    if not textual_exists:
        text_transformer = TextTransformer(model)
        export(
            text_transformer,
            clip.tokenize(["hello onnx"]).to(device),
            textual_path,
        )
    else:
        print(f"{textual_path} exists, skipping")
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
output_dir = "converted"

if __name__ == "__main__":
    print(f"Torch device: {device}")
    # Make sure the output directory exists before any export is attempted.
    os.makedirs(output_dir, exist_ok=True)
    available_models = clip.available_models()
    print(f"Available models: {available_models}")
    models = [
        ("RN50", "resnet-50"),
        ("RN101", "resnet-101"),
        ("RN50x4", "resnet-50x4"),
        ("RN50x16", "resnet-50x16"),
        ("RN50x64", "resnet-50x64"),
("RN50", "resnet-50"),
("RN50", "resnet-50"),
("RN50", "resnet-50"),
("ViT-B/16", "vit-base-patch16"),
("ViT-B/32", "vit-base-patch32"),
("ViT-L/14", "vit-large-patch14"),
("ViT-L/14@336px", "vit-large-patch14-336"),
]
print(f"Converting models: {models}")
for model in models:
convert(*model)
# For converting multiple models at once
# with Pool(1) as p:
# p.starmap(convert, models)
print("done")