nooneshouldtouch commited on
Commit
c58c513
·
verified ·
1 Parent(s): 04f3332

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -68,20 +68,30 @@ num_classes = len(class_names)
68
  num_anchors = 9
69
  model = None
70
 
71
- def prepare_model():
72
  global model
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
74
  num_out_filters = (num_anchors // 3) * (5 + num_classes)
75
  model = yolo_body(input_tensor, num_out_filters)
76
- weight_path = "model-data/weights/yolo_weights.h5"
77
- if not os.path.exists(weight_path):
78
- raise FileNotFoundError(f"Weight file not found: {weight_path}")
79
  model.load_weights(weight_path)
80
 
81
  @app.on_event("startup")
82
  def on_startup():
83
  fix_tf_gpu()
84
- prepare_model()
85
 
86
  @app.post("/upload")
87
  async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
@@ -117,22 +127,8 @@ def run_detection_on_frame(frame: np.ndarray, upload_id: int, db: Session) -> np
117
  image_data = np.array(resized_expanded) / 255.0
118
  prediction = model.predict(image_data)
119
  boxes = detection(prediction, None, len(class_names), (ih, iw), input_shape, 50, 0.3, 0.45, False)[0].numpy()
120
- workers, helmets, vests = [], [], []
121
  for box in boxes:
122
  x1, y1, x2, y2, _, cls_id = map(int, box)
123
- label = class_names[cls_id]
124
- if label == 'W':
125
- workers.append((x1, y1, x2, y2))
126
- elif label == 'H':
127
- helmets.append((x1, y1, x2, y2))
128
- elif label == 'V':
129
- vests.append((x1, y1, x2, y2))
130
- upload_obj = db.query(Upload).filter(Upload.id == upload_id).first()
131
- if upload_obj:
132
- upload_obj.total_workers += len(workers)
133
- upload_obj.total_helmets += len(helmets)
134
- upload_obj.total_vests += len(vests)
135
- db.commit()
136
  return frame
137
 
138
  def generate_pdf(upload_obj: Upload):
 
68
  num_anchors = 9
69
  model = None
70
 
71
+ def prepare_model(approach: int):
72
  global model
73
+ if approach not in [1, 2, 3]:
74
+ raise ValueError("Approach must be 1, 2, or 3.")
75
+
76
+ weight_files = {
77
+ 1: "pictor-ppe-v302-a1-yolo-v3-weights.h5",
78
+ 2: "pictor-ppe-v302-a2-yolo-v3-weights.h5",
79
+ 3: "pictor-ppe-v302-a3-yolo-v3-weights.h5",
80
+ }
81
+
82
+ weight_path = os.path.join("model-data", "weights", weight_files[approach])
83
+ if not os.path.exists(weight_path):
84
+ raise FileNotFoundError(f"Weight file not found: {weight_path}")
85
+
86
  input_tensor = Input(shape=(input_shape[0], input_shape[1], 3))
87
  num_out_filters = (num_anchors // 3) * (5 + num_classes)
88
  model = yolo_body(input_tensor, num_out_filters)
 
 
 
89
  model.load_weights(weight_path)
90
 
91
  @app.on_event("startup")
92
  def on_startup():
93
  fix_tf_gpu()
94
+ prepare_model(approach=1)
95
 
96
  @app.post("/upload")
97
  async def upload_file(approach: int = Form(...), file: UploadFile = File(...)):
 
127
  image_data = np.array(resized_expanded) / 255.0
128
  prediction = model.predict(image_data)
129
  boxes = detection(prediction, None, len(class_names), (ih, iw), input_shape, 50, 0.3, 0.45, False)[0].numpy()
 
130
  for box in boxes:
131
  x1, y1, x2, y2, _, cls_id = map(int, box)
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  return frame
133
 
134
  def generate_pdf(upload_obj: Upload):