"""Gradio demo that classifies an uploaded image as a cat, dog, or panda.

A TorchScript CNN is loaded from ``./models/cat_dog_cnn.pt`` (resolved relative
to the current working directory) and served through a simple Gradio UI.
"""

import gradio as gr
import torch
from torchvision import transforms

# map_location="cpu": the input tensors built below always live on the CPU, and
# this also lets a model saved from a GPU session load on CPU-only hosts.
model = torch.jit.load("./models/cat_dog_cnn.pt", map_location="cpu")
model.eval()

# Resize to 224x224, convert to a float tensor, then normalize per channel with
# the commonly used ImageNet mean / std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Label for each model output index. NOTE(review): presumably this matches the
# label order used in training — confirm against the training code.
CLASSES = ["Cat", "Dog", "Panda"]


def classify_image(inp):
    """Return the predicted class name for an uploaded image.

    Args:
        inp: PIL image from the Gradio ``Image`` component, or ``None`` when
            the user submits without uploading an image.

    Returns:
        One of ``CLASSES``, or a prompt string when no image was provided.
    """
    if inp is None:
        return "Please upload an image."
    # Normalize above expects exactly 3 channels; RGBA, grayscale and palette
    # uploads would otherwise fail with a channel-count mismatch.
    batch = transform(inp.convert("RGB")).unsqueeze(0)
    # Pure inference: skip autograd bookkeeping.
    with torch.inference_mode():
        out = model(batch)
    # Batch size is 1, so the flattened argmax is the predicted class index.
    return CLASSES[out.argmax().item()]


iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs="text",
    examples=[
        "./app_data/cat.jpg",
        "./app_data/dog.jpg",
        "./app_data/panda.jpg",
    ],
)

if __name__ == "__main__":
    iface.launch()