"""Export the ``haywoodsloan/ai-image-detector-deploy`` Swinv2 classifier to ONNX.

Running this script:

1. loads the Hugging Face checkpoint and its image processor,
2. traces the model with a processor-prepared dummy image and writes
   ``model.onnx``,
3. validates the graph with ``onnx.checker``,
4. runs the exported graph in ONNX Runtime and compares its logits with
   PyTorch's.

Success is only reported after the corresponding step actually ran.
"""

import torch
from transformers import AutoImageProcessor, Swinv2ForImageClassification
import onnx
import onnxruntime as ort

MODEL_ID = "haywoodsloan/ai-image-detector-deploy"
ONNX_MODEL_PATH = "model.onnx"
OPSET_VERSION = 11
# Side length of the synthetic RGB image fed through the processor to build
# the trace input (the processor may resize it to its own target size).
DUMMY_IMAGE_SIZE = 256
# Max absolute logit difference tolerated between PyTorch and ONNX Runtime.
# NOTE(review): chosen conservatively; tighten if the exported graph allows.
LOGITS_ATOL = 1e-3


def _make_dummy_pixel_values(image_processor):
    """Return a ``pixel_values`` batch for a single random image.

    The random uint8 image goes through the real image processor. The traced
    graph therefore receives the same spatial size and normalisation that
    inference will use, rather than a hard-coded shape that could silently
    disagree with the processor.
    """
    image = torch.randint(
        0, 256, (3, DUMMY_IMAGE_SIZE, DUMMY_IMAGE_SIZE), dtype=torch.uint8
    )
    return image_processor(images=image, return_tensors="pt")["pixel_values"]


def _export(model, pixel_values, path):
    """Trace *model* on *pixel_values* and write the ONNX graph to *path*."""
    with torch.no_grad():
        torch.onnx.export(
            model,
            (pixel_values,),
            path,
            input_names=["pixel_values"],
            output_names=["logits"],
            opset_version=OPSET_VERSION,
            # Only the batch dimension is dynamic; height/width are fixed by
            # the shape of the trace input.
            dynamic_axes={
                "pixel_values": {0: "batch_size"},
                "logits": {0: "batch_size"},
            },
        )


def _check_onnxruntime_parity(model, pixel_values, path):
    """Run the ONNX graph at *path* and compare its logits with PyTorch's.

    Returns:
        The maximum absolute difference between the two sets of logits.

    Raises:
        RuntimeError: if that difference exceeds ``LOGITS_ATOL``.
    """
    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    (onnx_logits,) = session.run(
        ["logits"], {"pixel_values": pixel_values.numpy()}
    )
    with torch.no_grad():
        torch_logits = model(pixel_values).logits
    max_diff = (torch_logits - torch.from_numpy(onnx_logits)).abs().max().item()
    if max_diff > LOGITS_ATOL:
        raise RuntimeError(
            f"ONNX Runtime logits differ from PyTorch by {max_diff:.3g} "
            f"(tolerance {LOGITS_ATOL:g})"
        )
    return max_diff


def main():
    """Export the model to ONNX, validate the graph, and check runtime parity."""
    image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = Swinv2ForImageClassification.from_pretrained(MODEL_ID)
    # Evaluation mode so dropout-style layers are inactive in the traced graph.
    model.eval()

    pixel_values = _make_dummy_pixel_values(image_processor)
    _export(model, pixel_values, ONNX_MODEL_PATH)

    # Structural validation; raises if the graph is malformed, so the message
    # below is only printed for a model that really was exported and checked.
    onnx.checker.check_model(onnx.load(ONNX_MODEL_PATH))
    print("The ONNX model is valid.")

    max_diff = _check_onnxruntime_parity(model, pixel_values, ONNX_MODEL_PATH)
    print(f"ONNX Runtime matches PyTorch (max |diff| = {max_diff:.3g}).")


if __name__ == "__main__":
    main()