Spaces:
Sleeping
Sleeping
File size: 6,349 Bytes
8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 8c6f511 40d4c17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any
from datetime import datetime, timezone
from io import BytesIO
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
import matplotlib.pyplot as plt
from src.yolo3.model import yolo_body
from src.yolo3.detect import detection
from src.utils.image import letterbox_image
from src.utils.fixes import fix_tf_gpu
from tensorflow.keras.layers import Input
# SQLite database file stored in the current working directory.
DB_URL = "sqlite:///./safety_monitor.db"
# sqlite3 refuses by default to use a connection from a thread other than the
# one that created it; disabling the check allows sessions to be used across
# threads (presumably needed because the web server may run work in a thread
# pool — confirm).
engine = create_engine(DB_URL, connect_args=dict(check_same_thread=False))
# Session factory: explicit commits, no automatic flush before queries.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Declarative base shared by the ORM models below.
Base = declarative_base()
class Upload(Base):
    """One uploaded file together with the aggregate detection counts computed for it."""

    __tablename__ = "uploads"
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String)  # name of the uploaded file
    filepath = Column(String)  # location the upload was saved to on disk
    # NOTE(review): SQLite has no native timezone-aware datetime type; values are
    # presumably read back naive even if stored aware — confirm before comparing.
    timestamp = Column(DateTime)
    approach = Column(Integer)  # client-supplied option; not read by the visible detection code
    # Per-class totals; the defaults apply when a row is inserted without a value.
    total_workers = Column(Integer, default=0)
    total_helmets = Column(Integer, default=0)
    total_vests = Column(Integer, default=0)
    worker_images = Column(Text, default="")  # not written by any code visible in this file
    # ORM cascade: deleting an Upload through the session also deletes its
    # detections, and detections removed from this collection are deleted.
    detections = relationship("SafetyDetection", back_populates="upload", cascade="all, delete-orphan")
class SafetyDetection(Base):
    """A single detection record that belongs to one Upload.

    NOTE(review): no code visible in this file creates these rows.
    """

    __tablename__ = "safety_detections"
    id = Column(Integer, primary_key=True, index=True)
    label = Column(String)  # presumably one of class_names ('H', 'V', 'W') — confirm
    box = Column(String)  # bounding box serialised as text; format not shown here — TODO confirm
    timestamp = Column(DateTime)
    upload_id = Column(Integer, ForeignKey("uploads.id"))  # owning Upload
    upload = relationship("Upload", back_populates="detections")
# Create any missing tables as soon as the module is imported.
Base.metadata.create_all(bind=engine)

app = FastAPI(title="Industrial Safety Monitor", version="1.0.0")
# NOTE(review): wildcard origins together with allow_credentials=True —
# Starlette then echoes the caller's Origin, which presumably lets any site make
# credentialed requests; restrict allow_origins if this is exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directories for uploaded, processed and generated files; created on import.
UPLOAD_FOLDER = "static/uploads"
PROCESSED_FOLDER = "static/processed"
WORKER_FOLDER = "static/workers"
CHARTS_FOLDER = "static/charts"
for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Detector configuration.
input_shape = (416, 416)  # network input size, used as (height, width) for the Keras Input
# Label order must match the class ids the detector emits:
# 'H' helmets, 'V' vests, 'W' workers (see run_detection_on_frame).
class_names = ['H', 'V', 'W']
num_classes = len(class_names)
num_anchors = 9
# Populated by prepare_model() at application startup.
model = None
def prepare_model():
    """Build the YOLOv3 network, load its trained weights, and publish it as ``model``.

    The module-level ``model`` is only replaced once the weights have loaded,
    so a failure never leaves an untrained network in place.

    Raises:
        FileNotFoundError: if the weight file does not exist.
    """
    global model
    weight_path = "model-data/weights/yolo_weights.h5"
    # Check for the weights before building the graph so a missing file fails fast
    # and no randomly initialised network is ever assigned to the global.
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")

    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    # Output filters per detection scale: anchors-per-scale * (5 + class scores);
    # presumably the conventional YOLOv3 layout (4 box coords + objectness) that
    # yolo_body expects, with the 9 anchors split across 3 scales — confirm.
    num_out_filters = (num_anchors // 3) * (5 + num_classes)
    net = yolo_body(input_tensor, num_out_filters)
    net.load_weights(weight_path)
    model = net
@app.on_event("startup")
def on_startup():
    """Run one-time initialisation before the app serves requests.

    Calls the project helper ``fix_tf_gpu`` (presumably TensorFlow GPU setup —
    confirm), then ``prepare_model`` to load the detector.

    NOTE(review): newer FastAPI versions deprecate ``on_event`` in favour of a
    ``lifespan`` handler passed to ``FastAPI(...)``.
    """
    for startup_step in (fix_tf_gpu, prepare_model):
        startup_step()
def _safe_upload_name(raw_name):
    """Reduce a client-supplied filename to a bare basename, or None if nothing usable remains."""
    if not raw_name:
        return None
    # Treat backslashes as separators too, then drop every directory component so
    # names like "../../etc/x" or "..\\x" cannot escape the upload folder.
    name = os.path.basename(raw_name.replace("\\", "/"))
    if name in ("", ".", ".."):
        return None
    return name


@app.post("/upload")
async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
    """Save an uploaded image, run detection on it, and produce a PDF report.

    Args:
        approach: client-selected option, stored on the Upload row.
        file: the uploaded image.

    Returns:
        A dict with a status message, the new Upload id and the PDF path.

    Raises:
        HTTPException: 400 if the filename is unusable or the image cannot be
            decoded; 500 if the processed image cannot be written.
    """
    filename = _safe_upload_name(file.filename)
    if filename is None:
        raise HTTPException(status_code=400, detail="A valid filename is required.")
    # NOTE(review): uploads with the same name overwrite each other's files.
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    with open(filepath, "wb") as f:
        f.write(await file.read())

    db = SessionLocal()
    try:
        upload_obj = Upload(filename=filename, filepath=filepath, timestamp=datetime.now(timezone.utc), approach=approach)
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        img = cv2.imread(filepath)
        if img is None:
            raise HTTPException(status_code=400, detail="Failed to read the image file.")

        # NOTE(review): inference, image encoding and PDF building run synchronously
        # inside this async handler and block the event loop; consider
        # fastapi.concurrency.run_in_threadpool.
        processed_img = run_detection_on_frame(img, upload_id, db)
        processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{filename}")
        if not cv2.imwrite(processed_path, processed_img):
            raise HTTPException(status_code=500, detail="Failed to write the processed image.")

        # Reload the totals that run_detection_on_frame committed before reporting them.
        db.refresh(upload_obj)
        pdf_path = generate_pdf(upload_obj)
    finally:
        # Always release the session, including on HTTPException or detection/PDF errors.
        db.close()
    return {"message": "File processed successfully.", "upload_id": upload_id, "pdf_path": pdf_path}
def run_detection_on_frame(frame: np.ndarray, upload_id: int, db: Session) -> np.ndarray:
    """Run the detector on one image and add its per-class counts to the Upload row.

    Args:
        frame: image array as produced by ``cv2.imread``.
        upload_id: primary key of the Upload whose totals are incremented.
        db: open session; it is committed when the Upload row exists.

    Returns:
        ``frame`` itself, unmodified.

    Raises:
        RuntimeError: if the model has not been loaded by ``prepare_model``.
    """
    if model is None:
        raise RuntimeError("Model is not loaded; prepare_model() must run first.")

    ih, iw = frame.shape[:2]
    resized = letterbox_image(frame, input_shape)
    # Add a batch axis and scale pixel values into [0, 1].
    image_data = np.expand_dims(resized, 0) / 255.0
    prediction = model.predict(image_data)
    # NOTE(review): positional arguments presumably mean anchors (None), class count,
    # original (h, w), model input shape, max boxes, score threshold, IoU threshold,
    # and a flag — confirm against src.yolo3.detect.detection.
    boxes = detection(prediction, None, len(class_names), (ih, iw), input_shape, 50, 0.3, 0.45, False)[0].numpy()

    counts = dict.fromkeys(class_names, 0)
    for box in boxes:
        # Each row is presumably (x1, y1, x2, y2, score, class_id); only the class id is used.
        cls_id = int(box[5])
        # Skip ids outside the label table; a negative id would otherwise wrap
        # silently onto the last label via negative indexing.
        if 0 <= cls_id < len(class_names):
            counts[class_names[cls_id]] += 1

    upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
    if upload_obj is not None:
        # `or 0` guards rows whose counters are NULL (column defaults apply only on insert).
        upload_obj.total_workers = (upload_obj.total_workers or 0) + counts['W']
        upload_obj.total_helmets = (upload_obj.total_helmets or 0) + counts['H']
        upload_obj.total_vests = (upload_obj.total_vests or 0) + counts['V']
        db.commit()
    # NOTE(review): no boxes are drawn and no SafetyDetection rows are stored; the
    # caller saves this return value as the "processed" image — confirm intended.
    return frame
def generate_pdf(upload_obj: Upload):
    """Render a short PDF summary of an upload and write it into PROCESSED_FOLDER.

    Args:
        upload_obj: the Upload row to report on; its filename, timestamp and
            worker/helmet/vest totals are printed.

    Returns:
        The path of the written file, ``PROCESSED_FOLDER/report_<id>.pdf``.
    """
    # Paragraph parses its text as ReportLab's XML-like mini-markup, so the
    # user-supplied filename must be escaped; otherwise characters such as '<'
    # and '&' break the build or inject markup into the report.
    from xml.sax.saxutils import escape

    styles = getSampleStyleSheet()
    if upload_obj.timestamp is not None:
        stamp = upload_obj.timestamp.strftime('%Y-%m-%d %H:%M:%S')
    else:
        stamp = "N/A"

    elements = [
        Paragraph("Industrial Safety Report", styles["Title"]),
        Paragraph(f"Filename: {escape(str(upload_obj.filename))}", styles["Normal"]),
        Paragraph(f"Timestamp: {stamp}", styles["Normal"]),
    ]

    # Row 0 is a real header so the grey header styling does not land on a data row.
    data = [
        ["Metric", "Count"],
        ["Total Workers", upload_obj.total_workers],
        ["Total Helmets", upload_obj.total_helmets],
        ["Total Vests", upload_obj.total_vests],
    ]
    table = Table(data)
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
        ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
        ("GRID", (0, 0), (-1, -1), 1, colors.black),
    ]))
    elements.append(table)

    # Build in memory first so a failed build never leaves a truncated file on disk.
    buffer = BytesIO()
    SimpleDocTemplate(buffer, pagesize=A4).build(elements)
    pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_obj.id}.pdf")
    with open(pdf_path, "wb") as f:
        f.write(buffer.getvalue())
    return pdf_path
|