# File size: 2,129 Bytes
# f7e9d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import re


# Tokenizer for the attention syntax consumed by parse_prompt_attention().
# The pattern is compiled with re.X, so the newlines inside it are ignored.
# Alternatives are tried left to right:
#   \( \) \[ \] \\   backslash-escaped bracket / backslash (kept as literal text)
#   \                a lone backslash (dropped by the parser)
#   ( [              opening round / square bracket
#   :<number>)       explicit weight closing a round bracket; group 1 is the
#                    number text, which may be malformed (e.g. "." or "1.2.3")
#   ) ]              closing round / square bracket
#   [^\\()\[\]:]+    a run of plain text
#   :                a colon that is not part of a weight
re_attention = re.compile(r"""

\\\(|

\\\)|

\\\[|

\\]|

\\\\|

\\|

\(|

\[|

:\s*([+-]?[.\d]+)\s*\)|

\)|

]|

[^\\()\[\]:]+|

:

""", re.X)

# The BREAK keyword as a whole word, together with any surrounding whitespace.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text, emphasis):
    """Split a prompt into ``[fragment, weight]`` pairs using attention syntax.

    Syntax handled when ``emphasis`` is not ``"None"``:

    - ``(x)``      multiplies the weight of ``x`` by 1.1
    - ``[x]``      multiplies the weight of ``x`` by 1 / 1.1
    - ``(x:1.5)``  multiplies the weight of ``x`` by 1.5
    - backslash-escaped brackets / backslashes are kept as literal text;
      a lone backslash is dropped
    - ``BREAK`` as a whole word becomes a ``["BREAK", -1]`` marker entry.
      Bracket multipliers never change a marker's -1 weight, and markers are
      never merged with neighbouring fragments.

    Unbalanced opening brackets are closed implicitly at the end of the text.
    Unmatched closing brackets and malformed weights such as ``(x:.)`` are
    kept as literal text. Adjacent fragments with equal weight are merged.

    Args:
        text: the prompt to parse.
        emphasis: emphasis mode name. ``"None"`` disables all parsing, and the
            whole text is returned verbatim with weight 1.0.

    Returns:
        A non-empty list of ``[str, float]`` pairs; BREAK markers carry the
        int ``-1``. An empty prompt yields ``[["", 1.0]]``.
    """
    if emphasis == "None":
        # interpret literally: no brackets, weights or BREAK handling
        return [[text, 1.0]]

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    res = []
    round_brackets = []      # res indices where each still-open "(" began
    square_brackets = []     # res indices where each still-open "[" began
    break_positions = set()  # res indices of ["BREAK", -1] marker entries

    def multiply_range(start_position, multiplier):
        # Skip BREAK markers so their -1 sentinel survives being nested
        # inside brackets. Indices are stable because res only grows here.
        for p in range(start_position, len(res)):
            if p not in break_positions:
                res[p][1] *= multiplier

    def parse_weight(raw):
        # The tokenizer accepts any run of digits and dots, so "." or
        # "1.2.3" can arrive here. Report those as "no weight" instead of
        # letting float() raise ValueError on user input.
        if raw is None:
            return None
        try:
            return float(raw)
        except ValueError:
            return None

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = parse_weight(m.group(1))

        if token.startswith('\\'):
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), weight)
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # plain text (or an unmatched / malformed closer): split on BREAK
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    break_positions.add(len(res))
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # brackets left open are closed implicitly at the end of the prompt
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        return [["", 1.0]]

    # Merge runs of equal weight in one linear pass (the former pop(i + 1)
    # loop was quadratic). BREAK markers are never merged with neighbours.
    merged = []
    prev_is_break = False
    for idx, (fragment, frag_weight) in enumerate(res):
        is_break = idx in break_positions
        if (merged and not is_break and not prev_is_break
                and merged[-1][1] == frag_weight):
            merged[-1][0] += fragment
        else:
            merged.append([fragment, frag_weight])
        prev_is_break = is_break

    return merged