tolgacangoz committed (verified) · Commit bf2b36b · Parent(s): 91661c8

Upload matryoshka.py

Files changed (1): matryoshka.py (+44, -143)

matryoshka.py CHANGED
@@ -94,78 +94,6 @@ EXAMPLE_DOC_STRING = """
         ```
 """
 
-import mlx
-from mlx.data.core import CharTrie
-
-def read_dictionary(token_file):
-    trie_key_scores = []
-    trie = CharTrie()
-
-    f = open(token_file, "rb")
-    sep = "\u2581".encode()
-
-    max_score = 0
-    for line in f:
-        line = line.rstrip()
-        token, score = line.split(b"\t")
-        score = -float(score)
-
-        token = token.replace(sep, b" ")
-        if trie.search(token):
-            raise RuntimeError(b"token " + token + b" already exists")
-        trie.insert(token)
-        trie_key_scores.append(score)
-        max_score = max(max_score, score)
-
-    eos, bos, pad = -1, -1, -1
-    for i in range(trie.num_keys()):
-        key = "".join(trie.key(i))
-        if key == "</s>":
-            eos = i
-        if key == "<unk>":
-            bos = i
-        if key == "<pad>":
-            pad = i
-
-    return trie, trie_key_scores, eos, bos, pad
-
-max_caption_length = 512
-max_token_length = 9
-padding_token = "<pad>"
-class Tokenizer:
-    def __init__(self, token_file):
-        (
-            self._trie,
-            self._trie_key_scores,
-            self.eos,
-            self.bos,
-            self.pad,
-        ) = read_dictionary(token_file)
-        self.vocab_size = self._trie.num_keys()
-
-    @property
-    def trie(self):
-        return self._trie
-
-    @property
-    def trie_key_scores(self):
-        return self._trie_key_scores
-
-    def tokens2text(self, tokens):
-        return "".join([self._trie.key_string(tok) for tok in tokens])
-
-    def token_id(self, token):
-        node = self._trie.search(token)
-        if node is None:
-            raise ValueError(f"token: {token} not found in vocab.")
-        return node.id
-
-tokenizer = Tokenizer("/kaggle/input/t5-vocab/t5.vocab.txt")
-
-mlx_tokenizer = mlx.data.core.Tokenizer(
-    tokenizer._trie, ignore_unk=True, trie_key_scores=tokenizer.trie_key_scores
-)
-
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
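The block removed above bolted a standalone MLX CharTrie tokenizer (loaded from a local Kaggle copy of t5.vocab.txt) onto the pipeline; the hunks below hand tokenization back to the pipeline's own `self.tokenizer`. For reference, a minimal sketch of that restored path with a T5-style tokenizer from transformers; the checkpoint name and prompt are illustrative assumptions, not part of this commit:

```python
from transformers import T5Tokenizer

# Assumption: any T5 checkpoint works here; the removed code read a T5 vocab file.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

text_inputs = tokenizer(
    "a photo of an astronaut riding a horse",  # illustrative prompt
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 512 for T5, same as the removed max_caption_length
    truncation=True,
    return_tensors="pt",
)
# The tokenizer appends </s> and pads with <pad> on its own, so the manual
# eos-append and pad-fill from the removed MLX block are no longer needed.
input_ids, attention_mask = text_inputs.input_ids, text_inputs.attention_mask
```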
@@ -3978,46 +3906,34 @@ class MatryoshkaPipeline(
 
         if prompt_embeds is None:
             # textual inversion: process multi-vector tokens if necessary
-            # if isinstance(self, TextualInversionLoaderMixin):
-            #     prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            # text_inputs = self.tokenizer(
-            #     prompt,
-            #     padding="max_length",
-            #     max_length=self.tokenizer.model_max_length,
-            #     truncation=True,
-            #     return_tensors="pt",
-            # )
-
-            d = prompt[: max_caption_length]
-            d = " " + d  # Prepad caption with space (or mlx tokenizes wrongly)
-            tokens = mlx_tokenizer.tokenize_shortest(d)
-            tokens = tokens + [tokenizer.eos]  # Append tokenization with eos symbol
-            if len(tokens) < max_token_length:
-                pad_length = max_token_length - len(tokens)
-                tokens = tokens + [tokenizer.token_id(padding_token)] * pad_length
-            max_len = min(len(tokens), max_token_length)
-            text_input_ids = torch.tensor(tokens[: max_len]).reshape(1, -1)
-            # text_input_ids = text_inputs.input_ids
-            # untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-            #     text_input_ids, untruncated_ids
-            # ):
-            #     removed_text = self.tokenizer.batch_decode(
-            #         untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-            #     )
-            #     logger.warning(
-            #         "The following part of your input was truncated because CLIP can only handle sequences up to"
-            #         f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            #     )
-
-            # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            #     attention_mask = text_inputs.attention_mask.to(device)
-            # else:
-            #     attention_mask = None
-            PAD_TOKEN = tokenizer.token_id(padding_token)
-            attention_mask = (text_input_ids != PAD_TOKEN).float().to(device)
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
 
             if clip_skip is None:
                 prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
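One behavioral detail in this hunk: the removed path built the attention mask by hand, comparing ids against the `<pad>` id, while the restored path takes `attention_mask` straight from the tokenizer output when the encoder's `use_attention_mask` is set. A small sketch of why the two agree for right-padded inputs; the ids below are illustrative, not taken from the commit:

```python
import torch

pad_id = 0  # T5's <pad> token id
# Four content tokens ending in </s> (id 1), then two pads (illustrative ids).
ids = torch.tensor([[71, 1712, 55, 1, 0, 0]])

manual_mask = (ids != pad_id).float()  # the removed MLX-era approach
tokenizer_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]])  # what the tokenizer would return

assert torch.equal(manual_mask, tokenizer_mask)  # identical for pad-only right padding
```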
@@ -4072,36 +3988,25 @@ class MatryoshkaPipeline(
                 uncond_tokens = negative_prompt
 
             # textual inversion: process multi-vector tokens if necessary
-            # if isinstance(self, TextualInversionLoaderMixin):
-            #     uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
 
             max_length = prompt_embeds.shape[1]
-            # uncond_input = self.tokenizer(
-            #     uncond_tokens,
-            #     padding="max_length",
-            #     max_length=max_length,
-            #     truncation=True,
-            #     return_tensors="pt",
-            # )
-            d = uncond_tokens[0][: max_caption_length]
-            d = " " + d
-            tokens = mlx_tokenizer.tokenize_shortest(d)
-            tokens = tokens + [tokenizer.eos]
-            if len(tokens) < max_token_length:
-                pad_length = max_token_length - len(tokens)
-                tokens = tokens + [tokenizer.token_id(padding_token)] * pad_length
-            max_len = min(len(tokens), max_token_length)
-            uncond_input_ids = torch.tensor(tokens[: max_len]).reshape(1, -1)
-
-            # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            #     attention_mask = uncond_input_ids.attention_mask.to(device)
-            # else:
-            #     attention_mask = None
-            PAD_TOKEN = tokenizer.token_id(padding_token)
-            attention_mask = (uncond_input_ids != PAD_TOKEN).float().to(device)
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
 
             negative_prompt_embeds = self.text_encoder(
-                uncond_input_ids.to(device),
+                uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
             )
             negative_prompt_embeds = negative_prompt_embeds[0]
@@ -4293,7 +4198,7 @@ class MatryoshkaPipeline(
         )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=torch.manual_seed(0), device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         if scales is not None:
             out = [latents]
             for s in scales[1:]:
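The one-line change above replaces a hard-coded `torch.manual_seed(0)` generator with the caller-supplied `generator`, so latents are no longer identical on every call. A minimal sketch of the restored behavior, assuming diffusers' `randn_tensor` helper; the shape and seed are illustrative:

```python
import torch
from diffusers.utils.torch_utils import randn_tensor

shape = (1, 4, 64, 64)  # illustrative latent shape, not from the diff

# Reproducible when the caller asks for it...
gen = torch.Generator("cpu").manual_seed(42)
latents = randn_tensor(shape, generator=gen, dtype=torch.float32)

# ...and fresh noise otherwise. With generator=torch.manual_seed(0), every
# pipeline call produced the same latents regardless of user input.
latents_fresh = randn_tensor(shape, generator=None, dtype=torch.float32)
```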
@@ -4377,10 +4282,6 @@ class MatryoshkaPipeline(
     def interrupt(self):
         return self._interrupt
 
-    @property
-    def model_type(self):
-        return "nested_unet"
-
     @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(