HyenaPixel / app.py
Spravil's picture
Create app.py
d976184 verified
raw
history blame
1.18 kB
from functools import lru_cache
from pathlib import Path

import gradio as gr
import hyenapixel.models  # noqa: F401 -- imported for side effects; presumably registers the hpx_/hb_ models with timm
import numpy as np
import timm
import torch
from PIL import Image
# Human-readable class labels, one per line; predict() indexes this list with
# the model's argmax, so line order must match the classifier's output order
# (presumably the 1000 ImageNet-1k classes -- confirm against imagenet.txt).
# The file is resolved next to this script rather than the current working
# directory, so the app also works when launched from another directory.
with open(Path(__file__).resolve().parent / "imagenet.txt", encoding="utf-8") as file:
    class_names = [line.rstrip() for line in file]
@lru_cache(maxsize=4)
def _load_model(model_name):
    """Return the pretrained timm model *model_name*, switched to eval mode.

    Cached so repeated predictions reuse the already-built model instead of
    re-creating it (and reloading its pretrained checkpoint) on every request.
    The cache is bounded so trying many models does not keep them all in memory.
    """
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    return model


def predict(model_name, image):
    """Classify *image* with the timm model *model_name* and return its label.

    Args:
        model_name: name of a timm-registered model. Names containing "_384"
            are fed 384x384 inputs; all others are fed 224x224.
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        The entry of the module-level ``class_names`` list at the index of the
        highest-scoring output.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        # Give the user a readable message instead of a crash inside the transform.
        raise gr.Error("Please upload an image.")

    model = _load_model(model_name)

    image_size = 384 if "_384" in model_name else 224
    transform = timm.data.create_transform(image_size)
    # Convert to RGB so grayscale/RGBA uploads produce the 3 channels that
    # per-channel normalisation expects.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(input_tensor)

    class_ind = int(logits[0].argmax())
    return class_names[class_ind]
# Models offered in the dropdown (all passed by name to timm.create_model).
MODEL_NAMES = [
    "hpx_former_s18",
    "hpx_former_s18_384",
    "hb_former_s18",
    "c_hpx_former_s18",
    "hpx_a_former_s18",
    "hb_a_former_s18",
    "hpx_former_b36",
    "hb_former_b36",
]

# Gradio UI: pick a model, upload an image, get the predicted class label.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            label="Select Model",
            value="hb_former_b36",
            choices=MODEL_NAMES,
        ),
        gr.Image(type="pil", label="Upload Image"),
    ],
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Classification",
    description="Choose a model and upload an image to predict the class.",
)

# Only start the (blocking) server when run as a script, so importing this
# module -- e.g. for testing or reuse -- does not launch the app.
if __name__ == "__main__":
    interface.launch()