# Source: backendsafety / app.py
# Last commit: 6cdf49c "Update app.py" (verified), author: nooneshouldtouch, size 5.89 kB
import os
import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any, List
from datetime import datetime, timezone
from io import BytesIO
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
import matplotlib.pyplot as plt
from src.yolo3.model import yolo_body
from src.yolo3.detect import detection
from src.utils.image import letterbox_image
from src.utils.fixes import fix_tf_gpu
from tensorflow.keras.layers import Input
# Application database: a single SQLite file in the working directory.
DB_URL = "sqlite:///./safety_monitor.db"

# Python's sqlite3 driver by default refuses to use a connection from a thread
# other than the one that opened it; that check is switched off so pooled
# connections can be used from whichever thread handles a request.
engine = create_engine(
    DB_URL,
    connect_args={"check_same_thread": False},
)

# Session factory. autoflush is disabled, so pending changes are written only
# on an explicit flush()/commit().
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base shared by the ORM models of this module.
Base = declarative_base()
class Upload(Base):
    """One uploaded image together with the aggregate counts stored for it."""

    __tablename__ = "uploads"

    id = Column(Integer, index=True, primary_key=True)
    # Client-supplied name and the on-disk location of the raw upload.
    filename = Column(String)
    filepath = Column(String)
    # The /upload handler writes datetime.now(timezone.utc) here.
    # NOTE(review): the column is not DateTime(timezone=True), so tzinfo may
    # not round-trip through SQLite — confirm if aware values are needed.
    timestamp = Column(DateTime)
    # The `approach` form value sent with the upload (prepare_model accepts
    # 1, 2 or 3; the value is not validated before being stored).
    approach = Column(Integer)
    # Aggregate counts; 0 when the row is first inserted.
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)
    # Free-form text; presumably a serialized list of worker crops — TODO confirm format.
    worker_images = Column(Text, default="")

    # Child detection rows. The cascade removes them when the Upload is deleted
    # or when a row is detached from this collection.
    detections = relationship(
        "SafetyDetection",
        back_populates="upload",
        cascade="all, delete-orphan",
    )
class SafetyDetection(Base):
    """One detection row (label plus box) attached to an Upload."""

    __tablename__ = "safety_detections"

    id = Column(Integer, index=True, primary_key=True)
    # Detected class; presumably one of class_names ('H', 'V', 'W') — TODO confirm.
    label = Column(String)
    # Box coordinates stored as text; the serialization format is chosen by the
    # code that writes these rows (not visible here).
    box = Column(String)
    timestamp = Column(DateTime)
    # Owning upload; the counterpart of Upload.detections.
    upload_id = Column(Integer, ForeignKey("uploads.id"))
    upload = relationship("Upload", back_populates="detections")
# Create any missing tables for the ORM models registered on Base; tables that
# already exist are left as they are.
Base.metadata.create_all(bind=engine)

app = FastAPI(title="Industrial Safety Monitor", version="1.0.0")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any site make credentialed cross-origin requests (Starlette answers with the
# caller's own Origin in that case). Narrow allow_origins before relying on
# cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Output locations, relative to the process working directory.
UPLOAD_FOLDER = "static/uploads"
PROCESSED_FOLDER = "static/processed"
WORKER_FOLDER = "static/workers"
CHARTS_FOLDER = "static/charts"

# Create every folder up front; directories that already exist are fine.
for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder
# Network input size; prepare_model uses it as (rows, cols) of the Keras Input.
input_shape = (416, 416)
# Detector class labels; presumably Helmet, Vest, Worker — TODO confirm.
class_names = list("HVW")
num_classes = len(class_names)
# Total anchor count; prepare_model assigns num_anchors // 3 to each output head.
num_anchors = 3 * 3
# The loaded network. Stays None until prepare_model() succeeds.
model = None
# Weight file for each supported detection approach (under model-data/weights).
_WEIGHT_FILES = {
    1: "pictor-ppe-v302-a1-yolo-v3-weights.h5",
    2: "pictor-ppe-v302-a2-yolo-v3-weights.h5",
    3: "pictor-ppe-v302-a3-yolo-v3-weights.h5",
}


def prepare_model(approach: int):
    """Build the YOLOv3 network and load the weights for ``approach``.

    The new network is assigned to the module-level ``model`` only after its
    weights have loaded. A failed (re)load therefore leaves the previously
    loaded model in place, rather than an untrained network.

    Args:
        approach: Which weight set to load; must be 1, 2 or 3.

    Raises:
        ValueError: If ``approach`` is not 1, 2 or 3.
        FileNotFoundError: If the weight file does not exist. The path is
            relative to the process working directory.
        Any exception raised while building the network or by
        ``load_weights`` propagates, and ``model`` is left unchanged.
    """
    global model
    if approach not in _WEIGHT_FILES:
        raise ValueError("Approach must be 1, 2, or 3.")
    weight_path = os.path.join("model-data", "weights", _WEIGHT_FILES[approach])
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")
    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    # Channels per detection head: presumably (anchors per head) x
    # (4 box offsets + 1 objectness + class scores), the usual YOLOv3 layout —
    # confirm against yolo_body.
    num_out_filters = (num_anchors // 3) * (5 + num_classes)
    candidate = yolo_body(input_tensor, num_out_filters)
    candidate.load_weights(weight_path)
    # Publish only a fully initialised network.
    model = candidate
@app.on_event("startup")
def on_startup():
    """Run the TensorFlow GPU fix-up, then load the approach-1 weights."""
    # fix_tf_gpu is opaque here. It presumably adjusts TensorFlow's GPU
    # configuration, so it is kept ahead of model construction — confirm.
    fix_tf_gpu()
    # NOTE(review): on_event is deprecated in recent FastAPI releases in favour
    # of a lifespan handler.
    prepare_model(1)
def generate_pdf(upload_id: int, total_workers: int, total_helmets: int, total_vests: int) -> str:
    """Write a one-page PDF summarising the detection counts of an upload.

    The file is written as ``report_<upload_id>.pdf`` inside PROCESSED_FOLDER,
    replacing any earlier report for the same id.

    Args:
        upload_id: Id of the Upload row; used only to name the file.
        total_workers: Worker count shown in the table.
        total_helmets: Helmet count shown in the table.
        total_vests: Vest count shown in the table.

    Returns:
        Path of the written PDF.
    """
    pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_id}.pdf")
    doc = SimpleDocTemplate(pdf_path, pagesize=A4)
    styles = getSampleStyleSheet()

    # Row 0 is a real header row. Without it, the header styling below (grey
    # background, white text) would be applied to the "Total Workers" data row.
    data = [
        ["Metric", "Count"],
        ["Total Workers", total_workers],
        ["Total Helmets Detected", total_helmets],
        ["Total Vests Detected", total_vests],
    ]
    table = Table(data, colWidths=[200, 100])
    table.setStyle(TableStyle([
        # Header row (row 0).
        ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
        ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
        # Whole table.
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('FONTNAME', (0, 0), (-1, -1), 'Helvetica-Bold'),
        # Data rows (row 1 onwards).
        ('BACKGROUND', (0, 1), (-1, -1), colors.beige),
    ]))

    elements = [
        Paragraph("Industrial Safety Detection Report", styles["Title"]),
        Spacer(1, 12),
        table,
    ]
    doc.build(elements)
    return pdf_path
def _safe_upload_name(raw_name) -> str:
    """Reduce a client-supplied filename to a bare name safe to join under a folder.

    Directory components are stripped (both "/" and "\\" separators), so a name
    such as "../../app.py" cannot escape the target folder. A missing name is
    treated as empty.

    Raises:
        HTTPException: 400 if nothing usable remains.
    """
    name = os.path.basename((raw_name or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name.")
    return name


@app.post("/upload")
async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
    """Save an uploaded image, run detection on it and build a PDF report.

    Steps:
    1. Save the raw file to UPLOAD_FOLDER under a sanitised name.
    2. Record an Upload row.
    3. Write the annotated frame to PROCESSED_FOLDER as ``processed_<name>``.
    4. Generate ``report_<id>.pdf``.

    Raises:
        HTTPException: 400 if the file name is unusable or the image cannot
            be decoded; 500 if the processed image cannot be written.
    """
    # The filename comes from the client and must not be trusted as a path.
    filename = _safe_upload_name(file.filename)
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(await file.read())

    # Decode before touching the database, so an unreadable upload does not
    # leave an orphan Upload row behind.
    img = cv2.imread(filepath)
    if img is None:
        try:
            os.remove(filepath)
        except OSError:
            pass  # best-effort cleanup; the 400 below is what matters
        raise HTTPException(status_code=400, detail="Failed to read the image file.")

    # NOTE(review): this handler is async but does blocking inference and I/O,
    # which stalls the event loop for the duration of each request.
    db = SessionLocal()
    try:
        upload_obj = Upload(filename=filename, filepath=filepath, timestamp=datetime.now(timezone.utc), approach=approach)
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        # NOTE(review): `approach` is stored but not passed to detection or to
        # prepare_model here. Unless run_detection_on_frame reads it back from
        # the Upload row, every request uses the weights loaded at startup — confirm.
        processed_img = run_detection_on_frame(img, upload_id, db)
        processed_filename = f"processed_{filename}"
        processed_path = os.path.join(PROCESSED_FOLDER, processed_filename)
        if not cv2.imwrite(processed_path, processed_img):
            raise HTTPException(status_code=500, detail="Failed to write the processed image.")

        pdf_path = generate_pdf(upload_id, upload_obj.total_workers, upload_obj.total_helmets, upload_obj.total_vests)
        # Reload the row so the response reflects what is stored in the database.
        db.refresh(upload_obj)
        return {
            "message": "File processed successfully.",
            "upload_id": upload_id,
            "total_workers": upload_obj.total_workers,
            "total_helmets": upload_obj.total_helmets,
            "total_vests": upload_obj.total_vests,
            "processed_image": processed_path,
            "pdf_report": pdf_path
        }
    finally:
        # Release the session on every path, including when detection or
        # PDF generation raises.
        db.close()