from transformers import AutoTokenizer from sentence_transformers import SentenceTransformer from langchain_text_splitters import RecursiveCharacterTextSplitter from typing import List import os from qdrent import store_embeddings os.environ["TOKENIZERS_PARALLELISM"] = "false" model_path = './models/e5-base-v2' # model_path = '/Volumes/AnuragSSD/anurag/Projects/vocrt/models/e5-base-v2' model = SentenceTransformer(model_path) embedding_model = SentenceTransformer(model_path) # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) # def custom_token_text_splitter( # text: str, # max_tokens: int = 350, # overlap_tokens: int = 100, # separators: List[str] = ["\n\n", "\n", ". ", "? ", "! ", ", ", " ", "-"], # min_chunk_tokens: int = 50, # ) -> List[str]: # def count_tokens(text): # return len(tokenizer.encode(text, add_special_tokens=True)) # def split_text(text_chunk: str, current_separator_index: int) -> List[str]: # if current_separator_index >= len(separators): # tokens = tokenizer.encode(text_chunk, add_special_tokens=True) # if len(tokens) <= max_tokens: # return [text_chunk] # else: # chunks = [] # step = max_tokens - overlap_tokens # for i in range(0, len(tokens), step): # chunk_tokens = tokens[i:i+max_tokens] # chunk_text = tokenizer.decode( # chunk_tokens, skip_special_tokens=True) # if chunk_text.strip(): # chunks.append(chunk_text) # return chunks # else: # separator = separators[current_separator_index] # if not separator: # return split_text(text_chunk, current_separator_index + 1) # splits = text_chunk.split(separator) # chunks = [] # temp_chunk = "" # for i, split in enumerate(splits): # piece_to_add = separator + split if temp_chunk else split # # Check the token count if we add this piece to temp_chunk # potential_new_chunk = temp_chunk + piece_to_add # token_count = count_tokens(potential_new_chunk) # if token_count <= max_tokens + overlap_tokens: # temp_chunk = potential_new_chunk # if i == len(splits) - 1 and temp_chunk.strip(): # chunks.append(temp_chunk.strip()) # else: # if temp_chunk.strip(): # chunks.append(temp_chunk.strip()) # temp_chunk = split # final_chunks = [] # for chunk in chunks: # if count_tokens(chunk) > max_tokens: # final_chunks.extend(split_text( # chunk, current_separator_index + 1)) # else: # final_chunks.append(chunk) # return final_chunks # chunks = split_text(text, 0) # if min_chunk_tokens > 0: # filtered_chunks = [] # for chunk in chunks: # if count_tokens(chunk) >= min_chunk_tokens or len(chunks) == 1: # filtered_chunks.append(chunk) # chunks = filtered_chunks # return chunks async def get_and_store_embeddings(input_texts, session_id, name, title, summary, categories): try: # chunks = custom_token_text_splitter( # input_texts, # max_tokens=400, # overlap_tokens=100, # separators=["\n\n", "\n", ". ", "? ", "! ", ", ", " "], # min_chunk_tokens=50, # ) text_splitter = RecursiveCharacterTextSplitter( chunk_size=400, chunk_overlap=100) chunks = text_splitter.split_text(input_texts) # # Printing chunks and their token counts # for i, chunk in enumerate(chunks): # token_count = len(tokenizer.encode( # chunk, add_special_tokens=False)) # print(f"Chunk {i+1} ({token_count} tokens):") # print(chunk.strip()) # print("-" * 70) # Preparing chunks with prefixes prefixed_chunks = [f"passage: {chunk.strip()}" for chunk in chunks] # Encoding the chunks chunk_embeddings = embedding_model.encode( prefixed_chunks, normalize_embeddings=True ) # print("embeddings : ", chunk_embeddings) await store_embeddings(session_id, chunk_embeddings, chunks, name, title, summary, categories) return True except Exception as e: print("Error in getting chunks and upserting into qdrant : ", e) return False def get_query_embeddings(text): query = f"query : {text}" chunk_embeddings = embedding_model.encode( query, normalize_embeddings=True ) return chunk_embeddings