# routers/embedding/__init__.py
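"""Embedding helpers: wraps a SentenceTransformer or OpenAI embeddings
backend behind the shared EMBEDDING_CTX singleton and splits Markdown/reST
docs into chunks that fit the model's maximum sequence length."""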

import os
import re
import sys
import threading
import torch
from sentence_transformers import SentenceTransformer, util
from typing import Iterator, List, Tuple, LiteralString


class EmbeddingContext:
    # These don't change
    TOKEN_LEN_MAX_FOR_EMBEDDING = 512

    # Set when creating the object
    lock = None
    model = None
    openai_client = None
    model_name = ''
    config_type = ''
    embedding_shape = None
    embedding_dtype = None
    embedding_device = None

    # Updates constantly (class-level dict: shared state, guard with `lock`)
    data = {}

    def __init__(self):
        try:
            from config import settings
        except ImportError:
            sys.path.append(os.path.abspath(
                os.path.join(os.path.dirname(__file__), '../..')))
            from config import settings

        self.lock = threading.Lock()
        config_type = settings.embedding_api
        model_name = settings.embedding_model

        if config_type == 'sbert':
            self.model = SentenceTransformer(model_name, use_auth_token=False)
            self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
            print("Max Sequence Length:", self.model.max_seq_length)

            self.encode = self.encode_sbert
            if torch.cuda.is_available():
                self.model = self.model.to('cuda')

        elif config_type == 'openai':
            from openai import OpenAI
            self.openai_client = OpenAI(
                # base_url = settings.openai_api_base
                api_key=settings.OPENAI_API_KEY,
            )
            self.encode = self.encode_openai

        else:
            raise ValueError(f"Unsupported embedding_api: {config_type!r}")

        self.model_name = model_name
        self.config_type = config_type

        # Probe the backend once to record the embedding shape, dtype and device.
        tmp = self.encode(['tmp'])
        self.embedding_shape = tmp.shape[1:]
        self.embedding_dtype = tmp.dtype
        self.embedding_device = tmp.device

    def encode(self, texts_to_embed):
        # Placeholder; replaced in __init__ with the backend-specific encoder.
        raise NotImplementedError

    def encode_sbert(self, texts_to_embed):
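        # normalize_embeddings=True returns unit-length vectors, so cosine
        # similarity downstream reduces to a plain dot product.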
        return self.model.encode(texts_to_embed,
                                 show_progress_bar=True,
                                 convert_to_tensor=True,
                                 normalize_embeddings=True)

    def encode_openai(self, texts_to_embed):
        import math
        import time

        # Count tokens across all inputs first; the chunking below keeps each
        # request under roughly 500k tokens.
        tokens_count = 0
        for text in texts_to_embed:
            tokens_count += len(self.get_tokens(text))

        chunks_num = max(1, math.ceil(tokens_count / 500000))
        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)

        embeddings = []
        for i in range(chunks_num):
            start = i * chunk_size
            end = start + chunk_size
            chunk = texts_to_embed[start:end]

            embeddings_tmp = self.openai_client.embeddings.create(
                model=self.model_name,
                input=chunk,
            ).data

            if embeddings_tmp is None:
                break

            embeddings.extend(embeddings_tmp)

            if i < chunks_num - 1:
                time.sleep(60)  # Wait 1 minute before the next call

        return torch.stack([torch.tensor(embedding.embedding, dtype=torch.float32)
                            for embedding in embeddings])

    def get_tokens(self, text):
        if self.model:
            return self.model.tokenizer.tokenize(text)

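        # Fallback for backends without a local tokenizer (e.g. OpenAI):
        # a rough split on word boundaries and non-word characters.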
        tokens = []
        for token in re.split(r'(\W|\b)', text):
            if token.strip():
                tokens.append(token)

        return tokens


class SplitDocs:
    def split_in_topics(self,
                        filedir: LiteralString = None,
                        *,
                        pattern_filename=r'(?<!navigation)\.(md|rst)',
                        pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
                        patterns_titles=(
                            r'^# (.+)', r'^## (.+)', r'^### (.+)'),
                        ) -> Iterator[Tuple[str, List[str], str]]:
        def matches_pattern(filename):
            return re.search(pattern_filename, filename) is not None

        def split_patterns_recursive(patterns, text, index=-1):
            sections = re.split(patterns[0], text, flags=re.MULTILINE)
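            # re.split with a capturing group interleaves unmatched text and
            # captures: odd indices hold the matched title/code-block text.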
            for i, section in enumerate(sections):
                if not section.strip():
                    continue
                is_match = bool(i & 1)
                if is_match:
                    yield (index, section)
                elif len(patterns) > 1:
                    for j, section_j in split_patterns_recursive(patterns[1:], section, index + 1):
                        yield (j, section_j)
                else:
                    yield (-1, section)

        for root, _, files in os.walk(filedir):
            for name in files:
                if not matches_pattern(name):
                    continue

                full_path = os.path.join(root, name)
                with open(full_path, 'r', encoding='utf-8') as file:
                    content = file.read()

                if pattern_content_sub:
                    content = re.sub(pattern_content_sub, '', content)

                rel_path = full_path.replace(filedir, '').replace('\\', '/')

                # Match fenced code blocks first so that '#' lines inside
                # them are not mistaken for headings.
                patterns = (r'(```[\s\S]+?```)', *patterns_titles)

                last_titles = []
                last_titles_index = []
                content_accum = ''
                i = -1  # Stays -1 if the file yields no sections at all.
                for i, section in split_patterns_recursive(patterns, content):
                    if i < 0:
                        content_accum += section
                        continue
                    if content_accum:
                        yield rel_path, last_titles, content_accum
                        content_accum = ''
                    if not last_titles_index or i > last_titles_index[-1]:
                        last_titles_index.append(i)
                        last_titles.append(section)
                        continue
                    while len(last_titles_index) > 1 and i < last_titles_index[-1]:
                        last_titles_index.pop()
                        last_titles.pop()
                    # Replace
                    last_titles_index[-1] = i
                    last_titles[-1] = section
                if content_accum or i != -1:
                    yield rel_path, last_titles, content_accum

    def reduce_text(_self, text):
        text = re.sub(r'^\n+', '', text)  # Strip leading newlines
        text = re.sub(r'<.*?>', '', text)  # Remove HTML tags
        text = re.sub(r':\S*: ', '', text)  # Remove ':role:'-style markers
        text = re.sub(r'\s*\n+', '\n', text)  # Collapse blank lines
        return text

    def embedding_header(_self, rel_path, titles):
        return f"{rel_path}\n# {' | '.join(titles)}\n\n"

    def split_for_embedding(self,
                            filedir: LiteralString = None,
                            *,
                            pattern_filename=r'(?<!navigation)\.(md|rst)',
                            pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
                            patterns_titles=(
                                r'^# (.+)', r'^## (.+)', r'^### (.+)'),
                            ):
        # Assumes the 'sbert' backend: EMBEDDING_CTX.model is None for OpenAI.
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
        texts = []

        for rel_path, titles, content in self.split_in_topics(
                filedir, pattern_filename=pattern_filename, pattern_content_sub=pattern_content_sub, patterns_titles=patterns_titles):
            header = self.embedding_header(rel_path, titles)
            tokens_pre_len = len(tokenizer.tokenize(header))
            tokens_so_far = tokens_pre_len
            text_so_far = header
            for part in self.reduce_text(content).splitlines():
                part += '\n'
                part_tokens_len = len(tokenizer.tokenize(part))
                if tokens_so_far + part_tokens_len > max_tokens:
                    texts.append(text_so_far)
                    text_so_far = header
                    tokens_so_far = tokens_pre_len
                text_so_far += part
                tokens_so_far += part_tokens_len

            if tokens_so_far != tokens_pre_len:
                texts.append(text_so_far)

        return texts


EMBEDDING_CTX = EmbeddingContext()
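

# A minimal usage sketch (the docs path below is hypothetical, and the
# 'sbert' backend is assumed so that util.cos_sim applies to the tensors):
if __name__ == '__main__':
    texts = SplitDocs().split_for_embedding('/path/to/docs')
    embeddings = EMBEDDING_CTX.encode(texts[:4])
    # With normalized embeddings, cosine similarity is a plain dot product.
    scores = util.cos_sim(embeddings, embeddings)
    print(embeddings.shape, scores.shape)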