rbarman commited on
Commit
d9bf0de
·
1 Parent(s): 49f2313
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoFeatureExtractor, ResNetForImageClassification
2
+ import torch
3
+ import gradio as gr
4
+
5
+ # load model
6
+ feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
7
+ model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
8
+
9
+ def predict(image):
10
+
11
+ inputs = feature_extractor(image, return_tensors="pt")
12
+ with torch.no_grad():
13
+ logits = model(**inputs).logits
14
+
15
+ # model predicts one of the 1000 ImageNet classes
16
+ predicted_label = logits.argmax(-1).item()
17
+ print(model.config.id2label[predicted_label])
18
+
19
+ # setup Gradio interface
20
+ title = "Image classifier"
21
+ description = "Image classification with pretrained resnet50 model"
22
+ #examples = ['elephant.jpg']
23
+ interpretation='default'
24
+ enable_queue=True
25
+
26
+ gr.Interface(
27
+ fn=predict,
28
+ inputs=gr.inputs.Image(),
29
+ outputs=gr.outputs.Label(num_top_classes=1),
30
+ title=title,
31
+ description=description,
32
+ #examples=examples,
33
+ interpretation=interpretation,
34
+ enable_queue=enable_queue
35
+ ).launch()