# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Image tokenizer.""" | |
import numpy as np | |
import torch | |
from torch import nn | |


class ImageTokenizer(nn.Module):
    """Tokenize image regions with visual prompts."""

    def __init__(
        self,
        image_encoder,
        prompt_encoder,
        image_decoder,
        concept_projector=None,
        text_tokenizer=None,
        text_decoder=None,
        pixel_mean=(103.53, 116.28, 123.675),
        pixel_std=(57.375, 57.12, 58.395),
    ):
        super(ImageTokenizer, self).__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.image_decoder = image_decoder
        self.concept_projector = concept_projector
        self.text_tokenizer = text_tokenizer
        self.text_decoder = text_decoder
        self.pixel_mean_value = pixel_mean  # BGR order.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
        # Store the reciprocal of std so normalization is a single multiply.
        self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())

    def get_inputs(self, inputs, dtype=None):
        """Return the model inputs.

        Parameters
        ----------
        inputs : dict
            The initial inputs.
        dtype : torch.dtype, optional
            The optional input dtype.

        Returns
        -------
        dict
            The model inputs.

        """
        img_dtype, img_device = self.pixel_mean.dtype, self.pixel_mean.device
        inputs["img"] = torch.as_tensor(inputs["img"], dtype=img_dtype, device=img_device)
        inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig).permute(0, 3, 1, 2)
        inputs["img"] = inputs["img"].to(dtype=dtype) if dtype else inputs["img"]
        return inputs
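
    # A minimal usage sketch for ``get_inputs`` (shapes are illustrative
    # assumptions, not requirements of this file): "img" is a batched,
    # channels-last image array in BGR order, matching ``pixel_mean``:
    #
    #   inputs = {"img": np.zeros((1, 1024, 1024, 3), dtype="uint8")}
    #   inputs = tokenizer.get_inputs(inputs, dtype=torch.float16)
    #   inputs["img"].shape  # torch.Size([1, 3, 1024, 1024]), normalized NCHW.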

    def get_features(self, inputs):
        """Return the image features.

        Parameters
        ----------
        inputs : dict
            The inputs.

        Returns
        -------
        dict
            The image features.

        """
        features = self.image_encoder(inputs["img"])
        # (N, C, H, W) -> (N, 1, H, W, C): channels-last with a prompt-group axis.
        img_embeds = features[0].permute(0, 2, 3, 1).unsqueeze_(1)
        return {"features": features, "img_embeds": img_embeds}

    def get_outputs(self, inputs):
        """Return the model outputs.

        Parameters
        ----------
        inputs : dict
            The model inputs.

        Returns
        -------
        dict
            The model outputs.

        """
        inputs.update(self.prompt_encoder(inputs))
        return self.image_decoder(inputs)

    def forward(self, inputs):
        """Define the computation performed at every call.

        Parameters
        ----------
        inputs : dict
            The initial inputs.

        Returns
        -------
        dict
            The model outputs.

        """
        inputs = self.get_inputs(inputs)
        inputs.update(self.get_features(inputs))
        return self.get_outputs(inputs)
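
    # End-to-end sketch (constructor arguments and prompt keys are stand-ins;
    # this file does not pin down concrete encoder/decoder classes):
    #
    #   tokenizer = ImageTokenizer(image_encoder, prompt_encoder, image_decoder)
    #   outputs = tokenizer({"img": img, "points": points})
    #
    # Keys beyond "img" are whatever ``prompt_encoder`` consumes.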

    def upscale_masks(self, masks, size):
        """Upscale masks using bilinear interpolation.

        Parameters
        ----------
        masks : torch.Tensor
            The input masks.
        size : Union[int, Tuple[int, int]]
            The output size.

        Returns
        -------
        torch.Tensor
            The output masks.

        """
        return nn.functional.interpolate(masks, size, mode="bilinear", align_corners=False)
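
    # Typical use (hypothetical key name): decoder masks come out at a low
    # resolution and are resized back to the original image extent:
    #
    #   masks = tokenizer.upscale_masks(outputs["mask_pred"], (im_h, im_w))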

    def predict_concept(self, visual_embeds, k=1):
        """Predict top-k concepts based on visual embeddings.

        Parameters
        ----------
        visual_embeds : torch.Tensor
            The embeddings to predict visual content.
        k : int, optional, default=1
            The number of top concepts to return.

        Returns
        -------
        Tuple[numpy.ndarray, numpy.ndarray]
            The concept scores and indices.

        """
        return self.concept_projector.decode(visual_embeds, k)
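
    # Sketch (hypothetical key name): with a concept projector attached,
    # top-k semantic labels can be read off region embeddings:
    #
    #   scores, indices = tokenizer.predict_concept(outputs["sem_embeds"], k=5)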

    def generate_text(self, visual_tokens, max_gen_len=None, temperature=0):
        """Generate text sequences based on visual tokens.

        Parameters
        ----------
        visual_tokens : torch.Tensor
            The tokens to prompt visual context.
        max_gen_len : int, optional
            The maximum length of the generated text sequences.
        temperature : float, optional
            The temperature for controlling randomness in sampling.

        Returns
        -------
        numpy.ndarray
            An array of generated texts.

        """
        max_gen_len = max_gen_len or self.text_decoder.max_text_len
        prompts = self.text_decoder.get_prompts(visual_tokens)
        out_shape = (prompts.size(0), self.text_decoder.max_text_len)
        tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
        tokens[:, 0], prev_pos = self.text_tokenizer.bos_id, 0
        eos_reached = np.array([False] * tokens.shape[0])
        for cur_pos in range(1, max_gen_len):
            # Feed only the tokens generated since the last step; ``prev_pos``
            # is the start position for the decoder's cached attention states.
            decode_seq_len = cur_pos - prev_pos
            x = torch.from_numpy(tokens[:, prev_pos:cur_pos]).to(device=prompts.device)
            logits = self.text_decoder.transformer(prompts, x, prev_pos)
            next_logits = logits[: x.size(0), decode_seq_len - 1]
            if temperature > 0:
                # Sample from the temperature-scaled distribution.
                p = nn.functional.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(p, 1).cpu().numpy().flatten()
            else:
                # Greedy decoding.
                next_token = next_logits.argmax(-1).cpu().numpy()
            tokens[:, cur_pos] = next_token
            eos_reached |= next_token == self.text_tokenizer.eos_id
            prev_pos, logits, next_logits = cur_pos, None, None
            if eos_reached.all():
                break
        return np.array(self.text_tokenizer.detokenize(tokens))
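
# A minimal text-generation sketch (``visual_tokens`` stands for whatever the
# image decoder produced for the prompted regions):
#
#   captions = tokenizer.generate_text(visual_tokens, max_gen_len=50, temperature=0)
#
# With temperature=0 the loop decodes greedily (argmax); a positive temperature
# switches to multinomial sampling over the softmaxed logits.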