from typing import List, Tuple, Generator
from concurrent.futures import ProcessPoolExecutor

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from tqdm import tqdm

from surya.model.detection.model import EfficientViTForSemanticSegmentation
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
from surya.schema import TextDetectionResult
from surya.settings import settings


def get_batch_size():
    """Use the configured batch size if set, otherwise pick a per-device default."""
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 36
    return batch_size


def batch_detection(
    images: List,
    model: EfficientViTForSemanticSegmentation,
    processor,
    batch_size=None
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    orig_sizes = [image.size for image in images]
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Pack images greedily into batches so that the total split count per batch
    # stays at or below batch_size (tall images are split into multiple tiles)
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)
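
    # Illustrative example: with batch_size=36 and splits_per_image=[20, 20, 10],
    # the loop above yields batches=[[0], [1, 2]] -- adding image 1 to the first
    # batch would exceed 36 splits, so it starts a new batch that image 2 then
    # joins. A single image with more splits than batch_size still gets its own,
    # oversized batch.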

    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        # .convert returns a copy, so the caller's images are left untouched
        batch_images = [images[j].convert("RGB") for j in batch_image_idxs]

        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            if len(preds) <= idx:
                # First split for this image: start a fresh list of heatmaps
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                # Later splits of the same image get stitched below the earlier ones
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to get the original tile height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        yield preds, [orig_sizes[j] for j in batch_image_idxs]


def parallel_get_lines(preds, orig_sizes):
    # Despite the plural name (kept from the batch API), orig_sizes is a single
    # image's (width, height) here
    heatmap, affinity_map = preds
    heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))

    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)

    result = TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
    )
    return result
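
# Note: parallel_get_lines must stay a module-level function (not a closure or
# lambda) so that ProcessPoolExecutor in batch_text_detection can pickle it and
# ship it to worker processes.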


def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

    results = []
    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    if parallelize:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for preds, orig_sizes in detection_generator:
                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
                results.extend(batch_results)
    else:
        for preds, orig_sizes in detection_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_lines(pred, orig_size))

    return results
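

if __name__ == "__main__":
    # Minimal usage sketch. Assumes the load_model/load_processor helpers exposed
    # by surya.model.detection.model in this version of the package, and a local
    # "page.png"; adjust both for your environment.
    from surya.model.detection.model import load_model, load_processor

    det_model = load_model()
    det_processor = load_processor()
    image = Image.open("page.png")

    [result] = batch_text_detection([image], det_model, det_processor)
    print(f"Found {len(result.bboxes)} text boxes")
    for box in result.bboxes:
        print(box.bbox)  # [x1, y1, x2, y2] in original image coordinates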