import re

# Tokenizer for the attention syntax. Compiled with re.X, so the whitespace
# between alternatives is ignored. Alternatives are tried in order, which is
# why the escape sequences come before the bare bracket forms.
re_attention = re.compile(
    r" \\\(|"  # backslash-escaped "("
    r" \\\)|"  # backslash-escaped ")"
    r" \\\[|"  # backslash-escaped "["
    r" \\]|"  # backslash-escaped "]"
    r" \\\\|"  # backslash-escaped backslash
    r" \\|"  # lone backslash
    r" \(|"  # opens a round-bracket group
    r" \[|"  # opens a square-bracket group
    r" :\s*([+-]?[.\d]+)\s*\)|"  # ":<number>)" - group 1 is the weight
    r" \)|"  # closes a round-bracket group
    r" ]|"  # closes a square-bracket group
    r" [^\\()\[\]:]+|"  # run of plain text
    r" : ",  # lone colon
    re.X,
)

# The word BREAK together with any whitespace around it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text, emphasis):
    """Split a prompt into ``[fragment, weight]`` pairs.

    Syntax handled when ``emphasis`` is anything other than ``"None"``:

    * ``(abc)``     - multiplies the weight of ``abc`` by 1.1
    * ``[abc]``     - multiplies the weight of ``abc`` by 1 / 1.1
    * ``(abc:2.5)`` - multiplies the weight of ``abc`` by 2.5
    * ``\\(``, ``\\)``, ``\\[``, ``\\]``, ``\\\\`` - the character after the
      backslash is taken literally; a lone backslash is dropped
    * ``BREAK``     - emitted as its own ``["BREAK", -1]`` entry

    Brackets still open at the end of the prompt are closed implicitly with
    their default multiplier. A closing bracket with no matching opener is
    kept as plain text.

    A ``:<number>)`` token whose number cannot be converted by ``float()``
    (for example ``1.2.3``) does not raise: the ``:<number>`` part is kept as
    literal text and the group is closed with the default 1.1 multiplier.

    Note: a ``BREAK`` entry inside a bracketed group has its ``-1`` weight
    multiplied along with everything else in that group.

    Args:
        text: The prompt string to parse.
        emphasis: Emphasis mode name. Only the value ``"None"`` is special:
            the prompt is then returned verbatim with weight 1.0.

    Returns:
        A non-empty list of two-element lists ``[fragment, weight]``.
        Adjacent entries with equal weights are merged into one.
    """
    if emphasis == "None":
        # Interpret the prompt literally: no emphasis syntax at all.
        return [[text, 1.0]]

    res = []
    round_brackets = []  # indices into res where each open "(" started
    square_brackets = []  # indices into res where each open "[" started

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every entry produced since the matching opening bracket.
        for entry in res[start_position:]:
            entry[1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith('\\'):
            # Escape sequence: keep whatever follows the backslash.
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            try:
                multiplier = float(weight)
            except ValueError:
                # The pattern accepts strings float() rejects ("1.2.3", ".").
                # Keep the ":<number>" text and fall back to a plain ")".
                res.append([token[:-1], 1.0])
                multiplier = round_bracket_multiplier
            multiply_range(round_brackets.pop(), multiplier)
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text (or an unmatched closer): split out BREAK markers.
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Implicitly close any brackets left open at the end of the prompt.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        return [["", 1.0]]

    # Merge runs of identical weights in a single pass.
    merged = [res[0]]
    for fragment, fragment_weight in res[1:]:
        if fragment_weight == merged[-1][1]:
            merged[-1][0] += fragment
        else:
            merged.append([fragment, fragment_weight])

    return merged