Jangai commited on
Commit
77e4539
·
verified ·
1 Parent(s): c588988

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -2,14 +2,17 @@ import gradio as gr
2
  import torch
3
  from transformers import ViTImageProcessor, ViTForImageClassification
4
  from PIL import Image
 
5
 
6
- # Load the pre-trained model and feature extractor
7
  model_name = "google/vit-base-patch16-224"
8
  image_processor = ViTImageProcessor.from_pretrained(model_name)
9
  model = ViTForImageClassification.from_pretrained(model_name)
10
 
11
  # Define the prediction function
12
  def predict(image):
 
 
13
  inputs = image_processor(images=image, return_tensors="pt")
14
  outputs = model(**inputs)
15
  logits = outputs.logits
 
2
  import torch
3
  from transformers import ViTImageProcessor, ViTForImageClassification
4
  from PIL import Image
5
+ import numpy as np
6
 
7
+ # Load the pre-trained model and image processor
8
  model_name = "google/vit-base-patch16-224"
9
  image_processor = ViTImageProcessor.from_pretrained(model_name)
10
  model = ViTForImageClassification.from_pretrained(model_name)
11
 
12
  # Define the prediction function
13
  def predict(image):
14
+ # Convert the dictionary to a PIL image
15
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
16
  inputs = image_processor(images=image, return_tensors="pt")
17
  outputs = model(**inputs)
18
  logits = outputs.logits