Spaces:
Configuration error
Configuration error
# backend.py | |
import os | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
import smtplib | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
from fastapi.responses import JSONResponse, StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from typing import Dict, Any | |
from datetime import datetime, timezone | |
from io import BytesIO | |
# SQLAlchemy imports | |
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func | |
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session | |
# ReportLab (PDF generation) | |
from reportlab.lib.pagesizes import A4 | |
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as RLImage, Table, TableStyle | |
from reportlab.lib.styles import getSampleStyleSheet | |
from reportlab.lib import colors | |
# Matplotlib (Chart generation) | |
import matplotlib | |
matplotlib.use('Agg') | |
import matplotlib.pyplot as plt | |
# YOLO-related imports | |
from src.yolo3.model import yolo_body | |
from src.yolo3.detect import detection | |
from src.utils.image import letterbox_image | |
from src.utils.fixes import fix_tf_gpu | |
from tensorflow.keras.layers import Input | |
############################################################################## | |
# Database Setup (SQLite) | |
############################################################################## | |
# SQLite database file, created relative to the process's working directory.
DB_URL = "sqlite:///./safety_monitor.db"

# check_same_thread=False lets a SQLite connection be used from a thread other
# than the one that created it (the sqlite3 driver rejects that by default).
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})

# Session factory: no autocommit and no automatic flush; callers commit explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that the ORM models in this file inherit from.
Base = declarative_base()
class Upload(Base):
    """
    One row per uploaded image or video.

    Holds the file name and saved path, the upload time, the selected approach
    number, the submitter's email address and the running detection totals.
    """

    __tablename__ = "uploads"

    # Identity and file location
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String)
    filepath = Column(String)

    # Request metadata
    timestamp = Column(DateTime)
    approach = Column(Integer)
    user_email = Column(String)  # address the report is emailed to

    # Detection counters, starting from zero
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)

    # Worker crop paths kept as one comma-separated string for simplicity
    worker_images = Column(Text, default="")

    # Child SafetyDetection rows; the cascade removes them together with the upload
    detections = relationship(
        "SafetyDetection",
        back_populates="upload",
        cascade="all, delete-orphan",
    )
class SafetyDetection(Base):
    """
    One row per detected piece of safety gear (helmet or vest bounding box).
    """

    __tablename__ = "safety_detections"

    id = Column(Integer, primary_key=True, index=True)

    # Class label of the detection, e.g. 'H' or 'V'
    label = Column(String)

    # Bounding box serialised as the string "x1,y1,x2,y2"
    box = Column(String)

    timestamp = Column(DateTime)

    # Owning upload (foreign key plus the ORM back-reference)
    upload_id = Column(Integer, ForeignKey("uploads.id"))
    upload = relationship("Upload", back_populates="detections")
# Create the tables for the models registered on Base, using the SQLite engine.
Base.metadata.create_all(bind=engine)
############################################################################## | |
# FastAPI App & Configuration | |
############################################################################## | |
# Application object; the metadata below is what FastAPI is given for its docs.
app = FastAPI(
    title="Industrial Safety Monitor (FastAPI + SQLite)",
    description="A YOLO-based safety gear detection app. Three endpoints: upload, results, dashboard.",
    version="1.0.0",
)

# Fully permissive CORS: every origin, method and header, with credentials on.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins with allow_credentials=True -- confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Directories | |
# Output locations, all relative to the working directory.
UPLOAD_FOLDER = "static/uploads"        # raw uploaded files
PROCESSED_FOLDER = "static/processed"   # annotated images / videos
WORKER_FOLDER = "static/workers"        # per-worker crops
CHARTS_FOLDER = "static/charts"

# File extensions (lower-case, without the dot) accepted by allowed_file().
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'mp4'}

# Make sure every output folder exists before the first request arrives.
for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
############################################################################## | |
# YOLO Model Setup | |
############################################################################## | |
# Module-level YOLO state. prepare_model() fills in the class/anchor values and
# the loaded network is stored in `model` by the startup hook / upload handler.
input_shape = (416, 416)        # spatial size used for the network Input layer
class_names = []                # populated in place by prepare_model()
num_classes = num_anchors = 0   # derived from class_names / anchor_boxes
model = anchor_boxes = None     # nothing is loaded until prepare_model() runs
def prepare_model(approach: int):
    """
    Build the YOLO network for the selected approach and load its weights.

    The module-level detection state (class_names, anchor_boxes, num_classes,
    num_anchors) is only updated once the weights have been loaded, so a failed
    call leaves the previously loaded configuration untouched and consistent
    with whatever model is still in use.

    Args:
        approach: 1, 2 or 3. Approaches 1 and 3 use the same anchor boxes;
            approach 2 has its own set. Each approach has its own weight file.

    Returns:
        The model returned by yolo_body() with the approach's weights loaded.

    Raises:
        NotImplementedError: if approach is not 1, 2 or 3.
        FileNotFoundError: if the weight file for the approach does not exist.
    """
    global input_shape, class_names, anchor_boxes
    global num_classes, num_anchors

    if approach not in [1, 2, 3]:
        raise NotImplementedError("Approach must be 1, 2, or 3")

    # Fail fast on a missing weight file, before building the network and
    # before touching any shared state.
    weight_path = f"model-data/weights/pictor-ppe-v302-a{approach}-yolo-v3-weights.h5"
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")

    # Classes: H=Helmet, V=Vest, W=Worker
    labels = ['H', 'V', 'W']

    # Anchor sizes per output scale; each group is divided by the value it is
    # paired with below (32, 16, 8).
    if approach == 2:
        raw_anchors = (
            [[73, 158], [128, 209], [224, 246]],
            [[32, 50], [40, 104], [76, 73]],
            [[6, 11], [11, 23], [19, 36]],
        )
    else:  # approaches 1 and 3 share one anchor set
        raw_anchors = (
            [[76, 59], [84, 136], [188, 225]],
            [[25, 15], [46, 29], [27, 56]],
            [[5, 3], [10, 8], [12, 26]],
        )
    divisors = (32, 16, 8)
    anchors = np.array(
        [np.array(group) / divisor for group, divisor in zip(raw_anchors, divisors)],
        dtype='float64',
    )

    class_count = len(labels)
    anchor_count = anchors.shape[0] * anchors.shape[1]

    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    num_out_filters = (anchor_count // 3) * (5 + class_count)
    _model = yolo_body(input_tensor, num_out_filters)
    _model.load_weights(weight_path)

    # Everything succeeded: publish the new configuration. class_names is
    # updated in place so existing references to the list see the change.
    class_names[:] = labels
    anchor_boxes = anchors
    num_classes = class_count
    num_anchors = anchor_count
    return _model
############################################################################## | |
# Utility & Detection Logic | |
############################################################################## | |
def allowed_file(filename: str) -> bool:
    """
    Return True when `filename` has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively.
    A missing (None) or empty filename, or one without a dot, is rejected
    instead of raising.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
def get_db() -> Session:
    """
    Generator that hands out one database session and always closes it.

    A session is created from SessionLocal, yielded to the caller, and closed
    in the `finally` clause once the generator is resumed or torn down.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def run_detection_on_frame(frame: np.ndarray,
                           approach: int,
                           upload_id: int,
                           db: Session) -> np.ndarray:
    """
    Run YOLO detection on one frame, record the results and annotate the frame.

    Args:
        frame: image array; it is drawn on in place and also returned.
        approach: accepted for interface compatibility; not used in this
            function (the module-level model/anchors decide the behaviour).
        upload_id: id of the Upload row the results belong to.
        db: open SQLAlchemy session used for all reads and writes.

    Returns:
        The same `frame` array with boxes and labels drawn on it.

    Side effects:
        * adds the per-class counts to the Upload row (when it exists),
        * inserts one SafetyDetection row per helmet and per vest,
        * writes one JPEG crop per worker into WORKER_FOLDER and appends the
          paths to Upload.worker_images,
        * commits the session once at the end.
    """
    global model, anchor_boxes, class_names, input_shape

    ih, iw = frame.shape[:2]
    resized = letterbox_image(frame, input_shape)
    image_data = np.expand_dims(resized, 0) / 255.0

    prediction = model.predict(image_data)
    boxes = detection(
        prediction,
        anchor_boxes,
        len(class_names),
        image_shape=(ih, iw),
        input_shape=input_shape,
        max_boxes=50,
        score_threshold=0.3,
        iou_threshold=0.45,
        classes_can_overlap=False
    )[0].numpy()

    # Tally boxes per class and draw them on the frame.
    label_colors = {'W': (0, 255, 0), 'H': (255, 0, 0), 'V': (0, 0, 255)}
    found = {'W': [], 'H': [], 'V': []}
    for box in boxes:
        x1, y1, x2, y2, _score, cls_id = box
        x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
        label = class_names[int(cls_id)]
        if label in found:
            found[label].append((x1, y1, x2, y2))
        color = label_colors.get(label, (255, 255, 0))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    workers, helmets, vests = found['W'], found['H'], found['V']

    upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
    if upload_obj:
        # `or 0` tolerates rows whose counters were stored as NULL.
        upload_obj.total_workers = (upload_obj.total_workers or 0) + len(workers)
        upload_obj.total_helmets = (upload_obj.total_helmets or 0) + len(helmets)
        upload_obj.total_vests = (upload_obj.total_vests or 0) + len(vests)

    # Insert SafetyDetection rows for helmets and vests.
    now_utc = datetime.now(timezone.utc)
    for gear_label, gear_boxes in (('H', helmets), ('V', vests)):
        for (bx1, by1, bx2, by2) in gear_boxes:
            db.add(SafetyDetection(
                label=gear_label,
                box=f"{bx1},{by1},{bx2},{by2}",
                timestamp=now_utc,
                upload_id=upload_id
            ))

    # Worker crops. Numbering continues after the crops already recorded for
    # this upload, so successive video frames neither overwrite earlier files
    # nor append duplicate paths.
    existing_imgs = []
    if upload_obj and upload_obj.worker_images:
        existing_imgs = [p for p in upload_obj.worker_images.split(",") if p]
    next_index = len(existing_imgs) + 1

    new_imgs = []
    for (wx1, wy1, wx2, wy2) in workers:
        # Clamp to >= 0: a negative slice bound would wrap around the array.
        wx1, wy1, wx2, wy2 = (max(v, 0) for v in (wx1, wy1, wx2, wy2))
        # The crop comes from the annotated frame, so drawn boxes are included.
        crop = frame[wy1:wy2, wx1:wx2]
        if crop.size == 0:
            continue
        worker_path = os.path.join(WORKER_FOLDER, f"worker_{upload_id}_{next_index}.jpg")
        cv2.imwrite(worker_path, crop)
        new_imgs.append(worker_path)
        next_index += 1

    if upload_obj:
        upload_obj.worker_images = ",".join(existing_imgs + new_imgs)

    # One commit covers the counters, the detection rows and the crop list.
    db.commit()
    return frame
def generate_and_email_pdf(upload_obj: Upload, db: Session):
    """
    Build a PDF summary for one upload and email it to upload_obj.user_email.

    Args:
        upload_obj: the Upload row to report on (counts, file name, crops).
        db: unused here; kept so existing callers keep working.

    Returns:
        None. When the upload has no email address nothing is generated.
        SMTP failures are printed, not raised (best-effort delivery).

    Configuration:
        The sender account can be supplied through the ISM_SENDER_EMAIL and
        ISM_SENDER_PASSWORD environment variables; the literals previously
        embedded in this function remain as fallbacks.
    """
    from email.mime.application import MIMEApplication
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText
    from xml.sax.saxutils import escape

    # Nothing to deliver to: skip the PDF work entirely.
    receiver_email = upload_obj.user_email
    if not receiver_email:
        print("No email to send to.")
        return

    worker_images = upload_obj.worker_images.split(",") if upload_obj.worker_images else []
    timestamp_text = (
        upload_obj.timestamp.strftime('%Y-%m-%d %H:%M:%S')
        if upload_obj.timestamp else "unknown"
    )

    # --- Build the PDF in memory -------------------------------------------
    buffer = BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=A4)
    styles = getSampleStyleSheet()
    normal = styles["Normal"]

    # Paragraph text is parsed as markup, so user-supplied values are escaped
    # to keep characters such as '<' and '&' from breaking the build.
    elements = [
        Paragraph("Industrial Safety Monitor Report", styles["Title"]),
        Paragraph(f"Upload ID: {upload_obj.id}", normal),
        Paragraph(f"Filename: {escape(str(upload_obj.filename))}", normal),
        Paragraph(f"Timestamp: {timestamp_text}", normal),
        Paragraph(f"Approach: {upload_obj.approach}", normal),
        Paragraph(f"User Email: {escape(str(receiver_email))}", normal),
        Spacer(1, 12),
    ]

    # Detection metrics. The style below formats row 0 as a header, so the
    # table starts with a real header row instead of the first metric.
    data = [
        ["Metric", "Count"],
        ["Total Workers", upload_obj.total_workers],
        ["Total Helmets", upload_obj.total_helmets],
        ["Total Vests", upload_obj.total_vests],
    ]
    table = Table(data, colWidths=[200, 200])
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
        ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
        ("ALIGN", (0, 0), (-1, -1), "CENTER"),
        ("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
        ("FONTSIZE", (0, 0), (-1, 0), 12),
        ("BOTTOMPADDING", (0, 0), (-1, 0), 12),
        ("BACKGROUND", (0, 1), (-1, -1), colors.beige),
        ("GRID", (0, 0), (-1, -1), 1, colors.black),
    ]))
    elements.append(table)
    elements.append(Spacer(1, 12))

    # Worker crops, if any; paths that no longer exist on disk are skipped.
    if worker_images:
        elements.append(Paragraph("Detected Workers:", styles["Heading3"]))
        elements.append(Spacer(1, 12))
        for wimg in worker_images:
            wimg = wimg.strip()
            if wimg and os.path.exists(wimg):
                elements.append(RLImage(wimg, width=100, height=75))
                elements.append(Spacer(1, 12))

    doc.build(elements)
    pdf_data = buffer.getvalue()

    # --- Email the PDF -----------------------------------------------------
    # SECURITY: credentials embedded in source are exposed to anyone who can
    # read the file. Prefer the environment variables and rotate the password
    # below. NOTE(review): the fallback address looks like a scrubbed
    # placeholder -- confirm the real sender account.
    sender_email = os.environ.get("ISM_SENDER_EMAIL", "[email protected]")
    sender_password = os.environ.get("ISM_SENDER_PASSWORD", "aobh rdgp iday bpwg")

    subject = "Industrial Safety Monitor - Your Detection Report"
    body = (
        "Hello,\n\n"
        "Please find attached the Industrial Safety Monitor detection report.\n"
        "Regards,\nISM Bot"
    )

    msg = MIMEMultipart()
    msg["From"] = sender_email
    msg["To"] = receiver_email
    msg["Subject"] = subject
    msg.attach(MIMEText(body, "plain"))

    part = MIMEApplication(pdf_data, _subtype="pdf")
    part.add_header("Content-Disposition", "attachment", filename="ISM_Report.pdf")
    msg.attach(part)

    # Best-effort delivery: a mail failure must not fail the upload request.
    try:
        with smtplib.SMTP("smtp.gmail.com", 587) as server:
            server.starttls()
            server.login(sender_email, sender_password)
            server.send_message(msg)
        print(f"Email sent successfully to {receiver_email}!")
    except Exception as e:
        print(f"Error sending email: {e}")
############################################################################## | |
# 1) /upload | |
############################################################################## | |
async def upload_file(
    approach: int = Form(...),
    file: UploadFile = File(...),
    user_email: str = Form(...),
):
    """
    Handle one upload end to end.

    1) Validate the approach and the file name.
    2) Make sure the YOLO model for the requested approach is loaded.
    3) Save the file, create the Upload row and run detection on the image
       or on every frame of the video.
    4) Generate the PDF report and email it to `user_email`.
    5) Return the detection counts as JSON.

    Raises:
        HTTPException(400): invalid approach, unsupported file type, or a
            file that cannot be read as an image/video.
        HTTPException(500): the model could not be prepared.

    NOTE(review): no route decorator is visible on this function in the
    reviewed source -- confirm how it is registered with `app`.
    """
    global model

    # Client errors are rejected before any expensive work is done.
    if approach not in (1, 2, 3):
        raise HTTPException(status_code=400, detail="Approach must be 1, 2, or 3")

    # The file name comes from the client: keep only its final component so
    # it cannot point outside UPLOAD_FOLDER (e.g. "../../x.png").
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if not allowed_file(filename):
        raise HTTPException(
            status_code=400,
            detail="Unsupported file type. Allowed: .png, .jpg, .jpeg, .gif, .mp4",
        )

    # Reload when no model is loaded or a different approach is requested.
    # The approach of a model loaded elsewhere (e.g. at startup) is unknown
    # here, so the first request reloads once to be certain.
    if model is None or globals().get("_loaded_approach") != approach:
        try:
            model = prepare_model(approach)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        globals()["_loaded_approach"] = approach

    # Save the uploaded file
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(await file.read())

    db = SessionLocal()
    try:
        # Create an Upload record
        upload_obj = Upload(
            filename=filename,
            filepath=filepath,
            timestamp=datetime.now(timezone.utc),
            approach=approach,
            user_email=user_email,
            total_workers=0,
            total_helmets=0,
            total_vests=0,
            worker_images=""
        )
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{filename}")
        lowered = filename.lower()

        if lowered.endswith((".png", ".jpg", ".jpeg", ".gif")):
            img = cv2.imread(filepath)
            if img is None:
                raise HTTPException(status_code=400, detail="Failed to read the image file.")
            annotated_frame = run_detection_on_frame(img, approach, upload_id, db)
            cv2.imwrite(processed_path, annotated_frame)

        elif lowered.endswith(".mp4"):
            video = cv2.VideoCapture(filepath)
            out = None
            try:
                if not video.isOpened():
                    raise HTTPException(status_code=400, detail="Failed to read the video file.")
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                original_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
                original_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
                fps = video.get(cv2.CAP_PROP_FPS)
                frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                out = cv2.VideoWriter(
                    processed_path, fourcc, fps, (original_width, original_height)
                )
                current_frame = 0
                while True:
                    ret, frame = video.read()
                    if not ret:
                        break
                    current_frame += 1
                    print(f"Processing frame {current_frame}/{frame_count} (Upload ID={upload_id})")
                    annotated_frame = run_detection_on_frame(frame, approach, upload_id, db)
                    out.write(annotated_frame)
            finally:
                # Release the capture/writer even if detection fails mid-video.
                video.release()
                if out is not None:
                    out.release()

        # Now fetch updated counts
        db.refresh(upload_obj)

        # Generate & email PDF
        generate_and_email_pdf(upload_obj, db)

        counts = {
            "total_workers": upload_obj.total_workers,
            "total_helmets": upload_obj.total_helmets,
            "total_vests": upload_obj.total_vests
        }
        return {
            "message": f"File uploaded, detection done, PDF emailed to {user_email}.",
            "upload_id": upload_id,
            "counts": counts
        }
    finally:
        # Single exit point for the session, including every error path.
        db.close()
############################################################################## | |
# 2) /results | |
############################################################################## | |
def get_results():
    """
    Return the details of the most recent upload (by timestamp).

    Returns:
        A dict with the upload's id, file names/paths, approach, email,
        detection counts, worker crop paths and formatted timestamp, or
        {"message": ...} when the uploads table is empty. `processed_path`
        is None when the processed file is not present on disk.

    The database session is closed on every path, including errors.
    """
    db = SessionLocal()
    try:
        latest = db.query(Upload).order_by(Upload.timestamp.desc()).first()
        if not latest:
            return {"message": "No uploads found in the database."}

        processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{latest.filename}")
        return {
            "upload_id": latest.id,
            "filename": latest.filename,
            "original_path": latest.filepath,
            "processed_path": processed_path if os.path.exists(processed_path) else None,
            "approach": latest.approach,
            "user_email": latest.user_email,
            "total_workers": latest.total_workers,
            "total_helmets": latest.total_helmets,
            "total_vests": latest.total_vests,
            "worker_images": (latest.worker_images.split(",") if latest.worker_images else []),
            # Guard against rows stored without a timestamp.
            "timestamp": (
                latest.timestamp.strftime("%Y-%m-%d %H:%M:%S")
                if latest.timestamp else None
            ),
        }
    finally:
        db.close()
############################################################################## | |
# 3) /dashboard | |
############################################################################## | |
def dashboard():
    """
    Return aggregated statistics over all uploads as a JSON-ready dict.

    The result contains the total number of uploads, the summed worker /
    helmet / vest counts, a per-day time series of the same figures, the
    number of uploads per approach, and a helmets-vs-vests distribution.
    Sums that come back as NULL (no rows) are reported as 0.

    The database session is closed on every path, including errors.
    """
    db = SessionLocal()
    try:
        # Total uploads
        total_uploads = db.query(Upload).count()

        # Summation of detections
        agg = db.query(
            func.sum(Upload.total_workers).label("tw"),
            func.sum(Upload.total_helmets).label("th"),
            func.sum(Upload.total_vests).label("tv")
        ).one()
        total_workers = agg.tw or 0
        total_helmets = agg.th or 0
        total_vests = agg.tv or 0

        # Time-series by day
        day = func.date(Upload.timestamp)
        day_rows = db.query(
            day.label("day"),
            func.count(Upload.id).label("uploads"),
            func.sum(Upload.total_workers).label("workers"),
            func.sum(Upload.total_helmets).label("helmets"),
            func.sum(Upload.total_vests).label("vests")
        ).group_by(day).order_by(day).all()

        # Approach usage
        approach_rows = db.query(
            Upload.approach,
            func.count(Upload.id).label("count")
        ).group_by(Upload.approach).all()

        return {
            "total_uploads": total_uploads,
            "total_workers": total_workers,
            "total_helmets": total_helmets,
            "total_vests": total_vests,
            "time_series": {
                "dates": [row.day for row in day_rows],
                "uploads_per_day": [row.uploads or 0 for row in day_rows],
                "workers_per_day": [row.workers or 0 for row in day_rows],
                "helmets_per_day": [row.helmets or 0 for row in day_rows],
                "vests_per_day": [row.vests or 0 for row in day_rows]
            },
            "approach_usage": [
                {"approach": f"Approach-{ar.approach}", "count": ar.count}
                for ar in approach_rows
            ],
            # Basic distribution of helmets vs. vests
            "safety_gear_distribution": {
                "labels": ["Helmets", "Vests"],
                "counts": [total_helmets, total_vests]
            }
        }
    finally:
        db.close()
############################################################################## | |
# Startup (Load YOLO Model) | |
############################################################################## | |
def on_startup():
    """
    Startup hook: apply the TensorFlow GPU fix, then try to preload the
    approach-1 YOLO model into the module-level `model`.

    Loading failures are printed and swallowed so the application can still
    start; in that case `model` keeps its previous value.
    """
    global model

    fix_tf_gpu()
    try:
        # Approach 1 is the default that gets preloaded (optional step).
        model = prepare_model(approach=1)
        print("YOLO model (Approach=1) loaded successfully.")
    except FileNotFoundError as missing:
        print(f"Model file not found on startup: {missing}")
    except Exception as failure:
        print(f"Error preparing model on startup: {failure}")