File size: 6,349 Bytes
8c6f511
 
 
 
 
40d4c17
8c6f511
 
 
 
 
 
 
40d4c17
8c6f511
 
 
 
 
 
 
 
 
 
40d4c17
8c6f511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d4c17
 
8c6f511
 
 
 
 
 
40d4c17
 
8c6f511
 
 
 
 
 
 
 
 
 
 
40d4c17
 
 
8c6f511
 
40d4c17
 
8c6f511
 
40d4c17
 
8c6f511
 
40d4c17
8c6f511
40d4c17
 
 
 
8c6f511
40d4c17
 
 
8c6f511
40d4c17
 
 
 
 
 
 
 
 
 
 
8c6f511
40d4c17
 
 
 
 
 
 
 
 
8c6f511
40d4c17
 
8c6f511
 
 
 
 
40d4c17
8c6f511
 
40d4c17
8c6f511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d4c17
8c6f511
 
 
 
40d4c17
8c6f511
 
40d4c17
 
 
8c6f511
 
 
40d4c17
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any
from datetime import datetime, timezone
from io import BytesIO
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, func
from sqlalchemy.orm import sessionmaker, relationship, declarative_base, Session
from reportlab.lib.pagesizes import A4
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
import matplotlib.pyplot as plt
from src.yolo3.model import yolo_body
from src.yolo3.detect import detection
from src.utils.image import letterbox_image
from src.utils.fixes import fix_tf_gpu
from tensorflow.keras.layers import Input

# SQLite database file in the current working directory.
DB_URL = "sqlite:///./safety_monitor.db"
# check_same_thread=False disables sqlite3's "only the creating thread may use
# this connection" check, so pooled connections can be used from other threads.
engine = create_engine(
    DB_URL,
    connect_args={"check_same_thread": False},
)
# Session factory; callers open one session, commit explicitly, and close it.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Declarative base shared by the ORM models in this module.
Base = declarative_base()

class Upload(Base):
    """ORM row for one uploaded file plus its aggregate PPE detection counts.

    Rows are inserted by the ``/upload`` handler; the ``total_*`` counters are
    incremented by ``run_detection_on_frame``.
    """
    __tablename__ = "uploads"
    id = Column(Integer, primary_key=True, index=True)
    filename = Column(String)  # name of the uploaded file as saved on disk
    filepath = Column(String)  # path the raw upload bytes were written to
    # Upload time. /upload writes an aware UTC datetime, but
    # NOTE(review): a plain DateTime column on SQLite does not persist tzinfo,
    # so values read back are likely naive — confirm before comparing them
    # with aware datetimes.
    timestamp = Column(DateTime)
    approach = Column(Integer)  # client-chosen approach; stored but not read by the visible code
    total_workers = Column(Integer, default=0)  # number of 'W' detections
    total_helmets = Column(Integer, default=0)  # number of 'H' detections
    total_vests = Column(Integer, default=0)  # number of 'V' detections
    worker_images = Column(Text, default="")  # not populated by the visible code
    # Child detection rows. With "all, delete-orphan", deleting an Upload (or
    # removing a child from this collection) also deletes those child rows.
    detections = relationship("SafetyDetection", back_populates="upload", cascade="all, delete-orphan")

class SafetyDetection(Base):
    """ORM row for a single detected object belonging to an ``Upload``.

    NOTE(review): nothing in the visible code inserts these rows yet.
    """
    __tablename__ = "safety_detections"
    id = Column(Integer, primary_key=True, index=True)
    label = Column(String)  # presumably one of class_names ('H', 'V', 'W') — confirm with the writer
    box = Column(String)  # box coordinates serialized as text; format not defined here — TODO confirm
    timestamp = Column(DateTime)
    upload_id = Column(Integer, ForeignKey("uploads.id"))  # owning Upload's primary key
    upload = relationship("Upload", back_populates="detections")

# Create any tables that do not exist yet (existing tables are left alone).
Base.metadata.create_all(bind=engine)

app = FastAPI(title="Industrial Safety Monitor", version="1.0.0")
# NOTE(review): wildcard origins combined with allow_credentials=True likely
# lets any website make credentialed cross-origin requests; restrict
# allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Output directories, relative to the working directory; created at import.
UPLOAD_FOLDER = "static/uploads"        # raw client uploads
PROCESSED_FOLDER = "static/processed"   # processed images and PDF reports
WORKER_FOLDER = "static/workers"        # not written by the visible code
CHARTS_FOLDER = "static/charts"         # not written by the visible code

for _folder in (UPLOAD_FOLDER, PROCESSED_FOLDER, WORKER_FOLDER, CHARTS_FOLDER):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Network input size (rows, cols) used by letterbox_image and the Input layer.
input_shape = 416, 416
# Class labels indexed by predicted class id: 'H' is counted as a helmet,
# 'V' as a vest and 'W' as a worker by run_detection_on_frame.
class_names = list("HVW")
num_classes = len(class_names)
# Total anchor count; prepare_model sizes its output layers with num_anchors // 3.
num_anchors = 9
# Detector instance; stays None until prepare_model() runs at startup.
model = None

def prepare_model(weight_path: str = "model-data/weights/yolo_weights.h5"):
    """Build the YOLOv3 network, load its weights and publish it as ``model``.

    Args:
        weight_path: Path to the Keras ``.h5`` weights file. Defaults to the
            path the service has always used.

    Raises:
        FileNotFoundError: If ``weight_path`` is not an existing file. This is
            checked before the (expensive) network build, and the global
            ``model`` is only assigned once weights are loaded, so a failure
            never leaves an un-weighted network behind.
    """
    global model
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(f"Weight file not found: {weight_path}")
    input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
    # Presumably (num_anchors // 3) anchors per output head, each predicting
    # 4 box coordinates + 1 objectness score + one score per class.
    num_out_filters = (num_anchors // 3) * (5 + num_classes)
    network = yolo_body(input_tensor, num_out_filters)
    network.load_weights(weight_path)
    model = network

@app.on_event("startup")
def on_startup():
    """Apply the TensorFlow GPU fix, then build and load the detector.

    Runs once when the application starts. An exception raised here (for
    example prepare_model's FileNotFoundError for missing weights) aborts
    startup. NOTE: newer FastAPI releases deprecate ``on_event`` in favour of
    lifespan handlers.
    """
    # fix_tf_gpu presumably adjusts TensorFlow GPU settings, which generally
    # must happen before any model is built — keep this order.
    for startup_step in (fix_tf_gpu, prepare_model):
        startup_step()

def _safe_upload_name(raw_name) -> str:
    """Return only the file-name component of a client-supplied upload name.

    Directory parts separated by ``/`` or ``\\`` are stripped, so the result
    can be joined under UPLOAD_FOLDER without escaping it (for example
    ``"../../app.py"`` becomes ``"app.py"``).

    Raises:
        HTTPException: 400 when the name is missing or reduces to nothing usable.
    """
    name = os.path.basename((raw_name or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name.")
    return name


@app.post("/upload")
async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
    """Store an uploaded image, run PPE detection on it and write a PDF report.

    Args:
        approach: Client-chosen processing approach; stored on the Upload row.
        file: The uploaded image.

    Returns:
        A dict with a status message, the new upload id and the report path.

    Raises:
        HTTPException: 400 if the file name is unusable or OpenCV cannot
            decode the uploaded bytes as an image.
    """
    # Never trust the client's file name: a name like "../../x" would
    # otherwise be written outside UPLOAD_FOLDER.
    filename = _safe_upload_name(file.filename)
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    # NOTE(review): uploads that share a name overwrite each other's files.
    with open(filepath, "wb") as f:
        f.write(await file.read())

    db = SessionLocal()
    # Close the session on every exit path: the 400 below, and any error
    # raised by detection, image writing or PDF generation.
    try:
        upload_obj = Upload(
            filename=filename,
            filepath=filepath,
            timestamp=datetime.now(timezone.utc),
            approach=approach,
        )
        db.add(upload_obj)
        db.commit()
        db.refresh(upload_obj)
        upload_id = upload_obj.id

        img = cv2.imread(filepath)
        if img is None:
            raise HTTPException(status_code=400, detail="Failed to read the image file.")

        # NOTE(review): inference runs synchronously inside this async
        # handler and blocks the event loop while it runs.
        processed_img = run_detection_on_frame(img, upload_id, db)
        processed_path = os.path.join(PROCESSED_FOLDER, f"processed_{filename}")
        cv2.imwrite(processed_path, processed_img)

        # Reload the counters run_detection_on_frame committed before reporting.
        db.refresh(upload_obj)
        pdf_path = generate_pdf(upload_obj)
    finally:
        db.close()
    return {"message": "File processed successfully.", "upload_id": upload_id, "pdf_path": pdf_path}

def run_detection_on_frame(frame: np.ndarray, upload_id: int, db: Session) -> np.ndarray:
    """Detect helmets, vests and workers in ``frame`` and add the counts to an upload.

    The frame is letterboxed to ``input_shape``, divided by 255 and passed to
    the global ``model``. The decoded boxes are tallied per class label, and
    the tallies are added to the matching ``Upload`` row, which is committed.

    Args:
        frame: Image array as returned by ``cv2.imread``.
        upload_id: Primary key of the Upload row to update.
        db: Open SQLAlchemy session; this function commits it.

    Returns:
        The input ``frame`` object itself; nothing is drawn on it here.
    """
    frame_h, frame_w = frame.shape[:2]
    batch = np.array(np.expand_dims(letterbox_image(frame, input_shape), 0)) / 255.0
    raw_output = model.predict(batch)
    # NOTE(review): anchor boxes are passed as None; YOLOv3 decoding normally
    # needs them — confirm detection() handles this. The trailing positional
    # args look like max_boxes=50, score_threshold=0.3, iou_threshold=0.45,
    # classes_can_overlap=False — TODO confirm against detection()'s signature.
    rows = detection(
        raw_output, None, len(class_names), (frame_h, frame_w), input_shape,
        50, 0.3, 0.45, False,
    )[0].numpy()

    # Each row is unpacked as (x1, y1, x2, y2, score, class_id); the score
    # is discarded and boxes are grouped by their class label.
    boxes_by_label = {"W": [], "H": [], "V": []}
    for row in rows:
        x1, y1, x2, y2, _, cls_id = map(int, row)
        label = class_names[cls_id]
        if label in boxes_by_label:
            boxes_by_label[label].append((x1, y1, x2, y2))

    record = db.query(Upload).filter(Upload.id == upload_id).first()
    if record is not None:
        record.total_workers += len(boxes_by_label["W"])
        record.total_helmets += len(boxes_by_label["H"])
        record.total_vests += len(boxes_by_label["V"])
        db.commit()
    return frame

def generate_pdf(upload_obj: Upload):
    """Render a PDF summary of an upload's detection counts and save it.

    The report shows the upload's file name and timestamp, followed by a table
    of worker, helmet and vest totals. It is written to
    ``PROCESSED_FOLDER/report_<id>.pdf``.

    Args:
        upload_obj: The Upload row to report on. Its counters should already
            reflect the latest detection run.

    Returns:
        Path of the written PDF file.
    """
    # ReportLab Paragraphs parse their text as XML-like markup, so a file name
    # containing "&" or "<" must be escaped or rendering fails.
    from xml.sax.saxutils import escape

    styles = getSampleStyleSheet()
    timestamp = upload_obj.timestamp
    timestamp_text = timestamp.strftime("%Y-%m-%d %H:%M:%S") if timestamp is not None else "unknown"

    # Row 0 is a real header row. The TableStyle below greys out row 0, which
    # would otherwise land on the first data row.
    data = [
        ["Metric", "Count"],
        ["Total Workers", upload_obj.total_workers],
        ["Total Helmets", upload_obj.total_helmets],
        ["Total Vests", upload_obj.total_vests],
    ]
    table = Table(data)
    table.setStyle(TableStyle([
        ("BACKGROUND", (0, 0), (-1, 0), colors.grey),
        ("TEXTCOLOR", (0, 0), (-1, 0), colors.whitesmoke),
        ("GRID", (0, 0), (-1, -1), 1, colors.black),
    ]))

    elements = [
        Paragraph("Industrial Safety Report", styles["Title"]),
        Paragraph(f"Filename: {escape(upload_obj.filename or '')}", styles["Normal"]),
        Paragraph(f"Timestamp: {timestamp_text}", styles["Normal"]),
        Spacer(1, 12),
        table,
    ]

    # Render into memory first, so a rendering error never leaves a truncated
    # PDF on disk.
    buffer = BytesIO()
    SimpleDocTemplate(buffer, pagesize=A4).build(elements)
    pdf_path = os.path.join(PROCESSED_FOLDER, f"report_{upload_obj.id}.pdf")
    with open(pdf_path, "wb") as f:
        f.write(buffer.getvalue())
    return pdf_path