"""Gradio demo: classify an uploaded image with a pretrained Hiera-B/224 model.

Class labels are read from ``Imagenet.txt`` (headerless; only the first
comma-separated column is kept) and indexed by the model's argmax class id.
"""

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import hiera

# Class-index -> label table (column 0 of a headerless file).
# NOTE(review): assumes one label per line, in the model's class order — confirm.
df = pd.read_csv('Imagenet.txt', usecols=[0], header=None)

model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")
# Inference only: disable any training-time stochastic layers.
model.eval()

input_size = 224

# Eval preprocessing: resize the shorter side to 256 (scaled for input_size)
# with bicubic interpolation, then center-crop a square of input_size.
transform_list = [
    transforms.Resize(
        int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC
    ),
    transforms.CenterCrop(input_size),
]
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])


def recognize(img):
    """Return the predicted class label for a single image.

    Args:
        img: A PIL image supplied by the Gradio input, or ``None`` when the
            user submits without providing an image.

    Returns:
        The label in ``df`` at the index of the highest-scoring class, or a
        short prompt string when no image was provided.
    """
    if img is None:
        return "Please provide an image."
    # ToTensor keeps the source channel count; force 3 channels so RGBA,
    # grayscale and palette images match Normalize and the model input.
    img = img.convert("RGB")
    # No manual resize here: transform_norm already resizes the shorter side
    # and center-crops, preserving aspect ratio. Pre-squashing to 224x224
    # distorted the image and made the crop discard its borders.
    img_norm = transform_norm(img)
    with torch.no_grad():  # no autograd graph needed for inference
        output = model(img_norm[None,])  # add a batch dimension
    out = output.argmax(dim=-1).item()
    return df.iloc[out, 0]


demo = gr.Interface(
    fn=recognize,
    inputs=gr.Image(type="pil"),
    outputs='text',
    examples=[['Banana.jpg']],
)
demo.launch()