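"""Image-to-text inference pipeline for the marutitecblic/HtmlTocode model:
loads the processor and model, preprocesses a screenshot image, and generates
text from it."""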
from typing import Any, Dict

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import ChannelDimension, PILImageResampling, to_numpy_array

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageToTextPipeline:
    """Wraps the processor and model behind a simple image -> text callable."""

    def __init__(self, model_path: str):
        # Load the processor (tokenizer + image processor) and the model.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(DEVICE)
        # Number of <image> placeholder tokens the perceiver resampler expects.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
        self.bos_token = self.processor.tokenizer.bos_token
        # Token ids the model must never emit: the image placeholder tokens.
        self.bad_words_ids = self.processor.tokenizer(
            ["<image>", "<fake_token_around_image>"], add_special_tokens=False
        ).input_ids

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # Build the prompt: BOS, then one <image> slot per resampler latent,
        # wrapped in the fake-token delimiters the model was trained with.
        image = Image.open(data["file"]).convert("RGB")
        inputs = self.processor.tokenizer(
            f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
            return_tensors="pt",
            add_special_tokens=False,
        )
        # Preprocess the image with the custom resize/normalize transform.
        inputs["pixel_values"] = self.processor.image_processor([image], transform=self.custom_transform)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        # Generate, forbidding the image placeholder tokens in the output.
        generated_ids = self.model.generate(**inputs, bad_words_ids=self.bad_words_ids, max_length=4096)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return {"text": generated_text}
    
    def convert_to_rgb(self, image):
        # Composite transparent images onto a white background so that
        # alpha regions do not render as black after the RGB conversion.
        if image.mode == "RGB":
            return image
        image_rgba = image.convert("RGBA")
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")
    def custom_transform(self, x):
        # Resize to the model's 960x960 input, rescale pixel values to [0, 1],
        # normalize with the processor's statistics, and move channels first.
        x = self.convert_to_rgb(x)
        x = to_numpy_array(x)
        x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
        x = self.processor.image_processor.rescale(x, scale=1 / 255)
        x = self.processor.image_processor.normalize(
            x,
            mean=self.processor.image_processor.image_mean,
            std=self.processor.image_processor.image_std,
        )
        x = to_channel_dimension_format(x, ChannelDimension.FIRST)
        return torch.tensor(x)
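

# A minimal local smoke test, not part of the original handler. It assumes the
# checkpoint id used elsewhere in this file ("marutitecblic/HtmlTocode") and a
# sample screenshot on disk; "screenshot.png" is a hypothetical placeholder.
if __name__ == "__main__":
    pipeline = ImageToTextPipeline("marutitecblic/HtmlTocode")
    result = pipeline({"file": "screenshot.png"})  # hypothetical sample image
    print(result["text"])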