from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import List
import os
from qdrent import store_embeddings
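# Avoid the HuggingFace tokenizers fork/parallelism warning when this module
# is used from forked or async worker processes.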
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model_path = './models/e5-base-v2'
# model_path = '/Volumes/AnuragSSD/anurag/Projects/vocrt/models/e5-base-v2'
embedding_model = SentenceTransformer(model_path)
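# e5-style models expect asymmetric prefixes: "passage: " for stored chunks
# (see get_and_store_embeddings) and "query: " for search queries
# (see get_query_embeddings).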
# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
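# The commented-out block below is a custom token-aware splitter kept for
# reference; the active code path uses RecursiveCharacterTextSplitter instead.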
# def custom_token_text_splitter(
#     text: str,
#     max_tokens: int = 350,
#     overlap_tokens: int = 100,
#     separators: List[str] = ["\n\n", "\n", ". ", "? ", "! ", ", ", " ", "-"],
#     min_chunk_tokens: int = 50,
# ) -> List[str]:
#     def count_tokens(text):
#         return len(tokenizer.encode(text, add_special_tokens=True))
#     def split_text(text_chunk: str, current_separator_index: int) -> List[str]:
#         if current_separator_index >= len(separators):
#             tokens = tokenizer.encode(text_chunk, add_special_tokens=True)
#             if len(tokens) <= max_tokens:
#                 return [text_chunk]
#             else:
#                 chunks = []
#                 step = max_tokens - overlap_tokens
#                 for i in range(0, len(tokens), step):
#                     chunk_tokens = tokens[i:i+max_tokens]
#                     chunk_text = tokenizer.decode(
#                         chunk_tokens, skip_special_tokens=True)
#                     if chunk_text.strip():
#                         chunks.append(chunk_text)
#                 return chunks
#         else:
#             separator = separators[current_separator_index]
#             if not separator:
#                 return split_text(text_chunk, current_separator_index + 1)
#             splits = text_chunk.split(separator)
#             chunks = []
#             temp_chunk = ""
#             for i, split in enumerate(splits):
#                 piece_to_add = separator + split if temp_chunk else split
#                 # Check the token count if we add this piece to temp_chunk
#                 potential_new_chunk = temp_chunk + piece_to_add
#                 token_count = count_tokens(potential_new_chunk)
#                 if token_count <= max_tokens + overlap_tokens:
#                     temp_chunk = potential_new_chunk
#                     if i == len(splits) - 1 and temp_chunk.strip():
#                         chunks.append(temp_chunk.strip())
#                 else:
#                     if temp_chunk.strip():
#                         chunks.append(temp_chunk.strip())
#                     temp_chunk = split
#             final_chunks = []
#             for chunk in chunks:
#                 if count_tokens(chunk) > max_tokens:
#                     final_chunks.extend(split_text(
#                         chunk, current_separator_index + 1))
#                 else:
#                     final_chunks.append(chunk)
#             return final_chunks
#     chunks = split_text(text, 0)
#     if min_chunk_tokens > 0:
#         filtered_chunks = []
#         for chunk in chunks:
#             if count_tokens(chunk) >= min_chunk_tokens or len(chunks) == 1:
#                 filtered_chunks.append(chunk)
#         chunks = filtered_chunks
#     return chunks
async def get_and_store_embeddings(input_texts, session_id, name, title, summary, categories):
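    """Chunk `input_texts`, embed each chunk with e5-base-v2 using the
    "passage: " prefix, and upsert the vectors plus their metadata into
    Qdrant via `store_embeddings`. Returns True on success, False on failure."""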
    try:
        # chunks = custom_token_text_splitter(
        #     input_texts,
        #     max_tokens=400,
        #     overlap_tokens=100,
        #     separators=["\n\n", "\n", ". ", "? ", "! ", ", ", " "],
        #     min_chunk_tokens=50,
        # )
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400, chunk_overlap=100)
        chunks = text_splitter.split_text(input_texts)
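        # Note: RecursiveCharacterTextSplitter measures chunk_size / chunk_overlap
        # in characters (its default length_function is len), not in model tokens,
        # unlike the token-based splitter commented out above.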
        # # Printing chunks and their token counts
        # for i, chunk in enumerate(chunks):
        #     token_count = len(tokenizer.encode(
        #         chunk, add_special_tokens=False))
        #     print(f"Chunk {i+1} ({token_count} tokens):")
        #     print(chunk.strip())
        #     print("-" * 70)
        # Preparing chunks with prefixes
        prefixed_chunks = [f"passage: {chunk.strip()}" for chunk in chunks]
        # Encoding the chunks
        chunk_embeddings = embedding_model.encode(
            prefixed_chunks,
            normalize_embeddings=True
        )
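        # normalize_embeddings=True L2-normalizes each vector, so cosine similarity
        # in Qdrant reduces to a plain dot product.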
        # print("embeddings : ", chunk_embeddings)
        await store_embeddings(session_id, chunk_embeddings, chunks, name, title, summary, categories)
        return True
    except Exception as e:
        print("Error while chunking and upserting into Qdrant:", e)
        return False
def get_query_embeddings(text):
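    """Embed a search query with the e5 "query: " prefix so it lies in the
    same vector space as the stored "passage: " chunks."""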
    query = f"query: {text}"
    query_embedding = embedding_model.encode(
        query,
        normalize_embeddings=True
    )
    return query_embedding
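

# Minimal usage sketch (illustrative only; the query string below is made up):
# embed an ad-hoc query and print its dimensionality -- e5-base-v2 produces
# 768-dimensional vectors.
if __name__ == "__main__":
    example_vector = get_query_embeddings("what was discussed about project deadlines")
    print(example_vector.shape)  # expected: (768,)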