# omniscience / inference.py
# dwb2023's picture
# Update inference.py
# e38f582 verified
# raw
# history blame
# 4.8 kB
import os
import PIL.Image
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import jax
import jax.numpy as jnp
import numpy as np
import functools
# Authenticate against the Hugging Face Hub at import time. The token is read
# from the HF_TOKEN environment variable; `hf_token` is None when it is unset.
hf_token = os.environ.get("HF_TOKEN")
login(
    token=hf_token,
    add_to_git_credential=True,
)
class PaliGemmaModel:
    """Wrapper around the ``google/paligemma-3b-mix-448`` checkpoint.

    Loads the model and its processor once and exposes a single ``infer``
    call that answers a text prompt about an image.
    """

    def __init__(self):
        """Load the model (in eval mode, on GPU when available) and processor."""
        self.model_id = "google/paligemma-3b-mix-448"

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        loaded = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id)
        self.model = loaded.eval().to(self.device)
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        """Generate a completion for ``text`` conditioned on ``image``.

        Args:
            image: Image passed to the processor.
            text: Prompt passed to the processor.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The first decoded sequence with its leading ``len(text)``
            characters dropped and any leading newlines stripped.
        """
        encoded = self.processor(text=text, images=image, return_tensors="pt")

        # Every processor output has to live on the same device as the model.
        on_device = {}
        for key, tensor in encoded.items():
            on_device[key] = tensor.to(self.device)

        with torch.inference_mode():
            generated_ids = self.model.generate(
                **on_device,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        decoded = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        first_sequence = decoded[0]
        # The slice assumes the decoded text starts with the prompt itself.
        completion = first_sequence[len(text):]
        return completion.lstrip("\n")
class VAEModel:
    """Mask decoder that maps VQ-VAE codebook indices to reconstructed masks.

    The weights are read from a NumPy archive whose keys use PyTorch-style
    names (``decoder.N.weight`` / ``decoder.N.bias``). They are rearranged
    once, at construction time, into the nested parameter layout consumed by
    the Flax decoder built in ``_build_decoder``.
    """

    def __init__(self, model_path: str):
        """Load the checkpoint at ``model_path`` and convert it to Flax params.

        Args:
            model_path: Path of the checkpoint archive, readable by ``np.load``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = self._get_params(model_path)

    def _get_params(self, checkpoint_path):
        """Converts PyTorch checkpoint to Flax params.

        Args:
            checkpoint_path: Path of the ``.npz`` archive holding the weights.

        Returns:
            A nested dict with the codebook under ``'_embeddings'`` and one
            ``{'bias', 'kernel'}`` entry per decoder layer.
        """
        # np.load keeps the archive's file handle open until it is closed, so
        # copy the arrays into a plain dict inside the context manager.
        with np.load(checkpoint_path) as archive:
            checkpoint = dict(archive)

        def transp(kernel):
            # Moves the two leading axes to the end in reversed order
            # (presumably PyTorch's (out, in, h, w) -> Flax's (h, w, in, out)).
            return np.transpose(kernel, (2, 3, 1, 0))

        def conv(name):
            return {
                'bias': checkpoint[name + '.bias'],
                'kernel': transp(checkpoint[name + '.weight']),
            }

        def resblock(name):
            # Only the even sub-indices carry weights in the checkpoint.
            return {
                'Conv_0': conv(name + '.0'),
                'Conv_1': conv(name + '.2'),
                'Conv_2': conv(name + '.4'),
            }

        return {
            '_embeddings': checkpoint['_vq_vae._embedding'],
            'Conv_0': conv('decoder.0'),
            'ResBlock_0': resblock('decoder.2.net'),
            'ResBlock_1': resblock('decoder.3.net'),
            'ConvTranspose_0': conv('decoder.4'),
            'ConvTranspose_1': conv('decoder.6'),
            'ConvTranspose_2': conv('decoder.8'),
            'ConvTranspose_3': conv('decoder.10'),
            'Conv_1': conv('decoder.12'),
        }

    def reconstruct_masks(self, codebook_indices):
        """Decode codebook indices into the decoder's single-channel output.

        Args:
            codebook_indices: Integer array of shape ``(batch, 16)``.

        Returns:
            The output of the decoder applied to the looked-up embeddings.
        """
        quantized = self._quantized_values_from_codebook_indices(codebook_indices)
        # `_decoder()` already returns the jitted `Decoder().apply`; it must be
        # called directly (the jitted callable has no `.apply` attribute).
        return self._decoder()({'params': self.params}, quantized)

    def _quantized_values_from_codebook_indices(self, codebook_indices):
        """Look up the embedding of every index and lay them out on a 4x4 grid.

        Args:
            codebook_indices: Integer array of shape ``(batch, 16)``.

        Returns:
            Array of shape ``(batch, 4, 4, embedding_dim)``.
        """
        batch_size, num_tokens = codebook_indices.shape
        assert num_tokens == 16, codebook_indices.shape
        unused_num_embeddings, embedding_dim = self.params['_embeddings'].shape
        encodings = jnp.take(self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0)
        encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
        return encodings

    def _decoder(self):
        """Return the jitted decoder ``apply`` function (built once, then reused)."""
        return self._build_decoder()

    @staticmethod
    @functools.cache
    def _build_decoder():
        """Define the Flax decoder and return its ``apply`` jitted for CPU.

        The decoder does not depend on any instance state, so it is cached on
        a static function; caching on the bound method would key on ``self``
        and keep every instance alive for the lifetime of the cache.
        """
        # Imported here because the module's top-level imports do not bind `nn`.
        import flax.linen as nn

        class ResBlock(nn.Module):
            features: int

            @nn.compact
            def __call__(self, x):
                original_x = x
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
                return x + original_x

        class Decoder(nn.Module):
            """Upscales quantized vectors to mask."""

            @nn.compact
            def __call__(self, x):
                num_res_blocks = 2
                dim = 128
                num_upsample_layers = 4

                x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
                x = nn.relu(x)

                for _ in range(num_res_blocks):
                    x = ResBlock(features=dim)(x)

                # Each stage upsamples with stride 2 and then halves the
                # feature count used by the next stage.
                for _ in range(num_upsample_layers):
                    x = nn.ConvTranspose(
                        features=dim,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        padding=2,
                        transpose_kernel=True,
                    )(x)
                    x = nn.relu(x)
                    dim //= 2

                x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
                return x

        return jax.jit(Decoder().apply, backend='cpu')