import os
import PIL.Image
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import functools

hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)

class PaliGemmaModel:
    """Wraps the PaliGemma mix checkpoint for image + text generation."""

    def __init__(self):
        self.model_id = "google/paligemma-3b-mix-448"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        """Runs greedy decoding and returns only the newly generated text."""
        inputs = self.processor(text=text, images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}  # Move inputs to the model's device.
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        # batch_decode returns the prompt followed by the completion; strip the prompt.
        return result[0][len(text):].lstrip("\n")
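
# Example usage (sketch; "example.jpg" is a placeholder, and the mix checkpoints
# expect task-prefixed prompts such as "caption en" or "segment <object>"):
#
#   pg = PaliGemmaModel()
#   image = PIL.Image.open("example.jpg")
#   print(pg.infer(image, text="caption en", max_new_tokens=20))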

class VAEModel:
    """Flax VQ-VAE decoder that reconstructs segmentation masks from codebook indices."""

    def __init__(self, model_path: str):
        # The decoder runs on CPU via JAX (see _decoder), so no torch device is needed.
        self.params = self._get_params(model_path)

    def _get_params(self, checkpoint_path):
        """Converts PyTorch checkpoint to Flax params."""
        checkpoint = dict(np.load(checkpoint_path))

        def transp(kernel):
            # PyTorch conv kernels are [out, in, H, W]; Flax expects [H, W, in, out].
            return np.transpose(kernel, (2, 3, 1, 0))

        def conv(name):
            return {
                'bias': checkpoint[name + '.bias'],
                'kernel': transp(checkpoint[name + '.weight']),
            }

        def resblock(name):
            return {
                'Conv_0': conv(name + '.0'),
                'Conv_1': conv(name + '.2'),
                'Conv_2': conv(name + '.4'),
            }

        return {
            '_embeddings': checkpoint['_vq_vae._embedding'],
            'Conv_0': conv('decoder.0'),
            'ResBlock_0': resblock('decoder.2.net'),
            'ResBlock_1': resblock('decoder.3.net'),
            'ConvTranspose_0': conv('decoder.4'),
            'ConvTranspose_1': conv('decoder.6'),
            'ConvTranspose_2': conv('decoder.8'),
            'ConvTranspose_3': conv('decoder.10'),
            'Conv_1': conv('decoder.12'),
        }

    def reconstruct_masks(self, codebook_indices):
        """Decodes [batch, 16] codebook indices into [batch, 64, 64, 1] mask logits."""
        quantized = self._quantized_values_from_codebook_indices(codebook_indices)
        # _decoder() already returns the jitted Decoder().apply, so call it directly.
        return self._decoder()({'params': self.params}, quantized)

    def _quantized_values_from_codebook_indices(self, codebook_indices):
        batch_size, num_tokens = codebook_indices.shape
        assert num_tokens == 16, codebook_indices.shape
        unused_num_embeddings, embedding_dim = self.params['_embeddings'].shape

        # Look up each index in the codebook and arrange the result as a 4x4 grid.
        encodings = jnp.take(self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0)
        encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
        return encodings

    @functools.cache
    def _decoder(self):
        # Cached so the Flax decoder is built and jit-compiled only once.
        class ResBlock(nn.Module):
            features: int

            @nn.compact
            def __call__(self, x):
                original_x = x
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
                return x + original_x

        class Decoder(nn.Module):
            """Upscales quantized vectors to mask."""

            @nn.compact
            def __call__(self, x):
                num_res_blocks = 2
                dim = 128
                num_upsample_layers = 4

                x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
                x = nn.relu(x)

                for _ in range(num_res_blocks):
                    x = ResBlock(features=dim)(x)

                for _ in range(num_upsample_layers):
                    x = nn.ConvTranspose(
                        features=dim,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        padding=2,
                        transpose_kernel=True,
                    )(x)
                    x = nn.relu(x)
                    dim //= 2

                x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)

                return x

        return jax.jit(Decoder().apply, backend='cpu')
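
if __name__ == "__main__":
    # Minimal smoke test (sketch): assumes a local copy of the VQ-VAE checkpoint,
    # commonly distributed as "vae-oid.npz"; adjust the path as needed.
    vae = VAEModel("vae-oid.npz")

    # A dummy batch of 16 codebook indices decodes to a [1, 64, 64, 1] mask:
    # four stride-2 transposed convolutions upsample the 4x4 grid to 64x64.
    dummy_indices = jnp.zeros((1, 16), dtype=jnp.int32)
    masks = vae.reconstruct_masks(dummy_indices)
    print("mask logits shape:", masks.shape)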