import io
import functools
from typing import List, Tuple

import jax
import numpy as np
import sentencepiece
import ml_collections
import tensorflow as tf
import supervision as sv
from PIL import Image

import big_vision.utils
import big_vision.sharding
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
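
# Single-image open-vocabulary detection with a fine-tuned PaliGemma checkpoint:
# load the model and tokenizer via big_vision, run batched greedy decoding on
# the available JAX devices, and parse the model's "<loc....> label" output into
# a bounding box with the `supervision` library.

# Maximum combined prompt + decode length, in tokens.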
SEQLEN = 128
class PaliGemmaManager:
    """Singleton wrapper around a PaliGemma model, its params and tokenizer."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Naive singleton: note that __init__ still re-runs on every instantiation.
        if not cls._instance:
            cls._instance = super(PaliGemmaManager, cls).__new__(cls)
        return cls._instance
    def __init__(self, model, params, tokenizer):
        self.model = model
        self.params = params
        self.tokenizer = tokenizer
        self.decode_fn = None
        self.decode = None
        self.mesh = None
        self.data_sharding = None
        self.params_sharding = None
        self.trainable_mask = None
        self.initialise_model()
    def initialise_model(self):
        self.decode_fn = predict_fns.get_all(self.model)['decode']
        self.decode = functools.partial(
            self.decode_fn, devices=jax.devices(), eos_token=self.tokenizer.eos_id())

        def is_trainable_param(name, param):
            if name.startswith("llm/layers/attn/"): return True
            if name.startswith("llm/"): return False
            if name.startswith("img/"): return False
            raise ValueError(f"Unexpected param name {name}")
        self.trainable_mask = big_vision.utils.tree_map_with_names(
            is_trainable_param, self.params)

        self.mesh = jax.sharding.Mesh(jax.devices(), ("data",))
        self.data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec("data"))
        self.params_sharding = big_vision.sharding.infer_sharding(
            self.params, strategy=[('.*', 'fsdp(axis="data")')], mesh=self.mesh)
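
    # After initialisation, `self.decode` greedily decodes across all local
    # devices, inputs are sharded along the single "data" mesh axis, and
    # `params_sharding` describes an FSDP-style layout for the weights.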
    def preprocess_image(self, image, size=224):
        image = np.asarray(image)
        if image.ndim == 2:  # Grayscale image: replicate the channel to get RGB.
            image = np.stack((image,) * 3, axis=-1)
        image = image[..., :3]  # Drop the alpha channel if present.
        assert image.shape[-1] == 3
        image = tf.constant(image)
        image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
        return image.numpy() / 127.5 - 1.0  # Scale pixel values to [-1, 1].
    def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
        separator = "\n"
        tokens = self.tokenizer.encode(prefix, add_bos=True) + self.tokenizer.encode(separator)
        mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
        mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.
        if suffix:
            suffix = self.tokenizer.encode(suffix, add_eos=True)
            tokens += suffix
            mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
            mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.
        mask_input = [1] * len(tokens)  # 1 if it's a real token, 0 if padding.
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding
        return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
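
    # Illustrative mask layout for prefix "detect cat" with no suffix (ids are
    # schematic): tokens = [BOS, detect, cat, \n] + padding; mask_ar stays all 0
    # (full attention over the prefix); mask_input is 1 for the real tokens and
    # 0 for the padding.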
    def postprocess_tokens(self, tokens):
        tokens = tokens.tolist()  # np.array to list[int].
        try:  # Remove tokens at and after EOS, if any.
            eos_pos = tokens.index(self.tokenizer.eos_id())
            tokens = tokens[:eos_pos]
        except ValueError:
            pass
        return self.tokenizer.decode(tokens)
    @staticmethod
    def split_and_keep_second_part(s):
        # Drops the echoed prefix (everything up to the first newline) from a
        # decoded response; currently unused in this script.
        parts = s.split('\n', 1)
        if len(parts) > 1:
            return parts[1]
        return s
    def data_iterator(self, image_bytes, caption):
        image = Image.open(io.BytesIO(image_bytes))
        image = self.preprocess_image(image)
        tokens, mask_ar, _, mask_input = self.preprocess_tokens(caption, seqlen=SEQLEN)
        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input),
        }
    def make_predictions(self, data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        outputs = []
        while True:
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)  # Marks a real example.
            except StopIteration:
                if len(examples) == 0:
                    return outputs
            # Fill a partial batch with copies of the last example, flagged as padding.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)  # Marks a padding example.
            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, self.data_sharding)
            tokens = self.decode({"params": self.params}, batch=batch,
                                 max_decode_len=seqlen, sampler=sampler)
            # Fetch model predictions from device and detokenize.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]  # Remove padding examples.
            responses = [self.postprocess_tokens(t) for t in tokens]
            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs
    def process_result_to_bbox(self, image, caption, classes, w, h):
        image = ((image + 1) / 2 * 255).astype(np.uint8)  # [-1, 1] -> [0, 255].
        try:
            detections = sv.Detections.from_lmm(
                lmm='paligemma',
                result=caption,
                resolution_wh=(w, h),
                classes=classes)
            # Coordinates are rescaled from the model's 224x224 input to (w, h).
            x1, y1, x2, y2 = detections.xyxy[0]
            width = x2 - x1
            height = y2 - y1
            output = [x1, y1, width, height]
        except Exception as e:
            print('Error detection')
            print(e)
            output = [0, 0, 0, 0]
        return output
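
    # supervision's PaliGemma parser turns "<loc....>...<loc....> label" output
    # into pixel-space xyxy boxes, rescaled to the supplied resolution_wh.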
    def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
        image_original = Image.open(io.BytesIO(image))
        original_width, original_height = image_original.size
        if "detect" not in caption:
            caption = f"detect {caption}"
        for image_arr, response in self.make_predictions(
                self.data_iterator(image, caption), num_examples=1):
            classes = response.replace("detect ", "")
            output = self.process_result_to_bbox(
                image_arr, response, classes, original_width, original_height)
            return (output, response)
INFERENCE_IMAGE = '3_(backup)AdityaBY_img_14.png'
INFERENCE_PROMPT = "A mother takes a picture of her daughter holding a colourful wind spinner in front of the entrance."
TOKENIZER_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model'
MODEL_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz'
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)
with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()
output, response = paligemma_manager.predict(image_bytes, INFERENCE_PROMPT)
image = Image.open(INFERENCE_IMAGE)
detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=image.size,
    classes=response)
coordinates = detections.xyxy[0]  # Assuming we want the first detection.
x1, y1, x2, y2 = coordinates
print('x1, y1, x2, y2:', coordinates)
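
# Optional visualisation: a minimal sketch, assuming the installed `supervision`
# release exposes `sv.BoxAnnotator().annotate(scene, detections)` (the annotator
# API has shifted across versions). The output filename is arbitrary.
annotated = sv.BoxAnnotator().annotate(scene=np.array(image).copy(), detections=detections)
Image.fromarray(annotated).save('annotated_output.png')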