Spaces:
Runtime error
Runtime error
import os | |
import PIL.Image | |
import torch | |
from huggingface_hub import login | |
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import functools | |
# Hugging Face Hub token, read from the environment (e.g. a Space secret).
hf_token = os.getenv("HF_TOKEN")
# Only authenticate when a token is configured: `login(token=None)` falls back
# to prompting the user, which cannot be answered in a headless deployment.
if hf_token:
    login(token=hf_token, add_to_git_credential=True)
class PaliGemmaModel:
    """Thin wrapper around a PaliGemma checkpoint for single-image prompting."""

    def __init__(self, model_id: str = "google/paligemma-3b-mix-448"):
        """Load the model (on CUDA when available, else CPU) and its processor.

        Args:
            model_id: Hugging Face Hub id of the PaliGemma checkpoint. Defaults
                to the checkpoint this wrapper has always used.
        """
        self.model_id = model_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = (
            PaliGemmaForConditionalGeneration.from_pretrained(self.model_id)
            .eval()
            .to(self.device)
        )
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    def infer(self, image: "PIL.Image.Image", text: str, max_new_tokens: int) -> str:
        """Greedily generate a continuation of ``text`` conditioned on ``image``.

        Args:
            image: Input image handed to the processor.
            text: Prompt text.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            Only the newly generated text (the prompt is excluded), with any
            leading newlines stripped.
        """
        inputs = self.processor(text=text, images=image, return_tensors="pt")
        # Move every tensor the processor produced onto the model's device.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # `generate` returns prompt + continuation. Cut the prompt off by its
        # length in tokens rather than assuming the decoded prompt has exactly
        # len(text) characters, which breaks when tokenization is lossy.
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding -> deterministic output
            )
        new_tokens = generated_ids[0][prompt_len:]
        return self.processor.decode(new_tokens, skip_special_tokens=True).lstrip("\n")
class VAEModel:
    """Decode PaliGemma segmentation codebook indices into masks.

    Weights come from a NumPy archive of a PyTorch VQ-VAE checkpoint and are
    converted to Flax's parameter layout (see `_get_params`). Decoding runs
    through a Flax module jitted for JAX's CPU backend (see `_decoder`).
    """

    def __init__(self, model_path: str):
        """Load and convert the decoder weights.

        Args:
            model_path: Path to an ``np.load``-readable archive (e.g. ``.npz``)
                holding the checkpoint keys read in `_get_params`.
        """
        # NOTE(review): nothing in this class reads `self.device`; decoding
        # happens in JAX on CPU. Kept in case external callers rely on it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = self._get_params(model_path)
        # Jitted decoder, built lazily (once) by `_decoder`.
        self._decoder_fn = None

    def _get_params(self, checkpoint_path):
        """Convert a PyTorch checkpoint stored as a NumPy archive to Flax params.

        Args:
            checkpoint_path: Path to the archive. Its keys are PyTorch
                state-dict names such as ``decoder.0.weight`` and
                ``_vq_vae._embedding``.

        Returns:
            A nested dict whose keys match the submodule names Flax assigns
            automatically in `_decoder` (``Conv_0``, ``ResBlock_0``, ...). It
            also has ``_embeddings``, which holds the codebook.
        """
        # `dict(...)` reads every array eagerly, so the archive can be closed
        # immediately instead of leaking its file handle.
        with np.load(checkpoint_path) as archive:
            checkpoint = dict(archive)

        def transp(kernel):
            # Axes (0, 1, 2, 3) -> (2, 3, 1, 0): PyTorch's (out, in, h, w)
            # conv weights become Flax's (h, w, in, out). The same permutation
            # suits the transposed convs used with `transpose_kernel=True`.
            return np.transpose(kernel, (2, 3, 1, 0))

        def conv(name):
            return {
                'bias': checkpoint[name + '.bias'],
                'kernel': transp(checkpoint[name + '.weight']),
            }

        def resblock(name):
            # Only indices 0, 2 and 4 of the PyTorch block carry weights
            # (presumably 1 and 3 are the ReLUs).
            return {
                'Conv_0': conv(name + '.0'),
                'Conv_1': conv(name + '.2'),
                'Conv_2': conv(name + '.4'),
            }

        return {
            '_embeddings': checkpoint['_vq_vae._embedding'],
            'Conv_0': conv('decoder.0'),
            'ResBlock_0': resblock('decoder.2.net'),
            'ResBlock_1': resblock('decoder.3.net'),
            'ConvTranspose_0': conv('decoder.4'),
            'ConvTranspose_1': conv('decoder.6'),
            'ConvTranspose_2': conv('decoder.8'),
            'ConvTranspose_3': conv('decoder.10'),
            'Conv_1': conv('decoder.12'),
        }

    def reconstruct_masks(self, codebook_indices):
        """Decode a batch of codebook indices.

        Args:
            codebook_indices: Integer array of shape ``(batch, 16)``. Each row
                is laid out as a 4x4 grid of codebook entries.

        Returns:
            The decoder output. Four stride-2 transposed convs upsample the
            4x4 grid and a final 1-feature conv follows, so this should be
            ``(batch, 64, 64, 1)``.
        """
        quantized = self._quantized_values_from_codebook_indices(codebook_indices)
        # `_decoder()` already returns the jitted `Decoder().apply`; call it
        # directly (the jitted function has no `.apply` attribute).
        return self._decoder()({'params': self.params}, quantized)

    def _quantized_values_from_codebook_indices(self, codebook_indices):
        """Look up codebook vectors and arrange them as ``(batch, 4, 4, dim)``."""
        batch_size, num_tokens = codebook_indices.shape
        assert num_tokens == 16, codebook_indices.shape
        unused_num_embeddings, embedding_dim = self.params['_embeddings'].shape
        encodings = jnp.take(
            self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0
        )
        encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
        return encodings

    def _decoder(self):
        """Return the jitted ``Decoder().apply``, building it on first use.

        Caching avoids redefining the Flax modules and re-tracing the jitted
        function on every `reconstruct_masks` call.
        """
        cached = getattr(self, '_decoder_fn', None)
        if cached is not None:
            return cached

        # The modules below use the flax.linen API (`nn.Module`, `nn.Conv`,
        # ...), but `nn` was never imported. Import it here, where it is used.
        from flax import linen as nn

        class ResBlock(nn.Module):
            """3x3 conv + ReLU, 3x3 conv + ReLU, 1x1 conv, plus a skip connection."""

            features: int

            # `@nn.compact` is required to declare submodules inline.
            @nn.compact
            def __call__(self, x):
                original_x = x
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
                return x + original_x

        class Decoder(nn.Module):
            """Upscales quantized vectors to mask."""

            @nn.compact
            def __call__(self, x):
                num_res_blocks = 2
                dim = 128
                num_upsample_layers = 4
                x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
                x = nn.relu(x)
                for _ in range(num_res_blocks):
                    x = ResBlock(features=dim)(x)
                # Each step doubles spatial size and halves the feature count
                # (128 -> 64 -> 32 -> 16).
                for _ in range(num_upsample_layers):
                    x = nn.ConvTranspose(
                        features=dim,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        padding=2,
                        transpose_kernel=True,
                    )(x)
                    x = nn.relu(x)
                    dim //= 2
                x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
                return x

        self._decoder_fn = jax.jit(Decoder().apply, backend='cpu')
        return self._decoder_fn