Files changed (1) hide show
  1. app.py +185 -64
app.py CHANGED
@@ -8,76 +8,194 @@ from compel import Compel, ReturnedEmbeddingsType
8
 
9
  import re
10
 
11
- def tokenize_line(text, tokenizer):
12
- tokens = tokenizer.tokenize(text)
13
- return tokens
14
-
 
15
  def parse_prompt_attention(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  res = []
17
- pattern = re.compile(r"\(([^)]+):([\d\.]+)\)")
18
- matches = pattern.findall(text)
19
- for match in matches:
20
- res.append((match[0], float(match[1])))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  return res
22
 
23
- def prompt_attention_to_invoke_prompt(attention_list):
24
- prompt = ""
25
- for item in attention_list:
26
- prompt += f"({item[0]}:{item[1]}) "
27
- return prompt.strip()
28
-
29
- def merge_embeds(prompts, compel):
30
- embeds = []
31
- pooled_embeds = []
32
- for prompt in prompts:
33
- conditioning, pooled = compel(prompt)
34
- embeds.append(conditioning)
35
- pooled_embeds.append(pooled)
36
- # 合并嵌入,这里使用平均值,可以根据需要调整
37
- merged_embed = torch.mean(torch.stack(embeds), dim=0)
38
- merged_pooled = torch.mean(torch.stack(pooled_embeds), dim=0)
39
- return merged_embed, merged_pooled
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
 
42
  if compel_process_sd:
43
  return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
44
  else:
45
  # fix bug weights conversion excessive emphasis
46
- prompt = prompt.replace("((", "(").replace("))", ")")
47
 
48
  # Convert to Compel
49
  attention = parse_prompt_attention(prompt)
50
-
51
- # 新增处理,当 attention 为空时
52
- if not attention:
53
- if only_convert_string:
54
- return prompt
55
- else:
56
- conditioning, pooled = compel(prompt)
57
- return conditioning, pooled
58
 
59
- global_attention_chunks = []
60
- # 下面的部分保持不变
61
  for att in attention:
62
- for chunk in att[0].split(','):
63
- temp_prompt_chunks = tokenize_line(chunk, pipeline.tokenizer)
64
- for small_chunk in temp_prompt_chunks:
65
  temp_dict = {
66
  "weight": round(att[1], 2),
67
- "length": len(pipeline.tokenizer.tokenize(f'{small_chunk},')),
68
- "prompt": f'{small_chunk},'
69
  }
70
- global_attention_chunks.append(temp_dict)
71
 
72
  max_tokens = pipeline.tokenizer.model_max_length - 2
73
- global_prompt_chunks = []
74
  current_list = []
75
  current_length = 0
76
- for item in global_attention_chunks:
77
- if current_length + item['length'] > max_tokens:
78
- global_prompt_chunks.append(current_list)
79
  current_list = [[item['prompt'], item['weight']]]
80
- current_length = item['length']
81
  else:
82
  if not current_list:
83
  current_list.append([item['prompt'], item['weight']])
@@ -86,14 +204,19 @@ def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_pr
86
  current_list.append([item['prompt'], item['weight']])
87
  else:
88
  current_list[-1][0] += f" {item['prompt']}"
89
- current_length += item['length']
90
  if current_list:
91
- global_prompt_chunks.append(current_list)
92
 
93
  if only_convert_string:
94
- return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chunks])
 
 
95
 
96
- return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chunks], compel)
 
 
 
97
 
98
  if not torch.cuda.is_available():
99
  DESCRIPTION += "\n<p>你现在运行在CPU上 但是此项目只支持GPU.</p>"
@@ -138,24 +261,22 @@ def infer(
138
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
139
  text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
140
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
141
- requires_pooled=[False, True]
 
142
  )
143
  # 在 infer 函数中调用 get_embed_new
144
- conditioning, pooled = get_embed_new(prompt, pipe, compel_instance)
145
-
146
- # 处理反向提示(negative_prompt)
147
- if use_negative_prompt and negative_prompt:
148
- negative_conditioning, negative_pooled = get_embed_new(negative_prompt, pipe, compel_instance)
149
- else:
150
- negative_conditioning = None
151
- negative_pooled = None
152
 
153
  # 在调用 pipe 时,使用新的参数名称(确保参数名称正确)
154
  image = pipe(
155
- prompt_embeds=conditioning,
156
- pooled_prompt_embeds=pooled,
157
- negative_prompt_embeds=negative_conditioning,
158
- negative_pooled_prompt_embeds=negative_pooled,
159
  width=width,
160
  height=height,
161
  guidance_scale=guidance_scale,
 
8
 
9
  import re
10
 
11
+ # =====================================
12
+ # Prompt weights
13
+ # =====================================
14
+ import torch
15
+ import re
16
  def parse_prompt_attention(text):
17
+ re_attention = re.compile(r"""
18
+ \\\(|
19
+ \\\)|
20
+ \\\[|
21
+ \\]|
22
+ \\\\|
23
+ \\|
24
+ \(|
25
+ \[|
26
+ :([+-]?[.\d]+)\)|
27
+ \)|
28
+ ]|
29
+ [^\\()\[\]:]+|
30
+ :
31
+ """, re.X)
32
+
33
  res = []
34
+ round_brackets = []
35
+ square_brackets = []
36
+
37
+ round_bracket_multiplier = 1.1
38
+ square_bracket_multiplier = 1 / 1.1
39
+
40
+ def multiply_range(start_position, multiplier):
41
+ for p in range(start_position, len(res)):
42
+ res[p][1] *= multiplier
43
+
44
+ for m in re_attention.finditer(text):
45
+ text = m.group(0)
46
+ weight = m.group(1)
47
+
48
+ if text.startswith('\\'):
49
+ res.append([text[1:], 1.0])
50
+ elif text == '(':
51
+ round_brackets.append(len(res))
52
+ elif text == '[':
53
+ square_brackets.append(len(res))
54
+ elif weight is not None and len(round_brackets) > 0:
55
+ multiply_range(round_brackets.pop(), float(weight))
56
+ elif text == ')' and len(round_brackets) > 0:
57
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
58
+ elif text == ']' and len(square_brackets) > 0:
59
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
60
+ else:
61
+ parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
62
+ for i, part in enumerate(parts):
63
+ if i > 0:
64
+ res.append(["BREAK", -1])
65
+ res.append([part, 1.0])
66
+
67
+ for pos in round_brackets:
68
+ multiply_range(pos, round_bracket_multiplier)
69
+
70
+ for pos in square_brackets:
71
+ multiply_range(pos, square_bracket_multiplier)
72
+
73
+ if len(res) == 0:
74
+ res = [["", 1.0]]
75
+
76
+ # merge runs of identical weights
77
+ i = 0
78
+ while i + 1 < len(res):
79
+ if res[i][1] == res[i + 1][1]:
80
+ res[i][0] += res[i + 1][0]
81
+ res.pop(i + 1)
82
+ else:
83
+ i += 1
84
+
85
  return res
86
 
87
+ def prompt_attention_to_invoke_prompt(attention):
88
+ tokens = []
89
+ for text, weight in attention:
90
+ # Round weight to 2 decimal places
91
+ weight = round(weight, 2)
92
+ if weight == 1.0:
93
+ tokens.append(text)
94
+ elif weight < 1.0:
95
+ if weight < 0.8:
96
+ tokens.append(f"({text}){weight}")
97
+ else:
98
+ tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
99
+ else:
100
+ if weight < 1.3:
101
+ tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
102
+ else:
103
+ tokens.append(f"({text}){weight}")
104
+ return "".join(tokens)
105
+
106
+ def concat_tensor(t):
107
+ t_list = torch.split(t, 1, dim=0)
108
+ t = torch.cat(t_list, dim=1)
109
+ return t
110
+
111
+ def merge_embeds(prompt_chanks, compel):
112
+ num_chanks = len(prompt_chanks)
113
+ if num_chanks != 0:
114
+ power_prompt = 1/(num_chanks*(num_chanks+1)//2)
115
+ prompt_embs = compel(prompt_chanks)
116
+ t_list = list(torch.split(prompt_embs, 1, dim=0))
117
+ for i in range(num_chanks):
118
+ t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt)
119
+ prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
120
+ else:
121
+ prompt_emb = compel('')
122
+ return prompt_emb
123
+
124
+ def detokenize(chunk, actual_prompt):
125
+ chunk[-1] = chunk[-1].replace('</w>', '')
126
+ chanked_prompt = ''.join(chunk).strip()
127
+ while '</w>' in chanked_prompt:
128
+ if actual_prompt[chanked_prompt.find('</w>')] == ' ':
129
+ chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
130
+ else:
131
+ chanked_prompt = chanked_prompt.replace('</w>', '', 1)
132
+ actual_prompt = actual_prompt.replace(chanked_prompt,'')
133
+ return chanked_prompt.strip(), actual_prompt.strip()
134
+
135
+ def tokenize_line(line, tokenizer): # split into chunks
136
+ actual_prompt = line.lower().strip()
137
+ actual_tokens = tokenizer.tokenize(actual_prompt)
138
+ max_tokens = tokenizer.model_max_length - 2
139
+ comma_token = tokenizer.tokenize(',')[0]
140
+
141
+ chunks = []
142
+ chunk = []
143
+ for item in actual_tokens:
144
+ chunk.append(item)
145
+ if len(chunk) == max_tokens:
146
+ if chunk[-1] != comma_token:
147
+ for i in range(max_tokens-1, -1, -1):
148
+ if chunk[i] == comma_token:
149
+ actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt)
150
+ chunks.append(actual_chunk)
151
+ chunk = chunk[i+1:]
152
+ break
153
+ else:
154
+ actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
155
+ chunks.append(actual_chunk)
156
+ chunk = []
157
+ else:
158
+ actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
159
+ chunks.append(actual_chunk)
160
+ chunk = []
161
+ if chunk:
162
+ actual_chunk, _ = detokenize(chunk, actual_prompt)
163
+ chunks.append(actual_chunk)
164
+
165
+ return chunks
166
 
167
  def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
168
+
169
  if compel_process_sd:
170
  return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
171
  else:
172
  # fix bug weights conversion excessive emphasis
173
+ prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
174
 
175
  # Convert to Compel
176
  attention = parse_prompt_attention(prompt)
177
+ global_attention_chanks = []
 
 
 
 
 
 
 
178
 
 
 
179
  for att in attention:
180
+ for chank in att[0].split(','):
181
+ temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer)
182
+ for small_chank in temp_prompt_chanks:
183
  temp_dict = {
184
  "weight": round(att[1], 2),
185
+ "lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')),
186
+ "prompt": f'{small_chank},'
187
  }
188
+ global_attention_chanks.append(temp_dict)
189
 
190
  max_tokens = pipeline.tokenizer.model_max_length - 2
191
+ global_prompt_chanks = []
192
  current_list = []
193
  current_length = 0
194
+ for item in global_attention_chanks:
195
+ if current_length + item['lenght'] > max_tokens:
196
+ global_prompt_chanks.append(current_list)
197
  current_list = [[item['prompt'], item['weight']]]
198
+ current_length = item['lenght']
199
  else:
200
  if not current_list:
201
  current_list.append([item['prompt'], item['weight']])
 
204
  current_list.append([item['prompt'], item['weight']])
205
  else:
206
  current_list[-1][0] += f" {item['prompt']}"
207
+ current_length += item['lenght']
208
  if current_list:
209
+ global_prompt_chanks.append(current_list)
210
 
211
  if only_convert_string:
212
+ return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks])
213
+
214
+ return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)
215
 
216
+ def add_comma_after_pattern_ti(text):
217
+ pattern = re.compile(r'\b\w+_\d+\b')
218
+ modified_text = pattern.sub(lambda x: x.group() + ',', text)
219
+ return modified_text
220
 
221
  if not torch.cuda.is_available():
222
  DESCRIPTION += "\n<p>你现在运行在CPU上 但是此项目只支持GPU.</p>"
 
261
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
262
  text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
263
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
264
+ requires_pooled=[False, True],
265
+ truncate_long_prompts=False
266
  )
267
  # 在 infer 函数中调用 get_embed_new
268
+ if not use_negative_prompt:
269
+ negative_prompt = ""
270
+ prompt = get_embed_new(prompt, pipe, compel, only_convert_string=True)
271
+ negative_prompt = get_embed_new(negative_prompt, pipe, compel, only_convert_string=True)
272
+ conditioning, pooled = compel([prompt, neg_prompt]) # 必须同时处理来保证长度相等
 
 
 
273
 
274
  # 在调用 pipe 时,使用新的参数名称(确保参数名称正确)
275
  image = pipe(
276
+ prompt_embeds=conditioning[0:1],
277
+ pooled_prompt_embeds=pooled[0:1],
278
+ negative_prompt_embeds=conditioning[1:2],
279
+ negative_pooled_prompt_embeds=pooled[1:2],
280
  width=width,
281
  height=height,
282
  guidance_scale=guidance_scale,