import torch

from collections import namedtuple

from backend.text_processing import parsing, emphasis
from backend import memory_management
from modules.shared import opts


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])


class PromptChunk:
    """A sequence of token ids together with their per-token emphasis multipliers."""

    def __init__(self):
        self.tokens = []
        self.multipliers = []


class T5TextProcessingEngine:
    def __init__(self, text_encoder, tokenizer, emphasis_name="Original", min_length=256):
        super().__init__()

        self.text_encoder = text_encoder.transformer
        self.tokenizer = tokenizer

        self.emphasis = emphasis.get_current_option(opts.emphasis)()
        self.min_length = min_length
        self.id_end = 1  # T5 end-of-sequence token id
        self.id_pad = 0  # T5 padding token id

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',', None)

        # Map vocabulary tokens containing attention-syntax characters to the
        # net weight multiplier implied by those characters: '(' and ']'
        # multiply by 1.1, '[' and ')' divide by 1.1.
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

    def tokenize(self, texts):
        tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
        return tokenized

    def encode_with_transformers(self, tokens):
        device = memory_management.text_encoder_device()
        tokens = tokens.to(device)

        # Keep the shared embedding table in float32 on the target device.
        self.text_encoder.shared.to(device=device, dtype=torch.float32)

        z = self.text_encoder(
            input_ids=tokens,
        )

        return z

    def tokenize_line(self, line):
        parsed = parsing.parse_prompt_attention(line, self.emphasis.name)

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0

        def next_chunk():
            # Close the current chunk: append the end token, pad it up to
            # min_length, store it, and start a new empty chunk.
            nonlocal token_count
            nonlocal chunk

            chunk.tokens = chunk.tokens + [self.id_end]
            chunk.multipliers = chunk.multipliers + [1.0]
            current_chunk_length = len(chunk.tokens)

            token_count += current_chunk_length
            remaining_count = self.min_length - current_chunk_length

            if remaining_count > 0:
                chunk.tokens += [self.id_pad] * remaining_count
                chunk.multipliers += [1.0] * remaining_count

            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            # The BREAK keyword forces the start of a new chunk.
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]
                chunk.tokens.append(token)
                chunk.multipliers.append(weight)
                position += 1

        if chunk.tokens or not chunks:
            next_chunk()

        return chunks, token_count

    def __call__(self, texts):
        zs = []
        cache = {}

        self.emphasis = emphasis.get_current_option(opts.emphasis)()

        for line in texts:
            if line in cache:
                line_z_values = cache[line]
            else:
                chunks, token_count = self.tokenize_line(line)
                line_z_values = []

                # pad all chunks to length of longest chunk
                max_tokens = 0
                for chunk in chunks:
                    max_tokens = max(len(chunk.tokens), max_tokens)

                for chunk in chunks:
                    tokens = chunk.tokens
                    multipliers = chunk.multipliers

                    remaining_count = max_tokens - len(tokens)
                    if remaining_count > 0:
                        tokens += [self.id_pad] * remaining_count
                        multipliers += [1.0] * remaining_count

                    z = self.process_tokens([tokens], [multipliers])[0]
                    line_z_values.append(z)

                cache[line] = line_z_values

            zs.extend(line_z_values)

        return torch.stack(zs)

    def process_tokens(self, batch_tokens, batch_multipliers):
        tokens = torch.asarray(batch_tokens)

        z = self.encode_with_transformers(tokens)

        # Apply the prompt-attention (emphasis) multipliers to the encoder output.
        self.emphasis.tokens = batch_tokens
        self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
        self.emphasis.z = z
        self.emphasis.after_transformers()
        z = self.emphasis.z

        return z
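

# A minimal usage sketch, assuming a loaded T5 text encoder and its tokenizer
# are already available from the surrounding backend (the `t5_encoder` and
# `t5_tokenizer` names below are placeholders, not objects defined here):
#
#     engine = T5TextProcessingEngine(t5_encoder, t5_tokenizer, min_length=256)
#     cond = engine(["an oil painting of (a lighthouse:1.2) BREAK stormy sea"])
#
# `cond` stacks one encoder-output tensor per chunk across all prompt lines;
# the BREAK keyword in a prompt starts a new chunk.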