JosephTK commited on
Commit
3602b81
·
1 Parent(s): 1c3fab1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -30
app.py CHANGED
@@ -1,47 +1,25 @@
1
  import gradio as gr
2
- from transformers import AutoImageProcessor, ResNetForImageClassification, YolosFeatureExtractor, YolosForObjectDetection
3
  import torch
4
 
 
 
5
 
6
-
7
-
8
- def detect(image1, image2):
9
- ### Image 1, the object ###
10
- processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
11
- model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
12
-
13
- inputs = processor(image1, return_tensors="pt")
14
-
15
- with torch.no_grad():
16
- logits = model(**inputs).logits
17
-
18
- # model predicts one of the 1000 ImageNet classes
19
- predicted_label = logits.argmax(-1).item()
20
- object_label = model.config.id2label[predicted_label]
21
-
22
- ### Image 2, object detections ###
23
- from PIL import Image
24
- import requests
25
-
26
- feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
27
- model = YolosForObjectDetection.from_pretrained('hustvl/yolos-small')
28
-
29
- inputs = feature_extractor(images=image2, return_tensors="pt")
30
  outputs = model(**inputs)
31
 
32
  # model predicts bounding boxes and corresponding COCO classes
33
  logits = outputs.logits
34
  bboxes = outputs.pred_boxes
35
 
36
- return object_label, bboxes
37
-
38
-
39
 
40
 
41
  demo = gr.Interface(
42
  fn=detect,
43
- inputs=[gr.inputs.Image(label="Object to detect"), gr.inputs.Image(label="Image to detect object in")],
44
- outputs=["text", "text"],
45
  title="Object Counts in Image"
46
  )
47
 
 
1
  import gradio as gr
2
+ from transformers import YolosFeatureExtractor, YolosForObjectDetection
3
  import torch
4
 
5
+ feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
6
+ model = YolosForObjectDetection.from_pretrained('hustvl/yolos-small')
7
 
8
+ def detect(image):
9
+ inputs = feature_extractor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  outputs = model(**inputs)
11
 
12
  # model predicts bounding boxes and corresponding COCO classes
13
  logits = outputs.logits
14
  bboxes = outputs.pred_boxes
15
 
16
+ return outputs
 
 
17
 
18
 
19
  demo = gr.Interface(
20
  fn=detect,
21
+ inputs=[gr.inputs.Image(label="Input image")],
22
+ outputs=["text"],
23
  title="Object Counts in Image"
24
  )
25