"""PaliGemma inference wrapper and a Flax decoder that turns VQ-VAE codes into masks."""

import functools
import os

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import PIL.Image
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Authenticate against the Hugging Face Hub at import time using the token
# found in the HF_TOKEN environment variable (None when it is not set).
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token, add_to_git_credential=True)


class PaliGemmaModel:
    """Loads the PaliGemma checkpoint and its processor and runs greedy generation."""

    def __init__(self):
        """Load model and processor, placing the model on CUDA when available."""
        self.model_id = "google/paligemma-3b-mix-448"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = (
            PaliGemmaForConditionalGeneration.from_pretrained(self.model_id)
            .eval()
            .to(self.device)
        )
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        """Generate text for ``image`` conditioned on the prompt ``text``.

        Args:
            image: Input image handed to the processor.
            text: Prompt handed to the processor.
            max_new_tokens: Upper bound on generated tokens, forwarded to
                ``generate``.

        Returns:
            The first decoded sequence with its leading ``len(text)``
            characters dropped and any leading newlines stripped.
        """
        inputs = self.processor(text=text, images=image, return_tensors="pt")
        # Move every processor output onto the same device as the model.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        # The slice assumes the decoded text starts with the prompt itself.
        return result[0][len(text):].lstrip("\n")


class VAEModel:
    """Decoder half of a VQ-VAE: maps codebook indices to decoder outputs."""

    def __init__(self, model_path: str):
        """Load decoder parameters from the checkpoint at ``model_path``."""
        # Kept for interface compatibility; nothing visible in this file reads it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = self._get_params(model_path)

    def _get_params(self, checkpoint_path):
        """Converts PyTorch checkpoint to Flax params.

        Args:
            checkpoint_path: Path passed to ``np.load``; expected to be an
                archive of named arrays (indexed by string keys below).

        Returns:
            A nested dict whose keys are the parameter names the decoder
            built by ``_build_decoder`` is applied with.
        """
        # Materialise every array inside the ``with`` block so the archive's
        # file handle is closed before returning.
        with np.load(checkpoint_path) as archive:
            checkpoint = dict(archive)

        def transp(kernel):
            # Reorder a 4-D kernel's axes (0, 1, 2, 3) -> (2, 3, 1, 0).
            return np.transpose(kernel, (2, 3, 1, 0))

        def conv(name):
            return {
                'bias': checkpoint[name + '.bias'],
                'kernel': transp(checkpoint[name + '.weight']),
            }

        def resblock(name):
            return {
                'Conv_0': conv(name + '.0'),
                'Conv_1': conv(name + '.2'),
                'Conv_2': conv(name + '.4'),
            }

        return {
            '_embeddings': checkpoint['_vq_vae._embedding'],
            'Conv_0': conv('decoder.0'),
            'ResBlock_0': resblock('decoder.2.net'),
            'ResBlock_1': resblock('decoder.3.net'),
            'ConvTranspose_0': conv('decoder.4'),
            'ConvTranspose_1': conv('decoder.6'),
            'ConvTranspose_2': conv('decoder.8'),
            'ConvTranspose_3': conv('decoder.10'),
            'Conv_1': conv('decoder.12'),
        }

    def reconstruct_masks(self, codebook_indices):
        """Decode ``codebook_indices`` of shape ``(batch, 16)`` with the decoder.

        Returns:
            Whatever the jitted decoder returns for the looked-up embeddings.
        """
        quantized = self._quantized_values_from_codebook_indices(codebook_indices)
        # ``_decoder()`` already *is* the jitted ``Decoder().apply``; call it
        # directly (a jitted function has no ``.apply`` attribute).
        return self._decoder()({'params': self.params}, quantized)

    def _quantized_values_from_codebook_indices(self, codebook_indices):
        """Look up embeddings for the indices and arrange them as a 4x4 grid.

        Args:
            codebook_indices: 2-D array of shape ``(batch, 16)``.

        Returns:
            Array of shape ``(batch, 4, 4, embedding_dim)``.
        """
        batch_size, num_tokens = codebook_indices.shape
        assert num_tokens == 16, codebook_indices.shape
        _, embedding_dim = self.params['_embeddings'].shape

        encodings = jnp.take(
            self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0
        )
        return encodings.reshape((batch_size, 4, 4, embedding_dim))

    def _decoder(self):
        """Return the shared jitted decoder apply function."""
        return self._build_decoder()

    @staticmethod
    @functools.cache
    def _build_decoder():
        """Build the Flax decoder once and return its CPU-jitted ``apply``.

        The decoder has no dependence on instance state, so it is cached on a
        zero-argument static method rather than on ``self``; caching on an
        instance method would keep every ``VAEModel`` alive in the cache.
        """

        # NOTE: the class names are significant. Flax derives default
        # submodule names from them, and those names must match the keys
        # produced by ``_get_params`` ('ResBlock_0', 'Conv_0', ...).
        class ResBlock(nn.Module):
            features: int

            @nn.compact
            def __call__(self, x):
                original_x = x
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
                return x + original_x

        class Decoder(nn.Module):
            """Upscales quantized vectors to mask."""

            @nn.compact
            def __call__(self, x):
                num_res_blocks = 2
                dim = 128
                num_upsample_layers = 4

                x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
                x = nn.relu(x)

                for _ in range(num_res_blocks):
                    x = ResBlock(features=dim)(x)

                for _ in range(num_upsample_layers):
                    x = nn.ConvTranspose(
                        features=dim,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        padding=2,
                        transpose_kernel=True,
                    )(x)
                    x = nn.relu(x)
                    # Halve the channel count after each upsampling stage.
                    dim //= 2

                x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
                return x

        return jax.jit(Decoder().apply, backend='cpu')