import os
import re
import uuid

import torch
from PIL import Image
def parse_prompt_attention(text):
    # A1111-style attention parser: returns a list of [text, weight] pairs.
    re_attention = re.compile(r"""
        \\\(|
        \\\)|
        \\\[|
        \\]|
        \\\\|
        \\|
        \(|
        \[|
        :([+-]?[.\d]+)\)|
        \)|
        ]|
        [^\\()\[\]:]+|
        :
    """, re.X)

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale the weight of everything appended since the bracket was opened.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith('\\'):
            res.append([text[1:], 1.0])  # escaped character, keep literally
        elif text == '(':
            round_brackets.append(len(res))
        elif text == '[':
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))  # explicit (text:weight)
        elif text == ')' and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == ']' and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Close any unbalanced brackets with their default multiplier.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # Merge runs of identical weights into a single entry.
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
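
# Example (illustrative, derived from the parser above): round brackets boost the
# weight of the enclosed text by 1.1, square brackets reduce it by 1/1.1, and an
# explicit ":<number>)" sets the multiplier directly.
#
#   parse_prompt_attention("a (red) cat")      -> [['a ', 1.0], ['red', 1.1], [' cat', 1.0]]
#   parse_prompt_attention("a (red:1.5) cat")  -> [['a ', 1.0], ['red', 1.5], [' cat', 1.0]]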
def prompt_attention_to_invoke_prompt(attention):
    # Convert [text, weight] pairs into Compel/InvokeAI emphasis syntax.
    tokens = []
    for text, weight in attention:
        # Round weight to 2 decimal places
        weight = round(weight, 2)
        if weight == 1.0:
            tokens.append(text)
        elif weight < 1.0:
            if weight < 0.8:
                tokens.append(f"({text}){weight}")
            else:
                tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
        else:
            if weight < 1.3:
                tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
            else:
                tokens.append(f"({text}){weight}")
    return "".join(tokens)
def concat_tensor(t):
    t_list = torch.split(t, 1, dim=0)
    t = torch.cat(t_list, dim=1)
    return t
def merge_embeds(prompt_chanks, compel):
    # Merge per-chunk Compel embeddings into a single conditioning tensor,
    # weighting earlier chunks more heavily (the weights sum to 1).
    num_chanks = len(prompt_chanks)
    if num_chanks != 0:
        power_prompt = 1 / (num_chanks * (num_chanks + 1) // 2)
        prompt_embs = compel(prompt_chanks)
        t_list = list(torch.split(prompt_embs, 1, dim=0))
        for i in range(num_chanks):
            t_list[-(i + 1)] = t_list[-(i + 1)] * ((i + 1) * power_prompt)
        prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
    else:
        prompt_emb = compel('')
    return prompt_emb
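
# Worked example (derived from the code above): with three chunks, power_prompt is
# 1/6, so the chunk embeddings are scaled by 3/6, 2/6 and 1/6 from first to last
# before summing; the scale factors always add up to 1.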
def detokenize(chunk, actual_prompt):
    # Rebuild readable text from CLIP BPE tokens ('</w>' marks end-of-word) and
    # strip the rebuilt text from the remaining prompt.
    chunk[-1] = chunk[-1].replace('</w>', '')
    chanked_prompt = ''.join(chunk).strip()
    while '</w>' in chanked_prompt:
        if actual_prompt[chanked_prompt.find('</w>')] == ' ':
            chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
        else:
            chanked_prompt = chanked_prompt.replace('</w>', '', 1)
    actual_prompt = actual_prompt.replace(chanked_prompt, '')
    return chanked_prompt.strip(), actual_prompt.strip()
def tokenize_line(line, tokenizer):  # split into chunks
    actual_prompt = line.lower().strip()
    actual_tokens = tokenizer.tokenize(actual_prompt)
    max_tokens = tokenizer.model_max_length - 2
    comma_token = tokenizer.tokenize(',')[0]

    chunks = []
    chunk = []
    for item in actual_tokens:
        chunk.append(item)
        if len(chunk) == max_tokens:
            if chunk[-1] != comma_token:
                # Prefer to cut the chunk at the last comma, if there is one.
                for i in range(max_tokens - 1, -1, -1):
                    if chunk[i] == comma_token:
                        actual_chunk, actual_prompt = detokenize(chunk[:i + 1], actual_prompt)
                        chunks.append(actual_chunk)
                        chunk = chunk[i + 1:]
                        break
                else:
                    # No comma found: cut at the hard token limit.
                    actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
                    chunks.append(actual_chunk)
                    chunk = []
            else:
                actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
                chunks.append(actual_chunk)
                chunk = []

    if chunk:
        actual_chunk, _ = detokenize(chunk, actual_prompt)
        chunks.append(actual_chunk)

    return chunks
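
# Usage sketch (assumes a CLIP-style tokenizer such as `pipe.tokenizer` from a
# diffusers Stable Diffusion pipeline; the name `pipe` is illustrative):
#
#   chunks = tokenize_line("a long prompt, with many phrases", pipe.tokenizer)
#   # each entry in `chunks` fits within tokenizer.model_max_length - 2 tokens,
#   # split at commas where possible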
def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
    if compel_process_sd:
        return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
    else:
        # Fix for excessive emphasis during weight conversion
        prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")

    # Convert to Compel
    attention = parse_prompt_attention(prompt)
    global_attention_chanks = []

    for att in attention:
        for chank in att[0].split(','):
            temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer)
            for small_chank in temp_prompt_chanks:
                temp_dict = {
                    "weight": round(att[1], 2),
                    "lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')),
                    "prompt": f'{small_chank},'
                }
                global_attention_chanks.append(temp_dict)

    # Pack the weighted pieces into chunks that fit the tokenizer's context window,
    # merging neighbours that share the same weight.
    max_tokens = pipeline.tokenizer.model_max_length - 2
    global_prompt_chanks = []
    current_list = []
    current_length = 0
    for item in global_attention_chanks:
        if current_length + item['lenght'] > max_tokens:
            global_prompt_chanks.append(current_list)
            current_list = [[item['prompt'], item['weight']]]
            current_length = item['lenght']
        else:
            if not current_list:
                current_list.append([item['prompt'], item['weight']])
            else:
                if item['weight'] != current_list[-1][1]:
                    current_list.append([item['prompt'], item['weight']])
                else:
                    current_list[-1][0] += f" {item['prompt']}"
            current_length += item['lenght']
    if current_list:
        global_prompt_chanks.append(current_list)

    if only_convert_string:
        return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks])

    return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)
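
# Usage sketch (hedged; `pipe` and the Compel wiring are assumptions about the
# calling code, not part of this file): get_embed_new is typically driven by a
# Compel instance built from the pipeline's tokenizer and text encoder, and the
# result is passed to the pipeline as prompt_embeds.
#
#   from compel import Compel
#   compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
#   prompt_embeds = get_embed_new("a (red:1.3) cat, highly detailed", pipe, compel)
#   image = pipe(prompt_embeds=prompt_embeds).images[0]
#
# With only_convert_string=True the function instead returns the prompt rewritten
# in Compel syntax as a plain string.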
def add_comma_after_pattern_ti(text):
    pattern = re.compile(r'\b\w+_\d+\b')
    modified_text = pattern.sub(lambda x: x.group() + ',', text)
    return modified_text
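
# Example (illustrative): tokens that look like "name_123" (e.g. textual-inversion
# style trigger words) get a trailing comma so they are treated as separate chunks.
#
#   add_comma_after_pattern_ti("style_42 portrait of a cat")
#   -> 'style_42, portrait of a cat'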
def save_image(img):
    path = "./tmp/"
    # Check if the input is a string (file path) and load the image if it is
    if isinstance(img, str):
        img = Image.open(img)  # Load the image from the file path

    # Ensure the local tmp directory exists
    if not os.path.exists(path):
        os.makedirs(path)

    # Generate a unique filename
    unique_name = str(uuid.uuid4()) + ".webp"
    unique_name = os.path.join(path, unique_name)

    # Convert the image to RGB so it can be saved as WebP
    webp_img = img.convert("RGB")

    # Save the image in WebP format with high quality
    webp_img.save(unique_name, "WEBP", quality=90)

    # Re-open the saved WebP file to verify it; note that only the path is returned
    with Image.open(unique_name) as webp_file:
        webp_image = webp_file.copy()

    return unique_name
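
# Usage sketch: save_image writes a uniquely named WebP into ./tmp/ and returns the
# file path (the re-opened PIL copy above is not returned).
#
#   saved_path = save_image(Image.new("RGB", (64, 64), "white"))
#   # saved_path is something like './tmp/<uuid>.webp'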