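"""Gradio ZeroGPU Space for document layout segmentation.

Compares several YOLOv8 models fine-tuned on DocLayNet against a custom
"DLA" model, drawing bounding boxes and labels for the detected layout
regions (captions, tables, section headers, etc.).
"""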
import gradio as gr
from ultralytics import YOLO
import spaces
import torch
import cv2
import numpy as np
import os
import requests
# Constants: colours for each layout class and padding for the label boxes
ENTITIES_COLORS = {
    "Caption": (191, 100, 21),
    "Footnote": (2, 62, 115),
    "Formula": (140, 80, 58),
    "List-item": (168, 181, 69),
    "Page-footer": (2, 69, 84),
    "Page-header": (83, 115, 106),
    "Picture": (255, 72, 88),
    "Section-header": (0, 204, 192),
    "Table": (116, 127, 127),
    "Text": (0, 153, 221),
    "Title": (196, 51, 2)
}
BOX_PADDING = 2
# Paths to the pre-trained YOLOv8 models and the DLA model
model_paths = {
    "YOLOv8x Model": "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt",
    "YOLOv8m Model": "yolov8m-doclaynet.pt",
    "YOLOv8n Model": "yolov8n-doclaynet.pt",
    "YOLOv8s Model": "yolov8s-doclaynet.pt",
    "DLA Model": "models/dla-model.pt"
}
# Ensure the model files are in the correct location
for model_name, model_path in model_paths.items():
    if not os.path.exists(model_path):
        # For demonstration, we only download the YOLOv8x model
        if model_name == "YOLOv8x Model":
            model_url = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
            response = requests.get(model_url)
            with open(model_path, "wb") as f:
                f.write(response.content)
# Load models, skipping entries whose weight files are missing so startup does not fail
models = {name: YOLO(path) for name, path in model_paths.items() if os.path.exists(path)}
# Class names shared by all models, taken from the colour map keys
class_names = list(ENTITIES_COLORS.keys())
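# The @spaces.GPU decorator requests a ZeroGPU device for each call (held for at
# most `duration` seconds), so the inference below runs on GPU when one is allocated.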
@spaces.GPU(duration=60)
def process_image(image, model_choice):
    try:
        if "YOLOv8" in model_choice:
            # Use the selected YOLOv8 model
            model = models[model_choice]
            results = model(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
            result = results[0]
            # Extract the annotated image (Ultralytics plots in BGR, convert to RGB for display)
            # and the detected labels with their confidence scores
            annotated_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
            detected_areas_labels = "\n".join(
                f"{class_names[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in result.boxes
            )
            return annotated_image, detected_areas_labels
elif model_choice == "DLA Model":
# Use the DLA model
image_path = "input_image.jpg" # Temporary save the uploaded image
cv2.imwrite(image_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
image = cv2.imread(image_path)
results = models[model_choice].predict(source=image, conf=0.2, iou=0.8)
boxes = results[0].boxes
if len(boxes) == 0:
return image
for box in boxes:
detection_class_conf = round(box.conf.item(), 2)
cls = class_names[int(box.cls)]
start_box = (int(box.xyxy[0][0]), int(box.xyxy[0][1]))
end_box = (int(box.xyxy[0][2]), int(box.xyxy[0][3]))
line_thickness = round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1
image = cv2.rectangle(img=image,
pt1=start_box,
pt2=end_box,
color=ENTITIES_COLORS[cls],
thickness=line_thickness)
text = cls + " " + str(detection_class_conf)
font_thickness = max(line_thickness - 1, 1)
(text_w, text_h), _ = cv2.getTextSize(text=text, fontFace=2, fontScale=line_thickness/3, thickness=font_thickness)
image = cv2.rectangle(img=image,
pt1=(start_box[0], start_box[1] - text_h - BOX_PADDING*2),
pt2=(start_box[0] + text_w + BOX_PADDING * 2, start_box[1]),
color=ENTITIES_COLORS[cls],
thickness=-1)
start_text = (start_box[0] + BOX_PADDING, start_box[1] - BOX_PADDING)
image = cv2.putText(img=image, text=text, org=start_text, fontFace=0, color=(255,255,255), fontScale=line_thickness/3, thickness=font_thickness)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "Labels: " + ", ".join(class_names)
        else:
            return None, "Invalid model choice"
    except Exception as e:
        return None, f"Error processing image: {e}"
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Layout Segmentation Comparison (ZeroGPU)")
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_image = gr.Image(type="pil", label="Annotated Image")
    model_choice = gr.Dropdown(list(model_paths.keys()), label="Select Model", value="YOLOv8x Model", scale=0.5)
    output_text = gr.Textbox(label="Detected Areas and Labels")
    btn = gr.Button("Run Document Segmentation")
    btn.click(fn=process_image, inputs=[input_image, model_choice], outputs=[output_image, output_text])
# Launch the demo with queuing
demo.queue(max_size=1).launch()
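# Example of calling the handler directly for a quick local test, bypassing the UI
# (the file name "page.png" is only illustrative):
#   from PIL import Image
#   annotated, labels = process_image(Image.open("page.png"), "YOLOv8x Model")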