Jsmithwek committed
Commit 643d4fb · 1 Parent(s): ac1bcf0

Update app.py

Files changed (1)
  1. app.py +64 -79
app.py CHANGED
@@ -1,81 +1,66 @@
- import io
-
+ from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
  import matplotlib.pyplot as plt
- import requests
- import streamlit as st
- import torch
+ import matplotlib.patches as patches
+ from random import choice
  from PIL import Image
- from transformers import DetrFeatureExtractor, DetrForObjectDetection
-
- # colors for visualization
- COLORS = [
-     [0.000, 0.447, 0.741],
-     [0.850, 0.325, 0.098],
-     [0.929, 0.694, 0.125],
-     [0.494, 0.184, 0.556],
-     [0.466, 0.674, 0.188],
-     [0.301, 0.745, 0.933]
- ]
-
-
- @st.cache(allow_output_mutation=True)
- def get_hf_components(model_name_or_path):
-     feature_extractor = DetrFeatureExtractor.from_pretrained(model_name_or_path)
-     model = DetrForObjectDetection.from_pretrained(model_name_or_path)
-     model.eval()
-     return feature_extractor, model
-
-
- @st.cache
- def get_img_from_url(url):
-     return Image.open(requests.get(url, stream=True).raw)
-
-
- def fig2img(fig):
-     buf = io.BytesIO()
-     fig.savefig(buf)
-     buf.seek(0)
-     img = Image.open(buf)
-     return img
-
-
- def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
-     keep = output_dict["scores"] > threshold
-     boxes = output_dict["boxes"][keep].tolist()
-     scores = output_dict["scores"][keep].tolist()
-     labels = output_dict["labels"][keep].tolist()
-     if id2label is not None:
-         labels = [id2label[x] for x in labels]
-
-     plt.figure(figsize=(16, 10))
-     plt.imshow(pil_img)
-     ax = plt.gca()
-     colors = COLORS * 100
-     for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
-         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
-         ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
-     plt.axis("off")
-     return fig2img(plt.gcf())
-
-
- def make_prediction(img, feature_extractor, model):
-     inputs = feature_extractor(img, return_tensors="pt")
-     outputs = model(**inputs)
-     img_size = torch.tensor([tuple(reversed(img.size))])
-     processed_outputs = feature_extractor.post_process(outputs, img_size)
-     return processed_outputs[0]
-
-
- def main():
-     option = st.selectbox("Which model should we use?", ("facebook/detr-resnet-50", "facebook/detr-resnet-101"))
-     feature_extractor, model = get_hf_components(option)
-     url = st.text_input("URL to some image", "http://images.cocodataset.org/val2017/000000039769.jpg")
-     img = get_img_from_url(url)
-     processed_outputs = make_prediction(img, feature_extractor, model)
-     threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.7)
-     viz_img = visualize_prediction(img, processed_outputs, threshold, model.config.id2label)
-     st.image(viz_img)
-
-
- if __name__ == "__main__":
-     main()
+ import os
+ from matplotlib import rcParams, font_manager
+ import streamlit as st
+ import urllib.request
+ import requests
+
+ extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-101")
+
+ model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-101")
+
+ from transformers import pipeline
+
+ pipe = pipeline('object-detection', model=model, feature_extractor=extractor)
+
+ img_url = st.text_input('Image URL', 'https://images.unsplash.com/photo-1556911220-bff31c812dba?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2468&q=80')
+
+ st.caption('Downloading Image...')
+
+ img_data = requests.get(img_url).content
+ with open('detect.jpg', 'wb') as handler:
+     handler.write(img_data)
+
+ st.caption('Running Detection...')
+
+ output = pipe(img_url)
+
+ st.caption('Adding Predictions to Image...')
+
+ fpath = "Poppins-SemiBold.ttf"
+ prop = font_manager.FontProperties(fname=fpath)
+
+ img = Image.open('detect.jpg')
+ plt.figure(dpi=2400)
+
+ # Create figure and axes
+ fig, ax = plt.subplots()
+
+ # Display the image
+ ax.imshow(img)
+
+ colors = ["#ef4444", "#f97316", "#eab308", "#84cc16", "#06b6d4", "#6366f1"]
+
+ # Create a Rectangle patch
+ for prediction in output:
+     selected_color = choice(colors)
+     x, y, w, h = prediction['box']['xmin'], prediction['box']['ymin'], prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
+     rect = patches.FancyBboxPatch((x, y), w, h, linewidth=1.25, edgecolor=selected_color, facecolor='none', boxstyle="round,pad=-0.0040,rounding_size=10",)
+     ax.add_patch(rect)
+     plt.text(x, y-25, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontsize=5, color=selected_color, fontproperties=prop)
+
+ plt.axis('off')
+
+ plt.savefig('detect-bbox.jpg', dpi=1200, bbox_inches='tight')
+
+ image = Image.open('detect-bbox.jpg')
+
+ st.image(image, caption='DETR Image')
+
+ plt.show()
+
+ st.caption('Done!')