luminoussg committed
Commit f7fcdf0 · verified · 1 Parent(s): 6f0d6b6

Update app.py

Files changed (1): app.py +156 -141
app.py CHANGED
@@ -1,16 +1,19 @@
 import gradio as gr
 import cv2
 import os
-import subprocess
 import pandas as pd
 import numpy as np
 import torch
 from ultralytics import YOLO
+from ultralytics.solutions import object_counter
+import subprocess
+import spaces  # Import spaces for ZeroGPU integration
 
-with gr.Blocks(theme=gr.themes.Dark()) as demo:
-
-# Loading a YOLO model
-model = YOLO('yolov8x.pt')
+# Initialize the YOLO model
+MODEL = "yolov8n.pt"
+model = YOLO(MODEL)
+model.fuse()
+
 dict_classes = model.model.names
 
 # Auxiliary functions
@@ -21,153 +24,165 @@ def resize_frame(frame, scale_percent):
     resized = cv2.resize(frame, dim, interpolation=cv2.INTER_AREA)
     return resized
 
-# Processing function
-def process_video(video_path, line_position):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    print(f"Using device: {device}")
-    model.to(device)
-
-    # Read video
-    video = cv2.VideoCapture(video_path)
-
-    # Scaling percentage of original frame
-    scale_percent = 50
-    class_IDS = [2, 3, 5, 7]
-    cy_linha = int(line_position * scale_percent / 100)
-    cx_sentido = int(2000 * scale_percent / 100)
-    offset = int(8 * scale_percent / 100)
-
-    # Initializing counters
-    contador_in = 0
-    contador_out = 0
-    veiculos_contador_in = dict.fromkeys(class_IDS, 0)
-    veiculos_contador_out = dict.fromkeys(class_IDS, 0)
-
-    # Getting video properties
-    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
-    fps = video.get(cv2.CAP_PROP_FPS)
-
-    if scale_percent != 100:
-        width = int(width * scale_percent / 100)
-        height = int(height * scale_percent / 100)
-
-    # Setting up video writer
-    tmp_output_path = "tmp_output.mp4"
-    output_video = cv2.VideoWriter(tmp_output_path,
-                                   cv2.VideoWriter_fourcc(*'mp4v'),
-                                   fps, (width, height))
-
-    for i in range(int(video.get(cv2.CAP_PROP_FRAME_COUNT))):
-        ret, frame = video.read()
+@spaces.GPU
+def process_video(video_file, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness, draw_tracks, view_img, view_in_counts, view_out_counts, track_thickness, region_thickness, line_dist_thresh, persist, conf, iou, classes, verbose):
+    # Ensure classes is a list of integers
+    classes = [int(x) for x in classes.split(',') if x.strip().isdigit()] if classes else None
+
+    line_points = [(line_start_x, line_start_y), (line_end_x, line_end_y)]
+
+    cap = cv2.VideoCapture(video_file)
+    if not cap.isOpened():
+        raise ValueError("Failed to open video file")
+
+    tmp_output_path = "processed_output_temp.mp4"
+    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale_percent / 100)
+    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale_percent / 100)
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+    video_writer = cv2.VideoWriter(tmp_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+
+    counter = object_counter.ObjectCounter(
+        classes_names=model.names,
+        view_img=view_img,
+        reg_pts=line_points,
+        draw_tracks=draw_tracks,
+        line_thickness=int(line_thickness),
+        track_thickness=int(track_thickness),
+        region_thickness=int(region_thickness),
+        line_dist_thresh=line_dist_thresh,
+        view_in_counts=view_in_counts,
+        view_out_counts=view_out_counts,
+        count_reg_color=(255, 0, 255),    # Magenta
+        track_color=(0, 255, 0),          # Green
+        count_txt_color=(255, 255, 255),  # White
+        count_bg_color=(50, 50, 50)       # Dark gray
+    )
+
+    prev_frame = None
+    prev_keypoints = None
+
+    while cap.isOpened():
+        ret, frame = cap.read()
         if not ret:
             break
        frame = resize_frame(frame, scale_percent)
 
-        y_hat = model.predict(frame, conf=0.7, classes=class_IDS, device=device, verbose=False)
-
-        boxes = y_hat[0].boxes.xyxy.cpu().numpy()
-        conf = y_hat[0].boxes.conf.cpu().numpy()
-        classes = y_hat[0].boxes.cls.cpu().numpy()
-
-        positions_frame = pd.DataFrame(boxes, columns=['xmin', 'ymin', 'xmax', 'ymax'])
-        positions_frame['conf'] = conf
-        positions_frame['class'] = classes
-
-        labels = [dict_classes[i] for i in classes]
-
-        cv2.line(frame, (0, cy_linha), (int(4500 * scale_percent / 100), cy_linha), (255, 255, 0), 8)
-
-        for ix, row in positions_frame.iterrows():
-            xmin, ymin, xmax, ymax, confidence, category = row.astype('int')
-            center_x, center_y = int((xmax + xmin) / 2), int((ymax + ymin) / 2)
-
-            cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (255, 0, 0), 5)
-            cv2.circle(frame, (center_x, center_y), 5, (255, 0, 0), -1)
-            cv2.putText(img=frame, text=labels[ix] + ' - ' + str(np.round(conf[ix], 2)),
-                        org=(xmin, ymin-10), fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=1, color=(255, 0, 0), thickness=2)
-
-            # Adjust counting logic based on new line position
-            if (center_y < (cy_linha + offset)) and (center_y > (cy_linha - offset)):
-                if (center_x >= 0) and (center_x <= cx_sentido):
-                    contador_in += 1
-                    veiculos_contador_in[category] += 1
-                else:
-                    contador_out += 1
-                    veiculos_contador_out[category] += 1
-
-        contador_in_plt = [f'{dict_classes[k]}: {i}' for k, i in veiculos_contador_in.items()]
-        contador_out_plt = [f'{dict_classes[k]}: {i}' for k, i in veiculos_contador_out.items()]
-
-        cv2.putText(img=frame, text='N. vehicles In',
-                    org=(30, 30), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
-                    fontScale=1, color=(255, 255, 0), thickness=1)
-
-        cv2.putText(img=frame, text='N. vehicles Out',
-                    org=(int(2800 * scale_percent / 100), 30),
-                    fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=1, color=(255, 255, 0), thickness=1)
-
-        xt = 40
-        for txt in range(len(contador_in_plt)):
-            xt += 30
-            cv2.putText(img=frame, text=contador_in_plt[txt],
-                        org=(30, xt), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
-                        fontScale=1, color=(255, 255, 0), thickness=1)
-
-            cv2.putText(img=frame, text=contador_out_plt[txt],
-                        org=(int(2800 * scale_percent / 100), xt), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
-                        fontScale=1, color=(255, 255, 0), thickness=1)
-
-        cv2.putText(img=frame, text=f'In:{contador_in}',
-                    org=(int(1820 * scale_percent / 100), cy_linha + 60),
-                    fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=1, color=(255, 255, 0), thickness=2)
-
-        cv2.putText(img=frame, text=f'Out:{contador_out}',
-                    org=(int(1800 * scale_percent / 100), cy_linha - 40),
-                    fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=1, color=(255, 255, 0), thickness=2)
-
-        output_video.write(frame)
-
-    output_video.release()
-
-    # Post-processing
-    output_path = "output.mp4"
-    if os.path.exists(output_path):
-        os.remove(output_path)
-
+        # Adjust line points based on scaling
+        scaled_line_points = [(int(x * scale_percent / 100), int(y * scale_percent / 100)) for x, y in line_points]
+        for point1, point2 in zip(scaled_line_points[:-1], scaled_line_points[1:]):
+            cv2.line(frame, tuple(map(int, point1)), tuple(map(int, point2)), (255, 255, 0), int(line_thickness))
+
+        tracks = model.track(frame, persist=persist, conf=conf, iou=iou, classes=classes, verbose=verbose)
+
+        # Update the counter with the current frame and tracks
+        frame = counter.start_counting(frame, tracks)
+
+        # Check if the previous frame is initialized for optical flow calculation
+        if prev_frame is not None:
+            try:
+                prev_frame_resized = resize_frame(prev_frame, scale_percent)
+                matched_keypoints, status, _ = cv2.calcOpticalFlowPyrLK(prev_frame_resized, frame, prev_keypoints, None)
+                prev_keypoints = matched_keypoints
+            except cv2.error as e:
+                print(f"Error in optical flow calculation: {e}")
+
+        prev_frame = frame.copy()
+        prev_keypoints = cv2.goodFeaturesToTrack(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), maxCorners=100, qualityLevel=0.3, minDistance=7, blockSize=7)
+
+        video_writer.write(frame)
+
+    cap.release()
+    video_writer.release()
+
+    # Reduce the resolution of the video for download
+    output_path = "processed_output.mp4"
+    if h > 1080:
+        resolution = "1920x1080"
+    else:
+        resolution = "1280x720"
+
     subprocess.run(
-        ["ffmpeg", "-i", tmp_output_path, "-crf", "18", "-preset", "veryfast", "-hide_banner", "-loglevel", "error", "-vcodec", "libx264", output_path])
+        ["ffmpeg", "-y", "-i", tmp_output_path, "-vf", f"scale={resolution}", "-crf", "18", "-preset", "veryfast", "-hide_banner", "-loglevel", "error", output_path]
+    )
     os.remove(tmp_output_path)
-
+
     return output_path
 
-# Gradio interface
-with gr.Blocks() as demo:
-    video_input = gr.File(label="Upload your video")
-    line_position = gr.Slider(0, 3000, value=1500, label="Line Position (px)")
-    preview_button = gr.Button("Preview Line")
-    process_button = gr.Button("Process Video")
-    video_output = gr.Video(label="Processed Video")
-    download_button = gr.File(label="Download Processed Video")
-
-    def preview_line(video, line_position):
-        video = cv2.VideoCapture(video.name)
-        ret, frame = video.read()
-        if ret:
-            scale_percent = 50
-            cy_linha = int(line_position * scale_percent / 100)
-            frame = resize_frame(frame, scale_percent)
-            cv2.line(frame, (0, cy_linha), (int(4500 * scale_percent / 100), cy_linha), (255, 255, 0), 8)
-            output_path = "preview_line.jpg"
-            cv2.imwrite(output_path, frame)
-            return output_path
-        return None
-
-    def process_video_and_display(video, line_position):
-        output_path = process_video(video.name, line_position)
-        return output_path, output_path
-
-    preview_button.click(preview_line, inputs=[video_input, line_position], outputs=gr.Image(label="Preview Line"))
-    process_button.click(process_video_and_display, inputs=[video_input, line_position], outputs=[video_output, download_button])
-
-demo.launch()
+def preview_line(video_file, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness):
+    cap = cv2.VideoCapture(video_file)
+    ret, frame = cap.read()
+    if not ret:
+        raise ValueError("Failed to read video frame")
+
+    frame = resize_frame(frame, scale_percent)
+    line_points = [(line_start_x, line_start_y), (line_end_x, line_end_y)]
+    scaled_line_points = [(int(x * scale_percent / 100), int(y * scale_percent / 100)) for x, y in line_points]
+    for point1, point2 in zip(scaled_line_points[:-1], scaled_line_points[1:]):
+        cv2.line(frame, tuple(map(int, point1)), tuple(map(int, point2)), (255, 255, 0), int(line_thickness))
+
+    preview_path = "preview_line.jpg"
+    cv2.imwrite(preview_path, frame)
+    return preview_path
+
+def gradio_app(video, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness, draw_tracks, view_img, view_in_counts, view_out_counts, track_thickness, region_thickness, line_dist_thresh, persist, conf, iou, classes_to_track, verbose):
+    output_path = process_video(video.name, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, int(line_thickness), draw_tracks, view_img, view_in_counts, view_out_counts, int(track_thickness), int(region_thickness), line_dist_thresh, persist, conf, iou, classes_to_track, verbose)
+    return output_path, output_path
+
+def update_preview(video, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness):
+    return preview_line(video.name, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, int(line_thickness))
+
+def set_4k_coordinates():
+    return 0, 1500, 3840, 1500
+
+def set_1080p_coordinates():
+    return 0, 700, 1920, 700
+
+with gr.Blocks(theme="dark") as demo:
+    with gr.Row():
+        with gr.Column(scale=1):
+            video_input = gr.File(label="Upload your video")
+            with gr.Row():
+                set_4k_button = gr.Button("4K")
+                set_1080p_button = gr.Button("1080p")
+            line_start_x = gr.Number(label="Line Start X", value=500, precision=0)
+            line_start_y = gr.Number(label="Line Start Y", value=1500, precision=0)
+            line_end_x = gr.Number(label="Line End X", value=3400, precision=0)
+            line_end_y = gr.Number(label="Line End Y", value=1500, precision=0)
+            line_thickness = gr.Slider(minimum=1, maximum=10, value=2, label="Line Thickness")
+            draw_tracks = gr.Checkbox(label="Draw Tracks", value=True)
+            view_img = gr.Checkbox(label="Display Image with Annotations", value=True)
+            view_in_counts = gr.Checkbox(label="Display In-Counts", value=True)
+            view_out_counts = gr.Checkbox(label="Display Out-Counts", value=True)
+            track_thickness = gr.Slider(minimum=1, maximum=10, value=2, label="Track Thickness")
+            region_thickness = gr.Slider(minimum=1, maximum=10, value=5, label="Region Thickness")
+            line_dist_thresh = gr.Slider(minimum=5, maximum=50, value=15, label="Line Distance Threshold")
+            persist = gr.Checkbox(label="Persist Tracks", value=True)
+            conf = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Confidence Threshold")
+            iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="IOU Threshold")
+            classes_to_track = gr.Textbox(label="Classes to Track (comma-separated ids)", value="2,3,5,7")
+            verbose = gr.Checkbox(label="Verbose Tracking", value=True)
+            scale_percent = gr.Slider(minimum=10, maximum=100, value=100, step=10, label="Scale Percentage")
+            process_button = gr.Button("Process Video")
+        with gr.Column(scale=2):
+            preview_image = gr.Image(label="Preview Line")
+            video_output = gr.Video(label="Processed Video")
+            download_button = gr.File(label="Download Processed Video")
+
+    def update_preview_and_display(video, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness):
+        preview_path = update_preview(video, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness)
+        return preview_path
+
+    video_input.change(update_preview_and_display, inputs=[video_input, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness], outputs=preview_image)
+    for component in [scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness, draw_tracks, view_img, view_in_counts, view_out_counts, track_thickness, region_thickness, line_dist_thresh, persist, conf, iou, classes_to_track, verbose]:
+        component.change(update_preview_and_display, inputs=[video_input, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness], outputs=preview_image)
+
+    set_4k_button.click(lambda: set_4k_coordinates(), outputs=[line_start_x, line_start_y, line_end_x, line_end_y])
+    set_1080p_button.click(lambda: set_1080p_coordinates(), outputs=[line_start_x, line_start_y, line_end_x, line_end_y])
+
+    def clear_previous_video():
+        return None, None
+
+    process_button.click(clear_previous_video, outputs=[video_output, download_button], queue=False)
+    process_button.click(gradio_app, inputs=[video_input, scale_percent, line_start_x, line_start_y, line_end_x, line_end_y, line_thickness, draw_tracks, view_img, view_in_counts, view_out_counts, track_thickness, region_thickness, line_dist_thresh, persist, conf, iou, classes_to_track, verbose], outputs=[video_output, download_button])
+
+demo.launch()
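
For local testing, the two new entry points can be exercised without the Gradio UI. The sketch below is illustrative only: sample.mp4 and the scratch module name counting_demo are hypothetical; it assumes an ultralytics release that still ships ultralytics.solutions.object_counter with the start_counting(frame, tracks) API used in this commit, and that the spaces package is installed (its @spaces.GPU decorator is a pass-through outside a ZeroGPU Space). Because app.py calls demo.launch() at import time, the functions are assumed to have been copied into counting_demo rather than imported from app directly.

# Hypothetical smoke test; assumes preview_line and process_video were
# copied into a scratch module named counting_demo.
from counting_demo import preview_line, process_video

# Draw the counting line on the first frame of a local test clip.
preview = preview_line("sample.mp4", 100,   # video path, scale_percent
                       0, 700, 1920, 700,   # line start/end (x, y)
                       2)                   # line_thickness
print("preview written to", preview)

# Track and count vehicles (COCO class ids 2, 3, 5, 7) across the line.
output = process_video(
    "sample.mp4", 100,   # video path, scale_percent
    0, 700, 1920, 700,   # line start/end coordinates
    2,                   # line_thickness
    True, False,         # draw_tracks, view_img (False for headless runs)
    True, True,          # view_in_counts, view_out_counts
    2, 5, 15,            # track_thickness, region_thickness, line_dist_thresh
    True, 0.1, 0.7,      # persist, conf, iou
    "2,3,5,7", False,    # classes, verbose
)
print("processed video written to", output)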