# safetyproject / app.py
# nooneshouldtouch's picture
# oiok
# 8050a76
# backend.py
import os
import cv2
import numpy as np
import tensorflow as tf
import smtplib
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any
from datetime import datetime, timezone
from io import BytesIO
# SQLAlchemy imports
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
# ReportLab (PDF generation)
from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as RLImage, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
# Matplotlib (Chart generation)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# YOLO-related imports
from src.yolo3.model import yolo_body
from src.yolo3.detect import detection
from src.utils.image import letterbox_image
from src.utils.fixes import fix_tf_gpu
from tensorflow.keras.layers import Input
##############################################################################
# Database Setup (SQLite)
##############################################################################
# SQLite database file, created relative to the process working directory.
DB_URL = "sqlite:///./safety_monitor.db"
# check_same_thread=False turns OFF sqlite3's "use the connection only in the
# thread that created it" check, so the connection may be used from other
# threads (it does not make SQLite single-threaded).
engine = create_engine(
    DB_URL, connect_args={"check_same_thread": False}
)
# Session factory bound to the engine; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class shared by the ORM models below.
Base = declarative_base()
class Upload(Base):
    """
    One row per uploaded file (image or video).

    Holds the file name/path, the upload time, the requested detection
    approach, the uploader's email address, running detection totals and
    the paths of saved worker crops.
    """

    __tablename__ = "uploads"

    # --- identity and file location ---
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String)
    filepath = Column(String)

    # --- request metadata ---
    timestamp = Column(DateTime)
    approach = Column(Integer)
    user_email = Column(String)  # address the PDF report is sent to

    # --- detection totals, each starting at zero ---
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)

    # Worker crop paths, kept as one comma-separated string for simplicity.
    worker_images = Column(Text, default="")

    # Child SafetyDetection rows; they are deleted together with the upload.
    detections = relationship(
        "SafetyDetection",
        back_populates="upload",
        cascade="all, delete-orphan",
    )
class SafetyDetection(Base):
    """
    A single detected piece of safety gear (helmet or vest) for an upload.
    """

    __tablename__ = "safety_detections"

    id = Column(Integer, primary_key=True, index=True)

    # Class label of the detection, e.g. 'H' or 'V'.
    label = Column(String)
    # Bounding box serialised as the string "x1,y1,x2,y2".
    box = Column(String)
    timestamp = Column(DateTime)

    # Owning upload (foreign key plus ORM back-reference).
    upload_id = Column(Integer, ForeignKey("uploads.id"))
    upload = relationship("Upload", back_populates="detections")
# Create the tables for every model registered on Base, at import time.
Base.metadata.create_all(bind=engine)
##############################################################################
# FastAPI App & Configuration
##############################################################################
# The FastAPI application object; all routes below are registered on it.
app = FastAPI(
    title="Industrial Safety Monitor (FastAPI + SQLite)",
    description="A YOLO-based safety gear detection app. Three endpoints: upload, results, dashboard.",
    version="1.0.0",
)
# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): FastAPI's CORS documentation states that origins, methods and
# headers must be listed explicitly (not "*") when allow_credentials=True —
# confirm whether credentialed requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Directories for raw uploads, annotated output, worker crops and charts.
UPLOAD_FOLDER = "static/uploads"
PROCESSED_FOLDER = "static/processed"
WORKER_FOLDER = "static/workers"
# CHARTS_FOLDER is created below but not referenced elsewhere in the visible
# part of this file.
CHARTS_FOLDER = "static/charts"
# Lower-case file extensions accepted by allowed_file().
# NOTE(review): 'gif' is accepted, but uploaded images are decoded with
# cv2.imread — confirm the deployed OpenCV build can read GIF files.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'mp4'}
# Create the directories at import time; exist_ok avoids errors on restart.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(PROCESSED_FOLDER, exist_ok=True)
os.makedirs(WORKER_FOLDER, exist_ok=True)
os.makedirs(CHARTS_FOLDER, exist_ok=True)
##############################################################################
# YOLO Model Setup
##############################################################################
# Module-level YOLO state shared by prepare_model() and run_detection_on_frame().
# Size passed to letterbox_image() and detection(), and used for the model input.
input_shape = (416, 416)
# Class labels; empty until prepare_model() fills the list in place.
class_names = []
# Anchor array for the prepared approach; None until prepare_model() runs.
anchor_boxes = None
# Derived counts, both set by prepare_model().
num_classes = 0
num_anchors = 0
# The loaded model; None until on_startup() or the /upload endpoint loads one.
model = None
def prepare_model(approach: int):
    """
    Build the YOLO model for the selected approach and load its weights.

    The module-level detection state (``class_names``, ``anchor_boxes``,
    ``num_classes``, ``num_anchors``) is only updated after the weights have
    loaded successfully, so a failed call leaves the previous state intact
    and consistent with whatever model is currently in use.

    Args:
        approach: Detection approach; must be 1, 2, or 3.

    Returns:
        The model returned by ``yolo_body`` with the approach's weights loaded.

    Raises:
        NotImplementedError: If ``approach`` is not 1, 2, or 3.
        FileNotFoundError: If the weight file for the approach is missing.
    """
    global input_shape, class_names, anchor_boxes
    global num_classes, num_anchors

    if approach not in (1, 2, 3):
        raise NotImplementedError("Approach must be 1, 2, or 3")

    # Fail fast on a missing weight file, before building anything.
    weight_path = f"model-data/weights/pictor-ppe-v302-a{approach}-yolo-v3-weights.h5"
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")

    # Raw anchor sizes per output scale; approaches 1 and 3 share one set.
    if approach == 2:
        raw_anchors = (
            [[73, 158], [128, 209], [224, 246]],
            [[32, 50], [40, 104], [76, 73]],
            [[6, 11], [11, 23], [19, 36]],
        )
    else:
        raw_anchors = (
            [[76, 59], [84, 136], [188, 225]],
            [[25, 15], [46, 29], [27, 56]],
            [[5, 3], [10, 8], [12, 26]],
        )
    # Each scale's anchors are divided by that scale's divisor (32, 16, 8).
    divisors = (32, 16, 8)
    new_anchor_boxes = np.array(
        [np.array(scale) / divisor for scale, divisor in zip(raw_anchors, divisors)],
        dtype='float64'
    )

    # Classes: H=Helmet, V=Vest, W=Worker
    labels = ['H', 'V', 'W']
    new_num_classes = len(labels)
    new_num_anchors = new_anchor_boxes.shape[0] * new_anchor_boxes.shape[1]

    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    num_out_filters = (new_num_anchors // 3) * (5 + new_num_classes)
    _model = yolo_body(input_tensor, num_out_filters)
    _model.load_weights(weight_path)

    # Publish the new state only now that the model is fully usable.
    class_names[:] = labels
    anchor_boxes = new_anchor_boxes
    num_classes = new_num_classes
    num_anchors = new_num_anchors
    return _model
##############################################################################
# Utility & Detection Logic
##############################################################################
def allowed_file(filename: str) -> bool:
    """
    Tell whether `filename` has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively;
    a name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
def get_db() -> Session:
    """
    Generator that hands out one database session and closes it afterwards.
    """
    # Leaving the `with` block closes the session, whether the consumer
    # finished normally or raised.
    with SessionLocal() as session:
        yield session
def run_detection_on_frame(frame: np.ndarray,
                           approach: int,
                           upload_id: int,
                           db: Session) -> np.ndarray:
    """
    Runs YOLO detection on a single frame, updates DB counters/detections,
    and returns the annotated frame.

    Args:
        frame: Image to analyse; boxes and labels are drawn onto it in place.
        approach: Detection approach of the request (not used here; the
            already-loaded global model decides the behaviour).
        upload_id: Primary key of the Upload row to update.
        db: Open database session used for all reads and writes.

    Returns:
        The same ``frame`` array, annotated with the detected boxes.

    Raises:
        RuntimeError: If no model has been loaded yet.

    If no Upload row exists for ``upload_id`` the frame is still annotated
    and returned, but nothing is written to the database or to disk.
    """
    global model, anchor_boxes, class_names, input_shape

    if model is None:
        raise RuntimeError("YOLO model is not loaded; call prepare_model() first.")

    ih, iw = frame.shape[:2]
    resized = letterbox_image(frame, input_shape)
    image_data = np.expand_dims(resized, 0) / 255.0
    prediction = model.predict(image_data)
    boxes = detection(
        prediction,
        anchor_boxes,
        len(class_names),
        image_shape=(ih, iw),
        input_shape=input_shape,
        max_boxes=50,
        score_threshold=0.3,
        iou_threshold=0.45,
        classes_can_overlap=False
    )[0].numpy()

    # Tally detections per class and draw them onto the frame.
    workers, helmets, vests = [], [], []
    for box in boxes:
        x1, y1, x2, y2, score, cls_id = box
        x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
        label = class_names[int(cls_id)]
        if label == 'W':
            workers.append((x1, y1, x2, y2))
            color = (0, 255, 0)
        elif label == 'H':
            helmets.append((x1, y1, x2, y2))
            color = (255, 0, 0)
        elif label == 'V':
            vests.append((x1, y1, x2, y2))
            color = (0, 0, 255)
        else:
            color = (255, 255, 0)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
    if upload_obj is None:
        # No parent row to attach results to; skip persistence entirely.
        return frame

    # `or 0` tolerates rows whose counters were stored as NULL.
    upload_obj.total_workers = (upload_obj.total_workers or 0) + len(workers)
    upload_obj.total_helmets = (upload_obj.total_helmets or 0) + len(helmets)
    upload_obj.total_vests = (upload_obj.total_vests or 0) + len(vests)

    # Insert SafetyDetection rows for helmets/vests.
    now_utc = datetime.now(timezone.utc)
    db.add_all([
        SafetyDetection(
            label=gear_label,
            box=f"{bx1},{by1},{bx2},{by2}",
            timestamp=now_utc,
            upload_id=upload_id
        )
        for gear_label, gear_boxes in (('H', helmets), ('V', vests))
        for (bx1, by1, bx2, by2) in gear_boxes
    ])

    # Save worker crops. They are cut from the annotated frame, so they
    # include any boxes/labels drawn above.
    existing_imgs = [w for w in (upload_obj.worker_images or "").split(",") if w]
    # Continue numbering after the crops already stored for this upload so
    # that later video frames do not overwrite earlier crops.
    next_index = len(existing_imgs)
    new_imgs = []
    for (wx1, wy1, wx2, wy2) in workers:
        # Clamp to the frame: a negative index would wrap around when slicing.
        cx1, cy1 = max(0, wx1), max(0, wy1)
        cx2, cy2 = min(iw, wx2), min(ih, wy2)
        if cx2 <= cx1 or cy2 <= cy1:
            continue
        crop = frame[cy1:cy2, cx1:cx2]
        worker_filename = f"worker_{upload_id}_{next_index + 1}.jpg"
        worker_path = os.path.join(WORKER_FOLDER, worker_filename)
        if cv2.imwrite(worker_path, crop):
            next_index += 1
            new_imgs.append(worker_path)

    upload_obj.worker_images = ",".join(existing_imgs + new_imgs)
    # One commit persists counters, detections and crop paths together.
    db.commit()
    return frame
def generate_and_email_pdf(upload_obj: Upload, db: Session):
    """
    Generates a PDF report for a single upload, then emails it to upload_obj.user_email.

    The SMTP account is read from the environment variables ``ISM_SMTP_USER``
    and ``ISM_SMTP_PASSWORD`` instead of being stored in the source. Emailing
    is best-effort: if the upload has no email address, the credentials are
    not configured, or sending fails, a message is printed and the function
    returns normally.

    Args:
        upload_obj: The upload whose totals and worker crops are reported.
        db: Database session (accepted for interface compatibility; unused).
    """
    from email.mime.application import MIMEApplication
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText
    from xml.sax.saxutils import escape

    # Grab top-level stats
    total_workers = upload_obj.total_workers
    total_helmets = upload_obj.total_helmets
    total_vests = upload_obj.total_vests
    worker_images = [
        path.strip()
        for path in (upload_obj.worker_images or "").split(",")
        if path.strip()
    ]
    if upload_obj.timestamp:
        timestamp_text = upload_obj.timestamp.strftime('%Y-%m-%d %H:%M:%S')
    else:
        timestamp_text = "N/A"

    # Create a PDF
    buffer = BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=A4)
    styles = getSampleStyleSheet()

    # Paragraph text is parsed as markup, so user-supplied values (filename,
    # email) are escaped to keep '<' or '&' from breaking the build.
    elements = [
        Paragraph("Industrial Safety Monitor Report", styles["Title"]),
        Paragraph(f"Upload ID: {upload_obj.id}", styles["Normal"]),
        Paragraph(f"Filename: {escape(str(upload_obj.filename))}", styles["Normal"]),
        Paragraph(f"Timestamp: {timestamp_text}", styles["Normal"]),
        Paragraph(f"Approach: {upload_obj.approach}", styles["Normal"]),
        Paragraph(f"User Email: {escape(str(upload_obj.user_email))}", styles["Normal"]),
        Spacer(1, 12),
    ]

    # Table of basic detection metrics. Row 0 is a real header row so the
    # header styling below no longer lands on the first metric.
    data = [
        ["Metric", "Count"],
        ["Total Workers", total_workers],
        ["Total Helmets", total_helmets],
        ["Total Vests", total_vests]
    ]
    table = Table(data, colWidths=[200, 200])
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
        ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
        ("ALIGN", (0, 0), (-1, -1), "CENTER"),
        ("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
        ("FONTSIZE", (0, 0), (-1, 0), 12),
        ("BOTTOMPADDING", (0, 0), (-1, 0), 12),
        ("BACKGROUND", (0, 1), (-1, -1), colors.beige),
        ("GRID", (0, 0), (-1, -1), 1, colors.black),
    ]))
    elements.append(table)
    elements.append(Spacer(1, 12))

    # Show worker crops, if any (only those still present on disk)
    if worker_images:
        elements.append(Paragraph("Detected Workers:", styles["Heading3"]))
        elements.append(Spacer(1, 12))
        for wimg in worker_images:
            if os.path.exists(wimg):
                elements.append(RLImage(wimg, width=100, height=75))
                elements.append(Spacer(1, 12))

    doc.build(elements)
    pdf_data = buffer.getvalue()

    # Email the PDF
    receiver_email = upload_obj.user_email
    if not receiver_email:
        print("No email to send to.")
        return  # skip emailing if no user email

    sender_email = os.environ.get("ISM_SMTP_USER")
    sender_password = os.environ.get("ISM_SMTP_PASSWORD")
    if not sender_email or not sender_password:
        print("SMTP credentials not configured (ISM_SMTP_USER / ISM_SMTP_PASSWORD); skipping email.")
        return

    subject = "Industrial Safety Monitor - Your Detection Report"
    body = (
        "Hello,\n\n"
        "Please find attached the Industrial Safety Monitor detection report.\n"
        "Regards,\nISM Bot"
    )
    msg = MIMEMultipart()
    msg["From"] = sender_email
    msg["To"] = receiver_email
    msg["Subject"] = subject
    msg.attach(MIMEText(body, "plain"))
    part = MIMEApplication(pdf_data, _subtype="pdf")
    part.add_header("Content-Disposition", "attachment", filename="ISM_Report.pdf")
    msg.attach(part)

    try:
        # The timeout keeps a stalled mail server from hanging the caller forever.
        with smtplib.SMTP("smtp.gmail.com", 587, timeout=30) as server:
            server.starttls()
            server.login(sender_email, sender_password)
            server.send_message(msg)
        print(f"Email sent successfully to {receiver_email}!")
    except Exception as e:
        # Deliberately best-effort: a mail failure must not fail the upload.
        print(f"Error sending email: {e}")
##############################################################################
# 1) /upload
##############################################################################
# Approach whose weights this endpoint last loaded into the global `model`.
# It stays None until the first upload, so a model loaded elsewhere is
# reloaded once rather than assumed to match the requested approach.
_loaded_approach = None


def _annotate_video(video, processed_path: str, approach: int, upload_id: int, db) -> None:
    """Run detection on every frame of an opened capture and write the annotated video."""
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    original_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    original_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = video.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        # Some files report 0/NaN FPS; fall back to a usable default.
        fps = 25.0
    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    out = cv2.VideoWriter(
        processed_path, fourcc, fps, (original_width, original_height)
    )
    try:
        current_frame = 0
        while True:
            ret, frame = video.read()
            if not ret:
                break
            current_frame += 1
            print(f"Processing frame {current_frame}/{frame_count} (Upload ID={upload_id})")
            out.write(run_detection_on_frame(frame, approach, upload_id, db))
    finally:
        out.release()


@app.post("/upload", summary="Upload image/video + email; run detection, send PDF to email.")
async def upload_file(
    approach: int = Form(...),
    file: UploadFile = File(...),
    user_email: str = Form(...),
):
    """
    1) User uploads an image/video with approach + email.
    2) We run YOLO detection.
    3) We store results in DB.
    4) We generate a PDF and email it to `user_email`.
    5) Return detection counts in JSON.

    Raises:
        HTTPException 400: invalid approach, unsupported file type, or a file
            that cannot be decoded as an image/video.
        HTTPException 500: the model for the requested approach cannot be loaded.
    """
    global model, _loaded_approach

    if approach not in (1, 2, 3):
        raise HTTPException(status_code=400, detail="Approach must be 1, 2, or 3")

    # Keep only the final path component so a client-supplied name such as
    # "../../x.png" cannot escape the upload directory.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if not allowed_file(filename):
        raise HTTPException(
            status_code=400,
            detail="Unsupported file type. Allowed: .png, .jpg, .jpeg, .gif, .mp4",
        )

    # This is the only await. Everything after it runs without yielding to
    # the event loop, so another request cannot swap the global model midway.
    contents = await file.read()

    # Load (or switch) the model whenever the requested approach differs
    # from the one currently loaded.
    try:
        if model is None or _loaded_approach != approach:
            model = prepare_model(approach)
            _loaded_approach = approach
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Save the uploaded file
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(contents)

    processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{filename}")
    is_video = filename.lower().endswith(".mp4")

    # Decode/open the media before touching the DB so an unreadable file
    # does not leave an empty Upload row behind.
    img = None
    video = None
    if is_video:
        video = cv2.VideoCapture(filepath)
        if not video.isOpened():
            video.release()
            raise HTTPException(status_code=400, detail="Failed to read the video file.")
    else:
        img = cv2.imread(filepath)
        if img is None:
            raise HTTPException(status_code=400, detail="Failed to read the image file.")

    db = SessionLocal()
    try:
        # Create an Upload record
        upload_obj = Upload(
            filename=filename,
            filepath=filepath,
            timestamp=datetime.now(timezone.utc),
            approach=approach,
            user_email=user_email,
            total_workers=0,
            total_helmets=0,
            total_vests=0,
            worker_images=""
        )
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        if is_video:
            _annotate_video(video, processed_path, approach, upload_id, db)
        else:
            annotated_frame = run_detection_on_frame(img, approach, upload_id, db)
            cv2.imwrite(processed_path, annotated_frame)

        # Now fetch updated counts
        db.refresh(upload_obj)
        # Generate & email PDF
        generate_and_email_pdf(upload_obj, db)
        counts = {
            "total_workers": upload_obj.total_workers,
            "total_helmets": upload_obj.total_helmets,
            "total_vests": upload_obj.total_vests
        }
    finally:
        # Release handles on success and on any failure alike.
        if video is not None:
            video.release()
        db.close()

    return {
        "message": f"File uploaded, detection done, PDF emailed to {user_email}.",
        "upload_id": upload_id,
        "counts": counts
    }
##############################################################################
# 2) /results
##############################################################################
@app.get("/results", summary="Fetch the most recent upload’s details.")
def get_results():
    """
    Returns the details (counts, file paths, worker_images) of the most recent upload.

    If there are no uploads, a dict with a single "message" key is returned
    instead. The database session is always closed, even when a query fails.
    """
    db = SessionLocal()
    try:
        # Newest first; the id breaks ties between equal timestamps.
        latest = (
            db.query(Upload)
            .order_by(Upload.timestamp.desc(), Upload.id.desc())
            .first()
        )
        if latest is None:
            return {"message": "No uploads found in the database."}

        processed_filename = f"processed_{latest.filename}"
        processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
        return {
            "upload_id": latest.id,
            "filename": latest.filename,
            "original_path": latest.filepath,
            # Only report the processed file if it is actually on disk.
            "processed_path": processed_path if os.path.exists(processed_path) else None,
            "approach": latest.approach,
            "user_email": latest.user_email,
            "total_workers": latest.total_workers,
            "total_helmets": latest.total_helmets,
            "total_vests": latest.total_vests,
            "worker_images": (latest.worker_images.split(",") if latest.worker_images else []),
            # A row without a timestamp yields None rather than an error.
            "timestamp": (
                latest.timestamp.strftime("%Y-%m-%d %H:%M:%S")
                if latest.timestamp else None
            )
        }
    finally:
        db.close()
##############################################################################
# 3) /dashboard
##############################################################################
@app.get("/dashboard", summary="Get aggregated statistics for a dashboard.")
def dashboard():
    """
    Returns aggregated stats (uploads, detection sums, time-series, approach usage) in JSON.

    Sums that come back as NULL (e.g. on an empty table) are reported as 0.
    The database session is always closed, even when a query fails.
    """
    db = SessionLocal()
    try:
        # Total uploads
        total_uploads = db.query(Upload).count()

        # Summation of detections across all uploads
        agg = db.query(
            func.sum(Upload.total_workers).label("tw"),
            func.sum(Upload.total_helmets).label("th"),
            func.sum(Upload.total_vests).label("tv")
        ).one()
        total_workers = agg.tw or 0
        total_helmets = agg.th or 0
        total_vests = agg.tv or 0

        # Time-series by day, oldest day first
        day = func.date(Upload.timestamp)
        day_rows = db.query(
            day.label("day"),
            func.count(Upload.id).label("uploads"),
            func.sum(Upload.total_workers).label("workers"),
            func.sum(Upload.total_helmets).label("helmets"),
            func.sum(Upload.total_vests).label("vests")
        ).group_by(day).order_by(day).all()

        # Approach usage
        approach_rows = db.query(
            Upload.approach,
            func.count(Upload.id).label("count")
        ).group_by(Upload.approach).all()
    finally:
        db.close()

    return {
        "total_uploads": total_uploads,
        "total_workers": total_workers,
        "total_helmets": total_helmets,
        "total_vests": total_vests,
        "time_series": {
            "dates": [row.day for row in day_rows],
            "uploads_per_day": [row.uploads or 0 for row in day_rows],
            "workers_per_day": [row.workers or 0 for row in day_rows],
            "helmets_per_day": [row.helmets or 0 for row in day_rows],
            "vests_per_day": [row.vests or 0 for row in day_rows]
        },
        "approach_usage": [
            {"approach": f"Approach-{ar.approach}", "count": ar.count}
            for ar in approach_rows
        ],
        # Basic distribution of helmets vs. vests
        "safety_gear_distribution": {
            "labels": ["Helmets", "Vests"],
            "counts": [total_helmets, total_vests]
        }
    }
##############################################################################
# Startup (Load YOLO Model)
##############################################################################
@app.on_event("startup")
def on_startup():
    """
    Startup hook: apply the TensorFlow GPU fix, then try to preload the
    approach-1 model. Failures are printed and the app starts regardless.
    """
    global model
    fix_tf_gpu()
    try:
        # Preloading approach 1 is optional; model stays unset on failure.
        model = prepare_model(approach=1)
    except FileNotFoundError as missing:
        print(f"Model file not found on startup: {missing}")
    except Exception as failure:
        print(f"Error preparing model on startup: {failure}")
    else:
        print("YOLO model (Approach=1) loaded successfully.")