antfraia commited on
Commit
06fd5a4
·
1 Parent(s): 36f2675

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -1,15 +1,21 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- # Define the function to predict using the model
4
  def predict_image(img):
5
- model = gr.load("models/google/vit-base-patch16-224")
6
- return model.predict(img)
 
 
7
 
8
  # Create the interface
9
  iface = gr.Interface(
10
  fn=predict_image,
11
- inputs=gr.Image(),
12
- outputs="label",
13
  live=True,
14
  capture_session=True,
15
  title="Image recognition",
 
1
  import gradio as gr
2
+ from transformers import ViTForImageClassification, ViTProcessor
3
+
4
+ # Load the model and processor
5
+ model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
6
+ processor = ViTProcessor.from_pretrained("google/vit-base-patch16-224")
7
 
 
8
  def predict_image(img):
9
+ inputs = processor(img, return_tensors="pt")
10
+ outputs = model(**inputs)
11
+ predictions = outputs.logits.argmax(-1)
12
+ return model.config.id2label[predictions.item()]
13
 
14
  # Create the interface
15
  iface = gr.Interface(
16
  fn=predict_image,
17
+ inputs=gr.Image(shape=(224, 224)),
18
+ outputs="text",
19
  live=True,
20
  capture_session=True,
21
  title="Image recognition",