import os
import shutil

import lancedb
import torch
import pyarrow as pa
import pandas as pd
from pathlib import Path
import tqdm
import numpy as np

from sentence_transformers import SentenceTransformer

from markdown_to_text import *
from settings import *


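# Rebuild the vector store from scratch on each run.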
shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
db = lancedb.connect(LANCEDB_DIRECTORY)
batch_size = 32

# Pick the best available device, then load the embedding model onto it.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = SentenceTransformer(EMB_MODEL_NAME, device=device)
model.eval()

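# Table schema: fixed-size float32 embedding vector, chunk text, and source document path.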
schema = pa.schema([
    pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMB_MODEL_NAME])),
    pa.field(TEXT_COLUMN_NAME, pa.string()),
    pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
])
tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")

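# Walk the markdown source tree and split each document into chunks; non-.md files are skipped.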
input_dir = Path(MARKDOWN_SOURCE_DIR)
files = list(input_dir.rglob("*"))

chunks = []
for file in files:
    if not os.path.isfile(file):
        continue

    file_path, file_ext = os.path.splitext(os.path.relpath(file, input_dir))
    if file_ext != '.md':
        print(f'Skipped {file_ext} extension: {file}')
        continue

    # Prefix each chunk with its path inside the docs tree so the embedding carries context.
    doc_header = ' / '.join(split_path(file_path)) + ':\n\n'
    with open(file, encoding='utf-8') as f:
        text = f.read()
    text = remove_comments(text)
    chunks.extend((doc_header + chunk, os.path.abspath(file)) for chunk in split_markdown(text))

# Quick look at the chunk length distribution before embedding.
from matplotlib import pyplot as plt
plt.hist([len(text) for text, _ in chunks], bins=100)
plt.show()

n_batches = int(np.ceil(len(chunks) / batch_size))
for i in tqdm.tqdm(range(n_batches)):
    # Drop empty chunks; keep texts and their source paths aligned.
    texts, doc_paths = [], []
    for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
        if len(text) > 0:
            texts.append(text)
            doc_paths.append(doc_path)
    if not texts:
        continue

    # Normalized embeddings make cosine similarity equivalent to a dot product at query time.
    encoded = model.encode(texts, normalize_embeddings=True, device=device)
    encoded = [list(vec) for vec in encoded]

    df = pd.DataFrame({
        VECTOR_COLUMN_NAME: encoded,
        TEXT_COLUMN_NAME: texts,
        DOCUMENT_PATH_COLUMN_NAME: doc_paths,
    })

    tbl.add(df)


# '''
# Create an IVF_PQ index: https://lancedb.github.io/lancedb/ann_indexes/
# At the size of the transformers docs an index is not really needed,
# but we could create one for demonstration purposes.
# '''
# tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
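
# Illustrative sanity check (assumption: run after the table is populated).
# Embed a query with the same model and pull back the nearest chunks;
# the query string below is just an example.
# query_vec = model.encode("how do I fine-tune a model?", normalize_embeddings=True, device=device)
# hits = tbl.search(query_vec).limit(3).to_pandas()
# print(hits[[TEXT_COLUMN_NAME, DOCUMENT_PATH_COLUMN_NAME]])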