Spaces:
Sleeping
Sleeping
import os | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
from fastapi.responses import JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from typing import Dict, Any | |
from datetime import datetime, timezone | |
from io import BytesIO | |
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func | |
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session | |
from reportlab.lib.pagesizes import A4 | |
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle | |
from reportlab.lib.styles import getSampleStyleSheet | |
from reportlab.lib import colors | |
import matplotlib.pyplot as plt | |
from src.yolo3.model import yolo_body | |
from src.yolo3.detect import detection | |
from src.utils.image import letterbox_image | |
from src.utils.fixes import fix_tf_gpu | |
from tensorflow.keras.layers import Input | |
# SQLite database file, resolved relative to the process's working directory.
DB_URL = "sqlite:///./safety_monitor.db"
# SQLite's same-thread check is disabled, presumably so a connection may be
# used from a thread other than the one that opened it — confirm intent.
engine = create_engine(DB_URL, connect_args=dict(check_same_thread=False))
# Session factory: explicit commits only, no implicit flush before queries.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Declarative base shared by the ORM models below.
Base = declarative_base()
class Upload(Base):
    """One uploaded image together with the detection totals computed for it."""

    __tablename__ = "uploads"

    # Declaration order is also the column order create_all emits; keep it.
    id = Column(Integer, index=True, primary_key=True)
    filename = Column(String)  # stored name of the uploaded file
    filepath = Column(String)  # location the raw upload was written to
    # Set when the upload is recorded. No timezone=True, so tz info is
    # presumably not preserved by SQLite — confirm if it matters.
    timestamp = Column(DateTime)
    approach = Column(Integer)  # client-supplied form value; meaning not visible here
    # Running per-label detection counts; 0 when the row is first inserted.
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)
    worker_images = Column(Text, default="")  # not written anywhere in this view
    # Child detection rows; deleted along with (or when detached from) this upload.
    detections = relationship(
        "SafetyDetection",
        cascade="all, delete-orphan",
        back_populates="upload",
    )
class SafetyDetection(Base):
    """One detection record (label + box) attached to an ``Upload``.

    NOTE(review): no rows of this table are created in the visible code.
    """

    __tablename__ = "safety_detections"

    # Declaration order is also the column order create_all emits; keep it.
    id = Column(Integer, index=True, primary_key=True)
    label = Column(String)
    box = Column(String)  # box stored as text; serialisation format not visible here
    timestamp = Column(DateTime)
    upload_id = Column(Integer, ForeignKey("uploads.id"))
    upload = relationship("Upload", back_populates="detections")
# Create any tables for the models above that do not exist yet.
Base.metadata.create_all(bind=engine)

app = FastAPI(version="1.0.0", title="Industrial Safety Monitor")
# NOTE(review): wildcard origins together with allow_credentials=True may let
# any website issue credentialed requests; narrow allow_origins before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Working directories (relative to the process CWD). All of them are created
# at import time so request handlers can write into them unconditionally.
# WORKER_FOLDER and CHARTS_FOLDER are not referenced in the visible code.
UPLOAD_FOLDER = "static/uploads"
PROCESSED_FOLDER = "static/processed"
WORKER_FOLDER = "static/workers"
CHARTS_FOLDER = "static/charts"
for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
# Detector configuration.
# A class id is an index into class_names; run_detection_on_frame counts
# 'W' as workers, 'H' as helmets and 'V' as vests.
class_names = ['H', 'V', 'W']
num_classes = len(class_names)

# Square network input size fed to the Keras Input layer (plus 3 channels).
input_shape = (416, 416)
# Total anchor count; prepare_model uses num_anchors // 3 of them per output,
# presumably one group per YOLO detection scale.
num_anchors = 9

# Loaded Keras model; stays None until prepare_model() runs.
model = None
def prepare_model(weight_path="model-data/weights/yolo_weights.h5"):
    """Build the YOLOv3 network, load its weights and publish it as ``model``.

    The network is assembled in a local variable. It is assigned to the
    module-level ``model`` only after ``load_weights`` succeeds. A missing or
    unreadable weight file therefore never leaves a random-weight network
    behind for request handlers to use.

    Args:
        weight_path: Keras ``.h5`` weight file to load. Defaults to the path
            this module has always used.

    Raises:
        FileNotFoundError: If ``weight_path`` does not exist.
    """
    global model
    # Fail fast, before spending time constructing the graph.
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")
    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    # Filters per output = anchors per output * (5 + num_classes); the 5 is
    # presumably 4 box coordinates + 1 objectness score (standard YOLO layout).
    num_out_filters = (num_anchors // 3) * (5 + num_classes)
    network = yolo_body(input_tensor, num_out_filters)
    network.load_weights(weight_path)
    model = network
def on_startup():
    """Prepare the runtime: apply the TF GPU fix, then load the detector.

    NOTE(review): no ``@app.on_event("startup")`` decorator or other
    registration is visible in this view. Confirm this is actually wired up
    as a startup hook; otherwise ``model`` stays ``None`` until loaded elsewhere.
    """
    for step in (fix_tf_gpu, prepare_model):
        step()
async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
    """Store an uploaded image, run detection on it and build a PDF report.

    Steps:
    1. Write the raw bytes to ``UPLOAD_FOLDER``.
    2. Decode the image and record an ``Upload`` row.
    3. Run ``run_detection_on_frame``, which updates the row's totals.
    4. Save the returned frame to ``PROCESSED_FOLDER``.
    5. Render the report via ``generate_pdf``.

    NOTE(review): no route decorator is visible in this view — confirm how
    this handler is registered. Inference and file I/O run synchronously
    inside this ``async`` handler and therefore block the event loop.

    Args:
        approach: Client-supplied form value stored on the ``Upload`` row.
        file: The uploaded image.

    Returns:
        A dict with a status message, the new ``upload_id`` and the PDF path.

    Raises:
        HTTPException: 400 if the upload has no usable file name or its bytes
            cannot be decoded as an image.
    """
    # The client controls ``file.filename``. Keep only the final path
    # component so names such as "../../x" cannot escape UPLOAD_FOLDER.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Missing or invalid file name.")
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(await file.read())

    # Validate before touching the DB, so unreadable uploads leave no orphan row.
    img = cv2.imread(filepath)
    if img is None:
        raise HTTPException(status_code=400, detail="Failed to read the image file.")

    db = SessionLocal()
    try:
        upload_obj = Upload(
            filename=filename,
            filepath=filepath,
            timestamp=datetime.now(timezone.utc),
            approach=approach,
        )
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        processed_img = run_detection_on_frame(img, upload_id, db)
        processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{filename}")
        cv2.imwrite(processed_path, processed_img)

        # Reload the totals committed by run_detection_on_frame before reporting.
        db.refresh(upload_obj)
        pdf_path = generate_pdf(upload_obj)
    finally:
        # Always release the session, even if detection or reporting raises.
        db.close()
    return {"message": "File processed successfully.", "upload_id": upload_id, "pdf_path": pdf_path}
def run_detection_on_frame(frame: np.ndarray, upload_id: int, db: Session) -> np.ndarray:
    """Run the detector on one frame and add its counts to an ``Upload`` row.

    Detections are grouped by class label. The numbers of 'W' (workers),
    'H' (helmets) and 'V' (vests) boxes are added to the matching totals on
    the ``Upload`` whose id is ``upload_id``, and the change is committed.

    Args:
        frame: Image array as produced by ``cv2.imread`` (BGR channel order).
            NOTE(review): confirm the model expects BGR rather than RGB input.
        upload_id: Primary key of the ``Upload`` row to update.
        db: Open session; this function commits on it.

    Returns:
        ``frame`` itself, unmodified — no boxes are drawn on it here.
    """
    global model
    if model is None:
        # Load lazily so requests still work if the startup hook never ran.
        prepare_model()

    ih, iw = frame.shape[:2]
    resized = letterbox_image(frame, input_shape)
    # Add a batch axis and scale pixel values into [0, 1].
    image_data = np.array(np.expand_dims(resized, 0)) / 255.0
    prediction = model.predict(image_data)
    # NOTE(review): the positional args 50 / 0.3 / 0.45 are presumably max
    # boxes, score threshold and IoU threshold — confirm against src.yolo3.detect.
    boxes = detection(prediction, None, len(class_names), (ih, iw), input_shape, 50, 0.3, 0.45, False)[0].numpy()

    # Each row is unpacked as x1, y1, x2, y2, score, class id.
    grouped = {label: [] for label in class_names}
    for box in boxes:
        x1, y1, x2, y2, _, cls_id = map(int, box)
        grouped[class_names[cls_id]].append((x1, y1, x2, y2))

    upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
    if upload_obj is not None:
        # The column defaults apply only on INSERT; treat NULL totals as 0.
        upload_obj.total_workers = (upload_obj.total_workers or 0) + len(grouped['W'])
        upload_obj.total_helmets = (upload_obj.total_helmets or 0) + len(grouped['H'])
        upload_obj.total_vests = (upload_obj.total_vests or 0) + len(grouped['V'])
        db.commit()
    return frame
def generate_pdf(upload_obj: Upload):
    """Render a one-page safety summary PDF for ``upload_obj``.

    The report contains the file name, the upload timestamp and a table of
    the worker/helmet/vest totals. It is written to
    ``PROCESSED_FOLDER/report_<id>.pdf``.

    Args:
        upload_obj: A persisted ``Upload`` row (its ``id`` names the file).

    Returns:
        Path of the written PDF file.
    """
    # Local import: only needed here, for escaping paragraph text.
    from xml.sax.saxutils import escape

    styles = getSampleStyleSheet()
    if upload_obj.timestamp is not None:
        timestamp_text = upload_obj.timestamp.strftime('%Y-%m-%d %H:%M:%S')
    else:
        timestamp_text = "N/A"
    elements = [
        Paragraph("Industrial Safety Report", styles["Title"]),
        # Paragraph text is parsed as ReportLab markup. The client-supplied
        # filename must be escaped, or '&' / '<' break the build.
        Paragraph(f"Filename: {escape(str(upload_obj.filename))}", styles["Normal"]),
        Paragraph(f"Timestamp: {timestamp_text}", styles["Normal"]),
    ]

    # Row 0 receives header styling below, so it holds column headers.
    data = [
        ["Metric", "Count"],
        ["Total Workers", upload_obj.total_workers],
        ["Total Helmets", upload_obj.total_helmets],
        ["Total Vests", upload_obj.total_vests],
    ]
    table = Table(data)
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
        ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
        ("GRID", (0, 0), (-1, -1), 1, colors.black),
    ]))
    elements.append(table)

    # Build in memory first so a failed build never leaves a partial file.
    buffer = BytesIO()
    SimpleDocTemplate(buffer, pagesize=A4).build(elements)
    pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_obj.id}.pdf")
    with open(pdf_path, "wb") as f:
        f.write(buffer.getvalue())
    return pdf_path