dwb2023 committed on
Commit
7403d98
·
verified ·
1 Parent(s): a2e1737

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +100 -1
inference.py CHANGED
@@ -17,7 +17,8 @@ class PaliGemmaModel:
17
 
18
  @spaces.GPU
19
  def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
20
- inputs = self.processor(text=text, images=image, return_tensors="pt").to(self.device)
 
21
  with torch.inference_mode():
22
  generated_ids = self.model.generate(
23
  **inputs,
@@ -26,3 +27,101 @@ class PaliGemmaModel:
26
  )
27
  result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
28
  return result[0][len(text):].lstrip("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @spaces.GPU
19
  def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
20
+ inputs = self.processor(text=text, images=image, return_tensors="pt")
21
+ inputs = {k: v.to(self.device) for k, v in inputs.items()} # Move inputs to the correct device
22
  with torch.inference_mode():
23
  generated_ids = self.model.generate(
24
  **inputs,
 
27
  )
28
  result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
29
  return result[0][len(text):].lstrip("\n")
30
+
31
+ class VAEModel:
32
+ def __init__(self, model_path: str):
33
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ self.params = self._get_params(model_path)
35
+
36
+ def _get_params(self, checkpoint_path):
37
+ """Converts PyTorch checkpoint to Flax params."""
38
+ checkpoint = dict(np.load(checkpoint_path))
39
+
40
+ def transp(kernel):
41
+ return np.transpose(kernel, (2, 3, 1, 0))
42
+
43
+ def conv(name):
44
+ return {
45
+ 'bias': checkpoint[name + '.bias'],
46
+ 'kernel': transp(checkpoint[name + '.weight']),
47
+ }
48
+
49
+ def resblock(name):
50
+ return {
51
+ 'Conv_0': conv(name + '.0'),
52
+ 'Conv_1': conv(name + '.2'),
53
+ 'Conv_2': conv(name + '.4'),
54
+ }
55
+
56
+ return {
57
+ '_embeddings': checkpoint['_vq_vae._embedding'],
58
+ 'Conv_0': conv('decoder.0'),
59
+ 'ResBlock_0': resblock('decoder.2.net'),
60
+ 'ResBlock_1': resblock('decoder.3.net'),
61
+ 'ConvTranspose_0': conv('decoder.4'),
62
+ 'ConvTranspose_1': conv('decoder.6'),
63
+ 'ConvTranspose_2': conv('decoder.8'),
64
+ 'ConvTranspose_3': conv('decoder.10'),
65
+ 'Conv_1': conv('decoder.12'),
66
+ }
67
+
68
+ def reconstruct_masks(self, codebook_indices):
69
+ quantized = self._quantized_values_from_codebook_indices(codebook_indices)
70
+ return self._decoder().apply({'params': self.params}, quantized)
71
+
72
+ def _quantized_values_from_codebook_indices(self, codebook_indices):
73
+ batch_size, num_tokens = codebook_indices.shape
74
+ assert num_tokens == 16, codebook_indices.shape
75
+ unused_num_embeddings, embedding_dim = self.params['_embeddings'].shape
76
+
77
+ encodings = jnp.take(self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0)
78
+ encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
79
+ return encodings
80
+
81
+ @functools.cache
82
+ def _decoder(self):
83
+ class ResBlock(nn.Module):
84
+ features: int
85
+
86
+ @nn.compact
87
+ def __call__(self, x):
88
+ original_x = x
89
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
90
+ x = nn.relu(x)
91
+ x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
92
+ x = nn.relu(x)
93
+ x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
94
+ return x + original_x
95
+
96
+ class Decoder(nn.Module):
97
+ """Upscales quantized vectors to mask."""
98
+
99
+ @nn.compact
100
+ def __call__(self, x):
101
+ num_res_blocks = 2
102
+ dim = 128
103
+ num_upsample_layers = 4
104
+
105
+ x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
106
+ x = nn.relu(x)
107
+
108
+ for _ in range(num_res_blocks):
109
+ x = ResBlock(features=dim)(x)
110
+
111
+ for _ in range(num_upsample_layers):
112
+ x = nn.ConvTranspose(
113
+ features=dim,
114
+ kernel_size=(4, 4),
115
+ strides=(2, 2),
116
+ padding=2,
117
+ transpose_kernel=True,
118
+ )(x)
119
+ x = nn.relu(x)
120
+ dim //= 2
121
+
122
+ x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
123
+
124
+ return x
125
+
126
+ return jax.jit(Decoder().apply, backend='cpu')
127
+