hasnanmr commited on
Commit
bc26e93
·
1 Parent(s): d2d613d

add app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import streamlit as st
2
+ # import torch
3
+ # from facenet_pytorch import MTCNN
4
+ # import pickle
5
+ # import cv2
6
+ # from PIL import Image
7
+ # import numpy as np
8
+ # from transformers import ViTImageProcessor, ViTModel
9
+ # import torch.nn as nn
10
+ # from torchvision import transforms
11
+ # from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
12
+ # import av
13
+
14
+ # class ViT(nn.Module):
15
+ # def __init__(self, base_model):
16
+ # super(ViT, self).__init__()
17
+ # self.base_model = base_model
18
+
19
+ # def forward(self, x):
20
+ # x = self.base_model(x).pooler_output
21
+ # return x
22
+
23
+ # @st.cache_resource
24
+ # def load_model():
25
+ # model_name = "google/vit-base-patch16-224"
26
+ # processor = ViTImageProcessor.from_pretrained(model_name)
27
+ # base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
28
+ # model = ViT(base_model)
29
+ # model.load_state_dict(torch.load('faceViT6.pth', map_location=torch.device('cpu')))
30
+ # model.eval()
31
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ # model.to(device)
33
+ # return model, processor, device
34
+
35
+ import gradio as gr
36
+ import cv2
37
+ import torch
38
+ from facenet_pytorch import MTCNN
39
+
40
+ # Load MTCNN for face detection
41
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
+ mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)
43
+
44
+ def align_faces(frame, mtcnn, device):
45
+ boxes, _ = mtcnn.detect(frame)
46
+ aligned_faces = []
47
+ if boxes is not None:
48
+ aligned_faces = mtcnn(frame)
49
+ if aligned_faces is not None:
50
+ aligned_faces = aligned_faces.to(device)
51
+ return aligned_faces, boxes
52
+
53
+ def draw_annotations(frame, detections, names=None):
54
+ if detections is None:
55
+ return frame
56
+ if names is None:
57
+ names = ["Unknown"] * len(detections)
58
+ for i, detection in enumerate(detections):
59
+ x1, y1, x2, y2 = map(int, detection)
60
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
61
+ if names[i]:
62
+ cv2.putText(frame, names[i], (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
63
+ return frame
64
+
65
+ def capture_frames():
66
+ cap = cv2.VideoCapture(0)
67
+
68
+ if not cap.isOpened():
69
+ raise RuntimeError("Error: Could not open video stream.")
70
+
71
+ while True:
72
+ ret, frame = cap.read()
73
+
74
+ if not ret:
75
+ raise RuntimeError("Error: Failed to capture image")
76
+
77
+ # Align faces using MTCNN
78
+ aligned_faces, boxes = align_faces(frame, mtcnn, device)
79
+
80
+ # Draw annotations on the frame
81
+ annotated_frame = draw_annotations(frame, boxes)
82
+
83
+ _, buffer = cv2.imencode('.jpg', annotated_frame)
84
+ frame_bytes = buffer.tobytes()
85
+
86
+ yield frame_bytes
87
+
88
+ def video_frame_generator():
89
+ for frame in capture_frames():
90
+ yield frame
91
+
92
+ def gradio_interface():
93
+ with gr.Blocks() as demo:
94
+ with gr.Row():
95
+ webcam_output = gr.Video(source=video_frame_generator, streaming=True, label="Webcam Output")
96
+ stop_button = gr.Button("Stop")
97
+
98
+ def stop_streaming():
99
+ # Placeholder for stopping streaming if necessary
100
+ return "Streaming stopped."
101
+
102
+ stop_button.click(fn=stop_streaming, inputs=None, outputs=None)
103
+
104
+ demo.launch(share=True, debug=True)
105
+
106
+ if __name__ == "__main__":
107
+ gradio_interface()