import torch
from transformers import AutoModelForObjectDetection
from surya.settings import settings

from PIL import Image
import numpy as np


class MaxResize:
    """Resize a PIL image so its longer side equals max_size, preserving aspect ratio."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        return image.resize((int(round(scale * width)), int(round(scale * height))))


def to_tensor(image):
    # Convert PIL Image to NumPy array
    np_image = np.array(image).astype(np.float32)

    # Rearrange dimensions to [C, H, W] format
    np_image = np_image.transpose((2, 0, 1))

    # Normalize to [0.0, 1.0]
    np_image /= 255.0

    return torch.from_numpy(np_image)


def normalize(tensor, mean, std):
    # In-place per-channel normalization: (x - mean) / std
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def structure_transform(image):
    # Preprocessing expected by Table Transformer: resize so the longer side
    # is 1000px, convert to tensor, then apply ImageNet mean/std normalization.
    image = MaxResize(1000)(image.convert("RGB"))  # convert guards against grayscale/RGBA inputs
    tensor = to_tensor(image)
    normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return normalized_tensor


def box_cxcywh_to_xyxy(x):
    # Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1)
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def rescale_bboxes(out_bbox, size):
    # Scale normalized cxcywh boxes up to absolute xyxy pixel coordinates
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
    return boxes


def outputs_to_objects(outputs, img_sizes, id2label):
    # Take the highest-scoring class per query (DETR-style detection head)
    m = outputs.logits.softmax(-1).max(-1)
    batch_labels = list(m.indices.detach().cpu().numpy())
    batch_scores = list(m.values.detach().cpu().numpy())
    batch_bboxes = outputs['pred_boxes'].detach().cpu()

    batch_objects = []
    for i in range(len(img_sizes)):
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])]
        pred_scores = batch_scores[i]
        pred_labels = batch_labels[i]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({
                    'label': class_label,
                    'score': float(score),
                    'bbox': [float(elem) for elem in bbox]}
                )

        # Split detections into row and column objects
        rows = []
        cols = []
        for cell in objects:
            if cell["label"] == "table column":
                cols.append(cell)

            if cell["label"] == "table row":
                rows.append(cell)
        batch_objects.append({
            "rows": rows,
            "cols": cols
        })

    return batch_objects


def load_tatr():
    return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL)


def batch_inference_tatr(model, images, batch_size):
    device = model.device
    # Copy id2label rather than mutating the model config in place; the model's
    # final (extra) logit index corresponds to the implicit "no object" class.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for i in range(0, len(images), batch_size):
        batch_images = images[i:i + batch_size]
        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)

        # Forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols
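

if __name__ == "__main__":
    # Minimal usage sketch. The image path and batch size below are illustrative
    # assumptions, not part of the original module; pass a cropped image of a
    # single table for sensible results.
    model = load_tatr()
    image = Image.open("table.png").convert("RGB")
    results = batch_inference_tatr(model, [image], batch_size=1)
    print(f"rows: {len(results[0]['rows'])}, cols: {len(results[0]['cols'])}")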