File size: 4,772 Bytes
da59cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import Pix2StructProcessor, Pix2StructVisionModel
from utils import download_default_font, render_header

class Pix2StructForRegression(nn.Module):
    """Pix2Struct vision encoder followed by a small MLP regression head.

    The head feeds the encoder's first-position hidden state through
    768 -> 1536 -> 768 -> 2 linear layers (ReLU between them) and applies a
    sigmoid, so both outputs lie in (0, 1).
    """

    def __init__(self, sourcemodel_path, device):
        """Load the pretrained encoder and create the regression head.

        Args:
            sourcemodel_path: Name or path passed to
                ``Pix2StructVisionModel.from_pretrained``.
            device: Stored on the instance and used as ``map_location`` by
                ``load_state_dict_file``. The module is not moved to this
                device here.
        """
        super().__init__()
        self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
        print("Pix2StructForRegression Model is Loaded...")

        encoder_width, expanded_width, num_targets = 768, 1536, 2
        self.regression_layer1 = nn.Linear(encoder_width, expanded_width)
        # NOTE: both dropout modules are registered but forward() never
        # applies them.
        self.dropout1 = nn.Dropout(0.1)
        self.regression_layer2 = nn.Linear(expanded_width, encoder_width)
        self.dropout2 = nn.Dropout(0.1)
        self.regression_layer3 = nn.Linear(encoder_width, num_targets)
        self.device = device
        print("Regression Layers are Loaded...")

    def forward(self, *args, **kwargs):
        """Run the encoder and regress two values from its first position.

        All arguments are forwarded unchanged to the wrapped encoder.

        Returns:
            Tensor holding the sigmoid-activated output of the last linear
            layer, one row per batch item.
        """
        encoder_out = self.model(*args, **kwargs)
        # Only position 0 of the sequence is used by the head.
        features = encoder_out.last_hidden_state[:, 0, :]

        for hidden_layer in (self.regression_layer1, self.regression_layer2):
            features = F.relu(hidden_layer(features))

        return torch.sigmoid(self.regression_layer3(features))

    def load_state_dict_file(self, checkpoint_path, strict=True):
        """Load weights for the whole module from a checkpoint file.

        Args:
            checkpoint_path: File (or file-like object) read by ``torch.load``.
            strict: Passed through to ``load_state_dict``.
        """
        print("Loading Model Weights...")
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.load_state_dict(weights, strict=strict)
        print("Model Weights are Loaded...")

class Inference:
    """Load the regression model once and serve point predictions on images.

    The entry point used by the UI is ``process_image_and_draw_circle``; the
    other methods are its individual steps.
    """

    # Defaults used when the constructor is called without arguments.
    DEFAULT_MODEL_NAME = "matcha-base"
    DEFAULT_CHECKPOINT_PATH = "model/pta-text-v0.1.pt"
    # Every input is resized to this (width, height) before rendering.
    INPUT_SIZE = (1920, 1080)
    # Patch budget handed to the processor.
    MAX_PATCHES = 2048

    def __init__(
        self,
        model_name=DEFAULT_MODEL_NAME,
        checkpoint_path=DEFAULT_CHECKPOINT_PATH,
    ) -> None:
        """Pick a device (CUDA when available) and load model and processor.

        Args:
            model_name: Base model name or path used for both the encoder and
                the processor.
            checkpoint_path: Fine-tuned weights loaded into the model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.processor = self.load_model_and_processor(
            model_name, checkpoint_path
        )
        print("Model and Processor are Loaded...")

    def load_model_and_processor(self, model_name, checkpoint_path):
        """Build the model in eval mode on ``self.device`` plus its processor.

        Returns:
            Tuple ``(model, processor)``.
        """
        model = Pix2StructForRegression(sourcemodel_path=model_name, device=self.device)
        model.load_state_dict_file(checkpoint_path=checkpoint_path)
        model.eval()
        model = model.to(self.device)
        processor = Pix2StructProcessor.from_pretrained(model_name, is_vqa=False)
        return model, processor

    def prepare_image(self, image, prompt, processor):
        """Resize the image, render the prompt header and encode the result.

        Args:
            image: PIL image; it is not modified (``resize`` returns a copy).
            prompt: Text passed to ``render_header`` as the header.
            processor: Processor that turns the rendered image into patches.

        Returns:
            Tuple ``(encoding, render_variables)``, where ``render_variables``
            is the third value returned by ``render_header``.
        """
        # NOTE: resizing to a fixed size ignores the original aspect ratio.
        # Predictions are normalized, so they are later applied to the
        # original image's own width and height.
        resized = image.resize(self.INPUT_SIZE)
        font_path = download_default_font()
        # The bbox argument is an all-zero placeholder, and the second return
        # value is discarded.
        rendered_image, _, render_variables = render_header(
            image=resized,
            header=prompt,
            bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
            font_path=font_path,
        )
        encoding = processor(
            images=rendered_image,
            max_patches=self.MAX_PATCHES,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return encoding, render_variables

    def predict_coordinates(self, encoding, model, render_variables):
        """Run the model and convert its y output to image-relative units.

        The model's y value is treated as a fraction of ``total_height``. It
        is scaled to pixels, shifted up by ``header_height`` and divided by
        ``height``, so the result is a fraction of the image without header.
        The x value is returned unchanged.

        Args:
            encoding: Mapping with ``flattened_patches`` and ``attention_mask``.
            model: Callable accepting those two keyword arguments.
            render_variables: Mapping with ``height``, ``header_height`` and
                ``total_height``.

        Returns:
            ``[x, y]`` as Python floats for a single-item batch. The result
            may fall outside [0, 1] when the prediction lands in the header.
        """
        image_height = render_variables["height"]
        header_height = render_variables["header_height"]
        total_height = render_variables["total_height"]

        with torch.no_grad():
            predictions = model(
                flattened_patches=encoding["flattened_patches"],
                attention_mask=encoding["attention_mask"],
            )
            predictions[:, 1] = (
                (predictions[:, 1] * total_height) - header_height
            ) / image_height

        return predictions.squeeze().tolist()

    def draw_circle_on_image(self, image, coordinates, radius=5):
        """Draw a filled red circle on ``image`` at normalized coordinates.

        Args:
            image: PIL image; it is drawn on in place.
            coordinates: ``(x, y)`` fractions of the image width and height.
            radius: Circle radius in pixels.

        Returns:
            The same image object that was passed in.
        """
        x, y = coordinates[0] * image.width, coordinates[1] * image.height
        print(coordinates)
        draw = ImageDraw.Draw(image)
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")
        return image

    def process_image_and_draw_circle(self, image, prompt):
        """Predict the point described by ``prompt`` and mark it on ``image``.

        Args:
            image: PIL image to analyse and annotate.
            prompt: Text describing the target.

        Returns:
            The input image with the predicted point drawn on it.

        Raises:
            ValueError: If ``image`` is ``None`` (nothing was uploaded).
        """
        # Fail with a clear message instead of an AttributeError on
        # ``None.resize`` further down.
        if image is None:
            raise ValueError(
                "No image was provided; upload an image before submitting."
            )
        encoding, render_variables = self.prepare_image(image, prompt, self.processor)
        pred_coordinates = self.predict_coordinates(
            encoding.to(self.device), self.model, render_variables
        )
        return self.draw_circle_on_image(image, pred_coordinates)


def main():
    """Load the model, build the Gradio interface around it and launch it."""
    engine = Inference()
    print("Model and Processor are Loaded...")

    # Components are created in the same order as they appear in the UI:
    # the two inputs first, then the output image.
    image_input = gr.Image(type="pil", label="Upload Image")
    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter prompt here...")
    result_output = gr.Image(type="pil")

    demo = gr.Interface(
        fn=engine.process_image_and_draw_circle,
        inputs=[image_input, prompt_input],
        outputs=result_output,
        title="Pix2Struct Image Processing",
        description="Upload an image and enter a prompt to see the model's prediction.",
    )
    demo.launch()
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()