tolgacangoz committed
Upload matryoshka.py

matryoshka.py  CHANGED  (+44 -143)
@@ -94,78 +94,6 @@ EXAMPLE_DOC_STRING = """
         ```
 """
 
-import mlx
-from mlx.data.core import CharTrie
-
-def read_dictionary(token_file):
-    trie_key_scores = []
-    trie = CharTrie()
-
-    f = open(token_file, "rb")
-    sep = "\u2581".encode()
-
-    max_score = 0
-    for line in f:
-        line = line.rstrip()
-        token, score = line.split(b"\t")
-        score = -float(score)
-
-        token = token.replace(sep, b" ")
-        if trie.search(token):
-            raise RuntimeError(b"token " + token + b" already exists")
-        trie.insert(token)
-        trie_key_scores.append(score)
-        max_score = max(max_score, score)
-
-    eos, bos, pad = -1, -1, -1
-    for i in range(trie.num_keys()):
-        key = "".join(trie.key(i))
-        if key == "</s>":
-            eos = i
-        if key == "<unk>":
-            bos = i
-        if key == "<pad>":
-            pad = i
-
-    return trie, trie_key_scores, eos, bos, pad
-
-max_caption_length = 512
-max_token_length = 9
-padding_token = "<pad>"
-class Tokenizer:
-    def __init__(self, token_file):
-        (
-            self._trie,
-            self._trie_key_scores,
-            self.eos,
-            self.bos,
-            self.pad,
-        ) = read_dictionary(token_file)
-        self.vocab_size = self._trie.num_keys()
-
-    @property
-    def trie(self):
-        return self._trie
-
-    @property
-    def trie_key_scores(self):
-        return self._trie_key_scores
-
-    def tokens2text(self, tokens):
-        return "".join([self._trie.key_string(tok) for tok in tokens])
-
-    def token_id(self, token):
-        node = self._trie.search(token)
-        if node is None:
-            raise ValueError(f"token: {token} not found in vocab.")
-        return node.id
-
-tokenizer = Tokenizer("/kaggle/input/t5-vocab/t5.vocab.txt")
-
-mlx_tokenizer = mlx.data.core.Tokenizer(
-    tokenizer._trie, ignore_unk=True, trie_key_scores=tokenizer.trie_key_scores
-)
-
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
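For reference, the block deleted above is a module-level tokenizer built on an MLX `CharTrie` and loaded from a hard-coded Kaggle path; the hunks below switch the pipeline to its own Hugging Face tokenizer instead. A minimal sketch of the equivalent lookup on the new path, assuming a T5-style checkpoint (the checkpoint name is illustrative, not part of this commit):

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; any T5-style tokenizer with <pad>/</s> specials will do.
tok = AutoTokenizer.from_pretrained("google/t5-v1_1-base")

# Equivalent of the deleted Tokenizer.token_id(padding_token):
pad_id = tok.convert_tokens_to_ids("<pad>")  # same value as tok.pad_token_id

# Equivalent of the deleted manual (ids != PAD_TOKEN) mask: the tokenizer
# already returns an attention mask alongside the ids.
enc = tok("a photo of a cat", padding="max_length", max_length=9, return_tensors="pt")
assert (enc.input_ids != pad_id).long().equal(enc.attention_mask)
```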
@@ -3978,46 +3906,34 @@ class MatryoshkaPipeline(
 
         if prompt_embeds is None:
             # textual inversion: process multi-vector tokens if necessary
-            # [… old lines 3981-4008 not preserved in this view …]
-            # )
-            # logger.warning(
-            #     "The following part of your input was truncated because CLIP can only handle sequences up to"
-            #     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            # )
-
-            # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            #     attention_mask = text_inputs.attention_mask.to(device)
-            # else:
-            #     attention_mask = None
-            PAD_TOKEN = tokenizer.token_id(padding_token)
-            attention_mask = (text_input_ids != PAD_TOKEN).float().to(device)
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
 
             if clip_skip is None:
                 prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
@@ -4072,36 +3988,25 @@ class MatryoshkaPipeline(
                 uncond_tokens = negative_prompt
 
             # textual inversion: process multi-vector tokens if necessary
-            # [… old lines 4075-4076 not preserved in this view …]
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
 
             max_length = prompt_embeds.shape[1]
-            # [… old lines 4079-4090 not preserved in this view …]
-            pad_length = max_token_length - len(tokens)
-            tokens = tokens + [tokenizer.token_id(padding_token)] * pad_length
-            max_len = min(len(tokens), max_token_length)
-            uncond_input_ids = torch.tensor(tokens[: max_len]).reshape(1, -1)
-
-            # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            #     attention_mask = uncond_input_ids.attention_mask.to(device)
-            # else:
-            #     attention_mask = None
-            PAD_TOKEN = tokenizer.token_id(padding_token)
-            attention_mask = (uncond_input_ids != PAD_TOKEN).float().to(device)
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
 
             negative_prompt_embeds = self.text_encoder(
-                uncond_input_ids.to(device),
+                uncond_input.input_ids.to(device),
                 attention_mask=attention_mask,
             )
             negative_prompt_embeds = negative_prompt_embeds[0]
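Both encode-prompt hunks above make the same substitution: the hand-built `(ids != PAD_TOKEN)` mask is dropped in favor of the `attention_mask` the tokenizer already returns. For right-padded batches the two agree; a self-contained check with made-up ids (the pad id of 0 is illustrative):

```python
import torch

PAD = 0  # illustrative pad-token id
ids = torch.tensor([[101, 7, 42, PAD, PAD],
                    [101, 9, PAD, PAD, PAD]])

manual_mask = (ids != PAD).float()  # the deleted code path
# What a tokenizer returns for the same right-padded batch:
tokenizer_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]]).float()

assert torch.equal(manual_mask, tokenizer_mask)
```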
@@ -4293,7 +4198,7 @@ class MatryoshkaPipeline(
         )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=…
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         if scales is not None:
             out = [latents]
             for s in scales[1:]:
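The replacement line restores the full `randn_tensor` call (the deleted line is cut off in this view). `randn_tensor` is the diffusers utility that draws initial noise while honoring a seeded `generator` and the target `device`/`dtype`; a small usage sketch (shape and seed are illustrative):

```python
import torch
from diffusers.utils.torch_utils import randn_tensor

shape = (1, 4, 64, 64)  # illustrative latent shape
gen = torch.Generator("cpu").manual_seed(0)

a = randn_tensor(shape, generator=gen, device=torch.device("cpu"), dtype=torch.float32)
gen.manual_seed(0)  # rewind the generator
b = randn_tensor(shape, generator=gen, device=torch.device("cpu"), dtype=torch.float32)

assert torch.equal(a, b)  # same seed => identical initial latents
```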
@@ -4377,10 +4282,6 @@ class MatryoshkaPipeline(
     def interrupt(self):
         return self._interrupt
 
-    @property
-    def model_type(self):
-        return "nested_unet"
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(