# routers/embedding/__init__.py

import os
import re
import sys
import threading
import torch
from sentence_transformers import SentenceTransformer, util
from typing import Dict, List, Tuple, Set, LiteralString, Iterator


class EmbeddingContext:
    # These don't change
    TOKEN_LEN_MAX_FOR_EMBEDDING = 512

    # Set when creating the object
    lock = None
    model = None
    openai_client = None
    model_name = ''
    config_type = ''
    embedding_shape = None
    embedding_dtype = None
    embedding_device = None

    # Updates constantly
    data = {}

    def __init__(self):
        try:
            from config import settings
        except ImportError:
            # Running outside the package layout: add the project root to sys.path.
            sys.path.append(os.path.abspath(
                os.path.join(os.path.dirname(__file__), '../..')))
            from config import settings

        self.lock = threading.Lock()
        config_type = settings.embedding_api
        model_name = settings.embedding_model

        if config_type == 'sbert':
            self.model = SentenceTransformer(model_name, use_auth_token=False)
            self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
            print("Max Sequence Length:", self.model.max_seq_length)

            self.encode = self.encode_sbert
            if torch.cuda.is_available():
                self.model = self.model.to('cuda')

        elif config_type == 'openai':
            from openai import OpenAI
            self.openai_client = OpenAI(
                # base_url = settings.openai_api_base
                api_key=settings.OPENAI_API_KEY,
            )
            self.encode = self.encode_openai

        self.model_name = model_name
        self.config_type = config_type

        # Probe the backend once to record the embedding shape, dtype and device.
        tmp = self.encode(['tmp'])
        self.embedding_shape = tmp.shape[1:]
        self.embedding_dtype = tmp.dtype
        self.embedding_device = tmp.device

    def encode(self, texts_to_embed):
        # Placeholder; __init__ rebinds this to encode_sbert or encode_openai.
        pass

    def encode_sbert(self, texts_to_embed):
        return self.model.encode(texts_to_embed,
                                 show_progress_bar=True,
                                 convert_to_tensor=True,
                                 normalize_embeddings=True)

    def encode_openai(self, texts_to_embed):
        import math
        import time

        tokens_count = 0
        for text in texts_to_embed:
            tokens_count += len(self.get_tokens(text))

        # Split the request into chunks of at most ~500k tokens and pause
        # between chunks to stay within the API rate limits.
        chunks_num = max(1, math.ceil(tokens_count / 500000))
        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)

        embeddings = []
        for i in range(chunks_num):
            start = i * chunk_size
            end = start + chunk_size
            chunk = texts_to_embed[start:end]

            embeddings_tmp = self.openai_client.embeddings.create(
                model=self.model_name,
                input=chunk,
            ).data

            if embeddings_tmp is None:
                break

            embeddings.extend(embeddings_tmp)

            if i < chunks_num - 1:
                time.sleep(60)  # Wait 1 minute before the next call

        return torch.stack(
            [torch.tensor(embedding.embedding, dtype=torch.float32)
             for embedding in embeddings])
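
    # Note: get_tokens is only an estimate helper. With the 'sbert' backend it
    # uses the model's real tokenizer; with the 'openai' backend (self.model is
    # None) it falls back to a rough regex split, which is good enough for the
    # chunk-size accounting in encode_openai above.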
    def get_tokens(self, text):
        if self.model:
            return self.model.tokenizer.tokenize(text)

        tokens = []
        for token in re.split(r'(\W|\b)', text):
            if token.strip():
                tokens.append(token)

        return tokens
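

# Illustrative usage sketch (not executed at import time): both backends return
# a 2D tensor of embeddings, and the sbert path normalizes them, so cosine
# similarity reduces to a dot product (or sentence_transformers' util.cos_sim):
#
#   emb = EMBEDDING_CTX.encode(["Add a cube", "Delete the selected object"])
#   score = util.cos_sim(emb[0], emb[1])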


class SplitDocs:
    def split_in_topics(self,
                        filedir: LiteralString = None,
                        *,
                        pattern_filename=r'(?<!navigation)\.(md|rst)',
                        pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
                        patterns_titles=(
                            r'^# (.+)', r'^## (.+)', r'^### (.+)'),
                        ) -> Iterator[Tuple[str, List[str], str]]:
        def matches_pattern(filename):
            return re.search(pattern_filename, filename) is not None

        def split_patterns_recursive(patterns, text, index=-1):
            # re.split with a capturing group alternates between unmatched text
            # (even indices) and the captured match (odd indices). Matches are
            # yielded with the current depth `index`; the text in between is
            # split again with the next pattern, or yielded as plain content
            # with index -1.
            sections = re.split(patterns[0], text, flags=re.MULTILINE)
            for i, section in enumerate(sections):
                if not section.strip():
                    continue
                is_match = bool(i & 1)
                if is_match:
                    yield (index, section)
                elif len(patterns) > 1:
                    yield from split_patterns_recursive(patterns[1:], section, index + 1)
                else:
                    yield (-1, section)
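
        # Walk the docs tree and, for each matching file, yield one
        # (rel_path, titles, content) tuple per leaf section, e.g. (illustrative):
        #   ('/modeling/meshes.md', ['Meshes', 'Editing'], '...section text...')
        # `titles` is the stack of headings leading to that section.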
        for root, _, files in os.walk(filedir):
            for name in files:
                if not matches_pattern(name):
                    continue

                full_path = os.path.join(root, name)
                with open(full_path, 'r', encoding='utf-8') as file:
                    content = file.read()

                if pattern_content_sub:
                    content = re.sub(pattern_content_sub, '', content)

                rel_path = full_path.replace(filedir, '').replace('\\', '/')

                # Protect code parts (fenced blocks are kept whole as content).
                patterns = (r'(```[\s\S]+?```)', *patterns_titles)

                last_titles = []
                last_titles_index = []
                content_accum = ''
                i = -1  # In case the file produces no sections at all.
                for i, section in split_patterns_recursive(patterns, content):
                    if i < 0:
                        content_accum += section
                        continue

                    # A new title: flush the content accumulated under the
                    # previous title stack.
                    if content_accum:
                        yield rel_path, last_titles, content_accum
                        content_accum = ''

                    if not last_titles_index or i > last_titles_index[-1]:
                        last_titles_index.append(i)
                        last_titles.append(section)
                        continue

                    # Going back up a level: drop deeper titles, then replace
                    # the current level.
                    while len(last_titles_index) > 1 and i < last_titles_index[-1]:
                        last_titles_index.pop()
                        last_titles.pop()

                    # Replace
                    last_titles_index[-1] = i
                    last_titles[-1] = section

                if content_accum or i != -1:
                    yield rel_path, last_titles, content_accum
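
    # reduce_text strips markup noise from section content before it is chunked
    # for embedding. Illustrative:
    #   reduce_text('\n<b>Note</b>\n\n:kbd: Ctrl\n')  ->  'Note\nCtrl\n'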
    def reduce_text(_self, text):
        text = re.sub(r'^\n+', '', text)  # Strip leading newlines
        text = re.sub(r'<.*?>', '', text)  # Remove HTML tags
        text = re.sub(r':\S*: ', '', text)  # Remove ':role:' markers
        text = re.sub(r'\s*\n+', '\n', text)  # Collapse blank lines
        return text
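
    # The header prefixed to every chunk, e.g. (illustrative):
    #   embedding_header('/modeling/meshes.md', ['Meshes', 'Editing'])
    #   -> '/modeling/meshes.md\n# Meshes | Editing\n\n'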
    def embedding_header(_self, rel_path, titles):
        return f"{rel_path}\n# {' | '.join(titles)}\n\n"
    def split_for_embedding(self,
                            filedir: LiteralString = None,
                            *,
                            pattern_filename=r'(?<!navigation)\.(md|rst)',
                            pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
                            patterns_titles=(
                                r'^# (.+)', r'^## (.+)', r'^### (.+)'),
                            ) -> List[str]:
        # Requires the 'sbert' backend: the local tokenizer drives chunk sizing.
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
        texts = []

        for rel_path, titles, content in self.split_in_topics(
                filedir,
                pattern_filename=pattern_filename,
                pattern_content_sub=pattern_content_sub,
                patterns_titles=patterns_titles):
            header = self.embedding_header(rel_path, titles)
            tokens_pre_len = len(tokenizer.tokenize(header))
            tokens_so_far = tokens_pre_len
            text_so_far = header

            for part in self.reduce_text(content).splitlines():
                part += '\n'
                part_tokens_len = len(tokenizer.tokenize(part))
                if tokens_so_far + part_tokens_len > max_tokens:
                    # The current chunk is full: emit it and start a new one
                    # with the same header.
                    texts.append(text_so_far)
                    text_so_far = header
                    tokens_so_far = tokens_pre_len
                text_so_far += part
                tokens_so_far += part_tokens_len

            # Emit the last chunk if it holds anything beyond the header.
            if tokens_so_far != tokens_pre_len:
                texts.append(text_so_far)

        return texts


EMBEDDING_CTX = EmbeddingContext()
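

# Minimal end-to-end sketch of how this module is meant to be used with the
# 'sbert' backend. The docs path below is a placeholder assumption; point it
# at a real Markdown/RST tree.
if __name__ == '__main__':
    docs_dir = '/path/to/docs'  # placeholder, not part of the original module
    chunks = SplitDocs().split_for_embedding(docs_dir)
    embeddings = EMBEDDING_CTX.encode(chunks)
    print(len(chunks), 'chunks ->', tuple(embeddings.shape))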