# Spaces:
# Sleeping
# Sleeping
import os | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
from fastapi.responses import JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from typing import Dict, Any, List | |
from datetime import datetime, timezone | |
from io import BytesIO | |
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func | |
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session | |
from reportlab.lib.pagesizes import A4 | |
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle | |
from reportlab.lib.styles import getSampleStyleSheet | |
from reportlab.lib import colors | |
import matplotlib.pyplot as plt | |
from src.yolo3.model import yolo_body | |
from src.yolo3.detect import detection | |
from src.utils.image import letterbox_image | |
from src.utils.fixes import fix_tf_gpu | |
from tensorflow.keras.layers import Input | |
# SQLite database file, resolved relative to the process's working directory.
DB_URL = "sqlite:///./safety_monitor.db"

# Disable sqlite3's same-thread check so a connection may be used from a
# thread other than the one that opened it.
engine = create_engine(DB_URL, connect_args=dict(check_same_thread=False))

# Session factory: callers open a session per unit of work and close it themselves.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base shared by every ORM model below.
Base = declarative_base()
class Upload(Base):
    """One uploaded image plus the aggregate detection counts recorded for it."""

    __tablename__ = "uploads"
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String)  # name the file was stored under
    filepath = Column(String)  # on-disk location of the stored upload
    # Upload time. Callers may pass an aware UTC datetime, but this is a plain
    # DateTime column. NOTE(review): presumably stored naive on SQLite — confirm.
    timestamp = Column(DateTime)
    approach = Column(Integer)  # approach number submitted with the upload (prepare_model accepts 1, 2 or 3)
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)
    # Text blob; its format is decided by the writer, which is not visible here.
    # Presumably a serialized list of worker crop paths — confirm.
    worker_images = Column(Text, default="")
    # Child detections. cascade="all, delete-orphan" means deleting an Upload
    # through the ORM (or removing a detection from this list) deletes those rows.
    detections = relationship("SafetyDetection", back_populates="upload", cascade="all, delete-orphan")
class SafetyDetection(Base):
    """A single labelled detection that belongs to one Upload row."""

    __tablename__ = "safety_detections"
    id = Column(Integer, primary_key=True, index=True)
    label = Column(String)  # NOTE(review): presumably one of class_names ('H', 'V', 'W') — confirm
    box = Column(String)  # bounding box serialized as text; the format is set by the writer (not visible here)
    timestamp = Column(DateTime)
    upload_id = Column(Integer, ForeignKey("uploads.id"))  # owning uploads.id
    upload = relationship("Upload", back_populates="detections")
# Create any tables declared on Base that do not exist yet; this runs at import time.
Base.metadata.create_all(engine)
app = FastAPI(version="1.0.0", title="Industrial Safety Monitor")

# Fully open CORS policy.
# NOTE(review): the CORS spec does not allow a wildcard origin on credentialed
# requests. Combining allow_origins=["*"] with allow_credentials=True should be
# revisited: either drop credentials or list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Output directories, relative to the process's working directory.
UPLOAD_FOLDER = "static/uploads"
PROCESSED_FOLDER = "static/processed"
WORKER_FOLDER = "static/workers"
CHARTS_FOLDER = "static/charts"

# Create each directory at import time; existing directories are left alone.
for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
# Network input size as (rows, cols); prepare_model appends 3 channels.
input_shape = (416, 416)

# Detector class labels. NOTE(review): presumably helmet / vest / worker — confirm.
class_names = ["H", "V", "W"]
num_classes = len(class_names)

# Total anchor count; prepare_model assigns num_anchors // 3 anchors per output.
num_anchors = 9

# Loaded network. It stays None until prepare_model() succeeds.
model = None
def prepare_model(approach: int):
    """Build the YOLOv3 network and load the weight file for ``approach``.

    Weights are read from ``model-data/weights/``, relative to the working
    directory. The module-level ``model`` is replaced only after the weights
    have loaded. If anything fails, a previously loaded model (or None) stays
    in place; an unweighted network is never published.

    Args:
        approach: Which weight set to load; must be 1, 2 or 3.

    Raises:
        ValueError: If ``approach`` is not 1, 2 or 3.
        FileNotFoundError: If the weight file for ``approach`` does not exist.
    """
    global model
    if approach not in (1, 2, 3):
        raise ValueError("Approach must be 1, 2, or 3.")
    weight_files = {
        1: "pictor-ppe-v302-a1-yolo-v3-weights.h5",
        2: "pictor-ppe-v302-a2-yolo-v3-weights.h5",
        3: "pictor-ppe-v302-a3-yolo-v3-weights.h5",
    }
    weight_path = os.path.join("model-data", "weights", weight_files[approach])
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")

    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    # num_anchors // 3 anchors per output (presumably one of YOLOv3's 3
    # scales). Each anchor predicts 5 values (presumably 4 box coordinates
    # plus objectness) and one score per class.
    num_out_filters = (num_anchors // 3) * (5 + num_classes)

    # Build and load into a local first. If load_weights raises, the global
    # is not left pointing at a randomly initialised network.
    new_model = yolo_body(input_tensor, num_out_filters)
    new_model.load_weights(weight_path)
    model = new_model
def on_startup():
    """Call the project's ``fix_tf_gpu`` helper, then load the approach-1 weights.

    NOTE(review): no ``@app.on_event("startup")`` (or equivalent) registration
    is visible for this function. Confirm it is wired to ``app``; otherwise
    ``model`` stays None.
    """
    default_approach = 1
    fix_tf_gpu()
    prepare_model(approach=default_approach)
def generate_pdf(upload_id: int, total_workers: int, total_helmets: int, total_vests: int) -> str:
    """Write a one-page PDF summary of an upload's detection totals.

    The report is saved as ``report_<upload_id>.pdf`` inside PROCESSED_FOLDER.
    It contains a title and a two-column "Metric / Count" table.

    Args:
        upload_id: Id of the Upload row; used only to name the file.
        total_workers: Number of workers detected.
        total_helmets: Number of helmets detected.
        total_vests: Number of vests detected.

    Returns:
        The path the PDF was written to.
    """
    pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_id}.pdf")
    doc = SimpleDocTemplate(pdf_path, pagesize=A4)
    styles = getSampleStyleSheet()

    # Row 0 is a real header row. The style below formats row 0 as a header
    # (grey background, light text, extra padding). Without this row, that
    # formatting landed on the "Total Workers" data row.
    data = [
        ["Metric", "Count"],
        ["Total Workers", str(total_workers)],
        ["Total Helmets Detected", str(total_helmets)],
        ["Total Vests Detected", str(total_vests)],
    ]
    table = Table(data, colWidths=[200, 100])
    table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('FONTNAME', (0, 0), (-1, -1), 'Helvetica-Bold'),
        ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
        ('BACKGROUND', (0, 1), (-1, -1), colors.beige),
    ]))

    elements = [
        Paragraph("Industrial Safety Detection Report", styles["Title"]),
        Spacer(1, 12),
        table,
    ]
    doc.build(elements)
    return pdf_path
async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
    """Store an uploaded image, run safety detection on it and build a PDF report.

    Args:
        approach: Approach number recorded on the Upload row; must be 1, 2 or 3
            (the values ``prepare_model`` accepts).
        file: The uploaded image.

    Returns:
        A dict with the new upload id, the worker/helmet/vest totals read back
        from the database after detection, and the paths of the processed
        image and the PDF report.

    Raises:
        HTTPException: 400 for an invalid ``approach``, a missing or unsafe
            file name, or content OpenCV cannot decode. 500 if the processed
            image cannot be written.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this function; confirm it is registered with ``app`` elsewhere.
    # NOTE(review): ``approach`` is only recorded here. The globally loaded
    # model is not switched; confirm that is intended.
    if approach not in (1, 2, 3):
        raise HTTPException(status_code=400, detail="Approach must be 1, 2, or 3.")

    # Keep only the final path component, so a client-supplied name such as
    # "../../app.py" cannot write outside UPLOAD_FOLDER. filename may be None.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name.")
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(await file.read())

    # Decode before touching the database, so unreadable uploads leave no row behind.
    img = cv2.imread(filepath)
    if img is None:
        raise HTTPException(status_code=400, detail="Failed to read the image file.")

    db = SessionLocal()
    try:
        upload_obj = Upload(filename=filename, filepath=filepath, timestamp=datetime.now(timezone.utc), approach=approach)
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        # NOTE(review): assumes run_detection_on_frame persists the totals for
        # upload_id via ``db``; it is not visible here.
        processed_img = run_detection_on_frame(img, upload_id, db)
        processed_filename = f"processed_{filename}"
        processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
        if not cv2.imwrite(processed_path, processed_img):
            raise HTTPException(status_code=500, detail="Failed to write the processed image.")

        # Reload once, before reporting, so the PDF and the JSON response
        # show the same totals.
        db.refresh(upload_obj)
        total_workers = upload_obj.total_workers
        total_helmets = upload_obj.total_helmets
        total_vests = upload_obj.total_vests

        pdf_path = generate_pdf(upload_id, total_workers, total_helmets, total_vests)
    finally:
        # Always release the session, including on errors.
        db.close()

    return {
        "message": "File processed successfully.",
        "upload_id": upload_id,
        "total_workers": total_workers,
        "total_helmets": total_helmets,
        "total_vests": total_vests,
        "processed_image": processed_path,
        "pdf_report": pdf_path
    }