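"""PaliGemma open-vocabulary detection inference.

Loads a fine-tuned PaliGemma checkpoint, runs a "detect ..." prompt on a
single image, and converts the predicted location tokens into a bounding box.
"""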
import io
import functools
from typing import List, Tuple

import jax
import numpy as np
import sentencepiece
import ml_collections
import tensorflow as tf
import supervision as sv
from PIL import Image

import big_vision.utils
import big_vision.sharding
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns


SEQLEN = 128  # Max text-sequence length (prefix + padding) fed to the model.

class PaliGemmaManager:
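  """Singleton that wraps a PaliGemma model, its params, tokenizer, and sharding setup."""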
  _instance = None

  def __new__(cls, *args, **kwargs):
    if not cls._instance:
      cls._instance = super(PaliGemmaManager, cls).__new__(cls)
    return cls._instance

  def __init__(self, model, params, tokenizer):
    self.model = model
    self.params = params
    self.tokenizer = tokenizer
    self.decode_fn = None
    self.decode = None
    self.mesh = None
    self.data_sharding = None
    self.params_sharding = None
    self.trainable_mask = None

    self.initialise_model()

  def initialise_model(self):
    self.decode_fn = predict_fns.get_all(self.model)['decode']
    self.decode = functools.partial(
        self.decode_fn, devices=jax.devices(), eos_token=self.tokenizer.eos_id())

    # Only the attention weights are marked trainable; this mask is unused for
    # pure inference but kept for parity with fine-tuning setups.
    def is_trainable_param(name, param):
      if name.startswith("llm/layers/attn/"):  return True
      if name.startswith("llm/"):              return False
      if name.startswith("img/"):              return False
      raise ValueError(f"Unexpected param name {name}")
    self.trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, self.params)

    # A single "data" mesh axis spanning all local devices (note the trailing
    # comma: a tuple of axis names, not a bare string).
    self.mesh = jax.sharding.Mesh(jax.devices(), ("data",))

    self.data_sharding = jax.sharding.NamedSharding(
        self.mesh, jax.sharding.PartitionSpec("data"))

    self.params_sharding = big_vision.sharding.infer_sharding(
        self.params, strategy=[('.*', 'fsdp(axis="data")')], mesh=self.mesh)

  def preprocess_image(self, image, size=224):
    image = np.asarray(image)
    if image.ndim == 2:  # Grayscale image: replicate the single channel 3x.
        image = np.stack((image,) * 3, axis=-1)

    image = image[..., :3]  # Drop the alpha channel if present.
    assert image.shape[-1] == 3

    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
    return image.numpy() / 127.5 - 1.0  # Scale [0, 255] -> [-1, 1].

  def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
    separator = "\n"
    tokens = self.tokenizer.encode(prefix, add_bos=True) + self.tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
    mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

    if suffix:
        suffix = self.tokenizer.encode(suffix, add_eos=True)
        tokens += suffix
        mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
        mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

    mask_input = [1] * len(tokens)    # 1 if it's a real token, 0 if padding.
    if seqlen:
        padding = [0] * max(0, seqlen - len(tokens))
        tokens = tokens[:seqlen] + padding
        mask_ar = mask_ar[:seqlen] + padding
        mask_loss = mask_loss[:seqlen] + padding
        mask_input = mask_input[:seqlen] + padding

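    # Schematic example (actual ids and token boundaries depend on the tokenizer):
    #   preprocess_tokens("detect cat", seqlen=6) ->
    #   tokens:     [BOS, detect, cat, \n, 0, 0]
    #   mask_ar:    [0,   0,      0,   0,  0, 0]  # prefix uses full attention
    #   mask_input: [1,   1,      1,   1,  0, 0]  # padding positions masked out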
    return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

  def postprocess_tokens(self, tokens):
    tokens = tokens.tolist()  # np.array to list[int]
    try:  # Remove tokens at and after EOS if any.
        eos_pos = tokens.index(self.tokenizer.eos_id())
        tokens = tokens[:eos_pos]
    except ValueError:
        pass
    return self.tokenizer.decode(tokens)

  @staticmethod
  def split_and_keep_second_part(s):
    parts = s.split('\n', 1)
    if len(parts) > 1:
        return parts[1]
    return s

  def data_iterator(self, image_bytes, caption):
    # Yields a single preprocessed example; make_predictions pads it up to a
    # full batch with masked-out copies.
    image = Image.open(io.BytesIO(image_bytes))
    image = self.preprocess_image(image)
    tokens, mask_ar, _, mask_input = self.preprocess_tokens(caption, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }

  def make_predictions(self, data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    outputs = []
    while True:
        examples = []
        try:
            for _ in range(batch_size):
                examples.append(next(data_iterator))
                examples[-1]["_mask"] = np.array(True)  # Marks a real example.
        except StopIteration:
            if len(examples) == 0:
                return outputs
            # Fall through to process the final, partially filled batch; the
            # next iteration hits StopIteration again and returns.

        # Pad the batch to full size with masked-out copies of the last example.
        while len(examples) % batch_size:
            examples.append(dict(examples[-1]))
            examples[-1]["_mask"] = np.array(False)  # Marks a padding example.

        batch = jax.tree.map(lambda *x: np.stack(x), *examples)
        batch = big_vision.utils.reshard(batch, self.data_sharding)
        tokens = self.decode({"params": self.params}, batch=batch,
                        max_decode_len=seqlen, sampler=sampler)

        # Fetch model predictions to device and detokenize.
        tokens, mask = jax.device_get((tokens, batch["_mask"]))
        tokens = tokens[mask]  # remove padding examples.
        responses = [self.postprocess_tokens(t) for t in tokens]

        for example, response in zip(examples, responses):
            outputs.append((example["image"], response))
            if num_examples and len(outputs) >= num_examples:
              return outputs

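  # PaliGemma emits each detection as four location tokens followed by a label,
  # e.g. "<loc0123><loc0456><loc0789><loc1000> cat". Each <locXXXX> value is a
  # coordinate quantized to 1024 bins, in the order y_min, x_min, y_max, x_max;
  # supervision's from_lmm parser rescales these to resolution_wh.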
  def process_result_to_bbox(self, image, caption, classes, w, h):
    # Note: `classes` is accepted for interface compatibility but not passed
    # to from_lmm, since exact class-name filtering could drop detections
    # whose labels do not match the raw caption string.
    image = ((image + 1) / 2 * 255).astype(np.uint8)  # [-1, 1] -> [0, 255], for visualization.

    try:
        detections = sv.Detections.from_lmm(
            lmm='paligemma',
            result=caption,
            resolution_wh=(w, h))

        # Coordinates are already in original-image pixels: from_lmm rescales
        # the model's quantized locations to resolution_wh.
        x1, y1, x2, y2 = detections.xyxy[0]
        output = [x1, y1, x2 - x1, y2 - y1]  # [x, y, width, height]
    except Exception as e:
        print('Error during detection:')
        print(e)
        output = [0, 0, 0, 0]

    return output

  def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
    image_original = Image.open(io.BytesIO(image))
    original_width, original_height = image_original.size
    if "detect" not in caption:
      caption = f"detect {caption}"

    output, response = [0, 0, 0, 0], ""
    for model_image, response in self.make_predictions(
        self.data_iterator(image, caption), num_examples=1):
      classes = response.replace("detect ", "")
      output = self.process_result_to_bbox(
          model_image, response, classes, original_width, original_height)

    return (output, response)


INFERENCE_IMAGE = '3_(backup)AdityaBY_img_14.png'
INFERENCE_PROMPT = "A mother takes a picture of her daughter holding a colourful wind spinner in front of the entrance."

TOKENIZER_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model'
MODEL_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz'


model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
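# vocab_size 257_152 = 256_000 Gemma tokens + 1_024 <locXXXX> tokens +
# 128 <segXXX> tokens; "So400m/14" selects the SigLIP ViT-So400m/14 encoder.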
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)
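
# Note: the first predict() call triggers JIT compilation of the decode step
# and can take much longer than subsequent calls.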

with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()

output, response = paligemma_manager.predict(image_bytes, INFERENCE_PROMPT)

image = Image.open(INFERENCE_IMAGE)
# classes is optional in from_lmm and omitted here for the same reason as in
# process_result_to_bbox.
detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=image.size)

coordinates = detections.xyxy[0]  # First detection.
x1, y1, x2, y2 = coordinates

print('x1, y1, x2, y2:', coordinates)