import gradio as gr
from ultralytics import YOLO
import spaces
import torch
import cv2
import numpy as np
import os
import requests

# Define constants for the new model
ENTITIES_COLORS = {
    "Caption": (191, 100, 21),
    "Footnote": (2, 62, 115),
    "Formula": (140, 80, 58),
    "List-item": (168, 181, 69),
    "Page-footer": (2, 69, 84),
    "Page-header": (83, 115, 106),
    "Picture": (255, 72, 88),
    "Section-header": (0, 204, 192),
    "Table": (116, 127, 127),
    "Text": (0, 153, 221),
    "Title": (196, 51, 2)
}
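# Note: the DLA branch below draws on a BGR image (OpenCV convention), so these
# colour tuples are interpreted in BGR order by the drawing calls.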
BOX_PADDING = 2

# Load pre-trained YOLOv8 models
model_paths = {
    "YOLOv8x Model": "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt",
    "YOLOv8m Model": "yolov8m-doclaynet.pt",
    "YOLOv8n Model": "yolov8n-doclaynet.pt",
    "YOLOv8s Model": "yolov8s-doclaynet.pt",
    "DLA Model": "models/dla-model.pt"
}

# Ensure the model files are in the correct location
for model_name, model_path in model_paths.items():
    if not os.path.exists(model_path):
        # For demonstration, we only download the YOLOv8x model
        if model_name == "YOLOv8x Model":
            model_url = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
            response = requests.get(model_url)
            response.raise_for_status()  # fail fast if the download did not succeed
            with open(model_path, "wb") as f:
                f.write(response.content)

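# The remaining weights are assumed to already be present at the paths above;
# otherwise the YOLO(...) constructor calls below will fail at startup.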
# Load models
models = {name: YOLO(path) for name, path in model_paths.items()}

# Class names taken from ENTITIES_COLORS (the DocLayNet label set), shared by all models
class_names = list(ENTITIES_COLORS.keys())
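# Assumption: each model's class indices follow this same DocLayNet ordering,
# since box.cls values are used to index class_names below.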

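# ZeroGPU: the decorator requests a GPU for up to 60 seconds per call on Hugging Face Spaces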
@spaces.GPU(duration=60)
def process_image(image, model_choice):
    try:
        if "YOLOv8" in model_choice:
            # Use the selected YOLOv8 model
            model = models[model_choice]
            results = model(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
            result = results[0]

            # Extract annotated image and labels with class names
            annotated_image = result.plot()

            detected_areas_labels = "\n".join([
                f"{class_names[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in result.boxes
            ])

            return annotated_image, detected_areas_labels
        
        elif model_choice == "DLA Model":
            # Use the DLA model
            image_path = "input_image.jpg"  # Temporary save the uploaded image
            cv2.imwrite(image_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
            image = cv2.imread(image_path)
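            # Low confidence threshold (0.2) keeps faint regions; iou=0.8 controls NMS overlap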
            results = models[model_choice].predict(source=image, conf=0.2, iou=0.8)
            boxes = results[0].boxes

            if len(boxes) == 0:
                return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "No regions detected"

            for box in boxes:
                detection_class_conf = round(box.conf.item(), 2)
                cls = class_names[int(box.cls)]
                start_box = (int(box.xyxy[0][0]), int(box.xyxy[0][1]))
                end_box = (int(box.xyxy[0][2]), int(box.xyxy[0][3]))

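                # Scale drawing thickness with image size so boxes stay legible at any resolution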
                line_thickness = round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1
                image = cv2.rectangle(img=image, 
                                      pt1=start_box, 
                                      pt2=end_box,
                                      color=ENTITIES_COLORS[cls], 
                                      thickness=line_thickness)
                
                text = cls + " " + str(detection_class_conf)
                font_thickness = max(line_thickness - 1, 1)
                (text_w, text_h), _ = cv2.getTextSize(text=text, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=line_thickness/3, thickness=font_thickness)
                image = cv2.rectangle(img=image,
                                      pt1=(start_box[0], start_box[1] - text_h - BOX_PADDING*2),
                                      pt2=(start_box[0] + text_w + BOX_PADDING * 2, start_box[1]),
                                      color=ENTITIES_COLORS[cls],
                                      thickness=-1)
                start_text = (start_box[0] + BOX_PADDING, start_box[1] - BOX_PADDING)
                image = cv2.putText(img=image, text=text, org=start_text, fontFace=cv2.FONT_HERSHEY_SIMPLEX, color=(255, 255, 255), fontScale=line_thickness/3, thickness=font_thickness)
            
            # Report the classes actually detected (mirrors the YOLOv8 branch) rather than every class name
            detected_areas_labels = "\n".join([
                f"{class_names[int(box.cls.item())].upper()}: {box.conf.item():.2f}" for box in boxes
            ])
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), detected_areas_labels
        
        else:
            return None, "Invalid model choice"

    except Exception as e:
        return None, f"Error processing image: {e}"

# Create the Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Layout Segmentation Comparison (ZeroGPU)")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_image = gr.Image(type="pil", label="Annotated Image")

    model_choice = gr.Dropdown(list(model_paths.keys()), label="Select Model", value="YOLOv8x Model")
    output_text = gr.Textbox(label="Detected Areas and Labels")
    
    btn = gr.Button("Run Document Segmentation")
    btn.click(fn=process_image, inputs=[input_image, model_choice], outputs=[output_image, output_text])

# Launch the demo with queuing
demo.queue(max_size=1).launch()
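
# Local run (a sketch, assuming the weights listed in model_paths are available and the
# `gradio`, `ultralytics`, `spaces`, `opencv-python`, and `requests` packages are installed):
#   python app.py
# then open the printed local URL in a browser.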