from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import List
import os
from qdrent import store_embeddings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
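
# This module chunks raw text, embeds the chunks with a local e5-base-v2
# SentenceTransformer, and upserts the vectors into Qdrant via
# qdrent.store_embeddings; get_query_embeddings builds matching query vectors.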


model_path = './models/e5-base-v2'
# model_path = '/Volumes/AnuragSSD/anurag/Projects/vocrt/models/e5-base-v2'

# Load the e5-base-v2 model once; the same instance embeds both passages and queries.
embedding_model = SentenceTransformer(model_path)
# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)


# def custom_token_text_splitter(
#     text: str,
#     max_tokens: int = 350,
#     overlap_tokens: int = 100,
#     separators: List[str] = ["\n\n", "\n", ". ", "? ", "! ", ", ", " ", "-"],
#     min_chunk_tokens: int = 50,
# ) -> List[str]:

#     def count_tokens(text):
#         return len(tokenizer.encode(text, add_special_tokens=True))

#     def split_text(text_chunk: str, current_separator_index: int) -> List[str]:
#         if current_separator_index >= len(separators):
#             tokens = tokenizer.encode(text_chunk, add_special_tokens=True)
#             if len(tokens) <= max_tokens:
#                 return [text_chunk]
#             else:
#                 chunks = []
#                 step = max_tokens - overlap_tokens
#                 for i in range(0, len(tokens), step):
#                     chunk_tokens = tokens[i:i+max_tokens]
#                     chunk_text = tokenizer.decode(
#                         chunk_tokens, skip_special_tokens=True)
#                     if chunk_text.strip():
#                         chunks.append(chunk_text)
#                 return chunks
#         else:
#             separator = separators[current_separator_index]
#             if not separator:
#                 return split_text(text_chunk, current_separator_index + 1)
#             splits = text_chunk.split(separator)
#             chunks = []
#             temp_chunk = ""

#             for i, split in enumerate(splits):

#                 piece_to_add = separator + split if temp_chunk else split

#                 # Check the token count if we add this piece to temp_chunk
#                 potential_new_chunk = temp_chunk + piece_to_add
#                 token_count = count_tokens(potential_new_chunk)

#                 if token_count <= max_tokens + overlap_tokens:
#                     temp_chunk = potential_new_chunk
#                     if i == len(splits) - 1 and temp_chunk.strip():
#                         chunks.append(temp_chunk.strip())
#                 else:
#                     if temp_chunk.strip():
#                         chunks.append(temp_chunk.strip())
#                     temp_chunk = split

#             final_chunks = []
#             for chunk in chunks:
#                 if count_tokens(chunk) > max_tokens:
#                     final_chunks.extend(split_text(
#                         chunk, current_separator_index + 1))
#                 else:
#                     final_chunks.append(chunk)
#             return final_chunks

#     chunks = split_text(text, 0)

#     if min_chunk_tokens > 0:
#         filtered_chunks = []
#         for chunk in chunks:
#             if count_tokens(chunk) >= min_chunk_tokens or len(chunks) == 1:
#                 filtered_chunks.append(chunk)
#         chunks = filtered_chunks

#     return chunks


async def get_and_store_embeddings(input_texts, session_id, name, title, summary, categories):
    try:

        # chunks = custom_token_text_splitter(
        #     input_texts,
        #     max_tokens=400,
        #     overlap_tokens=100,
        #     separators=["\n\n", "\n", ". ", "? ", "! ", ", ", " "],
        #     min_chunk_tokens=50,
        # )
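
        # NOTE: RecursiveCharacterTextSplitter counts chunk_size and chunk_overlap
        # in characters (length_function defaults to len), not tokens, unlike the
        # token-based splitter sketched above.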

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400, chunk_overlap=100)

        chunks = text_splitter.split_text(input_texts)

        # # Printing chunks and their token counts
        # for i, chunk in enumerate(chunks):
        #     token_count = len(tokenizer.encode(
        #         chunk, add_special_tokens=False))
        #     print(f"Chunk {i+1} ({token_count} tokens):")
        #     print(chunk.strip())
        #     print("-" * 70)

        # Prefix each chunk with "passage: ", as expected by e5-style models
        prefixed_chunks = [f"passage: {chunk.strip()}" for chunk in chunks]

        # Encoding the chunks
        chunk_embeddings = embedding_model.encode(
            prefixed_chunks,
            normalize_embeddings=True
        )

        # print("embeddings : ", chunk_embeddings)

        await store_embeddings(session_id, chunk_embeddings, chunks, name, title, summary, categories)
        return True
    except Exception as e:
        print("Error while chunking text and upserting embeddings into Qdrant:", e)
        return False


def get_query_embeddings(text):
    # e5 models expect the "query: " prefix (mirroring the "passage: " prefix above)
    query = f"query: {text}"
    query_embedding = embedding_model.encode(
        query,
        normalize_embeddings=True
    )
    return query_embedding
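

# Illustrative usage sketch (not part of the original module): it assumes the
# ./models/e5-base-v2 weights are present locally and that qdrent.store_embeddings
# can reach a running Qdrant instance; all metadata values below are placeholders.
if __name__ == "__main__":
    import asyncio

    sample_text = (
        "Qdrant is a vector database. It stores embeddings and supports "
        "semantic similarity search over them."
    )

    async def _demo():
        stored = await get_and_store_embeddings(
            input_texts=sample_text,
            session_id="demo-session",
            name="demo.txt",
            title="Demo document",
            summary="A short demo document about vector search.",
            categories=["demo"],
        )
        print("stored:", stored)

    asyncio.run(_demo())

    query_vector = get_query_embeddings("What does Qdrant store?")
    print("query embedding shape:", query_vector.shape)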