Spravil commited on
Commit
d976184
·
verified ·
1 Parent(s): 87cb855

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import timm
3
+ import hyenapixel.models
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ with open("imagenet.txt") as file:
9
+ class_names = [line.rstrip() for line in file]
10
+
11
+ def predict(model_name, image):
12
+ model = timm.create_model(model_name, pretrained=True)
13
+ model.eval()
14
+ image_size = 224
15
+ if "_384" in model_name:
16
+ image_size = 384
17
+ transform = timm.data.create_transform(image_size)
18
+ input_tensor = transform(image).unsqueeze(0)
19
+ with torch.no_grad():
20
+ output = model(input_tensor)
21
+ output_np = output[0].numpy()
22
+ class_ind = np.argmax(output_np)
23
+ return class_names[class_ind]
24
+
25
+ interface = gr.Interface(
26
+ fn=predict,
27
+ inputs=[
28
+ gr.Dropdown(label="Select Model", value="hb_former_b36", choices=["hpx_former_s18", "hpx_former_s18_384", "hb_former_s18", "c_hpx_former_s18", "hpx_a_former_s18", "hb_a_former_s18", "hpx_former_b36", "hb_former_b36"]),
29
+ gr.Image(type="pil", label="Upload Image")
30
+ ],
31
+ outputs=gr.Textbox(label="Predicted Class"),
32
+ title="Image Classification",
33
+ description="Choose a model and upload an image to predict the class."
34
+ )
35
+
36
+ interface.launch()