dwb2023 committed
Commit 725ac0d · verified · 1 Parent(s): 7403d98

Update app.py

Files changed (1)
  1. app.py +10 -116
app.py CHANGED
@@ -6,12 +6,13 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 import flax.linen as nn
-from inference import PaliGemmaModel
+from inference import PaliGemmaModel, VAEModel

 COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

-# Instantiate the model
+# Instantiate the models
 pali_gemma_model = PaliGemmaModel()
+vae_model = VAEModel('vae-oid.npz')

 ##### Parse segmentation output tokens into masks
 ##### Also returns bounding boxes with their labels
@@ -120,118 +121,6 @@ with gr.Blocks(css="style.css") as demo:
 ### Postprocessing Utils for Segmentation Tokens
 ### Segmentation tokens are passed to another VAE which decodes them to a mask

-_MODEL_PATH = 'vae-oid.npz'
-
-_SEGMENT_DETECT_RE = re.compile(
-    r'(.*?)' +
-    r'<loc(\d{4})>' * 4 + r'\s*' +
-    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
-    r'\s*([^;<>]+)? ?(?:; )?',
-)
-
-def _get_params(checkpoint):
-  """Converts PyTorch checkpoint to Flax params."""
-
-  def transp(kernel):
-    return np.transpose(kernel, (2, 3, 1, 0))
-
-  def conv(name):
-    return {
-        'bias': checkpoint[name + '.bias'],
-        'kernel': transp(checkpoint[name + '.weight']),
-    }
-
-  def resblock(name):
-    return {
-        'Conv_0': conv(name + '.0'),
-        'Conv_1': conv(name + '.2'),
-        'Conv_2': conv(name + '.4'),
-    }
-
-  return {
-      '_embeddings': checkpoint['_vq_vae._embedding'],
-      'Conv_0': conv('decoder.0'),
-      'ResBlock_0': resblock('decoder.2.net'),
-      'ResBlock_1': resblock('decoder.3.net'),
-      'ConvTranspose_0': conv('decoder.4'),
-      'ConvTranspose_1': conv('decoder.6'),
-      'ConvTranspose_2': conv('decoder.8'),
-      'ConvTranspose_3': conv('decoder.10'),
-      'Conv_1': conv('decoder.12'),
-  }
-
-def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
-  batch_size, num_tokens = codebook_indices.shape
-  assert num_tokens == 16, codebook_indices.shape
-  unused_num_embeddings, embedding_dim = embeddings.shape
-
-  encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
-  encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
-  return encodings
-
-@functools.cache
-def _get_reconstruct_masks():
-  """Reconstructs masks from codebook indices.
-  Returns:
-    A function that expects indices shaped `[B, 16]` of dtype int32, each
-    ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
-    `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
-  """
-
-  class ResBlock(nn.Module):
-    features: int
-
-    @nn.compact
-    def __call__(self, x):
-      original_x = x
-      x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
-      x = nn.relu(x)
-      x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
-      x = nn.relu(x)
-      x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
-      return x + original_x
-
-  class Decoder(nn.Module):
-    """Upscales quantized vectors to mask."""
-
-    @nn.compact
-    def __call__(self, x):
-      num_res_blocks = 2
-      dim = 128
-      num_upsample_layers = 4
-
-      x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
-      x = nn.relu(x)
-
-      for _ in range(num_res_blocks):
-        x = ResBlock(features=dim)(x)
-
-      for _ in range(num_upsample_layers):
-        x = nn.ConvTranspose(
-            features=dim,
-            kernel_size=(4, 4),
-            strides=(2, 2),
-            padding=2,
-            transpose_kernel=True,
-        )(x)
-        x = nn.relu(x)
-        dim //= 2
-
-      x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
-
-      return x
-
-  def reconstruct_masks(codebook_indices):
-    quantized = _quantized_values_from_codebook_indices(
-        codebook_indices, params['_embeddings']
-    )
-    return Decoder().apply({'params': params}, quantized)
-
-  with open(_MODEL_PATH, 'rb') as f:
-    params = _get_params(dict(np.load(f)))
-
-  return jax.jit(reconstruct_masks, backend='cpu')
-
 def extract_objs(text, width, height, unique_labels=False):
   """Returns objs for a string with "<loc>" and "<seg>" tokens."""
   objs = []
@@ -252,7 +141,7 @@ def extract_objs(text, width, height, unique_labels=False):
       mask = None
     else:
       seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
-      m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+      m64, = vae_model.reconstruct_masks(seg_indices[None])[..., 0]
       m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
       m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
      mask = np.zeros([height, width])
@@ -275,7 +164,12 @@ def extract_objs(text, width, height, unique_labels=False):

   return objs

-#########
+_SEGMENT_DETECT_RE = re.compile(
+    r'(.*?)' +
+    r'<loc(\d{4})>' * 4 + r'\s*' +
+    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+    r'\s*([^;<>]+)? ?(?:; )?',
+)

 if __name__ == "__main__":
     demo.queue(max_size=10).launch(debug=True)
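
The commit removes the VQ-VAE mask decoder from app.py and instead imports a VAEModel from inference.py, passing the checkpoint path 'vae-oid.npz' to its constructor rather than reading the old module-level _MODEL_PATH. Since inference.py is not part of this diff, the following is only a minimal sketch of what such a wrapper could look like, reassembled from the helpers deleted above; the constructor and the reconstruct_masks() method are the two things the updated app.py actually relies on, while every other name and detail here is an assumption, not the real contents of inference.py.

# inference.py (hypothetical sketch -- the actual committed file is not shown in this diff)
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np


class _ResBlock(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    residual = x
    x = nn.relu(nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x))
    x = nn.relu(nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x))
    x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
    return x + residual


class _Decoder(nn.Module):
  """Upscales [B, 4, 4, D] quantized codes to [B, 64, 64, 1] masks."""

  @nn.compact
  def __call__(self, x):
    dim = 128
    x = nn.relu(nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x))
    for _ in range(2):  # residual blocks at full width
      x = _ResBlock(features=dim)(x)
    for _ in range(4):  # upsample 4 -> 8 -> 16 -> 32 -> 64, halving channels each step
      x = nn.ConvTranspose(features=dim, kernel_size=(4, 4), strides=(2, 2),
                           padding=2, transpose_kernel=True)(x)
      x = nn.relu(x)
      dim //= 2
    return nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)


def _get_params(checkpoint):
  """Converts the PyTorch VQ-VAE checkpoint to Flax params (same mapping as the removed code)."""
  transp = lambda k: np.transpose(k, (2, 3, 1, 0))
  conv = lambda name: {'bias': checkpoint[name + '.bias'],
                       'kernel': transp(checkpoint[name + '.weight'])}
  resblock = lambda name: {'Conv_0': conv(name + '.0'),
                           'Conv_1': conv(name + '.2'),
                           'Conv_2': conv(name + '.4')}
  return {
      '_embeddings': checkpoint['_vq_vae._embedding'],
      'Conv_0': conv('decoder.0'),
      'ResBlock_0': resblock('decoder.2.net'),
      'ResBlock_1': resblock('decoder.3.net'),
      'ConvTranspose_0': conv('decoder.4'),
      'ConvTranspose_1': conv('decoder.6'),
      'ConvTranspose_2': conv('decoder.8'),
      'ConvTranspose_3': conv('decoder.10'),
      'Conv_1': conv('decoder.12'),
  }


class VAEModel:
  """Decodes PaliGemma <seg> codebook indices into 64x64 soft masks."""

  def __init__(self, checkpoint_path):
    with open(checkpoint_path, 'rb') as f:
      self._params = _get_params(dict(np.load(f)))
    self._decode = jax.jit(self._reconstruct, backend='cpu')

  def _reconstruct(self, codebook_indices):
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    # Look up the 16 VQ embeddings and arrange them on a 4x4 grid of codes.
    encodings = jnp.take(self._params['_embeddings'],
                         codebook_indices.reshape((-1)), axis=0)
    quantized = encodings.reshape((batch_size, 4, 4, -1))
    return _Decoder().apply({'params': self._params}, quantized)

  def reconstruct_masks(self, codebook_indices):
    """[B, 16] int32 indices in [0, 127] -> [B, 64, 64, 1] float32 masks in [-1, 1]."""
    return self._decode(codebook_indices)

Under these assumptions, vae_model.reconstruct_masks(seg_indices[None])[..., 0] in extract_objs returns the same float32 mask in [-1, 1] that the previous _get_reconstruct_masks()(seg_indices[None])[..., 0] produced, so the remaining postprocessing in extract_objs is unchanged by this commit.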