# Jiangxz01's picture
# Upload 56 files
# 52f1bcb verified
# raw
# history blame
# 3.61 kB
import torch
from transformers import DetrFeatureExtractor, AutoModelForObjectDetection
from surya.settings import settings
from PIL import Image
import numpy as np
class MaxResize:
    """Callable transform that rescales an image so its longest side is ``max_size``.

    The aspect ratio is kept. Because the scale factor is derived from the
    longest side, smaller images are scaled up and larger ones scaled down.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image`` resized by the factor ``max_size / max(width, height)``.

        ``image`` must expose ``size`` as a ``(width, height)`` pair and a
        ``resize`` method that accepts the target ``(width, height)`` tuple.
        """
        width, height = image.size
        factor = self.max_size / max(width, height)
        # Both sides use the same factor, rounded to the nearest whole pixel.
        target_size = tuple(int(round(factor * side)) for side in (width, height))
        return image.resize(target_size)
def to_tensor(image):
    """Convert an image to a float32 torch tensor in channel-first layout.

    The input is passed through ``np.array`` and must yield a 3-D array with
    the channel axis last; the transpose below fails otherwise. Values are
    divided by 255, so 8-bit pixel data ends up in [0.0, 1.0].
    """
    pixels = np.array(image)
    # Cast to float32, then move the channel axis to the front:
    # [H, W, C] -> [C, H, W].
    channels_first = np.transpose(pixels.astype(np.float32), (2, 0, 1))
    # Scale in place; the float32 buffer is a fresh copy owned by this function.
    channels_first /= 255.0
    return torch.from_numpy(channels_first)
def normalize(tensor, mean, std):
    """Normalize ``tensor`` channel by channel, in place, and return it.

    Iterates over the first dimension of ``tensor`` together with ``mean`` and
    ``std``; each slice becomes ``(slice - mean) / std``. Iteration stops at
    the shortest of the three, so extra channels are left untouched. The input
    tensor itself is modified.
    """
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        # Each slice is a view into ``tensor``, so in-place ops update the input.
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor
def structure_transform(image):
    """Prepare a PIL image as model input for table-structure recognition.

    Steps: ensure RGB mode, resize so the longest side is 1000 px, convert to
    a [C, H, W] float32 tensor in [0, 1], then normalize each channel.

    Args:
        image: PIL image (anything exposing ``size``/``resize``; ``mode`` and
            ``convert`` are used when present and the mode is not RGB).

    Returns:
        The normalized float32 tensor.
    """
    # to_tensor needs an [H, W, C] array and normalize supplies exactly three
    # channel statistics, so grayscale, palette or alpha-carrying images have
    # to be converted to RGB before anything else.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    resized = MaxResize(1000)(image)
    tensor = to_tensor(resized)
    # These per-channel mean/std values match the commonly used ImageNet
    # statistics.
    return normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner format.

    Args:
        x: Tensor whose last dimension has size 4, holding
            ``(x_c, y_c, w, h)``. Any number of leading dimensions is allowed.

    Returns:
        Tensor of the same shape holding ``(x_min, y_min, x_max, y_max)``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    # Stack on the same (last) axis that was unbound so single boxes and
    # batched inputs keep their layout; for (N, 4) input this equals dim=1.
    return torch.stack(corners, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them by ``size``.

    Args:
        out_bbox: Tensor of ``(x_c, y_c, w, h)`` boxes, as accepted by
            ``box_cxcywh_to_xyxy``.
        size: ``(width, height)`` pair; x coordinates are multiplied by
            ``width`` and y coordinates by ``height``.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes scaled by ``size``.
    """
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' device so non-CPU inputs do not raise a
    # device-mismatch error; CPU inputs behave exactly as before.
    scale = torch.tensor(
        [width, height, width, height],
        dtype=torch.float32,
        device=boxes.device,
    )
    return boxes * scale
def outputs_to_objects(outputs, img_sizes, id2label):
    """Turn raw detection outputs into per-image lists of table rows and columns.

    Args:
        outputs: Model output exposing ``logits`` as an attribute and
            ``pred_boxes`` via item access.
        img_sizes: One ``(width, height)`` pair per image, passed to
            ``rescale_bboxes`` to scale the predicted boxes.
        id2label: Mapping from class index to label name. Predictions whose
            label is ``'no object'`` are dropped.

    Returns:
        A list with one dict per entry of ``img_sizes``. Each dict has
        ``"rows"`` (label ``"table row"``) and ``"cols"`` (label
        ``"table column"``), holding ``{'label', 'score', 'bbox'}`` dicts.
    """
    # Highest-probability class (and its probability) for every prediction.
    best = outputs.logits.softmax(-1).max(-1)
    labels_per_image = list(best.indices.detach().cpu().numpy())
    scores_per_image = list(best.values.detach().cpu().numpy())
    boxes_per_image = outputs['pred_boxes'].detach().cpu()

    results = []
    for idx, img_size in enumerate(img_sizes):
        scaled_boxes = [
            box.tolist() for box in rescale_bboxes(boxes_per_image[idx], img_size)
        ]

        detections = []
        for label, score, box in zip(
            labels_per_image[idx], scores_per_image[idx], scaled_boxes
        ):
            name = id2label[int(label)]
            if name == 'no object':
                continue
            detections.append({
                'label': name,
                'score': float(score),
                'bbox': [float(coord) for coord in box],
            })

        # Only rows and columns are reported; other labels are discarded here.
        results.append({
            "rows": [det for det in detections if det["label"] == "table row"],
            "cols": [det for det in detections if det["label"] == "table column"],
        })
    return results
def load_tatr():
    """Load the pretrained table-structure detection model onto the configured device.

    The checkpoint is fetched through
    ``AutoModelForObjectDetection.from_pretrained`` and moved to
    ``settings.TORCH_DEVICE_MODEL``.
    """
    checkpoint = "microsoft/table-transformer-structure-recognition-v1.1-all"
    model = AutoModelForObjectDetection.from_pretrained(checkpoint)
    return model.to(settings.TORCH_DEVICE_MODEL)
def batch_inference_tatr(model, images, batch_size):
    """Run the table-structure model over ``images`` in batches.

    Args:
        model: Detection model (e.g. from ``load_tatr``). It must expose
            ``device`` and ``config.id2label`` and accept ``pixel_values``
            plus a ``pixel_mask`` keyword.
        images: Sequence of PIL images.
        batch_size: Number of images per forward pass.

    Returns:
        One ``{"rows": [...], "cols": [...]}`` dict per input image, in input
        order, as produced by ``outputs_to_objects``.
    """
    device = model.device

    # Work on a copy, built once. Writing the extra class straight into
    # model.config.id2label would add another "no object" key to the shared
    # config on every batch.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for start in range(0, len(images), batch_size):
        batch_images = images[start:start + batch_size]
        tensors = [structure_transform(img) for img in batch_images]

        # Resizing keeps the aspect ratio, so tensor shapes can differ within
        # a batch and a plain torch.stack would fail. Zero-pad to the largest
        # height/width and mark real pixels in the mask. When all shapes
        # match, this equals stacking with an all-ones mask.
        max_height = max(t.shape[-2] for t in tensors)
        max_width = max(t.shape[-1] for t in tensors)
        pixel_values = tensors[0].new_zeros(
            (len(tensors), tensors[0].shape[0], max_height, max_width)
        )
        pixel_mask = torch.zeros(
            (len(tensors), max_height, max_width), dtype=torch.long
        )
        for idx, tensor in enumerate(tensors):
            height, width = tensor.shape[-2:]
            pixel_values[idx, :, :height, :width] = tensor
            pixel_mask[idx, :height, :width] = 1

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values.to(device), pixel_mask=pixel_mask.to(device))

        # NOTE(review): with the mask supplied, DETR-style models are expected
        # to predict boxes relative to the unpadded image, so scaling by the
        # original image size stays valid - confirm on a mixed-size batch.
        rows_cols.extend(
            outputs_to_objects(outputs, [img.size for img in batch_images], id2label)
        )
    return rows_cols