# Source: Hugging Face file view of app.py (uploaded by hasnanmr, commit bc26e93 "add app.py", 3.38 kB).
# import streamlit as st
# import torch
# from facenet_pytorch import MTCNN
# import pickle
# import cv2
# from PIL import Image
# import numpy as np
# from transformers import ViTImageProcessor, ViTModel
# import torch.nn as nn
# from torchvision import transforms
# from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
# import av
# class ViT(nn.Module):
# def __init__(self, base_model):
# super(ViT, self).__init__()
# self.base_model = base_model
# def forward(self, x):
# x = self.base_model(x).pooler_output
# return x
# @st.cache_resource
# def load_model():
# model_name = "google/vit-base-patch16-224"
# processor = ViTImageProcessor.from_pretrained(model_name)
# base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
# model = ViT(base_model)
# model.load_state_dict(torch.load('faceViT6.pth', map_location=torch.device('cpu')))
# model.eval()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)
# return model, processor, device
import gradio as gr
import cv2
import torch
from facenet_pytorch import MTCNN
# Load MTCNN for face detection.
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Single shared detector instance, created once at import time and passed to
# align_faces() by capture_frames().
# NOTE(review): per the facenet-pytorch docs, keep_all=True returns every
# detected face, min_face_size is a size in pixels and thresholds are the
# per-stage detection cut-offs -- confirm against the installed version.
mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)
def align_faces(frame, mtcnn, device):
    """Detect faces in ``frame`` and return the detector's face output plus boxes.

    Args:
        frame: Image handed unchanged to ``mtcnn.detect`` and ``mtcnn(...)``.
        mtcnn: Detector exposing ``detect(frame)`` and being callable on a frame.
        device: Target passed to ``.to(device)`` on the detector's face output.

    Returns:
        A ``(aligned_faces, boxes)`` tuple where ``boxes`` is whatever
        ``mtcnn.detect`` reported and ``aligned_faces`` is:

        * ``[]`` when ``boxes`` is ``None`` (the detector is not called again),
        * ``None`` when ``mtcnn(frame)`` itself returns ``None``,
        * otherwise ``mtcnn(frame).to(device)``.
    """
    boxes, _ = mtcnn.detect(frame)

    # Nothing detected: skip the second detector call entirely.
    if boxes is None:
        return [], boxes

    # NOTE(review): calling the detector object presumably repeats the
    # detection already done by detect() above -- confirm against
    # facenet-pytorch before optimising.
    face_crops = mtcnn(frame)
    if face_crops is None:
        return face_crops, boxes

    return face_crops.to(device), boxes
def draw_annotations(frame, detections, names=None):
    """Draw a green box, and optionally a name label, for every detection.

    Drawing happens in place on ``frame``; the same object is returned.

    Args:
        frame: Image passed straight to ``cv2.rectangle`` / ``cv2.putText``.
        detections: Sequence of ``(x1, y1, x2, y2)`` boxes, or ``None`` when
            nothing was detected. Coordinates are converted with ``int``.
        names: Optional labels, one per detection. Defaults to ``"Unknown"``
            for every box. Falsy entries, and detections beyond the end of a
            shorter ``names`` sequence, get a box without a label.

    Returns:
        ``frame``, with the annotations drawn on it.
    """
    if detections is None:
        return frame

    if names is None:
        names = ["Unknown"] * len(detections)

    # Lowest baseline (in pixels from the top) at which a label is still
    # readable; roughly one text line at the 0.9 font scale used below.
    min_label_y = 25

    for index, detection in enumerate(detections):
        x1, y1, x2, y2 = map(int, detection)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # A names sequence shorter than the detections must not abort drawing
        # of the remaining boxes.
        label = names[index] if index < len(names) else None
        if not label:
            continue

        # Keep the label inside the image when the box touches the top edge;
        # y1 - 10 would otherwise place the text off-frame.
        label_y = max(y1 - 10, min_label_y)
        cv2.putText(frame, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    return frame
def capture_frames():
    """Yield JPEG-encoded, face-annotated frames from the default webcam.

    Opens capture device 0, runs face detection on every frame via
    ``align_faces`` (using the module-level ``mtcnn`` and ``device``), draws
    the resulting boxes with ``draw_annotations`` and yields each annotated
    frame as JPEG bytes. The loop only ends through an exception or when the
    generator is closed; in every case the capture device is released.

    Yields:
        bytes: One JPEG-encoded frame per captured image.

    Raises:
        RuntimeError: If the video stream cannot be opened, a frame cannot be
            read, or a frame cannot be JPEG-encoded.
    """
    cap = cv2.VideoCapture(0)
    try:
        if not cap.isOpened():
            raise RuntimeError("Error: Could not open video stream.")

        while True:
            ret, frame = cap.read()
            if not ret:
                raise RuntimeError("Error: Failed to capture image")

            # OpenCV delivers BGR frames while the MTCNN detector expects RGB,
            # so detect on a converted copy. Box coordinates are unaffected,
            # so annotations are still drawn on the original BGR frame.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Align faces using MTCNN (only the boxes are needed here)
            _, boxes = align_faces(rgb_frame, mtcnn, device)

            # Draw annotations on the frame
            annotated_frame = draw_annotations(frame, boxes)

            encoded, buffer = cv2.imencode('.jpg', annotated_frame)
            if not encoded:
                raise RuntimeError("Error: Failed to encode image")
            yield buffer.tobytes()
    finally:
        # Runs on errors and on generator close, so the camera handle is
        # never left open.
        cap.release()
def video_frame_generator():
    """Yield every frame produced by ``capture_frames``.

    Delegates with ``yield from`` so that closing or throwing into this
    generator is forwarded to ``capture_frames`` immediately, rather than
    only when the inner generator is garbage-collected.

    Yields:
        Whatever ``capture_frames`` yields, unchanged.
    """
    yield from capture_frames()
def gradio_interface():
    """Build the Gradio Blocks UI (a video component and a Stop button) and launch it.

    The app is launched with ``share=True`` and ``debug=True``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            # NOTE(review): a generator function is passed as `source` and
            # `streaming=True` is set; confirm that the Video component of the
            # installed Gradio version accepts both -- TODO verify.
            # `webcam_output` is not referenced again in this function.
            webcam_output = gr.Video(source=video_frame_generator, streaming=True, label="Webcam Output")
            stop_button = gr.Button("Stop")

        def stop_streaming():
            # Placeholder for stopping streaming if necessary
            # It only returns a message; it does not interact with the frame
            # generator or the capture device.
            return "Streaming stopped."

        # Registered with outputs=None, so the string returned by
        # stop_streaming is not routed to any component.
        stop_button.click(fn=stop_streaming, inputs=None, outputs=None)

    demo.launch(share=True, debug=True)
# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    gradio_interface()