antfraia commited on
Commit
d889240
·
1 Parent(s): 06aa8a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -1,14 +1,29 @@
1
  import gradio as gr
2
- from transformers import ViTForImageClassification, ViTProcessor
 
 
 
3
 
4
- # Load the model and processor
5
  model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
6
- processor = ViTProcessor.from_pretrained("google/vit-base-patch16-224")
 
 
 
 
 
 
 
7
 
8
  def predict_image(img):
9
- inputs = processor(img, return_tensors="pt")
10
- outputs = model(**inputs)
11
- predictions = outputs.logits.argmax(-1)
 
 
 
 
 
12
  return model.config.id2label[predictions.item()]
13
 
14
  # Create the interface
 
1
  import gradio as gr
2
+ from transformers import ViTForImageClassification
3
+ import torch
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
 
7
+ # Load the model
8
  model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
9
+ model.eval()
10
+
11
+ # Define the image preprocessing pipeline
12
+ transform = transforms.Compose([
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
16
+ ])
17
 
18
  def predict_image(img):
19
+ # Apply the transformations
20
+ tensor_img = transform(img).unsqueeze(0)
21
+
22
+ # Make prediction
23
+ with torch.no_grad():
24
+ outputs = model(tensor_img)
25
+ predictions = outputs.logits.argmax(-1)
26
+
27
  return model.config.id2label[predictions.item()]
28
 
29
  # Create the interface