Omarrran commited on
Commit
ce0c645
·
verified ·
1 Parent(s): ffcc387

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -47
app.py CHANGED
@@ -1,64 +1,231 @@
1
  """
2
- Script to download YOLOv8 model and save it to the models directory
3
- This ensures the model is included in the repository for Hugging Face Spaces
 
 
4
  """
5
 
 
 
 
 
6
  import os
 
 
7
  from ultralytics import YOLO
8
 
9
- def download_model(model_name="yolov8n.pt", save_dir="models"):
 
 
 
 
 
 
10
  """
11
- Download a YOLOv8 model and save it to the specified directory
12
-
13
- Args:
14
- model_name: Name of the YOLOv8 model to download
15
- save_dir: Directory to save the model to
16
  """
17
- # Create models directory if it doesn't exist
18
- if not os.path.exists(save_dir):
19
- os.makedirs(save_dir)
20
- print(f"Created directory: {save_dir}")
21
-
22
- # Full path to save the model
23
- model_path = os.path.join(save_dir, model_name)
24
-
25
- # Check if model already exists
26
- if os.path.exists(model_path):
27
- print(f"Model already exists at: {model_path}")
28
- # Load model to verify it's valid
 
 
 
 
 
 
 
 
 
 
29
  try:
30
- model = YOLO(model_path)
 
 
31
  print("Model loaded successfully")
32
- return model_path
33
  except Exception as e:
34
- print(f"Error loading existing model: {e}")
35
- print("Downloading new model...")
36
-
37
- # Download the model
38
- try:
39
- print(f"Downloading model: {model_name}")
40
- model = YOLO(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Save the model to the specified path
43
- model_file = model.export(format="torchscript")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Rename to the original model name
46
- if os.path.exists(model_file) and model_file != model_path:
47
- # Copy the model file to the specified path
48
- import shutil
49
- shutil.copy(model_file, model_path)
50
- print(f"Model saved to: {model_path}")
51
 
52
- return model_path
53
- except Exception as e:
54
- print(f"Error downloading model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return None
 
 
 
 
 
 
 
 
56
 
57
- if __name__ == "__main__":
58
- # Download YOLOv8n model
59
- model_path = download_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- if model_path:
62
- print(f"Model successfully downloaded to: {model_path}")
63
- else:
64
- print("Failed to download model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Phone Detection App for Hugging Face Spaces
3
+
4
+ This app uses YOLOv8 to detect phones in real-time through a webcam feed.
5
+ When a phone is detected, a warning message is displayed.
6
  """
7
 
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import time
12
  import os
13
+ import gradio as gr
14
+ from PIL import Image, ImageDraw, ImageFont
15
  from ultralytics import YOLO
16
 
17
+ # Configurations
18
+ MODEL_PATH = "models/yolov8n.pt" # Path to the model within the repository
19
+ TARGET_CLASS = "cell phone"
20
+ TARGET_CLASS_ID = 67 # In YOLOv8's COCO dataset
21
+ MIN_CONFIDENCE = 0.4 # Minimum confidence threshold for detections
22
+
23
+ class PhoneDetector:
24
  """
25
+ A class to handle phone detection using YOLOv8 model
 
 
 
 
26
  """
27
+ def __init__(self, model_path=MODEL_PATH, confidence=MIN_CONFIDENCE):
28
+ """
29
+ Initialize the phone detector
30
+
31
+ Args:
32
+ model_path: Path to the YOLOv8 model weights
33
+ confidence: Minimum confidence threshold for detections
34
+ """
35
+ self.target_class = TARGET_CLASS
36
+ self.target_class_id = TARGET_CLASS_ID
37
+ self.min_confidence = confidence
38
+
39
+ # Select device (GPU if available, otherwise CPU)
40
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ print(f"Using device: {self.device}")
42
+
43
+ # Check if model exists, otherwise use default YOLOv8n
44
+ if not os.path.exists(model_path):
45
+ print(f"Model not found at {model_path}, using default YOLOv8n")
46
+ model_path = "yolov8n.pt" # Will be downloaded automatically by YOLO
47
+
48
+ # Load model
49
  try:
50
+ print(f"Loading YOLOv8 model from {model_path}...")
51
+ self.model = YOLO(model_path)
52
+ self.model.to(self.device)
53
  print("Model loaded successfully")
 
54
  except Exception as e:
55
+ print(f"Error loading model: {e}")
56
+ print("Loading default YOLOv8n model...")
57
+ self.model = YOLO("yolov8n.pt")
58
+ self.model.to(self.device)
59
+
60
+ def detect(self, frame):
61
+ """
62
+ Detect phones in a frame and add visualization
63
+
64
+ Args:
65
+ frame: Input image frame (numpy array)
66
+
67
+ Returns:
68
+ Processed frame with detection visualization
69
+ """
70
+ if frame is None:
71
+ return None
72
+
73
+ # Convert to RGB if grayscale
74
+ if len(frame.shape) == 2:
75
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
76
+ elif frame.shape[2] == 4: # If RGBA
77
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
78
+
79
+ # Get frame dimensions
80
+ (h, w) = frame.shape[:2]
81
+
82
+ # Convert to PIL Image for easier text rendering
83
+ pil_image = Image.fromarray(frame)
84
+ draw = ImageDraw.Draw(pil_image)
85
+
86
+ # Try to load a nicer font, fall back to default if not available
87
+ try:
88
+ font = ImageFont.truetype("DejaVuSans.ttf", 25)
89
+ small_font = ImageFont.truetype("DejaVuSans.ttf", 15)
90
+ except IOError:
91
+ font = ImageFont.load_default()
92
+ small_font = ImageFont.load_default()
93
+
94
+ # Perform detection with YOLOv8
95
+ with torch.no_grad(): # Disable gradient calculation for inference
96
+ results = self.model.predict(frame, conf=self.min_confidence, verbose=False)
97
+
98
+ # Flag to track if a phone is detected in this frame
99
+ phone_detected = False
100
+
101
+ # Process detection results
102
+ if len(results) > 0:
103
+ for result in results:
104
+ boxes = result.boxes
105
+ for box in boxes:
106
+ # Get class ID
107
+ cls_id = int(box.cls[0].item())
108
+ class_name = result.names[cls_id]
109
+
110
+ # Check if the detected object is a cell phone
111
+ if class_name == self.target_class or cls_id == self.target_class_id:
112
+ phone_detected = True
113
+
114
+ # Get confidence score
115
+ conf = float(box.conf[0].item())
116
+
117
+ # Get bounding box coordinates
118
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
119
+
120
+ # Draw bounding box on PIL image
121
+ draw.rectangle([(x1, y1), (x2, y2)], outline="red", width=3)
122
+
123
+ # Display confidence and class
124
+ label = f"{class_name}: {conf:.2f}"
125
+ y_label = y1 - 15 if y1 - 15 > 15 else y1 + 15
126
+ draw.text((x1, y_label), label, fill="red", font=small_font)
127
 
128
+ # Display warning message if phone is detected
129
+ if phone_detected:
130
+ warning_text = "WARNING: Phone Detected!"
131
+
132
+ # Measure text size for centering (implementation differs based on PIL version)
133
+ try:
134
+ # For newer PIL versions
135
+ text_width = draw.textlength(warning_text, font=font)
136
+ except AttributeError:
137
+ # For older PIL versions
138
+ text_width = font.getmask(warning_text).getbbox()[2]
139
+
140
+ text_x = (w - text_width) // 2
141
+ text_y = h // 2
142
+
143
+ # Draw semi-transparent red rectangle for warning
144
+ overlay = Image.new('RGBA', pil_image.size, (0, 0, 0, 0))
145
+ overlay_draw = ImageDraw.Draw(overlay)
146
+ overlay_draw.rectangle([(0, text_y - 40), (w, text_y + 10)], fill=(255, 0, 0, 128))
147
+ pil_image = Image.alpha_composite(pil_image.convert('RGBA'), overlay).convert('RGB')
148
+ draw = ImageDraw.Draw(pil_image)
149
+
150
+ # Draw warning text
151
+ draw.text((text_x, text_y - 30), warning_text, fill="white", font=font)
152
 
153
+ # Add processing info at the bottom
154
+ device_text = f"Running on: {self.device}"
155
+ draw.text((10, h - 30), device_text, fill="green", font=small_font)
 
 
 
156
 
157
+ # Convert back to numpy array
158
+ result_frame = np.array(pil_image)
159
+
160
+ return result_frame
161
+
162
+ # Initialize the detector
163
+ detector = PhoneDetector()
164
+
165
+ # Function to process webcam frames
166
+ def process_webcam(image):
167
+ """
168
+ Process webcam input for Gradio interface
169
+
170
+ Args:
171
+ image: Input image from Gradio
172
+
173
+ Returns:
174
+ Processed image with phone detection visualization
175
+ """
176
+ if image is None:
177
  return None
178
+
179
+ # Process the frame
180
+ result_frame = detector.detect(image)
181
+
182
+ if result_frame is None:
183
+ return image
184
+
185
+ return result_frame
186
 
187
+ # Create Gradio interface
188
+ title = "Phone Detection with YOLOv8"
189
+ description = """
190
+ ## Real-time Phone Detection
191
+
192
+ This app uses YOLOv8 to detect phones in real-time through your webcam.
193
+ When a phone is detected, a warning message is displayed.
194
+
195
+ ### How it works:
196
+ 1. The webcam captures your video feed
197
+ 2. Each frame is analyzed by YOLOv8 to detect phones
198
+ 3. If a phone is detected, a warning message appears
199
+
200
+ ### Notes:
201
+ - You may need to give permission for camera access
202
+ - The app works best with good lighting conditions
203
+ - The model detects cell phones only
204
+ """
205
+
206
+ # Create Gradio blocks interface
207
+ with gr.Blocks(title=title) as demo:
208
+ gr.Markdown(description)
209
 
210
+ with gr.Row():
211
+ with gr.Column():
212
+ # Webcam input with streaming
213
+ webcam_input = gr.Image(label="Webcam", sources=["webcam"], streaming=True)
214
+
215
+ with gr.Column():
216
+ output_display = gr.Image(label="Detection Result")
217
+
218
+ # Stream processing
219
+ webcam_input.stream(process_webcam, inputs=webcam_input, outputs=output_display)
220
+
221
+ gr.Markdown("""
222
+ ### Technical Details
223
+ - Model: YOLOv8n (optimized for speed)
224
+ - Target class: "cell phone"
225
+ - Confidence threshold: 0.4
226
+
227
+ This application was developed using Ultralytics YOLOv8, Gradio, and OpenCV.
228
+ """)
229
+
230
+ # Launch the interface
231
+ demo.launch()