import torch

from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend import memory_management


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
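# PromptChunkFix pairs an offset with an embedding tensor; it mirrors the structure the
# CLIP text-processing engine uses for textual-inversion embeddings and appears to be kept
# here only for interface compatibility (nothing in this module references it).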


class PromptChunk:
    """Tokens of one prompt chunk together with their per-token emphasis multipliers."""

    def __init__(self):
        self.tokens = []
        self.multipliers = []


class T5TextProcessingEngine:
    def __init__(self, text_encoder, tokenizer, emphasis_name="Original", min_length=256):
        super().__init__()

        self.text_encoder = text_encoder.transformer
        self.tokenizer = tokenizer

        self.emphasis = emphasis.get_current_option(emphasis_name)()
        self.min_length = min_length
        self.id_end = 1  # T5 end-of-sequence token id
        self.id_pad = 0  # T5 padding token id

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        # Precompute emphasis multipliers for vocabulary tokens that themselves contain
        # attention brackets: each '(' scales by 1.1 and each ')' undoes it, while each
        # '[' scales by 1/1.1 and each ']' undoes that.
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

    def tokenize(self, texts):
        tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
        return tokenized

    def encode_with_transformers(self, tokens):
        device = memory_management.text_encoder_device()
        tokens = tokens.to(device)

        # Keep the shared embedding table in float32 on the text-encoder device
        # before running the T5 stack.
        self.text_encoder.shared.to(device=device, dtype=torch.float32)

        z = self.text_encoder(
            input_ids=tokens,
        )

        return z

    def tokenize_line(self, line):
        parsed = parsing.parse_prompt_attention(line)
        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0

        def next_chunk():
            # Close the current chunk: append the end token, pad up to min_length,
            # record it, and start a fresh chunk.
            nonlocal token_count
            nonlocal chunk

            chunk.tokens = chunk.tokens + [self.id_end]
            chunk.multipliers = chunk.multipliers + [1.0]
            current_chunk_length = len(chunk.tokens)

            token_count += current_chunk_length
            remaining_count = self.min_length - current_chunk_length

            if remaining_count > 0:
                chunk.tokens += [self.id_pad] * remaining_count
                chunk.multipliers += [1.0] * remaining_count

            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                # An explicit BREAK in the prompt forces a chunk boundary.
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]
                chunk.tokens.append(token)
                chunk.multipliers.append(weight)
                position += 1

        if chunk.tokens or not chunks:
            next_chunk()

        return chunks, token_count

    def __call__(self, texts):
        zs = []
        cache = {}

        for line in texts:
            if line in cache:
                # Duplicate lines reuse the embeddings computed for the first occurrence.
                line_z_values = cache[line]
            else:
                chunks, token_count = self.tokenize_line(line)
                line_z_values = []

                for chunk in chunks:
                    tokens = chunk.tokens
                    multipliers = chunk.multipliers

                    z = self.process_tokens([tokens], [multipliers])[0]
                    line_z_values.append(z)
                cache[line] = line_z_values

            zs.extend(line_z_values)

        return torch.stack(zs)

    def process_tokens(self, batch_tokens, batch_multipliers):
        tokens = torch.asarray(batch_tokens)

        z = self.encode_with_transformers(tokens)

        # Let the selected emphasis implementation rescale the hidden states
        # according to the per-token multipliers.
        self.emphasis.tokens = batch_tokens
        self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
        self.emphasis.z = z
        self.emphasis.after_transformers()
        z = self.emphasis.z

        return z
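

# Example usage (a minimal sketch, not part of the original module; `t5_encoder` and
# `t5_tokenizer` are assumptions -- the encoder must expose a `.transformer` T5 stack and
# the tokenizer must be a HuggingFace-style tokenizer providing `get_vocab()`):
#
#   engine = T5TextProcessingEngine(text_encoder=t5_encoder, tokenizer=t5_tokenizer)
#   cond = engine(["a photo of a cat", "a photo of a dog"])
#   # `cond` stacks one [seq_len, hidden_dim] tensor per prompt chunk, where each chunk
#   # is padded with id_pad up to at least `min_length` tokens.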