CXDJY committed on
Commit
a3e1670
·
1 Parent(s): 4623df1

Add application file

Browse files
Files changed (1) hide show
  1. app.py +85 -1
app.py CHANGED
@@ -1,7 +1,91 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
  def greet(name):
4
  return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import ViTForImageClassification, ViTModel, ViTImageProcessor
5
+ from PIL import Image
6
+ import io
7
+ from sklearn.preprocessing import LabelEncoder
8
+ import json
9
 
10
  def greet(name):
11
  return "Hello " + name + "!!"
12
 
13
+
14
+ async def test2(file, top_k: int = 5):
15
+ extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
16
+
17
+ if not extension:
18
+ return "Image format must be jpg, jpeg, or png!"
19
+ # Read image contents
20
+ contents = await file.read()
21
+
22
+ # Preprocess image
23
+ image_tensor = preprocess_image(contents)
24
+
25
+ # Make predictions
26
+ predictions = predict(image_tensor, top_k)
27
+
28
+ item = {"predictions": predictions, "filename": file.filename}
29
+ return json.dumps(item)
30
+
31
# Label encoder used to turn predicted class indices back into label names.
encoder = LabelEncoder()
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects,
# which can execute code. Only ever load an 'encoder.npy' you produced
# yourself; never point this at an untrusted file.
encoder.classes_ = np.load('encoder.npy', allow_pickle=True)

# Backbone weights pulled from the Hugging Face Hub.
pretrained_model = ViTModel.from_pretrained('pillIdentifierAI/pillIdentifier')

# ViTImageProcessor takes the target resolution via `size`; `image_size` is
# not one of its parameters, so the resize target is spelled out explicitly.
# NOTE(review): do_rescale=False means 0-255 pixel values are normalized
# with mean/std 0.5 without first being scaled to 0-1 -- confirm this
# matches how the model was trained.
feature_extractor = ViTImageProcessor(
    size={"height": 224, "width": 224},
    do_resize=True,
    do_normalize=True,
    do_rescale=False,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)

# Build a classification model around the pretrained backbone.
config = pretrained_model.config
config.num_labels = 2112  # Change this to the appropriate number of classes
model = ViTForImageClassification(config)
# NOTE(review): only the backbone is replaced here; the classifier head
# created by ViTForImageClassification(config) keeps its fresh
# initialization. Confirm whether trained classifier weights need loading.
model.vit = pretrained_model

# Inference only: put the model in evaluation mode.
model.eval()
51
+
52
def preprocess_image(contents):
    """Decode raw image bytes and run them through the feature extractor.

    Args:
        contents: Encoded image bytes (anything ``PIL.Image.open`` can read).

    Returns:
        A ``torch.float32`` tensor holding the extractor's ``pixel_values``
        for this single image (no batch dimension).
    """
    pil_image = Image.open(io.BytesIO(contents))

    # The extractor is fed RGB only, so convert any other mode first.
    pil_image = pil_image if pil_image.mode == 'RGB' else pil_image.convert('RGB')

    # The extractor works on a batch; pass a one-image list and take entry 0.
    pixel_values = feature_extractor(images=[pil_image])['pixel_values']

    return torch.tensor(pixel_values[0], dtype=torch.float32)
67
+
68
def predict(image_tensor, top_k=5):
    """Classify one preprocessed image and return its top-k labels.

    Args:
        image_tensor: Tensor for a single image without a batch dimension;
            the batch axis is added here before calling the model.
        top_k: Maximum number of classes to return. Values larger than the
            number of classes are clamped; values <= 0 yield an empty dict.

    Returns:
        A dict keyed by 1-based rank, each value being
        ``{'label': str, 'probability': float}``, ordered from most to
        least likely.
    """
    # Ensure the model is in evaluation mode
    model.eval()

    with torch.no_grad():
        outputs = model(pixel_values=image_tensor.unsqueeze(0))  # Add batch dimension
        # Logits are unnormalized scores; softmax turns them into real
        # probabilities so the 'probability' field means what it says.
        probabilities = torch.softmax(outputs.logits, dim=1)[0].numpy()

    # Never ask for more classes than the model produces.
    k = max(0, min(int(top_k), probabilities.shape[0]))

    # Softmax is monotonic, so ranking by probability equals ranking by logit.
    top_indices = np.argsort(probabilities)[::-1][:k]

    result = {}
    for rank, class_index in enumerate(top_indices, start=1):
        class_name = encoder.inverse_transform([class_index])[0]
        result[rank] = {
            'label': str(class_name),
            'probability': float(probabilities[class_index]),
        }

    return result
89
+
90
def _classify_upload(image, top_k: int = 5) -> str:
    """Adapt a Gradio image input to the preprocess/predict pipeline.

    Args:
        image: PIL image supplied by the Gradio ``Image`` component, or
            ``None`` when the form is submitted without an image.
        top_k: Number of highest-scoring classes to report.

    Returns:
        A JSON string with the key ``"predictions"``, or a short message
        when no image was provided.
    """
    if image is None:
        return "Please upload an image."

    # preprocess_image() expects encoded bytes, so re-encode losslessly as PNG.
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, format="PNG")

    predictions = predict(preprocess_image(buffer.getvalue()), top_k)
    return json.dumps({"predictions": predictions})


# The interface takes an image, so it must call an image-aware function;
# greet() concatenates strings and fails on image input.
iface = gr.Interface(fn=_classify_upload, inputs=gr.Image(type="pil"), outputs="text")
iface.launch()