moistdio's picture
Upload folder using huggingface_hub
6831a54 verified
import re
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith('\\'):
res.append([text[1:], 1.0])
elif text == '(':
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
parts = re.split(re_break, text)
for i, part in enumerate(parts):
if i > 0:
res.append(["BREAK", -1])
res.append([part, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res