atlury commited on
Commit
db520f8
·
verified ·
1 Parent(s): fea7704

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -4,6 +4,7 @@ import cv2
4
  import numpy as np
5
  import os
6
  import requests
 
7
 
8
  # Ensure the model file is in the correct location
9
  model_path = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
@@ -14,18 +15,25 @@ if not os.path.exists(model_path):
14
  with open(model_path, "wb") as f:
15
  f.write(response.content)
16
 
17
- # Load the document segmentation model
18
- docseg_model = YOLO(model_path)
 
19
 
20
  def process_image(image):
21
  # Convert image to the format YOLO model expects
22
  image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
23
- results = docseg_model(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
24
 
25
  # Extract annotated image from results
26
  annotated_img = results[0].plot()
 
27
 
28
- return annotated_img, results[0].boxes
 
 
 
 
 
29
 
30
  # Define the Gradio interface
31
  interface = gr.Interface(
 
4
  import numpy as np
5
  import os
6
  import requests
7
+ import torch
8
 
9
  # Ensure the model file is in the correct location
10
  model_path = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
 
15
  with open(model_path, "wb") as f:
16
  f.write(response.content)
17
 
18
+ # Load the document segmentation model on CPU
19
+ device = torch.device('cpu')
20
+ docseg_model = YOLO(model_path).to(device)
21
 
22
  def process_image(image):
23
  # Convert image to the format YOLO model expects
24
  image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
25
+ results = docseg_model(image)
26
 
27
  # Extract annotated image from results
28
  annotated_img = results[0].plot()
29
+ annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
30
 
31
+ # Prepare detected areas and labels as text output
32
+ detected_areas_labels = "\n".join(
33
+ [f"{box.label}: {box.conf:.2f}" for box in results[0].boxes]
34
+ )
35
+
36
+ return annotated_img, detected_areas_labels
37
 
38
  # Define the Gradio interface
39
  interface = gr.Interface(