File size: 2,438 Bytes
0b5f4ac 1332ee4 0b5f4ac 1332ee4 0b5f4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright 2023 Advanced Micro Devices, Inc. on behalf of itself and its subsidiaries and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Megvii, Inc. and its affiliates.
import onnxruntime
import argparse
from PIL import Image
import torchvision.transforms as transforms
# Command-line interface. Arguments are parsed at import time and the
# resulting module-level ``args`` is read by read_image() and main().
parser = argparse.ArgumentParser()
parser.add_argument(
    '--onnx_path',
    type=str,
    default="EfficientNet_int.onnx",
    help="Path of the ONNX model to run.",
)
parser.add_argument(
    '--image_path',
    type=str,
    required=True,
    help="Path of the input image.",
)
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
# Only consulted when --ipu is given (passed as the Vitis AI provider's
# "config_file" option in main()).
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    help="Path of the config file for setting provider_options.",
)
# "nhwc" makes read_image() permute the tensor to channels-last.
parser.add_argument(
    '--data_format',
    type=str,
    choices=["nchw", "nhwc"],
    default="nchw",
    help="Memory layout of the image tensor fed to the model.",
)
args = parser.parse_args()
def read_image(image_path=None, data_format=None):
    """Load an image and preprocess it into a single-image model input batch.

    The image is converted to 3-channel RGB, turned into a float tensor,
    resized to 224x224 and normalized with the per-channel mean/std below
    (the commonly used ImageNet statistics).

    Args:
        image_path: Path of the image to load. Defaults to ``args.image_path``.
        data_format: ``"nhwc"`` for channels-last output; any other value
            (normally ``"nchw"``) keeps channels-first. Defaults to
            ``args.data_format``.

    Returns:
        A numpy array of shape (1, 3, 224, 224) for ``"nchw"`` or
        (1, 224, 224, 3) for ``"nhwc"``.
    """
    if image_path is None:
        image_path = args.image_path
    if data_format is None:
        data_format = args.data_format

    # Image.open is lazy and holds the file open; convert (which loads the
    # pixels) inside the context manager so the handle is released promptly.
    # Converting to RGB lets grayscale, palette and RGBA files produce the
    # three channels that the 3-element Normalize below requires.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        normalize,
    ])
    # unsqueeze(0) adds the leading batch dimension of size 1.
    img_tensor = transform(rgb_image).unsqueeze(0)
    if data_format == "nhwc":
        # Reuse the tensor computed above rather than running the pipeline a
        # second time; contiguous() makes the permuted data C-ordered.
        img_tensor = img_tensor.permute((0, 2, 3, 1)).contiguous()
    return img_tensor.numpy()
def main():
    """Run one inference on the preprocessed image and print the argmax class id."""
    # Pick execution providers: Vitis AI (configured through a JSON file) when
    # --ipu is given, otherwise CUDA then CPU, in that order of preference.
    if not args.ipu:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    else:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]

    session = onnxruntime.InferenceSession(
        args.onnx_path,
        providers=providers,
        provider_options=provider_options,
    )

    # Feed the image to the model's first declared input.
    input_name = session.get_inputs()[0].name
    batch = read_image()
    first_output = session.run(None, {input_name: batch})[0]

    # Row 0 of the first output; read_image() builds a batch of one, so this
    # is presumably that image's per-class values (assumes (batch, classes)).
    top_class = first_output[0].argmax()
    print("class id =", top_class)


if __name__ == "__main__":
    main()
|