TextSnap / src /utils.py
sitammeur's picture
Update src/utils.py
b6654d0 verified
raw
history blame
5.35 kB
from PIL import ImageDraw
import numpy as np
import re
# Use a color map for bounding boxes.
# Boxes pick their color by index (cycling through the list), so every entry
# is kept distinct: a repeated value would give two boxes the same color.
colormap = [
    "#0000FF",  # blue
    "#FFA500",  # orange
    "#008000",  # green
    "#800080",  # purple
    "#A52A2A",  # brown
    "#FFC0CB",  # pink
    "#808080",  # gray
    "#808000",  # olive
    "#00FFFF",  # cyan
    "#FF0000",  # red
    "#00FF00",  # lime
    "#4B0082",  # indigo
    "#D2B48C",  # tan (was a second, adjacent copy of indigo)
    "#EE82EE",  # violet
    "#008080",  # teal (was a second copy of cyan)
    "#FF00FF",  # magenta
    "#FF7F50",  # coral
    "#FFD700",  # gold
    "#87CEEB",  # sky blue
]
# Text cleaning function
def clean_text(text):
    """
    Clean raw model output into a single, normally spaced line of text.

    The special tokens "<pad>" and "</s>" are removed, every run of
    whitespace (including line breaks) is collapsed to one space, leading
    and trailing whitespace is dropped, and a space is inserted after any
    period that is directly followed by a non-whitespace character.

    A period is left alone when it sits between two digits ("3.14") or is
    followed by another period ("..."), so decimal numbers and ellipses are
    not split apart.

    Args:
        text (str): The input text to be cleaned.

    Returns:
        str: The cleaned and properly formatted text.
    """
    # Remove unwanted tokens
    text = text.replace("<pad>", "").replace("</s>", "")

    # str.split() with no argument splits on any whitespace run and discards
    # empty pieces, so joining with " " trims the ends, flattens line breaks
    # and collapses repeated spaces in one step.
    cleaned_text = " ".join(text.split())

    # Add a space after a period that runs straight into the next token.
    #   (?<=\.)            position is right after a period
    #   (?!(?<=\d\.)\d)    ...but not in the middle of digit.digit
    #   (?=[^\s.])         ...and the next char is neither space nor period
    cleaned_text = re.sub(
        r"(?<=\.)(?!(?<=\d\.)\d)(?=[^\s.])", " ", cleaned_text
    )

    # Return the cleaned text
    return cleaned_text
# Convert hex color to RGBA with the given alpha
def hex_to_rgba(hex_color, alpha):
    """
    Build an RGBA tuple from a hexadecimal color string and an alpha value.

    Leading "#" characters are stripped, then the first three pairs of
    characters are parsed as base-16 integers for red, green and blue.
    The alpha value is passed through unchanged as the fourth element.

    Args:
        hex_color (str): The hexadecimal color code (e.g., "#FF0000").
        alpha (int): The alpha value for the RGBA color (0-255).

    Returns:
        tuple: A tuple representing the RGBA color values (red, green, blue, alpha).
    """
    digits = hex_color.lstrip("#")
    # Channels live at character offsets 0, 2 and 4, two hex digits each.
    red, green, blue = (
        int(digits[offset : offset + 2], 16) for offset in (0, 2, 4)
    )
    return (red, green, blue, alpha)
# Draw OCR bounding boxes with enhanced visual elements
def draw_ocr_bboxes(image, prediction):
    """
    Draw bounding boxes with enhanced visual elements on the given image based on the OCR prediction.

    For each (box, label) pair a colored outline with corner arcs is drawn,
    followed by a semi-transparent label background and the label text
    placed above the box. Colors are taken from ``colormap`` in order and
    wrap around when there are more boxes than colors.

    The drawing happens directly on ``image``; the same object is returned.

    Args:
        image (PIL.Image.Image): The input image on which the bounding boxes will be drawn.
        prediction (dict): The OCR prediction containing 'quad_boxes' and 'labels'.
            Each quad box is indexed as eight values
            (x1, y1, x2, y2, x3, y3, x4, y4).

    Returns:
        PIL.Image.Image: The image with the bounding boxes drawn.
    """
    # Create a drawing object for the image with RGBA mode
    draw = ImageDraw.Draw(image, "RGBA")

    # Extract bounding boxes and labels from the prediction
    bboxes, labels = prediction["quad_boxes"], prediction["labels"]

    # Outline width and corner radius are the same for every box
    box_outline_width = 3
    corner_radius = 10
    # Gap between the label text and the edge of its background rectangle
    label_padding = 5
    # How far above the top of the box the label background starts
    label_offset = 20

    for i, (box, label) in enumerate(zip(bboxes, labels)):
        # Select color for the bounding box and label
        color = colormap[i % len(colormap)]
        new_box = (np.array(box)).tolist()

        # Walk the four edges: vertex j -> vertex j + 1 (wrapping to vertex 0)
        for j in range(4):
            start_x, start_y = new_box[j * 2], new_box[j * 2 + 1]
            end_x, end_y = new_box[(j * 2 + 2) % 8], new_box[(j * 2 + 3) % 8]

            # Draw the arcs for the rounded corners
            draw.arc(
                [
                    (start_x - corner_radius, start_y - corner_radius),
                    (start_x + corner_radius, start_y + corner_radius),
                ],
                90 + j * 90,
                180 + j * 90,
                fill=color,
                width=box_outline_width,
            )
            draw.arc(
                [
                    (end_x - corner_radius, end_y - corner_radius),
                    (end_x + corner_radius, end_y + corner_radius),
                ],
                j * 90,
                90 + j * 90,
                fill=color,
                width=box_outline_width,
            )

            # Draw the lines connecting the arcs
            if j in [0, 1, 2]:
                draw.line(
                    [
                        (start_x + corner_radius if j != 1 else start_x, start_y),
                        (end_x - corner_radius if j != 1 else end_x, end_y),
                    ],
                    fill=color,
                    width=box_outline_width,
                )
            else:
                draw.line(
                    [
                        (start_x, start_y + corner_radius),
                        (end_x, end_y - corner_radius),
                    ],
                    fill=color,
                    width=box_outline_width,
                )

        # Calculate the position for the text label
        text_x = min(new_box[0::2])
        text_y = min(new_box[1::2]) - label_offset
        text_origin = (text_x + label_padding, text_y + label_padding)

        # Measure where the rendered label ends. textlength() returns only a
        # single width value and cannot be unpacked into (width, height), so
        # textbbox() is used to get both the right and bottom extents.
        _, _, text_right, text_bottom = draw.textbbox(text_origin, label)

        rgba_color = hex_to_rgba(color, 200)  # Semi-transparent background for text

        # Draw the background rectangle for the text
        draw.rectangle(
            [
                text_x,
                text_y,
                text_right + label_padding,
                text_bottom + label_padding,
            ],
            fill=rgba_color,
        )

        # Draw the text label
        draw.text(text_origin, label, fill=(0, 0, 0, 255))

    # Return the image with the OCR boxes drawn
    return image