# Earlier Streamlit + ViT embedding version, kept for reference:
#
# import streamlit as st
# import torch
# from facenet_pytorch import MTCNN
# import pickle
# import cv2
# from PIL import Image
# import numpy as np
# from transformers import ViTImageProcessor, ViTModel
# import torch.nn as nn
# from torchvision import transforms
# from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
# import av
#
# class ViT(nn.Module):
#     def __init__(self, base_model):
#         super(ViT, self).__init__()
#         self.base_model = base_model
#
#     def forward(self, x):
#         return self.base_model(x).pooler_output
#
# @st.cache_resource
# def load_model():
#     model_name = "google/vit-base-patch16-224"
#     processor = ViTImageProcessor.from_pretrained(model_name)
#     base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
#     model = ViT(base_model)
#     model.load_state_dict(torch.load('faceViT6.pth', map_location=torch.device('cpu')))
#     model.eval()
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model.to(device)
#     return model, processor, device

import gradio as gr
import cv2
import torch
from facenet_pytorch import MTCNN

# Load MTCNN for face detection.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)


def align_faces(frame, mtcnn, device):
    """Detect faces in an RGB frame and return aligned crops plus bounding boxes.

    Note: this runs the detector twice (once for boxes, once for the aligned
    crops); MTCNN also exposes detect()/extract() to avoid the second pass.
    """
    boxes, _ = mtcnn.detect(frame)
    aligned_faces = []
    if boxes is not None:
        aligned_faces = mtcnn(frame)
        if aligned_faces is not None:
            aligned_faces = aligned_faces.to(device)
    return aligned_faces, boxes


def draw_annotations(frame, detections, names=None):
    """Draw bounding boxes and name labels on the frame (modified in place)."""
    if detections is None:
        return frame
    if names is None:
        names = ["Unknown"] * len(detections)
    for i, detection in enumerate(detections):
        x1, y1, x2, y2 = map(int, detection)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        if names[i]:
            cv2.putText(frame, names[i], (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
    return frame


def capture_frames():
    """Read webcam frames, detect faces, and yield annotated RGB frames."""
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError("Error: Could not open video stream.")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # camera disconnected or stream ended
            # facenet_pytorch's MTCNN expects RGB input; OpenCV captures BGR.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Align faces using MTCNN. The aligned crops are unused in the
            # live path; see the recognition sketch at the end of the file.
            aligned_faces, boxes = align_faces(rgb_frame, mtcnn, device)
            # Draw annotations; Gradio displays RGB numpy arrays directly,
            # so no JPEG encoding is needed.
            yield draw_annotations(rgb_frame, boxes)
    finally:
        cap.release()


def gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            webcam_output = gr.Image(label="Webcam Output")
        with gr.Row():
            start_button = gr.Button("Start")
            stop_button = gr.Button("Stop")

        # A generator function streams each yielded frame to the output
        # component; Stop cancels the running generator.
        stream_event = start_button.click(fn=capture_frames, inputs=None,
                                          outputs=webcam_output)
        stop_button.click(fn=None, inputs=None, outputs=None,
                          cancels=[stream_event])

    demo.launch(share=True, debug=True)


if __name__ == "__main__":
    gradio_interface()
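

# --------------------------------------------------------------------------
# Hedged sketch: the disabled Streamlit/ViT loader at the top of this file
# suggests the aligned face crops were meant to be embedded and matched
# against known identities (hence the `names` parameter of draw_annotations).
# This is a minimal sketch of that step, assuming the ViT wrapper above and a
# hypothetical `known_embeddings` dict mapping each name to a 1-D embedding
# tensor; the 0.6 cosine-similarity threshold is likewise an assumption, not
# a value from the original script. It is illustrative only and not wired
# into the demo.
import torch.nn.functional as F


def recognize_faces(aligned_faces, model, known_embeddings, threshold=0.6):
    """Return one name per aligned face, or "Unknown" below the threshold."""
    names = []
    with torch.no_grad():
        # The ViT wrapper returns pooled embeddings of shape (N, hidden_dim).
        embeddings = model(aligned_faces)
    for emb in embeddings:
        best_name, best_sim = "Unknown", threshold
        for name, ref in known_embeddings.items():
            sim = F.cosine_similarity(emb, ref, dim=0).item()
            if sim > best_sim:
                best_name, best_sim = name, sim
        names.append(best_name)
    return names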