backendsafety / app.py
nooneshouldtouch's picture
Update app.py
40d4c17 verified
raw
history blame
6.35 kB
import os
import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any
from datetime import datetime, timezone
from io import BytesIO
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
import matplotlib.pyplot as plt
from src.yolo3.model import yolo_body
from src.yolo3.detect import detection
from src.utils.image import letterbox_image
from src.utils.fixes import fix_tf_gpu
from tensorflow.keras.layers import Input
# Database connection string. Defaults to a local SQLite file; set the
# SAFETY_MONITOR_DB_URL environment variable to point at another database.
DB_URL = os.environ.get("SAFETY_MONITOR_DB_URL", "sqlite:///./safety_monitor.db")
# check_same_thread is an option of the sqlite3 driver only: by default sqlite3
# refuses to use a connection from a thread other than the one that created it,
# and the connection pool may hand connections across threads. Other drivers
# reject the unknown argument, so only pass it for SQLite URLs.
_connect_args = {"check_same_thread": False} if DB_URL.startswith("sqlite") else {}
engine = create_engine(DB_URL, connect_args=_connect_args)
# Session factory; callers open a session per request and close it themselves.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base shared by the ORM models below.
Base = declarative_base()
class Upload(Base):
    """An uploaded image file plus the aggregate detection counts computed for it.

    Rows are inserted by the ``/upload`` endpoint; the ``total_*`` counters are
    incremented by ``run_detection_on_frame``.
    """

    __tablename__ = "uploads"
    id = Column(Integer, primary_key=True, index=True)
    # File name of the upload, as recorded by the upload endpoint.
    filename = Column(String)
    # Path on disk where the uploaded bytes were written (under UPLOAD_FOLDER).
    filepath = Column(String)
    # Upload time; the endpoint writes datetime.now(timezone.utc).
    # NOTE(review): plain DateTime (no timezone=True) - tz info may not
    # round-trip through the database; confirm before comparing with aware datetimes.
    timestamp = Column(DateTime)
    # Integer "approach" form field sent by the client; its meaning is not
    # defined in this file.
    approach = Column(Integer)
    # Running detection totals; each call to run_detection_on_frame adds the
    # number of boxes it found for the corresponding class.
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)
    # Not written by any code visible in this file - TODO confirm it is still used.
    worker_images = Column(Text, default="")
    # Per-detection child rows; they are deleted with their Upload, and also
    # when detached from it (delete-orphan).
    detections = relationship("SafetyDetection", back_populates="upload", cascade="all, delete-orphan")
class SafetyDetection(Base):
    """A single detected object belonging to an ``Upload``.

    NOTE(review): no code visible in this file inserts rows into this table.
    """

    __tablename__ = "safety_detections"
    id = Column(Integer, primary_key=True, index=True)
    # Class label; presumably one of class_names ('H', 'V', 'W') - confirm
    # against whichever code writes these rows.
    label = Column(String)
    # Bounding box serialised as text; the format is not established in this file.
    box = Column(String)
    timestamp = Column(DateTime)
    # Owning upload (uploads.id).
    upload_id = Column(Integer, ForeignKey("uploads.id"))
    upload = relationship("Upload", back_populates="detections")
# Create any missing tables for the models declared above.
Base.metadata.create_all(bind=engine)

app = FastAPI(title="Industrial Safety Monitor", version="1.0.0")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True - the
# CORS spec forbids a literal "*" origin on credentialed responses; confirm how
# the installed Starlette handles this and whether credentials are needed at
# all (no cookie/auth handling is visible in this file).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Output directories; all are created up front so later writes can assume they exist.
UPLOAD_FOLDER = "static/uploads"
PROCESSED_FOLDER = "static/processed"
WORKER_FOLDER = "static/workers"
CHARTS_FOLDER = "static/charts"
for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Detector configuration: network input size and class labels, where the
# detection code treats 'H' as helmet, 'V' as vest and 'W' as worker.
input_shape = (416, 416)
class_names = ['H', 'V', 'W']
num_classes = len(class_names)
num_anchors = 9
# Populated by prepare_model() at startup.
model = None
def prepare_model(weight_path: str = "model-data/weights/yolo_weights.h5"):
    """Build the YOLOv3 network, load its trained weights and publish it as ``model``.

    The global ``model`` is only replaced after the weights have loaded
    successfully, so a failure never leaves an un-weighted network in place.

    Args:
        weight_path: Path to the Keras ``.h5`` weight file.

    Raises:
        FileNotFoundError: If ``weight_path`` does not exist.
    """
    global model
    # Check before building the graph: fail fast, and never expose a
    # half-initialised network through the global.
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")
    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    # Output channels per detection scale: anchors per scale x (4 box coords +
    # objectness + class scores).
    num_out_filters = (num_anchors // 3) * (5 + num_classes)
    net = yolo_body(input_tensor, num_out_filters)
    net.load_weights(weight_path)
    model = net
@app.on_event("startup")
def on_startup():
    """One-time server initialisation.

    Runs ``fix_tf_gpu`` (a TensorFlow GPU workaround whose exact effect is not
    visible in this file) and then builds and loads the detector, in that order.
    """
    startup_steps = (fix_tf_gpu, prepare_model)
    for step in startup_steps:
        step()
@app.post("/upload")
async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
    """Save an uploaded image, run PPE detection on it and write a PDF report.

    Args:
        approach: Integer form field stored verbatim on the ``Upload`` row.
        file: The uploaded image.

    Returns:
        Dict with a status message, the new upload id and the path of the
        generated PDF report.

    Raises:
        HTTPException: 400 if the file name is unusable or the bytes cannot be
            decoded as an image; 500 if the processed image cannot be written.
    """
    # file.filename is client-controlled: keep only the final path component so
    # names such as "../../app.py" cannot write outside UPLOAD_FOLDER.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="A valid file name is required.")
    # NOTE(review): uploads with the same name overwrite each other's files on disk.
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(await file.read())

    db = SessionLocal()
    try:
        upload_obj = Upload(filename=filename, filepath=filepath, timestamp=datetime.now(timezone.utc), approach=approach)
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        img = cv2.imread(filepath)
        if img is None:
            raise HTTPException(status_code=400, detail="Failed to read the image file.")

        # NOTE(review): detection (model.predict) runs synchronously inside this
        # async handler, blocking the event loop while it runs.
        processed_img = run_detection_on_frame(img, upload_id, db)
        processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{filename}")
        if not cv2.imwrite(processed_path, processed_img):
            raise HTTPException(status_code=500, detail="Failed to write the processed image.")

        # Reload the counters committed by run_detection_on_frame before reporting.
        db.refresh(upload_obj)
        pdf_path = generate_pdf(upload_obj)
    finally:
        # Release the session on every path, including HTTPException and model errors.
        db.close()

    return {"message": "File processed successfully.", "upload_id": upload_id, "pdf_path": pdf_path}
def run_detection_on_frame(frame: np.ndarray, upload_id: int, db: Session) -> np.ndarray:
    """Detect workers, helmets and vests in one image and update the upload's totals.

    Args:
        frame: Image array as loaded by ``cv2.imread`` (the visible caller passes
            the decoded upload). It is not modified.
        upload_id: Primary key of the ``Upload`` row whose counters are incremented.
        db: Open session; this function commits it when the row exists.

    Returns:
        A copy of ``frame`` with each detection drawn as a labelled rectangle.

    Raises:
        RuntimeError: If the model has not been loaded by ``prepare_model``.
    """
    if model is None:
        raise RuntimeError("Model is not loaded; call prepare_model() first.")

    ih, iw = frame.shape[:2]
    # Letterbox to the network input size, add a batch axis and scale to [0, 1].
    batch = np.expand_dims(letterbox_image(frame, input_shape), 0) / 255.0
    prediction = model.predict(batch)
    # NOTE(review): the anchor-box argument is passed as None; confirm that
    # src.yolo3.detect.detection supplies default anchors in that case. The
    # trailing positionals are presumably max boxes (50), score threshold (0.3),
    # IoU threshold (0.45) and a class-overlap flag - confirm against detect.py.
    boxes = detection(prediction, None, len(class_names), (ih, iw), input_shape, 50, 0.3, 0.45, False)[0].numpy()

    # BGR drawing colours per label ('W' worker, 'H' helmet, 'V' vest).
    box_colors = {"W": (255, 0, 0), "H": (0, 255, 0), "V": (0, 165, 255)}
    found = {name: [] for name in class_names}
    annotated = frame.copy()
    for box in boxes:
        # Rows are unpacked as (x1, y1, x2, y2, score, class_id); the score is unused.
        x1, y1, x2, y2, _, cls_id = map(int, box)
        label = class_names[cls_id]
        found[label].append((x1, y1, x2, y2))
        color = box_colors.get(label, (255, 255, 255))
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
        cv2.putText(annotated, label, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

    upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
    if upload_obj is not None:
        # `or 0` guards rows whose counters were stored as NULL.
        upload_obj.total_workers = (upload_obj.total_workers or 0) + len(found.get("W", ()))
        upload_obj.total_helmets = (upload_obj.total_helmets or 0) + len(found.get("H", ()))
        upload_obj.total_vests = (upload_obj.total_vests or 0) + len(found.get("V", ()))
        db.commit()
    return annotated
def generate_pdf(upload_obj: Upload):
    """Render a one-page safety summary PDF for an upload and return its path.

    The report shows the upload's filename and timestamp, then a table of the
    aggregate worker/helmet/vest counts. The document is built in memory first,
    so a rendering error never leaves a truncated file on disk.

    Args:
        upload_obj: The ``Upload`` row to report on; its ``total_*`` counters
            should already reflect the detection run.

    Returns:
        Path of the written PDF: ``PROCESSED_FOLDER/report_<id>.pdf``.
    """
    # Paragraph text is parsed as ReportLab markup and the filename comes from
    # the client, so characters such as "<" or "&" must be escaped.
    from xml.sax.saxutils import escape

    styles = getSampleStyleSheet()
    timestamp = upload_obj.timestamp
    timestamp_text = timestamp.strftime('%Y-%m-%d %H:%M:%S') if timestamp is not None else "unknown"

    elements = [
        Paragraph("Industrial Safety Report", styles["Title"]),
        Paragraph(f"Filename: {escape(str(upload_obj.filename))}", styles["Normal"]),
        Paragraph(f"Timestamp: {timestamp_text}", styles["Normal"]),
        Spacer(1, 12),
    ]

    # Row 0 is a dedicated header row, so the header styling below lands on
    # column labels rather than on the first data row.
    data = [
        ["Metric", "Count"],
        ["Total Workers", upload_obj.total_workers],
        ["Total Helmets", upload_obj.total_helmets],
        ["Total Vests", upload_obj.total_vests],
    ]
    table = Table(data)
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
        ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
        ("GRID", (0, 0), (-1, -1), 1, colors.black),
    ]))
    elements.append(table)

    buffer = BytesIO()
    SimpleDocTemplate(buffer, pagesize=A4).build(elements)

    pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_obj.id}.pdf")
    with open(pdf_path, "wb") as f:
        f.write(buffer.getvalue())
    return pdf_path