add torchscript and safetensors versions of retccl model
This commit adds TorchScript and Safetensors versions of the RetCCL model. TorchScript lets the model be loaded without the Python code that defines its architecture, and even without a Python runtime (for example, from C++ through LibTorch). This should make it easier for others to incorporate RetCCL into their own applications. In fact, I was planning to upload RetCCL to HuggingFace myself, but I found this repository first. I hope to use the TorchScript version of this model in several applications and to pull it from this repository.
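To illustrate what "without code defining its implementation" means in practice, here is a minimal sketch using a hypothetical toy model (TinyNet, a stand-in, not RetCCL itself). It scripts the model, serializes it to an in-memory buffer, deletes the Python class, and then loads and runs the archive with torch.jit.load. The same kind of archive can also be loaded from C++ with LibTorch's torch::jit::load.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model standing in for RetCCL's ResNet-50."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv -> ReLU -> global average pool -> (batch, 8) feature vector.
        return torch.flatten(self.pool(torch.relu(self.conv(x))), 1)


eager = TinyNet().eval()
scripted = torch.jit.script(eager)

# Serialize the TorchScript archive (a file path works the same way).
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)

# Remove the Python class definition: loading the archive does not need it.
del TinyNet

buffer.seek(0)
loaded = torch.jit.load(buffer, map_location="cpu")

x = torch.ones(1, 3, 224, 224)
with torch.no_grad():
    features = loaded(x)
    reference = eager(x)
```

The loaded module runs even though TinyNet no longer exists in the Python namespace, which is the property that makes a TorchScript file usable from applications that do not ship the model's source code.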
Safetensors is a file format developed by HuggingFace to address drawbacks of PyTorch's pickle-based format, most notably that unpickling an untrusted file can execute arbitrary code. I have uploaded a safetensors version of the weights as well.
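To show why safetensors avoids pickle's risk, here is a simplified, stdlib-only sketch of its layout: an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape, and byte offsets, and then the raw tensor bytes. The helpers encode_f32 and decode_f32 are hypothetical and handle only float32 tensors given as flat Python lists; in practice one would use safetensors.torch.save_file and load_file, as in the script below. Reading the file is plain JSON and byte parsing, so nothing in it is ever executed.

```python
import json
import struct


def encode_f32(tensors):
    """Serialize {name: (shape, flat_values)} into safetensors-style bytes (float32 only)."""
    header, chunks, offset = {}, [], 0
    for name, (shape, values) in tensors.items():
        count = 1
        for dim in shape:
            count *= dim
        if count != len(values):
            raise ValueError(f"{name}: shape {shape} does not match {len(values)} values")
        data = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {
            "dtype": "F32",
            "shape": list(shape),
            "data_offsets": [offset, offset + len(data)],  # relative to start of data buffer
        }
        chunks.append(data)
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    # Layout: 8-byte little-endian header length, JSON header, raw tensor bytes.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + b"".join(chunks)


def decode_f32(blob):
    """Parse safetensors-style bytes; only JSON and raw bytes are read, no code runs."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8:8 + header_len])
    buffer = blob[8 + header_len:]
    tensors = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional string-to-string metadata
            continue
        if info["dtype"] != "F32":
            raise ValueError(f"{name}: this sketch only handles F32, got {info['dtype']}")
        begin, end = info["data_offsets"]
        raw = buffer[begin:end]
        values = list(struct.unpack(f"<{len(raw) // 4}f", raw))
        tensors[name] = (tuple(info["shape"]), values)
    return tensors


weights = {
    "fc.weight": ((2, 2), [1.0, 2.5, -0.5, 0.0]),
    "fc.bias": ((2,), [0.25, -1.0]),
}
blob = encode_f32(weights)
restored = decode_f32(blob)
```

The round trip reproduces the input exactly here because every value chosen is representable in float32.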
Below is the code I used to create the two files. First, I cloned the RetCCL GitHub repository and made minor changes to ResNet.py to satisfy TorchScript's requirements: namely, I set self.instDis = nn.Identity() and self.groupDis = nn.Identity() when those attributes were not otherwise set.
from safetensors.torch import save_file
import torch
from torch import nn

import ResNet  # RetCCL's ResNet.py, patched as described above

# Build the RetCCL ResNet-50 and load the pretrained weights.
model = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)
pretext_model = torch.load("/home/jakub/Downloads/retccl.pth", map_location="cpu")
# Replace the final fc layer with a pass-through; strict loading then requires
# that the checkpoint contains no fc weights.
model.fc = nn.Identity()
model.load_state_dict(pretext_model, strict=True)
model.eval()

# Save TorchScript model.
model_jit = torch.jit.script(model, example_inputs=[(torch.ones(1, 3, 224, 224),)])
torch.jit.save(model_jit, "retccl_torchscript.pth")

# Save safetensors weights.
save_file(pretext_model, "retccl.safetensors")

# Ensure model outputs are the same in JIT and original model.
x = torch.ones(1, 3, 224, 224)
with torch.no_grad():
    orig = model(x)
    new = model_jit(x)
assert torch.equal(orig, new)
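The ResNet.py change described above can be illustrated with a hypothetical module (TwoBranchHead, not the actual RetCCL code). Assuming forward() branches on a plain boolean attribute, TorchScript compiles both sides of the if statement, so every attribute referenced there must exist on the module; an nn.Identity() placeholder satisfies the compiler when the branch is never taken.

```python
import torch
from torch import nn


class TwoBranchHead(nn.Module):
    """Hypothetical illustration of the optional-branch pattern; not RetCCL's ResNet.py."""

    def __init__(self, two_branch: bool = False):
        super().__init__()
        self.two_branch = two_branch
        self.fc = nn.Linear(16, 4)
        if two_branch:
            self.instDis = nn.Linear(16, 4)
        else:
            # Placeholder: TorchScript type-checks the `if self.two_branch` body
            # even when it never runs, so `self.instDis` must be defined.
            self.instDis = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.two_branch:
            return self.fc(x) + self.instDis(x)
        return self.fc(x)


# Both configurations compile with torch.jit.script.
single = torch.jit.script(TwoBranchHead(two_branch=False).eval())
double = torch.jit.script(TwoBranchHead(two_branch=True).eval())

x = torch.ones(2, 16)
with torch.no_grad():
    out_single = single(x)
    out_double = double(x)
```

An alternative would be annotating the attribute as Optional and checking it for None, but a placeholder module keeps the attribute's type fixed and requires no changes to forward().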
Thanks for the contribution! I will review and merge later today or tomorrow.