Spaces:
Running
Running
File size: 3,606 Bytes
52f1bcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
from transformers import DetrFeatureExtractor, AutoModelForObjectDetection
from surya.settings import settings
from PIL import Image
import numpy as np
class MaxResize:
    """Callable transform that rescales an image so its longer side is ``max_size``.

    Aspect ratio is preserved. The scale factor is ``max_size / max(width, height)``,
    so images smaller than ``max_size`` are enlarged as well as large ones shrunk.
    """

    def __init__(self, max_size=800):
        # Target length, in pixels, for the longer side of the output image.
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image.resize(...)`` at the computed size (rounded to ints)."""
        w, h = image.size
        factor = self.max_size / max(w, h)
        new_size = tuple(int(round(dim * factor)) for dim in (w, h))
        return image.resize(new_size)
def to_tensor(image):
    """Convert an image to a float32 ``(C, H, W)`` tensor scaled by 1/255.

    Args:
        image: Anything ``np.array`` accepts, e.g. a PIL image or an
            ``(H, W, C)`` / ``(H, W)`` array. Values are divided by 255, which
            maps 8-bit pixel values into [0.0, 1.0].

    Returns:
        ``torch.Tensor`` of dtype float32 with shape ``(C, H, W)``. A 2-D
        (single-channel) input yields ``C == 1`` instead of raising.
    """
    # ``astype`` always returns a copy, so the in-place division below never
    # mutates the caller's array.
    np_image = np.array(image).astype(np.float32)
    if np_image.ndim == 2:
        # Grayscale input has no channel axis; add one so the HWC -> CHW
        # transpose below is valid.
        np_image = np_image[:, :, np.newaxis]
    # Rearrange dimensions from [H, W, C] to [C, H, W].
    np_image = np_image.transpose((2, 0, 1))
    np_image /= 255.0
    return torch.from_numpy(np_image)
def normalize(tensor, mean, std):
    """Standardize ``tensor`` channel by channel, in place, as ``(x - mean) / std``.

    Channels are taken along the first dimension and paired positionally with
    ``mean`` and ``std``. Pairing stops at the shortest of the three sequences,
    so any surplus channels are left unchanged.

    Returns:
        The same tensor object that was passed in, after modification.
    """
    for channel, mu, sigma in zip(tensor, mean, std):
        # ``channel`` is a view into ``tensor``; the in-place ops write through.
        channel.sub_(mu)
        channel.div_(sigma)
    return tensor
def structure_transform(image):
    """Preprocess a PIL image into the normalized tensor fed to the model.

    Steps: force 3-channel RGB, resize so the longer side is 1000 px, convert
    to a ``(C, H, W)`` float tensor in [0, 1], then normalize each channel with
    the mean/std constants below (the standard ImageNet values).

    Args:
        image: A PIL image of any mode.

    Returns:
        A float32 tensor of shape ``(3, H, W)``.
    """
    # Grayscale ("L") and palette ("P") images would become a 2-D array, and
    # "RGBA" a 4-channel tensor whose last channel is never normalized. Convert
    # up front so the tensor always has exactly three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = MaxResize(1000)(image)
    tensor = to_tensor(image)
    normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return normalized_tensor
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x1, y1, x2, y2).

    Args:
        x: Tensor whose last dimension has size 4 and holds ``(cx, cy, w, h)``.
            Any number of leading dimensions is accepted.

    Returns:
        Tensor with the same shape as ``x`` holding ``(x_min, y_min, x_max, y_max)``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    # Stack on the same axis that was unbound. A fixed dim=1 is only correct for
    # 2-D input: it transposes batched (B, N, 4) input and fails for 1-D input.
    return torch.stack(b, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Convert fractional (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) boxes.

    Args:
        out_bbox: Tensor of boxes in ``(cx, cy, w, h)`` layout, last dim 4,
            expressed as fractions of the image size.
        size: ``(width, height)`` used to scale the x and y coordinates.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes multiplied by
        ``[width, height, width, height]``.
    """
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale factors on the boxes' device so GPU tensors work too;
    # a CPU-only constant would raise a device-mismatch error.
    scale = torch.tensor(
        [width, height, width, height], dtype=torch.float32, device=boxes.device
    )
    return boxes * scale
def outputs_to_objects(outputs, img_sizes, id2label):
    """Turn raw detector outputs into per-image lists of table rows and columns.

    Each query's predicted class is the arg-max of its softmaxed logits, and its
    score is that class's probability. Only queries labelled ``"table row"`` or
    ``"table column"`` are kept; everything else, including ``"no object"``,
    is discarded.

    Args:
        outputs: Model output with a ``logits`` attribute and a ``'pred_boxes'``
            item. Assumed shapes are (batch, queries, classes) and
            (batch, queries, 4); TODO confirm against the model.
        img_sizes: One ``(width, height)`` pair per image. Boxes for image ``k``
            are rescaled with ``img_sizes[k]``.
        id2label: Mapping from integer class index to label name. Every
            predicted index is looked up, so a missing index raises ``KeyError``.

    Returns:
        A list with one ``{"rows": [...], "cols": [...]}`` dict per entry of
        ``img_sizes``. Each detection is a dict with ``'label'``, ``'score'``
        (float) and ``'bbox'`` (``[x1, y1, x2, y2]`` scaled to the image size).
    """
    best = outputs.logits.softmax(-1).max(-1)
    batch_labels = best.indices.detach().cpu().numpy()
    batch_scores = best.values.detach().cpu().numpy()
    batch_bboxes = outputs['pred_boxes'].detach().cpu()

    batch_objects = []
    # One image index drives the labels, scores and boxes for that image; no
    # inner loop reuses it.
    for img_idx, img_size in enumerate(img_sizes):
        pred_bboxes = rescale_bboxes(batch_bboxes[img_idx], img_size).tolist()
        rows = []
        cols = []
        for label, score, bbox in zip(batch_labels[img_idx], batch_scores[img_idx], pred_bboxes):
            class_label = id2label[int(label)]
            if class_label == "table row":
                bucket = rows
            elif class_label == "table column":
                bucket = cols
            else:
                continue
            bucket.append({
                'label': class_label,
                'score': float(score),
                'bbox': [float(elem) for elem in bbox],
            })
        batch_objects.append({
            "rows": rows,
            "cols": cols,
        })
    return batch_objects
def load_tatr(checkpoint="microsoft/table-transformer-structure-recognition-v1.1-all", device=None):
    """Load the Table Transformer structure-recognition model onto a device.

    Args:
        checkpoint: Model id or local path passed to
            ``AutoModelForObjectDetection.from_pretrained``. Defaults to the
            checkpoint this module has always used.
        device: Target device for ``.to()``. When ``None``,
            ``settings.TORCH_DEVICE_MODEL`` is read at call time.

    Returns:
        The loaded model, moved to ``device``.
    """
    if device is None:
        device = settings.TORCH_DEVICE_MODEL
    model = AutoModelForObjectDetection.from_pretrained(checkpoint)
    return model.to(device)
def _pad_batch(tensors):
    """Zero-pad ``(C, H, W)`` tensors to a shared size; return ``(pixel_values, pixel_mask)``.

    ``pixel_mask`` is 1 where a pixel comes from the original tensor and 0 in
    padding, in the ``(batch, H, W)`` layout DETR-style models accept.
    """
    channels = max(t.shape[0] for t in tensors)
    max_h = max(t.shape[1] for t in tensors)
    max_w = max(t.shape[2] for t in tensors)
    pixel_values = torch.zeros((len(tensors), channels, max_h, max_w), dtype=tensors[0].dtype)
    pixel_mask = torch.zeros((len(tensors), max_h, max_w), dtype=torch.long)
    for idx, tensor in enumerate(tensors):
        c, h, w = tensor.shape
        pixel_values[idx, :c, :h, :w] = tensor
        pixel_mask[idx, :h, :w] = 1
    return pixel_values, pixel_mask


def batch_inference_tatr(model, images, batch_size):
    """Run table-structure recognition over ``images`` in mini-batches.

    Args:
        model: The object-detection model. Its ``device`` and
            ``config.id2label`` are read; the config is not modified.
        images: Sequence of PIL images.
        batch_size: Number of images per forward pass (must be positive).

    Returns:
        One ``{"rows": [...], "cols": [...]}`` dict per input image, in input
        order, as produced by ``outputs_to_objects``.
    """
    device = model.device
    # Work on a copy of the label map. Writing "no object" into
    # model.config.id2label in place would add a new entry on every batch and
    # every call, and HF configs derive num_labels from that dict.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for start in range(0, len(images), batch_size):
        batch_images = images[start:start + batch_size]
        # MaxResize only fixes the longer side, so images in one batch can differ
        # in shape. Pad them to a common size and tell the model which pixels are
        # real; plain torch.stack would fail on mixed aspect ratios.
        pixel_values, pixel_mask = _pad_batch([structure_transform(img) for img in batch_images])
        with torch.no_grad():
            outputs = model(
                pixel_values=pixel_values.to(device),
                pixel_mask=pixel_mask.to(device),
            )
        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols