nooneshouldtouch commited on
Commit
40d4c17
·
verified ·
1 Parent(s): dca75ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -566
app.py CHANGED
@@ -1,666 +1,155 @@
1
- # backend.py
2
-
3
  import os
4
  import cv2
5
  import numpy as np
6
  import tensorflow as tf
7
- import smtplib
8
-
9
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
10
- from fastapi.responses import JSONResponse, StreamingResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
-
13
  from typing import Dict, Any
14
  from datetime import datetime, timezone
15
  from io import BytesIO
16
-
17
- # SQLAlchemy imports
18
  from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
19
  from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
20
-
21
- # ReportLab (PDF generation)
22
  from reportlab.lib.pagesizes import A4
23
- from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as RLImage, Table, TableStyle
24
  from reportlab.lib.styles import getSampleStyleSheet
25
  from reportlab.lib import colors
26
-
27
- # Matplotlib (Chart generation)
28
- import matplotlib
29
- matplotlib.use('Agg')
30
  import matplotlib.pyplot as plt
31
-
32
- # YOLO-related imports
33
  from src.yolo3.model import yolo_body
34
  from src.yolo3.detect import detection
35
  from src.utils.image import letterbox_image
36
  from src.utils.fixes import fix_tf_gpu
37
  from tensorflow.keras.layers import Input
38
 
39
-
40
- ##############################################################################
41
- # Database Setup (SQLite)
42
- ##############################################################################
43
-
44
  DB_URL = "sqlite:///./safety_monitor.db"
45
-
46
- engine = create_engine(
47
- DB_URL, connect_args={"check_same_thread": False} # for single-threaded SQLite
48
- )
49
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
50
  Base = declarative_base()
51
 
52
  class Upload(Base):
53
- """
54
- Stores information about each upload (image or video), plus the user's email.
55
- """
56
  __tablename__ = "uploads"
57
-
58
  id = Column(Integer, primary_key=True, index=True)
59
  filename = Column(String)
60
  filepath = Column(String)
61
  timestamp = Column(DateTime)
62
  approach = Column(Integer)
63
- user_email = Column(String) # The user’s email address
64
  total_workers = Column(Integer, default=0)
65
  total_helmets = Column(Integer, default=0)
66
  total_vests = Column(Integer, default=0)
67
- # We'll store worker_images as a comma-separated string for simplicity
68
  worker_images = Column(Text, default="")
69
-
70
- # Relationship to SafetyDetection
71
  detections = relationship("SafetyDetection", back_populates="upload", cascade="all, delete-orphan")
72
 
73
-
74
  class SafetyDetection(Base):
75
- """
76
- Stores individual safety gear detections (e.g., bounding boxes for helmets/vests).
77
- """
78
  __tablename__ = "safety_detections"
79
-
80
  id = Column(Integer, primary_key=True, index=True)
81
- label = Column(String) # e.g. 'H', 'V'
82
- box = Column(String) # bounding box as string, e.g. "x1,y1,x2,y2"
83
  timestamp = Column(DateTime)
84
-
85
  upload_id = Column(Integer, ForeignKey("uploads.id"))
86
  upload = relationship("Upload", back_populates="detections")
87
 
88
-
89
  Base.metadata.create_all(bind=engine)
90
 
 
 
91
 
92
- ##############################################################################
93
- # FastAPI App & Configuration
94
- ##############################################################################
95
-
96
- app = FastAPI(
97
- title="Industrial Safety Monitor (FastAPI + SQLite)",
98
- description="A YOLO-based safety gear detection app. Three endpoints: upload, results, dashboard.",
99
- version="1.0.0",
100
- )
101
-
102
- # Allow cross-origin requests (optional)
103
- app.add_middleware(
104
- CORSMiddleware,
105
- allow_origins=["*"],
106
- allow_credentials=True,
107
- allow_methods=["*"],
108
- allow_headers=["*"],
109
- )
110
-
111
- # Directories
112
  UPLOAD_FOLDER = "static/uploads"
113
  PROCESSED_FOLDER = "static/processed"
114
  WORKER_FOLDER = "static/workers"
115
  CHARTS_FOLDER = "static/charts"
116
- ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'mp4'}
117
-
118
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
119
  os.makedirs(PROCESSED_FOLDER, exist_ok=True)
120
  os.makedirs(WORKER_FOLDER, exist_ok=True)
121
  os.makedirs(CHARTS_FOLDER, exist_ok=True)
122
 
123
- ##############################################################################
124
- # YOLO Model Setup
125
- ##############################################################################
126
-
127
  input_shape = (416, 416)
128
- class_names = []
129
- anchor_boxes = None
130
- num_classes = 0
131
- num_anchors = 0
132
  model = None
133
 
134
- def prepare_model(approach: int):
135
- """
136
- Prepares the YOLO model for the selected approach (1, 2, or 3).
137
- """
138
- global input_shape, class_names, anchor_boxes
139
- global num_classes, num_anchors
140
-
141
- if approach not in [1, 2, 3]:
142
- raise NotImplementedError("Approach must be 1, 2, or 3")
143
-
144
- # Classes: H=Helmet, V=Vest, W=Worker
145
- class_names[:] = ['H', 'V', 'W']
146
-
147
- # Anchor boxes by approach
148
- if approach == 1:
149
- anchor_boxes = np.array(
150
- [
151
- np.array([[76, 59], [84, 136], [188, 225]]) / 32,
152
- np.array([[25, 15], [46, 29], [27, 56]]) / 16,
153
- np.array([[5, 3], [10, 8], [12, 26]]) / 8
154
- ],
155
- dtype='float64'
156
- )
157
- elif approach == 2:
158
- anchor_boxes = np.array(
159
- [
160
- np.array([[73, 158], [128, 209], [224, 246]]) / 32,
161
- np.array([[32, 50], [40, 104], [76, 73]]) / 16,
162
- np.array([[6, 11], [11, 23], [19, 36]]) / 8
163
- ],
164
- dtype='float64'
165
- )
166
- else: # approach == 3
167
- anchor_boxes = np.array(
168
- [
169
- np.array([[76, 59], [84, 136], [188, 225]]) / 32,
170
- np.array([[25, 15], [46, 29], [27, 56]]) / 16,
171
- np.array([[5, 3], [10, 8], [12, 26]]) / 8
172
- ],
173
- dtype='float64'
174
- )
175
-
176
- num_classes = len(class_names)
177
- num_anchors = anchor_boxes.shape[0] * anchor_boxes.shape[1]
178
-
179
  input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
180
  num_out_filters = (num_anchors // 3) * (5 + num_classes)
181
- _model = yolo_body(input_tensor, num_out_filters)
182
-
183
- weight_path = f"model-data/weights/pictor-ppe-v302-a{approach}-yolo-v3-weights.h5"
184
  if not os.path.exists(weight_path):
185
  raise FileNotFoundError(f"Weight file not found: {weight_path}")
 
186
 
187
- _model.load_weights(weight_path)
188
- return _model
189
-
190
- ##############################################################################
191
- # Utility & Detection Logic
192
- ##############################################################################
193
-
194
- def allowed_file(filename: str) -> bool:
195
- return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
196
 
197
- def get_db() -> Session:
198
- """
199
- Yields a database session.
200
- """
201
  db = SessionLocal()
202
- try:
203
- yield db
204
- finally:
 
 
 
 
 
 
 
 
205
  db.close()
 
 
 
 
 
 
 
 
 
206
 
207
- def run_detection_on_frame(frame: np.ndarray,
208
- approach: int,
209
- upload_id: int,
210
- db: Session) -> np.ndarray:
211
- """
212
- Runs YOLO detection on a single frame, updates DB counters/detections,
213
- and returns the annotated frame.
214
- """
215
- global model, anchor_boxes, class_names, input_shape
216
-
217
  ih, iw = frame.shape[:2]
218
  resized = letterbox_image(frame, input_shape)
219
  resized_expanded = np.expand_dims(resized, 0)
220
  image_data = np.array(resized_expanded) / 255.0
221
-
222
  prediction = model.predict(image_data)
223
- boxes = detection(
224
- prediction,
225
- anchor_boxes,
226
- len(class_names),
227
- image_shape=(ih, iw),
228
- input_shape=input_shape,
229
- max_boxes=50,
230
- score_threshold=0.3,
231
- iou_threshold=0.45,
232
- classes_can_overlap=False
233
- )[0].numpy()
234
-
235
- # Tally
236
  workers, helmets, vests = [], [], []
237
  for box in boxes:
238
- x1, y1, x2, y2, score, cls_id = box
239
- x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
240
- cls_id = int(cls_id)
241
  label = class_names[cls_id]
242
-
243
  if label == 'W':
244
  workers.append((x1, y1, x2, y2))
245
- color = (0, 255, 0)
246
  elif label == 'H':
247
  helmets.append((x1, y1, x2, y2))
248
- color = (255, 0, 0)
249
  elif label == 'V':
250
  vests.append((x1, y1, x2, y2))
251
- color = (0, 0, 255)
252
- else:
253
- color = (255, 255, 0)
254
-
255
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
256
- cv2.putText(frame, label, (x1, y1 - 10),
257
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
258
-
259
  upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
260
  if upload_obj:
261
  upload_obj.total_workers += len(workers)
262
  upload_obj.total_helmets += len(helmets)
263
  upload_obj.total_vests += len(vests)
264
  db.commit()
265
-
266
- # Insert SafetyDetection for helmets/vests
267
- now_utc = datetime.now(timezone.utc)
268
- for (hx1, hy1, hx2, hy2) in helmets:
269
- db.add(SafetyDetection(
270
- label='H',
271
- box=f"{hx1},{hy1},{hx2},{hy2}",
272
- timestamp=now_utc,
273
- upload_id=upload_id
274
- ))
275
- for (vx1, vy1, vx2, vy2) in vests:
276
- db.add(SafetyDetection(
277
- label='V',
278
- box=f"{vx1},{vy1},{vx2},{vy2}",
279
- timestamp=now_utc,
280
- upload_id=upload_id
281
- ))
282
- db.commit()
283
-
284
- # Also save worker crops
285
- worker_images_list = []
286
- for idx, (wx1, wy1, wx2, wy2) in enumerate(workers, start=1):
287
- crop = frame[wy1:wy2, wx1:wx2]
288
- if crop.size == 0:
289
- continue
290
- worker_filename = f"worker_{upload_id}_{idx}.jpg"
291
- worker_path = os.path.join(WORKER_FOLDER, worker_filename)
292
- cv2.imwrite(worker_path, crop)
293
- worker_images_list.append(worker_path)
294
-
295
- # Append new worker images
296
- existing_imgs = upload_obj.worker_images.split(",") if upload_obj.worker_images else []
297
- all_imgs = existing_imgs + worker_images_list
298
- upload_obj.worker_images = ",".join([w for w in all_imgs if w])
299
- db.commit()
300
-
301
  return frame
302
 
303
- def generate_and_email_pdf(upload_obj: Upload, db: Session):
304
- """
305
- Generates a PDF report for a single upload, then emails it to upload_obj.user_email.
306
- """
307
- # We’ll produce a single-page-ish PDF with the detection summary for this upload.
308
-
309
- # Grab top-level stats
310
- total_workers = upload_obj.total_workers
311
- total_helmets = upload_obj.total_helmets
312
- total_vests = upload_obj.total_vests
313
- worker_images = upload_obj.worker_images.split(",") if upload_obj.worker_images else []
314
-
315
- # Create a PDF
316
  buffer = BytesIO()
317
  doc = SimpleDocTemplate(buffer, pagesize=A4)
318
  elements = []
319
  styles = getSampleStyleSheet()
320
-
321
- # Title
322
- elements.append(Paragraph("Industrial Safety Monitor Report", styles["Title"]))
323
- elements.append(Paragraph(f"Upload ID: {upload_obj.id}", styles["Normal"]))
324
  elements.append(Paragraph(f"Filename: {upload_obj.filename}", styles["Normal"]))
325
  elements.append(Paragraph(f"Timestamp: {upload_obj.timestamp.strftime('%Y-%m-%d %H:%M:%S')}", styles["Normal"]))
326
- elements.append(Paragraph(f"Approach: {upload_obj.approach}", styles["Normal"]))
327
- elements.append(Paragraph(f"User Email: {upload_obj.user_email}", styles["Normal"]))
328
- elements.append(Spacer(1, 12))
329
-
330
- # Table of basic detection metrics
331
- data = [
332
- ["Total Workers", total_workers],
333
- ["Total Helmets", total_helmets],
334
- ["Total Vests", total_vests]
335
- ]
336
- table = Table(data, colWidths=[200, 200])
337
- table.setStyle(TableStyle([
338
- ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
339
- ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
340
- ("ALIGN", (0, 0), (-1, -1), "CENTER"),
341
- ("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
342
- ("FONTSIZE", (0, 0), (-1, 0), 12),
343
- ("BOTTOMPADDING", (0, 0), (-1, 0), 12),
344
- ("BACKGROUND", (0, 1), (-1, -1), colors.beige),
345
- ("GRID", (0, 0), (-1, -1), 1, colors.black),
346
- ]))
347
  elements.append(table)
348
- elements.append(Spacer(1, 12))
349
-
350
- # Show worker crops, if any
351
- if worker_images:
352
- elements.append(Paragraph("Detected Workers:", styles["Heading3"]))
353
- elements.append(Spacer(1, 12))
354
- for wimg in worker_images:
355
- wimg = wimg.strip()
356
- if wimg and os.path.exists(wimg):
357
- elements.append(RLImage(wimg, width=100, height=75))
358
- elements.append(Spacer(1, 12))
359
-
360
  doc.build(elements)
361
  buffer.seek(0)
362
- pdf_data = buffer.getvalue()
363
-
364
- # Email the PDF
365
- receiver_email = upload_obj.user_email
366
- if not receiver_email:
367
- print("No email to send to.")
368
- return # skip emailing if no user email
369
-
370
- # Adjust credentials
371
- sender_email = "[email protected]"
372
- sender_password = "aobh rdgp iday bpwg"
373
- subject = "Industrial Safety Monitor - Your Detection Report"
374
- body = (
375
- "Hello,\n\n"
376
- "Please find attached the Industrial Safety Monitor detection report.\n"
377
- "Regards,\nISM Bot"
378
- )
379
-
380
- from email.mime.multipart import MIMEMultipart
381
- from email.mime.text import MIMEText
382
- from email.mime.application import MIMEApplication
383
-
384
- msg = MIMEMultipart()
385
- msg["From"] = sender_email
386
- msg["To"] = receiver_email
387
- msg["Subject"] = subject
388
- msg.attach(MIMEText(body, "plain"))
389
-
390
- part = MIMEApplication(pdf_data, _subtype="pdf")
391
- part.add_header("Content-Disposition", "attachment", filename="ISM_Report.pdf")
392
- msg.attach(part)
393
-
394
- try:
395
- with smtplib.SMTP("smtp.gmail.com", 587) as server:
396
- server.starttls()
397
- server.login(sender_email, sender_password)
398
- server.send_message(msg)
399
- print(f"Email sent successfully to {receiver_email}!")
400
- except Exception as e:
401
- print(f"Error sending email: {e}")
402
-
403
-
404
- ##############################################################################
405
- # 1) /upload
406
- ##############################################################################
407
-
408
- @app.post("/upload", summary="Upload image/video + email; run detection, send PDF to email.")
409
- async def upload_file(
410
- approach: int = Form(...),
411
- file: UploadFile = File(...),
412
- user_email: str = Form(...),
413
- ):
414
- """
415
- 1) User uploads an image/video with approach + email.
416
- 2) We run YOLO detection.
417
- 3) We store results in DB.
418
- 4) We generate a PDF and email it to `user_email`.
419
- 5) Return detection counts in JSON.
420
- """
421
- global model
422
-
423
- db = SessionLocal()
424
-
425
- # Prepare YOLO model for the chosen approach
426
- try:
427
- if (model is None) or (approach not in [1, 2, 3]):
428
- model = prepare_model(approach)
429
- except Exception as e:
430
- db.close()
431
- raise HTTPException(status_code=500, detail=str(e))
432
-
433
- # Check file type
434
- filename = file.filename
435
- if not allowed_file(filename):
436
- db.close()
437
- raise HTTPException(
438
- status_code=400,
439
- detail="Unsupported file type. Allowed: .png, .jpg, .jpeg, .gif, .mp4",
440
- )
441
-
442
- # Save the uploaded file
443
- filepath = os.path.join(UPLOAD_FOLDER, filename)
444
- with open(filepath, "wb") as f:
445
- f.write(await file.read())
446
-
447
- # Create an Upload record
448
- upload_obj = Upload(
449
- filename=filename,
450
- filepath=filepath,
451
- timestamp=datetime.now(timezone.utc),
452
- approach=approach,
453
- user_email=user_email,
454
- total_workers=0,
455
- total_helmets=0,
456
- total_vests=0,
457
- worker_images=""
458
- )
459
- db.add(upload_obj)
460
- db.commit()
461
- db.refresh(upload_obj)
462
- upload_id = upload_obj.id
463
-
464
- # If it's an image
465
- if filename.lower().endswith((".png", ".jpg", ".jpeg", ".gif")):
466
- img = cv2.imread(filepath)
467
- if img is None:
468
- db.close()
469
- raise HTTPException(status_code=400, detail="Failed to read the image file.")
470
-
471
- # Run detection on the single image
472
- annotated_frame = run_detection_on_frame(img, approach, upload_id, db)
473
-
474
- # Save processed image
475
- processed_filename = f"processed_{filename}"
476
- processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
477
- cv2.imwrite(processed_path, annotated_frame)
478
-
479
- # If it's a video
480
- elif filename.lower().endswith(".mp4"):
481
- video = cv2.VideoCapture(filepath)
482
- if not video.isOpened():
483
- db.close()
484
- raise HTTPException(status_code=400, detail="Failed to read the video file.")
485
-
486
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
487
- processed_filename = f"processed_{filename}"
488
- processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
489
-
490
- original_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
491
- original_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
492
- fps = video.get(cv2.CAP_PROP_FPS)
493
- frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
494
-
495
- out = cv2.VideoWriter(
496
- processed_path, fourcc, fps, (original_width, original_height)
497
- )
498
-
499
- current_frame = 0
500
- while True:
501
- ret, frame = video.read()
502
- if not ret:
503
- break
504
- current_frame += 1
505
- print(f"Processing frame {current_frame}/{frame_count} (Upload ID={upload_id})")
506
-
507
- annotated_frame = run_detection_on_frame(frame, approach, upload_id, db)
508
- out.write(annotated_frame)
509
-
510
- video.release()
511
- out.release()
512
-
513
- # Now fetch updated counts
514
- db.refresh(upload_obj)
515
-
516
- # Generate & email PDF
517
- generate_and_email_pdf(upload_obj, db)
518
-
519
- counts = {
520
- "total_workers": upload_obj.total_workers,
521
- "total_helmets": upload_obj.total_helmets,
522
- "total_vests": upload_obj.total_vests
523
- }
524
-
525
- db.close()
526
- return {
527
- "message": f"File uploaded, detection done, PDF emailed to {user_email}.",
528
- "upload_id": upload_id,
529
- "counts": counts
530
- }
531
-
532
-
533
- ##############################################################################
534
- # 2) /results
535
- ##############################################################################
536
-
537
- @app.get("/results", summary="Fetch the most recent upload’s details.")
538
- def get_results():
539
- """
540
- Returns the details (counts, file paths, worker_images) of the most recent upload.
541
- """
542
- db = SessionLocal()
543
- latest = db.query(Upload).order_by(Upload.timestamp.desc()).first()
544
- if not latest:
545
- db.close()
546
- return {"message": "No uploads found in the database."}
547
-
548
- processed_filename = f"processed_{latest.filename}"
549
- processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
550
- data = {
551
- "upload_id": latest.id,
552
- "filename": latest.filename,
553
- "original_path": latest.filepath,
554
- "processed_path": processed_path if os.path.exists(processed_path) else None,
555
- "approach": latest.approach,
556
- "user_email": latest.user_email,
557
- "total_workers": latest.total_workers,
558
- "total_helmets": latest.total_helmets,
559
- "total_vests": latest.total_vests,
560
- "worker_images": (latest.worker_images.split(",") if latest.worker_images else []),
561
- "timestamp": latest.timestamp.strftime("%Y-%m-%d %H:%M:%S")
562
- }
563
- db.close()
564
- return data
565
-
566
-
567
- ##############################################################################
568
- # 3) /dashboard
569
- ##############################################################################
570
-
571
- @app.get("/dashboard", summary="Get aggregated statistics for a dashboard.")
572
- def dashboard():
573
- """
574
- Returns aggregated stats (uploads, detection sums, time-series, approach usage) in JSON.
575
- """
576
- db = SessionLocal()
577
-
578
- # Total uploads
579
- total_uploads = db.query(Upload).count()
580
-
581
- # Summation of detections
582
- agg = db.query(
583
- func.sum(Upload.total_workers).label("tw"),
584
- func.sum(Upload.total_helmets).label("th"),
585
- func.sum(Upload.total_vests).label("tv")
586
- ).one()
587
- total_workers = agg.tw or 0
588
- total_helmets = agg.th or 0
589
- total_vests = agg.tv or 0
590
-
591
- # Time-series by day
592
- day_rows = db.query(
593
- func.date(Upload.timestamp).label("day"),
594
- func.count(Upload.id).label("uploads"),
595
- func.sum(Upload.total_workers).label("workers"),
596
- func.sum(Upload.total_helmets).label("helmets"),
597
- func.sum(Upload.total_vests).label("vests")
598
- ).group_by(func.date(Upload.timestamp)).order_by(func.date(Upload.timestamp)).all()
599
-
600
- dates = []
601
- uploads_per_day = []
602
- workers_per_day = []
603
- helmets_per_day = []
604
- vests_per_day = []
605
-
606
- for row in day_rows:
607
- dates.append(row.day)
608
- uploads_per_day.append(row.uploads or 0)
609
- workers_per_day.append(row.workers or 0)
610
- helmets_per_day.append(row.helmets or 0)
611
- vests_per_day.append(row.vests or 0)
612
-
613
- # Approach usage
614
- approach_rows = db.query(
615
- Upload.approach,
616
- func.count(Upload.id).label("count")
617
- ).group_by(Upload.approach).all()
618
- approach_data = []
619
- for ar in approach_rows:
620
- approach_data.append({
621
- "approach": f"Approach-{ar.approach}",
622
- "count": ar.count
623
- })
624
-
625
- # Basic distribution of helmets vs. vests
626
- safety_gear_labels = ["Helmets", "Vests"]
627
- safety_gear_counts = [total_helmets, total_vests]
628
-
629
- db.close()
630
- return {
631
- "total_uploads": total_uploads,
632
- "total_workers": total_workers,
633
- "total_helmets": total_helmets,
634
- "total_vests": total_vests,
635
- "time_series": {
636
- "dates": dates,
637
- "uploads_per_day": uploads_per_day,
638
- "workers_per_day": workers_per_day,
639
- "helmets_per_day": helmets_per_day,
640
- "vests_per_day": vests_per_day
641
- },
642
- "approach_usage": approach_data,
643
- "safety_gear_distribution": {
644
- "labels": safety_gear_labels,
645
- "counts": safety_gear_counts
646
- }
647
- }
648
-
649
-
650
- ##############################################################################
651
- # Startup (Load YOLO Model)
652
- ##############################################################################
653
-
654
- @app.on_event("startup")
655
- def on_startup():
656
- fix_tf_gpu()
657
- global model
658
- try:
659
- # Load default approach=1 at startup (optional)
660
- model_local = prepare_model(approach=1)
661
- model = model_local
662
- print("YOLO model (Approach=1) loaded successfully.")
663
- except FileNotFoundError as e:
664
- print(f"Model file not found on startup: {e}")
665
- except Exception as e:
666
- print(f"Error preparing model on startup: {e}")
 
 
 
1
  import os
2
  import cv2
3
  import numpy as np
4
  import tensorflow as tf
 
 
5
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
6
+ from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
 
8
  from typing import Dict, Any
9
  from datetime import datetime, timezone
10
  from io import BytesIO
 
 
11
  from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
12
  from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
 
 
13
  from reportlab.lib.pagesizes import A4
14
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
15
  from reportlab.lib.styles import getSampleStyleSheet
16
  from reportlab.lib import colors
 
 
 
 
17
  import matplotlib.pyplot as plt
 
 
18
  from src.yolo3.model import yolo_body
19
  from src.yolo3.detect import detection
20
  from src.utils.image import letterbox_image
21
  from src.utils.fixes import fix_tf_gpu
22
  from tensorflow.keras.layers import Input
23
 
 
 
 
 
 
24
  DB_URL = "sqlite:///./safety_monitor.db"
25
+ engine = create_engine(DB_URL, connect_args={"check_same_thread": False})
 
 
 
26
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
27
  Base = declarative_base()
28
 
29
  class Upload(Base):
 
 
 
30
  __tablename__ = "uploads"
 
31
  id = Column(Integer, primary_key=True, index=True)
32
  filename = Column(String)
33
  filepath = Column(String)
34
  timestamp = Column(DateTime)
35
  approach = Column(Integer)
 
36
  total_workers = Column(Integer, default=0)
37
  total_helmets = Column(Integer, default=0)
38
  total_vests = Column(Integer, default=0)
 
39
  worker_images = Column(Text, default="")
 
 
40
  detections = relationship("SafetyDetection", back_populates="upload", cascade="all, delete-orphan")
41
 
 
42
  class SafetyDetection(Base):
 
 
 
43
  __tablename__ = "safety_detections"
 
44
  id = Column(Integer, primary_key=True, index=True)
45
+ label = Column(String)
46
+ box = Column(String)
47
  timestamp = Column(DateTime)
 
48
  upload_id = Column(Integer, ForeignKey("uploads.id"))
49
  upload = relationship("Upload", back_populates="detections")
50
 
 
51
  Base.metadata.create_all(bind=engine)
52
 
53
+ app = FastAPI(title="Industrial Safety Monitor", version="1.0.0")
54
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  UPLOAD_FOLDER = "static/uploads"
57
  PROCESSED_FOLDER = "static/processed"
58
  WORKER_FOLDER = "static/workers"
59
  CHARTS_FOLDER = "static/charts"
 
 
60
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
61
  os.makedirs(PROCESSED_FOLDER, exist_ok=True)
62
  os.makedirs(WORKER_FOLDER, exist_ok=True)
63
  os.makedirs(CHARTS_FOLDER, exist_ok=True)
64
 
 
 
 
 
65
  input_shape = (416, 416)
66
+ class_names = ['H', 'V', 'W']
67
+ num_classes = len(class_names)
68
+ num_anchors = 9
 
69
  model = None
70
 
71
+ def prepare_model():
72
+ global model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
74
  num_out_filters = (num_anchors // 3) * (5 + num_classes)
75
+ model = yolo_body(input_tensor, num_out_filters)
76
+ weight_path = "model-data/weights/yolo_weights.h5"
 
77
  if not os.path.exists(weight_path):
78
  raise FileNotFoundError(f"Weight file not found: {weight_path}")
79
+ model.load_weights(weight_path)
80
 
81
+ @app.on_event("startup")
82
+ def on_startup():
83
+ fix_tf_gpu()
84
+ prepare_model()
 
 
 
 
 
85
 
86
+ @app.post("/upload")
87
+ async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
88
+ global model
 
89
  db = SessionLocal()
90
+ filename = file.filename
91
+ filepath = os.path.join(UPLOAD_FOLDER, filename)
92
+ with open(filepath, "wb") as f:
93
+ f.write(await file.read())
94
+ upload_obj = Upload(filename=filename, filepath=filepath, timestamp=datetime.now(timezone.utc), approach=approach)
95
+ db.add(upload_obj)
96
+ db.commit()
97
+ db.refresh(upload_obj)
98
+ upload_id = upload_obj.id
99
+ img = cv2.imread(filepath)
100
+ if img is None:
101
  db.close()
102
+ raise HTTPException(status_code=400, detail="Failed to read the image file.")
103
+ processed_img = run_detection_on_frame(img, upload_id, db)
104
+ processed_filename = f"processed_{filename}"
105
+ processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
106
+ cv2.imwrite(processed_path, processed_img)
107
+ db.refresh(upload_obj)
108
+ pdf_path = generate_pdf(upload_obj)
109
+ db.close()
110
+ return {"message": "File processed successfully.", "upload_id": upload_id, "pdf_path": pdf_path}
111
 
112
+ def run_detection_on_frame(frame: np.ndarray, upload_id: int, db: Session) -> np.ndarray:
113
+ global model
 
 
 
 
 
 
 
 
114
  ih, iw = frame.shape[:2]
115
  resized = letterbox_image(frame, input_shape)
116
  resized_expanded = np.expand_dims(resized, 0)
117
  image_data = np.array(resized_expanded) / 255.0
 
118
  prediction = model.predict(image_data)
119
+ boxes = detection(prediction, None, len(class_names), (ih, iw), input_shape, 50, 0.3, 0.45, False)[0].numpy()
 
 
 
 
 
 
 
 
 
 
 
 
120
  workers, helmets, vests = [], [], []
121
  for box in boxes:
122
+ x1, y1, x2, y2, _, cls_id = map(int, box)
 
 
123
  label = class_names[cls_id]
 
124
  if label == 'W':
125
  workers.append((x1, y1, x2, y2))
 
126
  elif label == 'H':
127
  helmets.append((x1, y1, x2, y2))
 
128
  elif label == 'V':
129
  vests.append((x1, y1, x2, y2))
 
 
 
 
 
 
 
 
130
  upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
131
  if upload_obj:
132
  upload_obj.total_workers += len(workers)
133
  upload_obj.total_helmets += len(helmets)
134
  upload_obj.total_vests += len(vests)
135
  db.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return frame
137
 
138
+ def generate_pdf(upload_obj: Upload):
 
 
 
 
 
 
 
 
 
 
 
 
139
  buffer = BytesIO()
140
  doc = SimpleDocTemplate(buffer, pagesize=A4)
141
  elements = []
142
  styles = getSampleStyleSheet()
143
+ elements.append(Paragraph("Industrial Safety Report", styles["Title"]))
 
 
 
144
  elements.append(Paragraph(f"Filename: {upload_obj.filename}", styles["Normal"]))
145
  elements.append(Paragraph(f"Timestamp: {upload_obj.timestamp.strftime('%Y-%m-%d %H:%M:%S')}", styles["Normal"]))
146
+ data = [["Total Workers", upload_obj.total_workers], ["Total Helmets", upload_obj.total_helmets], ["Total Vests", upload_obj.total_vests]]
147
+ table = Table(data)
148
+ table.setStyle(TableStyle([("BACKGROUND", (0, 0), (-1, 0), colors.grey), ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke), ("GRID", (0, 0), (-1, -1), 1, colors.black)]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  elements.append(table)
 
 
 
 
 
 
 
 
 
 
 
 
150
  doc.build(elements)
151
  buffer.seek(0)
152
+ pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_obj.id}.pdf")
153
+ with open(pdf_path, "wb") as f:
154
+ f.write(buffer.getvalue())
155
+ return pdf_path