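"""Blink-detection demo.

An OpenCV DNN face detector (res10 SSD) finds the largest face, an LBF
facemark model locates the eye landmarks, and a per-video eye-aspect-ratio
(EAR) threshold, calibrated on the first frames, counts blinks. A Gradio UI
streams the annotated frames and returns the processed video plus an EAR plot.
"""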
import cv2
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
# try:
#     from pygame import mixer
#     mixer_init = True
# except ModuleNotFoundError:
#     mixer = None
#     mixer_init = False
# ------------------------------------------------------------------------------
# 1. Initializations.
# ------------------------------------------------------------------------------
# Initialize counter for the number of blinks detected.
BLINK = 0
# Model file paths.
MODEL_PATH = "./model/res10_300x300_ssd_iter_140000.caffemodel"
CONFIG_PATH = "./model/deploy.prototxt"
LBF_MODEL = "./model/lbfmodel.yaml"
# Create a face detector network instance.
net = cv2.dnn.readNetFromCaffe(CONFIG_PATH, MODEL_PATH)
# Create the landmark detector instance.
landmarkDetector = cv2.face.createFacemarkLBF()
landmarkDetector.loadModel(LBF_MODEL)
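# NOTE: lbfmodel.yaml is assumed to be the standard 68-point facemark model;
# in that layout, indices 36-41 are the right eye and 42-47 the left eye,
# which is what the EAR computation below relies on.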
# ------------------------------------------------------------------------------
# 2. Function definitions.
# ------------------------------------------------------------------------------
def detect_faces(image, detection_threshold=0.70):
    blob = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123])
    net.setInput(blob)
    detections = net.forward()
    faces = []
    img_h = image.shape[0]
    img_w = image.shape[1]
    # Each detection row is [_, _, confidence, x1, y1, x2, y2], with box
    # coordinates normalized to [0, 1].
    for detection in detections[0][0]:
        if detection[2] >= detection_threshold:
            left = detection[3] * img_w
            top = detection[4] * img_h
            right = detection[5] * img_w
            bottom = detection[6] * img_h
            face_w = right - left
            face_h = bottom - top
            face_roi = (left, top, face_w, face_h)
            faces.append(face_roi)
    return np.array(faces).astype(int)
def get_primary_face(faces, frame_h, frame_w):
    primary_face_index = None
    face_height_max = 0
    for idx in range(len(faces)):
        face = faces[idx]
        x1 = face[0]
        y1 = face[1]
        x2 = x1 + face[2]
        y2 = y1 + face[3]
        if x1 > frame_w or y1 > frame_h or x2 > frame_w or y2 > frame_h:
            continue
        if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0:
            continue
        # Prioritize the face with the maximum height.
        if face[3] > face_height_max:
            primary_face_index = idx
            face_height_max = face[3]
    if primary_face_index is not None:
        primary_face = faces[primary_face_index]
    else:
        primary_face = None
    return primary_face
def visualize_eyes(landmarks, frame):
    # Landmarks 36-47 cover both eyes in the 68-point layout.
    for i in range(36, 48):
        cv2.circle(frame, tuple(landmarks[i].astype("int")), 2, (0, 255, 0), -1)
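# Eye aspect ratio (EAR), as defined by Soukupová and Čech (2016):
# EAR = (|p2 - p6| + |p3 - p5|) / (2 * |p1 - p4|), where p1..p6 are the six
# landmarks of one eye, p1/p4 the horizontal corners and p2, p3, p5, p6 the
# vertical pairs. EAR stays roughly constant while the eye is open and drops
# toward zero during a blink, so a calibrated threshold on the left/right
# average is a simple blink signal.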
def get_eye_aspect_ratio(landmarks):
    vert_dist_1right = calculate_distance(landmarks[37], landmarks[41])
    vert_dist_2right = calculate_distance(landmarks[38], landmarks[40])
    vert_dist_1left = calculate_distance(landmarks[43], landmarks[47])
    vert_dist_2left = calculate_distance(landmarks[44], landmarks[46])
    horz_dist_right = calculate_distance(landmarks[36], landmarks[39])
    horz_dist_left = calculate_distance(landmarks[42], landmarks[45])
    EAR_left = (vert_dist_1left + vert_dist_2left) / (2.0 * horz_dist_left)
    EAR_right = (vert_dist_1right + vert_dist_2right) / (2.0 * horz_dist_right)
    ear = (EAR_left + EAR_right) / 2
    return ear
def calculate_distance(A, B):
    distance = ((A[0] - B[0]) ** 2 + (A[1] - B[1]) ** 2) ** 0.5
    return distance
# def play(file):
#     if mixer_init:
#         mixer.init()
#         sound = mixer.Sound(file)
#         sound.play()
# ------------------------------------------------------------------------------
# 3. Processing function (to be used in Gradio).
# ------------------------------------------------------------------------------
def process_video(input_video):
    # Filenames for the outputs.
    out_video_filename = "processed_video.mp4"
    out_plot_filename = "ear_plot.png"
    cap = cv2.VideoCapture(input_video)
    # Probe the first frame for the output dimensions.
    ret, frame = cap.read()
    if not ret:
        print("Cannot read the input video.")
        return
    frame_h = frame.shape[0]
    frame_w = frame.shape[1]
    # Rewind so the probed frame is processed as well.
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    # Initialize the writer for the processed video; fall back to 30 fps
    # if the container does not report a frame rate.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    fps = cap.get(cv2.CAP_PROP_FPS) if cap.get(cv2.CAP_PROP_FPS) > 0 else 30
    out_writer = cv2.VideoWriter(out_video_filename, fourcc, fps, (frame_w, frame_h))
    # Calibration state.
    frame_count = 0
    frame_calib = 30  # Number of frames to use for threshold calibration.
    sum_ear = 0
    BLINK = 0
    state_prev = state_curr = "open"
    ear_values = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Detect faces.
        faces = detect_faces(frame, detection_threshold=0.90)
        if len(faces) > 0:
            # Use the primary face.
            primary_face = get_primary_face(faces, frame_h, frame_w)
            if primary_face is not None:
                cv2.rectangle(
                    frame,
                    (primary_face[0], primary_face[1]),
                    (primary_face[0] + primary_face[2], primary_face[1] + primary_face[3]),
                    (0, 255, 0),
                    3,
                )
                # Detect landmarks.
                retval, landmarksList = landmarkDetector.fit(frame, np.expand_dims(primary_face, 0))
                if retval:
                    landmarks = landmarksList[0][0]
                    # Display detections.
                    visualize_eyes(landmarks, frame)
                    # Get EAR.
                    ear = get_eye_aspect_ratio(landmarks)
                    ear_values.append(ear)
                    if frame_count < frame_calib:
                        frame_count += 1
                        sum_ear += ear
                    elif frame_count == frame_calib:
                        frame_count += 1
                        # Average over the calibration window (frame_calib
                        # frames); frame_count has already been advanced past it.
                        avg_ear = sum_ear / frame_calib
                        HIGHER_TH = 0.90 * avg_ear
                        LOWER_TH = 0.80 * HIGHER_TH
                        print("SET EAR HIGH: ", HIGHER_TH)
                        print("SET EAR LOW: ", LOWER_TH)
                    else:
                        if ear < LOWER_TH:
                            state_curr = "closed"
                        elif ear > HIGHER_TH:
                            state_curr = "open"
                        # Count a blink on each closed -> open transition.
                        if state_prev == "closed" and state_curr == "open":
                            BLINK += 1
                            # if mixer_init:
                            #     play("./click.wav")
                        state_prev = state_curr
                        cv2.putText(
                            frame,
                            f"Blink Counter: {BLINK}",
                            (10, 80),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1.5,
                            (0, 0, 255),
                            4,
                            cv2.LINE_AA,
                        )
            else:
                # No valid face detected.
                pass
        else:
            # No faces.
            pass
        frame_out_final = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        out_writer.write(frame)
        yield frame_out_final, None, None
    cap.release()
    out_writer.release()
    # Plot the EAR values if any were collected.
    if ear_values:
        plt.figure(figsize=(10, 5.625))
        plt.plot(ear_values, label="EAR")
        plt.title("Eye Aspect Ratio (EAR) over time")
        plt.xlabel("Frame Index")
        plt.ylabel("EAR")
        plt.legend()
        plt.grid(True)
        plt.savefig(out_plot_filename)
        plt.close()
    else:
        out_plot_filename = None
    yield None, out_video_filename, out_plot_filename
# ------------------------------------------------------------------------------
# 4. Gradio UI
# ------------------------------------------------------------------------------
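# Note: process_video is a generator. It yields (rgb_frame, None, None) for
# every processed frame so the UI can stream a live preview, then one final
# (None, video_path, plot_path) tuple once the writer and the plot are done;
# process_gradio below simply forwards both phases to the Gradio outputs.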
def process_gradio(video_file):
    if video_file is None:
        return
    video_path = video_file
    output_frames = None
    processed_video = None
    plot_img = None
    # Process the video using the generator.
    for frame_out, processed_video_path, plot_path in process_video(video_path):
        if frame_out is not None:
            output_frames = frame_out  # Update the preview frame.
            yield output_frames, None, None  # Gradio updates the frame stream step by step.
        else:
            processed_video = processed_video_path
            plot_img = plot_path
    # Final yield with the processed video and EAR plot.
    yield None, processed_video, plot_img
with gr.Blocks() as demo:
    gr.Markdown("# Blink Detection with OpenCV")
    gr.Markdown("Upload a video to detect blinks and view the EAR plot after processing.")
    with gr.Row():
        video_input = gr.Video(label="Input Video")
        output_frames = gr.Image(label="Output Frames")
    process_btn = gr.Button("Process")
    with gr.Row():
        processed_video = gr.Video(label="Processed Video")
        ear_plot = gr.Image(label="EAR Plot")
    process_btn.click(
        process_gradio,
        inputs=video_input,
        outputs=[output_frames, processed_video, ear_plot],
    )
    examples = [
        ["./input-video.mp4"],
    ]
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[video_input],
            label="Load Example Video",
        )
if __name__ == "__main__":
    demo.launch()