LPX55 commited on
Commit
7cc2ec6
·
verified ·
1 Parent(s): 7820a52

Create utils/onnx.py

Browse files
Files changed (1) hide show
  1. utils/onnx.py +34 -0
utils/onnx.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, Swinv2ForImageClassification
3
+ import onnx
4
+ import onnxruntime as ort
5
+
6
+ # Load the model and processor
7
+ image_processor = AutoImageProcessor.from_pretrained("haywoodsloan/ai-image-detector-deploy")
8
+ model = Swinv2ForImageClassification.from_pretrained("haywoodsloan/ai-image-detector-deploy")
9
+
10
+ # Set the model to evaluation mode
11
+ model.eval()
12
+
13
+ # Create dummy input for tracing
14
+ dummy_input = torch.randn(1, 3, 256, 256) # Batch size of 1, 3 color channels, 256x256 image
15
+
16
+ # Export the model to ONNX
17
+ onnx_model_path = "model.onnx"
18
+ # torch.onnx.export(
19
+ # model,
20
+ # dummy_input,
21
+ # onnx_model_path,
22
+ # input_names=["pixel_values"],
23
+ # output_names=["logits"],
24
+ # opset_version=11,
25
+ # dynamic_axes={
26
+ # "pixel_values": {0: "batch_size"},
27
+ # "logits": {0: "batch_size"}
28
+ # }
29
+ # )
30
+
31
+ # Verify the ONNX model
32
+ # onnx_model = onnx.load(onnx_model_path)
33
+ # onnx.checker.check_model(onnx_model)
34
+ print("The ONNX model is valid.")