zoheb commited on
Commit
68ac5b3
Β·
1 Parent(s): 605fe32

initial commit

Browse files
Files changed (3) hide show
  1. README.md +6 -4
  2. app.py +170 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Yolos Demo
3
- emoji: 🐨
4
  colorFrom: purple
5
- colorTo: gray
6
  sdk: streamlit
7
  sdk_version: 1.10.0
 
8
  app_file: app.py
9
- pinned: false
 
10
  license: mit
11
  ---
12
 
 
1
  ---
2
+ title: YOLOS Demo (Balloons)
3
+ emoji: 🎈
4
  colorFrom: purple
5
+ colorTo: red
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
+ python_version: 3.9.13
9
  app_file: app.py
10
+ models: zoheb/yolos-small-balloon
11
+ pinned: true
12
  license: mit
13
  ---
14
 
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import cv2
3
+ from PIL import Image
4
+ import streamlit as st
5
+ from transformers import AutoModelForObjectDetection, AutoFeatureExtractor
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+ from stqdm import stqdm
9
+ from pathlib import Path
10
+
11
+
12
+ # Load the model
13
+ best_model_path = "zoheb/yolos-small-balloon"
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ feature_extractor = AutoFeatureExtractor.from_pretrained(best_model_path, size=512, max_size=864)
16
+ model_pt = AutoModelForObjectDetection.from_pretrained(best_model_path).to(device)
17
+
18
+ # colors for visualization
19
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
20
+ [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
21
+
22
+ # Convert Video to Frames
23
+ def video_to_frames(video, dir):
24
+ cap = cv2.VideoCapture(str(video))
25
+ success, image = cap.read()
26
+
27
+ frame_count = 0
28
+ while success:
29
+ frameId = int(round(cap.get(1))) # current frame number
30
+ if frameId % 5 == 0:
31
+ cv2.imwrite(f"{str(dir)}/frame_{frame_count}.jpg", image)
32
+ frame_count += 1
33
+ success, image = cap.read()
34
+
35
+ cap.release()
36
+ #print (f"No. of frames {frame_count}")
37
+
38
+
39
+ # for output bounding box post-processing
40
+ def box_cxcywh_to_xyxy(x):
41
+ x_c, y_c, w, h = x.unbind(1)
42
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
43
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
44
+ return torch.stack(b, dim=1)
45
+
46
+ # rescale bboxes
47
+ def rescale_bboxes(out_bbox, size):
48
+ img_w, img_h = size
49
+ b = box_cxcywh_to_xyxy(out_bbox)
50
+ b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
51
+ return b
52
+
53
+ # Save predicted frame
54
+ def save_results(pil_img, prob, boxes, mod_img_path):
55
+ plt.figure(figsize=(18,10))
56
+ plt.imshow(pil_img)
57
+ id2label = {0: 'balloon'}
58
+ ax = plt.gca()
59
+ colors = COLORS * 100
60
+ for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
61
+ ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
62
+ fill=False, color=c, linewidth=3))
63
+ cl = p.argmax()
64
+ text = f'{id2label[cl.item()]}: {p[cl]:0.2f}'
65
+ ax.text(xmin, ymin, text, fontsize=15,
66
+ bbox=dict(facecolor='yellow', alpha=0.5))
67
+ plt.axis('off')
68
+ plt.tight_layout(pad=0)
69
+ plt.savefig(mod_img_path, transparent=True)
70
+ plt.close()
71
+
72
+ # Save predictions
73
+ def save_predictions(image, outputs, mod_img_path, threshold=0.9):
74
+ # keep only predictions with confidence >= threshold
75
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
76
+ keep = probas.max(-1).values > threshold
77
+
78
+ # convert predicted boxes from [0; 1] to image scales
79
+ bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
80
+
81
+ # save results
82
+ save_results(image, probas[keep], bboxes_scaled, mod_img_path)
83
+
84
+ # Predict on frames
85
+ def predict_on_frames(dir, mod_dir):
86
+ files = [f for f in dir.glob('*.jpg') if f.is_file()]
87
+ #for sorting the file names properly
88
+ files.sort(key = lambda x: int(x.stem[6:]))
89
+
90
+ for i in stqdm(range(len(files)), desc="Generating... this is a slow task"):
91
+ filename = Path(dir, files[i])
92
+ #print(filename)
93
+ #reading each files
94
+ img = Image.open(str(filename))
95
+ # extract features
96
+ img_ftr = feature_extractor(images=img, return_tensors="pt")
97
+ pixel_values = img_ftr["pixel_values"].to(device)
98
+ # forward pass to get class logits and bounding boxes
99
+ outputs = model_pt(pixel_values=pixel_values)
100
+ mod_img_path = Path(mod_dir, files[i].name)
101
+ save_predictions(img, outputs, mod_img_path)
102
+
103
+ # Convert frames to video
104
+ def frames_to_video(dir, path, fps=5):
105
+ frame_array = []
106
+ files = [f for f in dir.glob('*.jpg') if f.is_file()]
107
+ #for sorting the file names properly
108
+ files.sort(key = lambda x: int(x.stem[6:]))
109
+ for file in files:
110
+ filename = Path(dir, file)
111
+ #reading each files
112
+ img = cv2.imread(str(filename))
113
+ height, width, _ = img.shape
114
+ size = (width, height)
115
+ #print(filename)
116
+ #inserting the frames into an image array
117
+ frame_array.append(img)
118
+ out = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'DIVX'), fps, size)
119
+ for item in frame_array:
120
+ # writing to a image array
121
+ out.write(item)
122
+ out.release()
123
+
124
+
125
+ # Main
126
+ if __name__=='__main__':
127
+ st.title('Detect Balloons using YOLOS')
128
+
129
+ # All dir and Files
130
+ BASE_DIR = Path(__file__).parent.absolute()
131
+
132
+ FRAMES_DIR = Path(BASE_DIR, "extracted_images")
133
+ MOD_DIR = Path(BASE_DIR, "modified_images")
134
+
135
+ if FRAMES_DIR.exists() and FRAMES_DIR.is_dir():
136
+ shutil.rmtree(FRAMES_DIR)
137
+ FRAMES_DIR.mkdir(parents=True, exist_ok=True)
138
+
139
+ if MOD_DIR.exists() and MOD_DIR.is_dir():
140
+ shutil.rmtree(MOD_DIR)
141
+ MOD_DIR.mkdir(parents=True, exist_ok=True)
142
+
143
+ generated_video = Path(BASE_DIR, "final_video.mp4")
144
+
145
+ # Upload the video
146
+ uploaded_file = st.file_uploader("Upload a small video containing Balloons", type=["mp4"])
147
+ if uploaded_file is not None:
148
+ st.video(uploaded_file)
149
+ vid = uploaded_file.name
150
+ st.info(f'Uploaded {vid}')
151
+ with open(vid, mode='wb') as f:
152
+ f.write(uploaded_file.read())
153
+ uploaded_video = Path(BASE_DIR, vid)
154
+
155
+ # Detect balloon in the frames and generate video
156
+ try:
157
+ video_to_frames(uploaded_video, FRAMES_DIR)
158
+ predict_on_frames(FRAMES_DIR, MOD_DIR)
159
+ frames_to_video(MOD_DIR, generated_video)
160
+ st.success("Successfully Generated!!")
161
+
162
+ # Video file Generated
163
+ video_file = open(str(generated_video), 'rb')
164
+ video_bytes = video_file.read()
165
+ st.video(video_bytes)
166
+ st.download_button('Download the Video', video_bytes, file_name=generated_video.name)
167
+ except Exception as e:
168
+ st.error(f"Could not convert the file due to {e}")
169
+ else:
170
+ st.info('File Not Uploaded Yet!!!')
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy==1.23.3
2
+ pillow==9.2.0
3
+ opencv-python==4.6.0.66
4
+ matplotlib==3.6.1
5
+ torch==1.12.1
6
+ transformers==4.22.2
7
+ stqdm==0.0.4