# Hugging Face Space: HyenaPixel image-classification demo
# (page status when this copy was captured: Sleeping)
import functools

import gradio as gr
import hyenapixel.models  # noqa: F401  (imported for side effects; not referenced directly)
import numpy as np
import timm
import torch
from PIL import Image
# Class labels, one per line. predict() indexes this list with the argmax of
# the model's logits, so line i must be the label for output index i.
# An explicit encoding keeps decoding independent of the platform locale.
with open("imagenet.txt", encoding="utf-8") as file:
    class_names = [line.rstrip() for line in file]
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Build the pretrained timm model *model_name* in eval mode, with caching.

    Creating a model with ``pretrained=True`` is expensive, so the few most
    recently used models are kept instead of being rebuilt on every request.
    The bound limits how many large models stay resident at once.
    """
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    return model


def _image_size_for(model_name):
    """Return the square input resolution used for *model_name*.

    Names containing ``"_384"`` get 384px inputs; all others get 224px.
    """
    return 384 if "_384" in model_name else 224


def predict(model_name, image):
    """Classify *image* with the timm model *model_name*.

    Args:
        model_name: Name passed straight to ``timm.create_model``
            (e.g. one of the HyenaPixel variants offered in the UI).
        image: PIL image to classify, or ``None`` when nothing was uploaded.

    Returns:
        The entry of ``class_names`` at the index of the highest output
        logit, or a short prompt string when *image* is ``None``.
    """
    if image is None:
        # An empty gr.Image input arrives as None; the transform would crash on it.
        return "Please upload an image."

    model = _load_model(model_name)
    transform = timm.data.create_transform(_image_size_for(model_name))

    # The normalization step and the model stem assume three channels. Gradio's
    # Image component normally delivers RGB already, but converting here keeps
    # predict() safe for RGBA/grayscale/palette images passed in directly.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    class_ind = int(np.argmax(output[0].numpy()))
    return class_names[class_ind]
# Model names offered in the dropdown. predict() hands the selected name
# straight to timm.create_model, so every entry must be buildable by timm.
MODEL_NAMES = [
    "hpx_former_s18",
    "hpx_former_s18_384",
    "hb_former_s18",
    "c_hpx_former_s18",
    "hpx_a_former_s18",
    "hb_a_former_s18",
    "hpx_former_b36",
    "hb_former_b36",
]

interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(label="Select Model", value="hb_former_b36", choices=MODEL_NAMES),
        gr.Image(type="pil", label="Upload Image"),
    ],
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Classification",
    description="Choose a model and upload an image to predict the class.",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module (e.g. for tests) does not launch a web server as a side effect.
if __name__ == "__main__":
    interface.launch()