|
import numpy as np |
|
import onnx |
|
import torch |
|
|
|
from StyleTransferModel_128 import StyleTransferModel |
|
|
|
def save_as_onnx_model(torch_model_path, save_emap=True, img_size=128,
                       originalInswapperClassCompatible=True,
                       emap_path="emap.npy"):
    """Export a ``StyleTransferModel`` checkpoint to an ONNX file.

    The ONNX file is written next to the checkpoint, with the checkpoint's
    final extension replaced by ``.onnx`` (``model.pth`` -> ``model.onnx``).

    Args:
        torch_model_path: Path (str or PathLike) to the saved ``state_dict``.
        save_emap: If true, load the matrix stored at ``emap_path`` and append
            it to the exported graph as a float32 ``[512, 512]`` initializer
            named ``'emap'``.
        img_size: Height and width of the dummy ``target`` image used for
            tracing.
        originalInswapperClassCompatible: If true, export with static input
            shapes (``target`` = ``(1, 3, img_size, img_size)``,
            ``source`` = ``(1, 512)``). If false, batch/channel/spatial axes
            of ``target`` and ``output`` and the batch axis of ``source`` are
            exported as dynamic.
        emap_path: ``.npy`` file read when ``save_emap`` is true. Defaults to
            ``"emap.npy"`` relative to the current working directory, which
            was the previously hard-coded location.

    Returns:
        The path of the written ONNX file.

    Raises:
        ValueError: If the loaded emap does not hold exactly 512 * 512 values.
    """
    import os
    import warnings

    # splitext only replaces the final extension. The old
    # str.replace(".pth", ".onnx") also rewrote ".pth" inside directory names,
    # and for a path without ".pth" it returned the input path unchanged,
    # so the export overwrote the checkpoint itself.
    root, _ = os.path.splitext(os.fspath(torch_model_path))
    output_path = root + ".onnx"

    # Load and validate the emap up front, so a bad or missing file fails
    # before any ONNX file is written.
    emap = None
    if save_emap:
        emap = np.asarray(np.load(emap_path), dtype=np.float32)
        if emap.size != 512 * 512:
            raise ValueError(
                f"{emap_path!r} holds {emap.size} values; "
                f"expected {512 * 512} for a [512, 512] emap"
            )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    torch_model = StyleTransferModel().to(device)
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(torch_model_path, map_location=device)
    # strict=False tolerates key mismatches. Report them anyway: silently
    # exporting a partially initialized model is hard to diagnose later.
    incompatible = torch_model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {torch_model_path!r} does not fully match "
            f"StyleTransferModel: missing keys={incompatible.missing_keys}, "
            f"unexpected keys={incompatible.unexpected_keys}"
        )
    torch_model.eval()

    if originalInswapperClassCompatible:
        # Static shapes, fixed by the dummy inputs below.
        dynamic_axes = None
    else:
        image_axes = {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'}
        dynamic_axes = {
            'target': image_axes,
            'source': {0: 'batch_size'},
            'output': image_axes,
        }

    dummy_inputs = {
        'target': torch.randn(1, 3, img_size, img_size, device=device),
        'source': torch.randn(1, 512, device=device),
    }
    # Per the torch.onnx.export docs, a trailing dict in the args tuple is
    # bound to forward()'s parameters by name. This assumes forward() takes
    # parameters named `target` and `source`, as the original dict form did.
    torch.onnx.export(
        torch_model,
        (dummy_inputs,),
        output_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['target', 'source'],
        output_names=['output'],
        dynamic_axes=dynamic_axes,
    )

    if emap is not None:
        emap_tensor = onnx.helper.make_tensor(
            name='emap',
            data_type=onnx.TensorProto.FLOAT,
            dims=[512, 512],
            # Flatten explicitly so the value count matches prod(dims)
            # whether or not this onnx version flattens ndarrays itself.
            vals=emap.reshape(-1),
        )
        # Re-load the exported model only when there is something to add.
        model = onnx.load(output_path)
        # Appended last. NOTE(review): consumers may locate emap as the final
        # initializer (InsightFace's INSwapper appears to) - keep it last.
        model.graph.initializer.append(emap_tensor)
        onnx.save(model, output_path)

    return output_path
|
|