# HyenaPixel / app.py
# Hugging Face Space file-viewer header (page residue, not part of the program):
#   Spravil's picture
#   Update app.py
#   c80591e verified
#   raw
#   history blame
#   1.21 kB
import gradio as gr
import timm
import hyenapixel.models
import torch
import numpy as np
from PIL import Image
# Load the class labels, one per line; the i-th label is paired with the
# i-th model output in predict(). The encoding is pinned so decoding does not
# depend on the host locale.
with open("imagenet.txt", encoding="utf-8") as file:
    class_names = [line.rstrip() for line in file]
# Models already instantiated by this process, keyed by timm model name, so
# repeated requests do not re-create the network and reload its weights.
_MODEL_CACHE = {}


def _get_model(model_name):
    """Return the eval-mode pretrained model for *model_name*, creating it on first use."""
    model = _MODEL_CACHE.get(model_name)
    if model is None:
        model = timm.create_model(model_name, pretrained=True)
        model.eval()
        _MODEL_CACHE[model_name] = model
    return model


def predict(model_name, image):
    """Classify *image* with the selected model.

    Args:
        model_name: Name passed to ``timm.create_model``. A name containing
            ``"_384"`` selects a 384-pixel input size; any other name uses 224.
        image: PIL image to classify, or ``None`` when nothing was uploaded.

    Returns:
        Dict mapping each entry of ``class_names`` to its softmax probability
        (a Python float), paired in file order with the model outputs.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of failing inside the
        # transform with an opaque error.
        raise gr.Error("Please upload an image first.")

    model = _get_model(model_name)
    image_size = 384 if "_384" in model_name else 224
    transform = timm.data.create_transform(image_size)

    # Defensive: uploads may be grayscale, palette or RGBA. Assumes the
    # transform/model expect three-channel RGB input -- TODO confirm.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_tensor)

    # torch.softmax requires an explicit dim (the original call omitted it and
    # raised TypeError). Normalize over the last axis, presumed to be the class
    # axis of a (batch, classes) output; [0] then selects the single image.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    return dict(zip(class_names, probabilities))
# --- UI wiring -------------------------------------------------------------
# Dropdown of selectable model names; the chosen value is passed to predict()
# as its first argument.
model_selector = gr.Dropdown(
    label="Select Model",
    value="hb_former_b36",
    choices=[
        "hpx_former_s18",
        "hpx_former_s18_384",
        "hb_former_s18",
        "c_hpx_former_s18",
        "hpx_a_former_s18",
        "hb_a_former_s18",
        "hpx_former_b36",
        "hb_former_b36",
    ],
)

# The uploaded picture is handed to predict() as a PIL image.
image_input = gr.Image(type="pil", label="Upload Image")

# Label widget configured to display the ten top classes.
prediction_output = gr.Label(label="Prediction", num_top_classes=10)

interface = gr.Interface(
    fn=predict,
    inputs=[model_selector, image_input],
    outputs=prediction_output,
    title="Image Classification",
    description="Choose a model and upload an image to predict the class.",
)

# Start the web app as soon as the script is executed.
interface.launch()