import os
import re

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import comfy.utils
import folder_paths

from .utils import install_package

try:
    from lark import Lark, Transformer, v_args
except ImportError:
    # lark is published on PyPI as "lark-parser"; install it on demand.
    print('installing lark-parser...')
    install_package('lark-parser')
    from lark import Lark, Transformer, v_args

# Translation model location: a local copy under models/prompt_generator is
# preferred; translate() and zh_to_en() fall back to the Hugging Face hub.
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')

# Loaded lazily on first use so that importing this module stays cheap.
zh_en_model, zh_en_tokenizer = None, None


def correct_prompt_syntax(prompt=""):
    corrected_elements = []

    # Normalize full-width (CJK) punctuation to the ASCII forms the prompt
    # grammar expects, and treat stray backslashes as separators.
    prompt = (prompt.replace('（', '(').replace('）', ')')
              .replace('，', ',').replace('；', ',')
              .replace('。', '.').replace('：', ':')
              .replace('\\', ','))

    # Collapse whitespace runs and strip spaces just inside brackets.
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    prompt = (prompt.replace("< ", "<").replace(" >", ">")
              .replace("( ", "(").replace(" )", ")")
              .replace("[ ", "[").replace(" ]", "]"))

    prompt_elements = prompt.split(',')

    def balance_brackets(element, open_bracket, close_bracket):
        # Append however many closing brackets are missing.
        open_brackets_count = element.count(open_bracket)
        close_brackets_count = element.count(close_bracket)
        return element + close_bracket * (open_brackets_count - close_brackets_count)

    for element in prompt_elements:
        element = element.strip()

        if not element:
            continue

        if element[0] in '([':
            corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']')
        elif element[0] == '<':
            corrected_element = balance_brackets(element, '<', '>')
        else:
            # Drop closing brackets that have no matching opener.
            corrected_element = element.lstrip(')]')

        corrected_elements.append(corrected_element)

    return ','.join(corrected_elements)
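
# Illustrative behavior (not executed by the node): unmatched brackets are
# closed and separators are normalized, e.g.
#   correct_prompt_syntax('(masterpiece, [best quality')
#   -> '(masterpiece),[best quality]'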


def detect_language(input_str):
    # Classify a string as mostly Chinese or mostly English by comparing the
    # number of CJK characters with the number of alphabetic characters.
    count_cn = count_en = 0
    for char in input_str:
        if '\u4e00' <= char <= '\u9fff':
            count_cn += 1
        elif char.isalpha():
            count_en += 1

    if count_cn > count_en:
        return "cn"
    elif count_en > count_cn:
        return "en"
    else:
        return "unknown"


def has_chinese(text):
    # <...> tags, __wildcards__ and trailing "embedding:..." references are
    # stripped first, so their names never count as Chinese text.
    has_cn = False
    _text = text
    _text = re.sub(r'<.*?>', '', _text)
    _text = re.sub(r'__.*?__', '', _text)
    _text = re.sub(r'embedding:.*?$', '', _text)
    for char in _text:
        if '\u4e00' <= char <= '\u9fff':
            has_cn = True
            break
    return has_cn
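
# For example, has_chinese("一只猫") is True, while has_chinese("<lora:猫:1>")
# is False because the tag is stripped before scanning.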


def translate(text):
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    # Fall back to the hub checkpoint when no local copy exists.
    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    # Load the model and tokenizer once and reuse them across calls.
    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        encoded = zh_en_tokenizer([text], return_tensors="pt")
        encoded = encoded.to(zh_en_model.device)
        sequences = zh_en_model.generate(**encoded)
        return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
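
# Example (illustrative; the exact wording depends on the checkpoint):
#   translate("一只可爱的猫") -> "A lovely cat"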


@v_args(inline=True)
class ChinesePromptTranslate(Transformer):
    """Lark transformer that translates Chinese fragments to English while
    preserving (), [], <lora:...>, embedding and schedule syntax."""

    def sentence(self, *args):
        return ", ".join(args)

    def phrase(self, *args):
        return "".join(args)

    def emphasis(self, *args):
        return "(" + "".join(args) + ")"

    def weak_emphasis(self, *args):
        print('weak_emphasis:', args)
        return "[" + "".join(args) + "]"

    def embedding(self, *args):
        print('prompt embedding', args[0])
        if len(args) == 1:
            embedding_name = str(args[0])
            return f"embedding:{embedding_name}"
        elif len(args) > 1:
            embedding_name, *numbers = args
            if len(numbers) == 2:
                return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}"
            elif len(numbers) == 1:
                return f"embedding:{embedding_name}:{numbers[0]}"
            else:
                return f"embedding:{embedding_name}"

    def lora(self, *args):
        if len(args) == 1:
            return f"<lora:{args[0]}>"
        elif len(args) > 1:
            # args[0] is the literal "lora" keyword matched by WORD.
            _, lora_name, *numbers = args
            lora_name = str(lora_name).strip()
            if len(numbers) == 2:
                return f"<lora:{lora_name}:{numbers[0]}:{numbers[1]}>"
            elif len(numbers) == 1:
                return f"<lora:{lora_name}:{numbers[0]}>"
            else:
                return f"<lora:{lora_name}>"

    def weight(self, word, number):
        translated_word = translate(str(word)).rstrip('.')
        return f"({translated_word}:{str(number).strip()})"

    def schedule(self, *args):
        print('prompt schedule', args)
        data = [str(arg).strip() for arg in args]
        return f"[{':'.join(data)}]"

    def word(self, word):
        word = str(word)
        match_cn = re.search(r'@.*?@', word)
        if re.search(r'__.*?__', word):
            # __wildcards__ pass through untranslated.
            return word.rstrip('.')
        elif match_cn:
            # Text wrapped in @...@ is kept verbatim; only the fragments
            # before and after it are translated.
            chinese = match_cn.group()
            before = word.split('@', 1)[0]
            before = translate(str(before)).rstrip('.') if before else ''
            after = word.rsplit('@', 1)
            after = after[-1] if len(after) > 1 else ''
            after = translate(after).rstrip('.') if after else ''
            return before + chinese.replace('@', '').rstrip('.') + after
        elif detect_language(word) == "cn":
            return translate(word).rstrip('.')
        else:
            return word.rstrip('.')
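
# End-to-end behavior of the transformer (illustrative translations): parsing
# "(一只猫:1.2),杰作" yields something like "(a cat:1.2), masterpiece".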


# Grammar for Stable Diffusion style prompts. A raw string keeps the regex
# escapes in NUMBER and WORD intact.
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
        | "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD

NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""


def zh_to_en(text):
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    pbar = comfy.utils.ProgressBar(len(text) + 1)
    texts = [correct_prompt_syntax(t) for t in text]

    # The Marian tokenizer requires sentencepiece.
    install_package('sentencepiece', '0.2.0')

    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")

    prompt_result = []
    en_texts = []

    # Build the parser once and reuse it for every prompt; the transformer
    # translates Chinese fragments while the parse preserves prompt syntax.
    parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())

    for t in texts:
        if t:
            result = parser.parse(t).children
            en_texts.append(result[0])

    # Move the model back to the CPU to free VRAM once translation is done.
    zh_en_model.to('cpu')

    pbar.update(1)
    for t in en_texts:
        prompt_result.append(t)
        pbar.update(1)

    if len(prompt_result) == 0:
        prompt_result = [""]

    return prompt_result
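
# Minimal manual check (assumes a ComfyUI environment for comfy.utils and
# folder_paths, and that the opus-mt-zh-en weights are available locally or
# from the Hugging Face hub); the output shown is illustrative:
#   zh_to_en(["一只可爱的猫, (杰作:1.2)"])  # -> ["a lovely cat, (masterpiece:1.2)"]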