"""Gradio demo: classify an uploaded image with a pretrained HyenaPixel/timm model."""
import functools
from typing import Optional

import gradio as gr
import timm
import timm.data
# Imported for its side effect: presumably registers the hpx_*/hb_* model
# names with timm's create_model — confirm against the hyenapixel package.
import hyenapixel.models  # noqa: F401
import torch
import numpy as np
from PIL import Image

# How many (model, transform) pairs stay loaded at once. Each entry holds a
# full set of pretrained weights, so this bounds memory while still avoiding
# a reload when the same model is requested repeatedly.
MODEL_CACHE_SIZE = 2

# Model names offered in the dropdown (passed verbatim to timm.create_model).
MODEL_CHOICES = [
    "hpx_former_s18",
    "hpx_former_s18_384",
    "hb_former_s18",
    "c_hpx_former_s18",
    "hpx_a_former_s18",
    "hb_a_former_s18",
    "hpx_former_b36",
    "hb_former_b36",
]

# One class name per line; the list is indexed by the model's argmax output index.
with open("imagenet.txt", encoding="utf-8") as file:
    class_names = [line.rstrip() for line in file]


@functools.lru_cache(maxsize=MODEL_CACHE_SIZE)
def _load_model(model_name: str):
    """Create pretrained ``model_name`` in eval mode together with its input transform.

    Cached so the weights are loaded once per model instead of on every request.

    Returns:
        A ``(model, transform)`` tuple.
    """
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    # Names carrying the "_384" suffix are run at 384x384 input; all others at 224x224.
    image_size = 384 if "_384" in model_name else 224
    transform = timm.data.create_transform(image_size)
    return model, transform


def predict(model_name: str, image: Optional[Image.Image]) -> str:
    """Return the predicted class name for ``image`` using model ``model_name``.

    Args:
        model_name: timm model name (one of ``MODEL_CHOICES`` when called from the UI).
        image: The uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        The entry of ``class_names`` at the highest-scoring output index.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    model, transform = _load_model(model_name)

    # ImageNet-pretrained models take 3-channel RGB input, but uploads may be
    # RGBA, palette or grayscale images; normalise the mode before transforming.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        output = model(input_tensor)

    class_ind = int(np.argmax(output[0].numpy()))
    return class_names[class_ind]


interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            label="Select Model",
            value="hb_former_b36",
            choices=MODEL_CHOICES,
        ),
        gr.Image(type="pil", label="Upload Image"),
    ],
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Classification",
    description="Choose a model and upload an image to predict the class.",
)

interface.launch()