import torch
import torch.nn as nn
import onnx
import data, utils
from train import device, NUM_CLASSES
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import onnxruntime as ort
import numpy as np
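# Rebuild the training-time architecture: an ImageNet-pretrained EfficientNet-B0 backbone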
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
# Replace the classifier head so the output matches NUM_CLASSES food categories
model.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(1280, NUM_CLASSES),
    # nn.Softmax()  # not needed here: argmax over raw logits gives the same prediction
)
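# Load the best checkpoint saved during training and move the model to the training device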
model = utils.load_model(model, "save_model/best_model.pth").to(device)
PATH = "save_model/food_cpu.onnx"
# Export the model to ONNX and run an inference pass on CPU (project utility)
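# utils.onnx_inference is a project helper; a roughly equivalent direct export
# (a sketch, assuming a single 3x224x224 image input named 'input') would be:
#   torch.onnx.export(model.cpu(), torch.randn(1, 3, 224, 224), PATH,
#                     input_names=['input'], output_names=['output'])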
utils.onnx_inference(model, PATH, "cpu")
# Validate the exported graph; check_model raises if the model is malformed
onnx_model = onnx.load(PATH)
onnx.checker.check_model(onnx_model)
# Grab one sample (image tensor, label) from the test set
x, y = data.test_datasets[0]
# Run the exported model with ONNX Runtime; the input name matches the one used at export time
ort_sess = ort.InferenceSession(PATH)
outputs = ort_sess.run(None, {'input': x.unsqueeze(dim=0).numpy()})
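# Optional sanity check (sketch): with model.eval(), the ONNX output should
# closely match the PyTorch output for the same sample
# pt_out = model(x.unsqueeze(dim=0).to(device)).detach().cpu().numpy()
# assert np.allclose(outputs[0], pt_out, atol=1e-4)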
# Map the highest logit back to a class name and compare with the ground truth
classes = data.train_datasets.classes
predicted, actual = classes[outputs[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')
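# Optional (sketch): since the exported model returns raw logits (no Softmax layer),
# a numerically stable softmax gives a confidence score for the prediction
# logits = outputs[0][0]
# probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
# print(f'Confidence: {probs.max():.3f}')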