# --- Previous Streamlit implementation, kept commented out for reference ---
# import streamlit as st
# import torch
# from facenet_pytorch import MTCNN
# import pickle
# import cv2
# from PIL import Image
# import numpy as np
# from transformers import ViTImageProcessor, ViTModel
# import torch.nn as nn
# from torchvision import transforms
# from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
# import av
#
# class ViT(nn.Module):
#     def __init__(self, base_model):
#         super(ViT, self).__init__()
#         self.base_model = base_model
#
#     def forward(self, x):
#         x = self.base_model(x).pooler_output
#         return x
#
# @st.cache_resource
# def load_model():
#     # NOTE: the processor and the weights come from different checkpoints
#     # (vit-base vs. vit-small); preserved here as in the original.
#     model_name = "google/vit-base-patch16-224"
#     processor = ViTImageProcessor.from_pretrained(model_name)
#     base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
#     model = ViT(base_model)
#     model.load_state_dict(torch.load('faceViT6.pth', map_location=torch.device('cpu')))
#     model.eval()
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model.to(device)
#     return model, processor, device
import gradio as gr
import cv2
import torch
from facenet_pytorch import MTCNN

# Load MTCNN for face detection; the thresholds are the P-Net/R-Net/O-Net confidence cutoffs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)
def align_faces(frame, mtcnn, device):
    """Detect faces in an RGB frame and return aligned face crops plus bounding boxes."""
    boxes, _ = mtcnn.detect(frame)
    aligned_faces = []
    if boxes is not None:
        # A second MTCNN pass returns the cropped, aligned face tensors.
        aligned_faces = mtcnn(frame)
        if aligned_faces is not None:
            aligned_faces = aligned_faces.to(device)
    return aligned_faces, boxes
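# A minimal sketch of how the aligned face tensors could feed a recognition step,
# kept commented out like the Streamlit block above. Everything here is
# illustrative: `embedder` stands in for any embedding model (e.g. the ViT in the
# commented-out code), and `known_embeddings`/`known_names` are assumed,
# precomputed galleries that do not exist in this file.
#
# def recognize_faces(aligned_faces, embedder, known_embeddings, known_names, threshold=0.6):
#     import torch.nn.functional as F
#     names = []
#     with torch.no_grad():
#         embeddings = embedder(aligned_faces)  # (N, D), one row per detected face
#     for emb in embeddings:
#         # Cosine similarity of this face against every enrolled identity.
#         sims = F.cosine_similarity(emb.unsqueeze(0), known_embeddings)
#         best = sims.argmax().item()
#         names.append(known_names[best] if sims[best] >= threshold else "Unknown")
#     return names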
def draw_annotations(frame, detections, names=None):
    """Draw bounding boxes and name labels on the frame (in place) and return it."""
    if detections is None:
        return frame
    if names is None:
        names = ["Unknown"] * len(detections)
    for i, detection in enumerate(detections):
        x1, y1, x2, y2 = map(int, detection)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        if names[i]:
            cv2.putText(frame, names[i], (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
    return frame
def capture_frames():
    """Read webcam frames, annotate detected faces, and yield annotated RGB frames."""
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError("Error: Could not open video stream.")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                raise RuntimeError("Error: Failed to capture image")
            # OpenCV captures BGR; MTCNN (and Gradio's Image display) expect RGB.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Align faces using MTCNN
            aligned_faces, boxes = align_faces(rgb_frame, mtcnn, device)
            # Draw annotations on the frame and yield it as a numpy array,
            # which Gradio's Image component accepts directly.
            yield draw_annotations(rgb_frame, boxes)
    finally:
        cap.release()
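# Note: cv2.VideoCapture(0) opens a camera attached to the *server*, which is
# usually unavailable on hosted platforms such as Hugging Face Spaces. A
# per-frame variant like the commented sketch below could instead be wired to a
# browser webcam component (e.g. gr.Image with streaming enabled); this is an
# assumption about deployment, not part of the original app.
#
# def process_single_frame(rgb_frame):
#     """Annotate one RGB frame coming from the browser webcam."""
#     if rgb_frame is None:
#         return None
#     _, boxes = align_faces(rgb_frame, mtcnn, device)
#     return draw_annotations(rgb_frame, boxes)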
def video_frame_generator():
    # Thin no-argument wrapper so it can be used directly as a Gradio event handler.
    yield from capture_frames()
def gradio_interface():
    with gr.Blocks() as demo:
        webcam_output = gr.Image(label="Webcam Output")
        with gr.Row():
            start_button = gr.Button("Start")
            stop_button = gr.Button("Stop")
        # A generator handler streams successive frames into the Image component;
        # passing the click event to `cancels` stops the running generator.
        stream_event = start_button.click(fn=video_frame_generator, inputs=None, outputs=webcam_output)
        stop_button.click(fn=None, inputs=None, outputs=None, cancels=[stream_event])
    demo.launch(share=True, debug=True)

if __name__ == "__main__":
    gradio_interface()
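# To try this locally (assumes a machine with a webcam and that this file is the
# app entry point; package names inferred from the imports above):
#   pip install gradio opencv-python torch facenet-pytorch
#   python app.py
# then click "Start" in the browser UI.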