import base64
import io
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import cv2
import easyocr
import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as ImageType
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
from torchvision.ops import box_convert
from torchvision.transforms import ToPILImage
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import load_image
from ultralytics import YOLO

# Instantiating a reader once at import time ensures the EasyOCR detection and
# recognition weights are downloaded/cached before the handler serves requests.
easyocr.Reader(["en"])


class EndpointHandler:
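    """Inference Endpoints handler that parses UI screenshots into labeled elements.

    Combines a YOLO icon detector, EasyOCR text detection, and a Florence-2
    captioning model, and returns a base64-encoded annotated image together
    with the list of detected bounding boxes.
    """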

    def __init__(self, model_dir: str = "/repository") -> None:
        # Prefer CUDA, then Apple MPS, and fall back to CPU.
        self.device = (
            torch.device("cuda")
            if torch.cuda.is_available()
            else (
                torch.device("mps")
                if torch.backends.mps.is_available()
                else torch.device("cpu")
            )
        )

        # YOLO model fine-tuned to detect interactable UI icons.
        self.yolo = YOLO(f"{model_dir}/icon_detect/model.pt")

        # Florence-2 processor and fine-tuned captioning weights for the detected icons.
        self.processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base", trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            f"{model_dir}/icon_caption",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to(self.device)

        # OCR reader for text elements and the annotator used to draw the output image.
        self.ocr = easyocr.Reader(["en"])
        self.annotator = BoxAnnotator()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
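        """Run the full screenshot-parsing pipeline on a single request.

        Expects a payload of the form ``{"inputs": {"image": ..., "image_size": ...,
        "bbox_threshold": ..., "iou_threshold": ...}}`` where ``image`` is anything
        accepted by ``transformers.image_utils.load_image`` (URL, file path, base64
        data URI, or PIL image). Returns a dict with the annotated image
        (base64-encoded PNG) under ``"image"`` and the element list under ``"bboxes"``.
        """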
        data = data.pop("inputs")

        image = load_image(data["image"])

        ocr_texts, ocr_bboxes = self.check_ocr_bboxes(
            image,
            out_format="xyxy",
            ocr_kwargs={"text_threshold": 0.8},
        )
        annotated_image, filtered_bboxes_out = self.get_som_labeled_img(
            image,
            image_size=data.get("image_size", None),
            ocr_texts=ocr_texts,
            ocr_bboxes=ocr_bboxes,
            bbox_threshold=data.get("bbox_threshold", 0.05),
            iou_threshold=data.get("iou_threshold", None),
        )
        return {
            "image": annotated_image,
            "bboxes": filtered_bboxes_out,
        }

    def check_ocr_bboxes(
        self,
        image: ImageType,
        out_format: Literal["xywh", "xyxy"] = "xywh",
        ocr_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[str], List[List[int]]]:
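        """Run EasyOCR on the image and return the detected texts and their bounding boxes."""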
        if image.mode == "RGBA":
            image = image.convert("RGB")

        result = self.ocr.readtext(np.array(image), **(ocr_kwargs or {}))
        texts = [str(item[1]) for item in result]
        bboxes = [
            self.coordinates_to_bbox(item[0], format=out_format) for item in result
        ]
        return (texts, bboxes)

    @staticmethod
    def coordinates_to_bbox(
        coordinates: np.ndarray, format: Literal["xywh", "xyxy"] = "xywh"
    ) -> List[int]:
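        """Convert an EasyOCR quadrilateral (four corner points) into an xywh or xyxy box."""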
        match format:
            case "xywh":
                return [
                    int(coordinates[0][0]),
                    int(coordinates[0][1]),
                    int(coordinates[2][0] - coordinates[0][0]),
                    int(coordinates[2][1] - coordinates[0][1]),
                ]
            case "xyxy":
                return [
                    int(coordinates[0][0]),
                    int(coordinates[0][1]),
                    int(coordinates[2][0]),
                    int(coordinates[2][1]),
                ]

    @staticmethod
    def bbox_area(bbox: List[float], w: int, h: int) -> float:
        # Boxes are normalized to [0, 1]; scale back to pixels before computing the area.
        bbox = [bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h]
        return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

    @staticmethod
    def remove_bbox_overlap(
        xyxy_bboxes: List[Dict[str, Any]],
        ocr_bboxes: Optional[List[Dict[str, Any]]] = None,
        iou_threshold: Optional[float] = 0.7,
    ) -> List[Dict[str, Any]]:
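        """Merge YOLO icon boxes with OCR text boxes, dropping redundant icons.

        An icon box that overlaps another icon above ``iou_threshold`` and is the
        larger of the two is discarded; OCR boxes contained in an icon box are
        absorbed as that icon's text content.
        """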
        filtered_bboxes = []
        if ocr_bboxes is not None:
            filtered_bboxes.extend(ocr_bboxes)

        for i, bbox_outer in enumerate(xyxy_bboxes):
            bbox_left = bbox_outer["bbox"]
            valid_bbox = True

            for j, bbox_inner in enumerate(xyxy_bboxes):
                if i == j:
                    continue

                bbox_right = bbox_inner["bbox"]
                if (
                    intersection_over_union(
                        bbox_left,
                        bbox_right,
                    )
                    > iou_threshold
                ) and (area(bbox_left) > area(bbox_right)):
                    valid_bbox = False
                    break

            if not valid_bbox:
                continue

            if ocr_bboxes is None:
                filtered_bboxes.append(bbox_outer)
                continue

            box_added = False
            ocr_labels = []
            for ocr_bbox in ocr_bboxes:
                if not box_added:
                    bbox_right = ocr_bbox["bbox"]
                    if overlap(bbox_right, bbox_left):
                        # The OCR box sits inside the icon box: absorb its text.
                        try:
                            ocr_labels.append(ocr_bbox["content"])
                            filtered_bboxes.remove(ocr_bbox)
                        except Exception:
                            continue
                    elif overlap(bbox_left, bbox_right):
                        # The icon box sits inside an OCR box: skip the icon.
                        box_added = True
                        break

            if not box_added:
                filtered_bboxes.append(
                    {
                        "type": "icon",
                        "bbox": bbox_outer["bbox"],
                        "interactivity": True,
                        "content": " ".join(ocr_labels) if ocr_labels else None,
                    }
                )

        return filtered_bboxes

    def get_som_labeled_img(
        self,
        image: ImageType,
        image_size: Optional[Dict[Literal["w", "h"], int]] = None,
        ocr_texts: Optional[List[str]] = None,
        ocr_bboxes: Optional[List[List[int]]] = None,
        bbox_threshold: float = 0.01,
        iou_threshold: Optional[float] = None,
        caption_prompt: Optional[str] = None,
        caption_batch_size: int = 64,
    ) -> Tuple[str, List[Dict[str, Any]]]:
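        """Produce a Set-of-Marks (SoM) labeled screenshot.

        Detects icons with YOLO, merges them with the OCR boxes, captions the
        remaining uncaptioned icons with Florence-2, draws numbered boxes on the
        image, and returns the base64-encoded PNG together with the element
        dictionaries.
        """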
        if image.mode == "RGBA":
            image = image.convert("RGB")

        w, h = image.size
        if image_size is None:
            imgsz = [h, w]
        else:
            imgsz = [image_size.get("h", h), image_size.get("w", w)]

        out = self.yolo.predict(
            image,
            imgsz=imgsz,
            conf=bbox_threshold,
            iou=iou_threshold or 0.7,
            verbose=False,
        )[0]
        if out.boxes is None:
            raise RuntimeError(
                "YOLO prediction failed to produce the bounding boxes..."
            )

        # Normalize both the YOLO and the OCR boxes to [0, 1] xyxy coordinates.
        xyxy_bboxes = out.boxes.xyxy
        xyxy_bboxes = xyxy_bboxes / torch.Tensor([w, h, w, h]).to(xyxy_bboxes.device)
        image_np = np.asarray(image)

        if ocr_bboxes:
            ocr_bboxes = torch.tensor(ocr_bboxes) / torch.Tensor([w, h, w, h])
            ocr_bboxes = ocr_bboxes.tolist()

        ocr_bboxes = [
            {
                "type": "text",
                "bbox": bbox,
                "interactivity": False,
                "content": text,
                "source": "box_ocr_content_ocr",
            }
            for bbox, text in zip(ocr_bboxes or [], ocr_texts or [])
            if self.bbox_area(bbox, w, h) > 0
        ]
        xyxy_bboxes = [
            {
                "type": "icon",
                "bbox": bbox,
                "interactivity": True,
                "content": None,
                "source": "box_yolo_content_yolo",
            }
            for bbox in xyxy_bboxes.tolist()
            if self.bbox_area(bbox, w, h) > 0
        ]

        filtered_bboxes = self.remove_bbox_overlap(
            xyxy_bboxes=xyxy_bboxes,
            ocr_bboxes=ocr_bboxes,
            iou_threshold=iou_threshold or 0.7,
        )

        # Sort so that boxes that already carry text content come first; the
        # remaining (uncaptioned) icon boxes start at `starting_idx`.
        filtered_bboxes_out = sorted(
            filtered_bboxes, key=lambda x: x["content"] is None
        )
        starting_idx = next(
            (
                idx
                for idx, bbox in enumerate(filtered_bboxes_out)
                if bbox["content"] is None
            ),
            -1,
        )

        filtered_bboxes = torch.tensor([box["bbox"] for box in filtered_bboxes_out])
        non_ocr_bboxes = (
            filtered_bboxes[starting_idx:] if starting_idx >= 0 else filtered_bboxes[:0]
        )

        # Crop every uncaptioned box out of the screenshot and resize it to 64x64
        # before captioning.
        bbox_images = []
        for coordinates in non_ocr_bboxes:
            try:
                xmin, xmax = (
                    int(coordinates[0] * image_np.shape[1]),
                    int(coordinates[2] * image_np.shape[1]),
                )
                ymin, ymax = (
                    int(coordinates[1] * image_np.shape[0]),
                    int(coordinates[3] * image_np.shape[0]),
                )
                cropped_image = image_np[ymin:ymax, xmin:xmax, :]
                cropped_image = cv2.resize(cropped_image, (64, 64))
                bbox_images.append(ToPILImage()(cropped_image))
            except Exception:
                continue

        if caption_prompt is None:
            caption_prompt = "<CAPTION>"

        # Caption the cropped icons in batches with Florence-2.
        captions = []
        for idx in range(0, len(bbox_images), caption_batch_size):
            batch = bbox_images[idx : idx + caption_batch_size]
            inputs = self.processor(
                images=batch,
                text=[caption_prompt] * len(batch),
                return_tensors="pt",
                do_resize=False,
            )
            if self.device.type in {"cuda", "mps"}:
                # Only floating point tensors are cast to float16; input_ids stay integral.
                inputs = inputs.to(device=self.device, dtype=torch.float16)

            with torch.inference_mode():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=20,
                    num_beams=1,
                    do_sample=False,
                    early_stopping=False,
                )

            generated_texts = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            captions.extend([text.strip() for text in generated_texts])

        ocr_texts = [
            f"Text Box ID {idx}: {text}" for idx, text in enumerate(ocr_texts or [])
        ]
        # Fill the generated captions into the boxes that still have no content.
        for bbox in filtered_bboxes_out:
            if bbox["content"] is None:
                bbox["content"] = captions.pop(0)

        filtered_bboxes = box_convert(
            boxes=filtered_bboxes, in_fmt="xyxy", out_fmt="cxcywh"
        )

        # Draw numbered boxes on a copy of the original image and encode it as PNG.
        annotated_image = image_np.copy()
        bboxes_annotate = filtered_bboxes * torch.Tensor([w, h, w, h])
        xyxy_annotate = box_convert(
            bboxes_annotate, in_fmt="cxcywh", out_fmt="xyxy"
        ).numpy()
        detections = Detections(xyxy=xyxy_annotate)
        labels = [str(idx) for idx in range(bboxes_annotate.shape[0])]

        annotated_image = self.annotator.annotate(
            scene=annotated_image,
            detections=detections,
            labels=labels,
            image_size=(w, h),
        )
        assert w == annotated_image.shape[1] and h == annotated_image.shape[0]

        out_image = Image.fromarray(annotated_image)
        out_buffer = io.BytesIO()
        out_image.save(out_buffer, format="PNG")
        encoded_image = base64.b64encode(out_buffer.getvalue()).decode("ascii")

        return encoded_image, filtered_bboxes_out


def area(bbox: List[int]) -> int:
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def intersection_area(bbox_left: List[int], bbox_right: List[int]) -> int:
    # The intersection spans from the larger of the two left/top edges to the
    # smaller of the two right/bottom edges; negative extents are clamped to zero.
    return max(
        0, min(bbox_left[2], bbox_right[2]) - max(bbox_left[0], bbox_right[0])
    ) * max(0, min(bbox_left[3], bbox_right[3]) - max(bbox_left[1], bbox_right[1]))


def intersection_over_union(bbox_left: List[int], bbox_right: List[int]) -> float:
    # Returns the maximum of the plain IoU and the two per-box overlap ratios, so a
    # small box fully contained in a large one still scores close to 1.
    intersection = intersection_area(bbox_left, bbox_right)
    bbox_left_area = area(bbox_left)
    bbox_right_area = area(bbox_right)
    union = bbox_left_area + bbox_right_area - intersection + 1e-6

    ratio_left, ratio_right = 0, 0
    if bbox_left_area > 0 and bbox_right_area > 0:
        ratio_left = intersection / bbox_left_area
        ratio_right = intersection / bbox_right_area
    return max(intersection / union, ratio_left, ratio_right)


def overlap(bbox_left: List[int], bbox_right: List[int]) -> bool:
    # True when more than 80% of bbox_left's area lies inside bbox_right.
    intersection = intersection_area(bbox_left, bbox_right)
    ratio_left = intersection / area(bbox_left)
    return ratio_left > 0.80


class BoxAnnotator:
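    """Draw labeled bounding boxes on an image with OpenCV.

    The label text color is chosen from the box color's luminance, and
    ``get_optimal_label_pos`` shifts labels so they do not cover other
    detections or fall outside the image.
    """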

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,
        text_thickness: int = 2,
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
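        """Draw each detection box and its label on ``scene`` and return it.

        Labels fall back to the class id when ``labels`` is missing or does not
        match the number of detections.
        """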
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding

                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height

                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
            else:
                (
                    text_x,
                    text_y,
                    text_background_x1,
                    text_background_y1,
                    text_background_x2,
                    text_background_y2,
                ) = self.get_optimal_label_pos(
                    self.text_padding,
                    text_width,
                    text_height,
                    x1,
                    y1,
                    x2,
                    y2,
                    detections,
                    image_size,
                )

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # Pick black or white text depending on the label background luminance.
            box_color = color.as_rgb()
            luminance = (
                0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            )
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene

    @staticmethod
    def get_optimal_label_pos(
        text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
    ):
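        """Pick a label position that avoids other detections and the image border.

        Tries four candidate positions (above-left, outer-left, outer-right, and
        above-right of the box) and returns the first non-overlapping one; the
        last candidate is used as the fallback.
        """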
        def get_is_overlap(
            detections,
            text_background_x1,
            text_background_y1,
            text_background_x2,
            text_background_y2,
            image_size,
        ):
            is_overlap = False
            for i in range(len(detections)):
                detection = detections.xyxy[i].astype(int)
                if (
                    intersection_over_union(
                        [
                            text_background_x1,
                            text_background_y1,
                            text_background_x2,
                            text_background_y2,
                        ],
                        detection,
                    )
                    > 0.3
                ):
                    is_overlap = True
                    break
            # A label that falls outside the image also counts as overlapping.
            if (
                text_background_x1 < 0
                or text_background_x2 > image_size[0]
                or text_background_y1 < 0
                or text_background_y2 > image_size[1]
            ):
                is_overlap = True
            return is_overlap

        # Candidate 1: above the box, anchored at its top-left corner.
        text_x = x1 + text_padding
        text_y = y1 - text_padding

        text_background_x1 = x1
        text_background_y1 = y1 - 2 * text_padding - text_height

        text_background_x2 = x1 + 2 * text_padding + text_width
        text_background_y2 = y1
        is_overlap = get_is_overlap(
            detections,
            text_background_x1,
            text_background_y1,
            text_background_x2,
            text_background_y2,
            image_size,
        )
        if not is_overlap:
            return (
                text_x,
                text_y,
                text_background_x1,
                text_background_y1,
                text_background_x2,
                text_background_y2,
            )

        # Candidate 2: outside the box, to the left of its top-left corner.
        text_x = x1 - text_padding - text_width
        text_y = y1 + text_padding + text_height

        text_background_x1 = x1 - 2 * text_padding - text_width
        text_background_y1 = y1

        text_background_x2 = x1
        text_background_y2 = y1 + 2 * text_padding + text_height
        is_overlap = get_is_overlap(
            detections,
            text_background_x1,
            text_background_y1,
            text_background_x2,
            text_background_y2,
            image_size,
        )
        if not is_overlap:
            return (
                text_x,
                text_y,
                text_background_x1,
                text_background_y1,
                text_background_x2,
                text_background_y2,
            )

        # Candidate 3: outside the box, to the right of its top-right corner.
        text_x = x2 + text_padding
        text_y = y1 + text_padding + text_height

        text_background_x1 = x2
        text_background_y1 = y1

        text_background_x2 = x2 + 2 * text_padding + text_width
        text_background_y2 = y1 + 2 * text_padding + text_height

        is_overlap = get_is_overlap(
            detections,
            text_background_x1,
            text_background_y1,
            text_background_x2,
            text_background_y2,
            image_size,
        )
        if not is_overlap:
            return (
                text_x,
                text_y,
                text_background_x1,
                text_background_y1,
                text_background_x2,
                text_background_y2,
            )

        # Candidate 4: above the box, anchored at its top-right corner (fallback).
        text_x = x2 - text_padding - text_width
        text_y = y1 - text_padding

        text_background_x1 = x2 - 2 * text_padding - text_width
        text_background_y1 = y1 - 2 * text_padding - text_height

        text_background_x2 = x2
        text_background_y2 = y1

        is_overlap = get_is_overlap(
            detections,
            text_background_x1,
            text_background_y1,
            text_background_x2,
            text_background_y2,
            image_size,
        )
        if not is_overlap:
            return (
                text_x,
                text_y,
                text_background_x1,
                text_background_y1,
                text_background_x2,
                text_background_y2,
            )

        return (
            text_x,
            text_y,
            text_background_x1,
            text_background_y1,
            text_background_x2,
            text_background_y2,
        )
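

# Minimal local usage sketch (not part of the handler contract). Assumptions: the
# model weights have been unpacked under "./repository" with the icon_detect/ and
# icon_caption/ layout expected above, and a "screenshot.png" test image exists.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./repository")
    with open("screenshot.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    response = handler(
        {
            "inputs": {
                "image": f"data:image/png;base64,{encoded}",
                "bbox_threshold": 0.05,
            }
        }
    )
    print(f"Detected {len(response['bboxes'])} elements")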