# NOTE: the next line appears to be page-scrape residue (a hosting status
# banner), not part of the program.
# Spaces: Runtime error
import os | |
import time | |
from itertools import islice | |
import shutil | |
from threading import Thread | |
import lancedb | |
import gradio as gr | |
import polars as pl | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
# Extra CSS injected into the page: table-cell contents that overflow get
# scrollbars (overflow: auto).
STYLE = (
    ".gradio-container td span {\n"
    "overflow: auto !important;\n"
    "}"
)

# Sentence-embedding model used to turn a search phrase into a query vector.
# Instantiated once, at import time.
EMBEDDING_MODEL = SentenceTransformer("TaylorAI/bge-micro")

# Upper bound on how many rows are taken from the dataset stream.
MAX_N_ROWS = 3_000_000
# Number of rows written to the table per batch.
N_ROWS_BATCH = 5_000
# Number of hits returned for each search.
N_SEARCH_RESULTS = 15
# Dataset configuration name passed to load_dataset(); also substituted into
# the README header text.
CRAWL_DUMP = "CC-MAIN-2020-05"
# Cached database connection; populated on first call to db().
DB = None
# One entry per column shown in the results table:
# (column name, display datatype, column width).
# The three parallel lists below are derived from this single table so they
# always stay the same length and in the same order.
_DISPLAY_COLUMN_SPECS = (
    ("text", "str", "300px"),
    ("url", "str", "100px"),
    ("token_count", "number", "50px"),
    ("count", "number", "25px"),
)

DISPLAY_COLUMNS = [name for name, _, _ in _DISPLAY_COLUMN_SPECS]
DISPLAY_COLUMN_TYPES = [kind for _, kind, _ in _DISPLAY_COLUMN_SPECS]
DISPLAY_COLUMN_WIDTHS = [width for _, _, width in _DISPLAY_COLUMN_SPECS]
def rename_embedding_column(row):
    """Move the value stored under ``"embedding"`` to the key ``"vector"``.

    The mapping is modified in place and the same object is returned, so the
    function can be used directly inside a comprehension.

    Raises:
        KeyError: if ``row`` has no ``"embedding"`` key.
    """
    # pop() removes the old key and hands back its value in one step.
    row["vector"] = row.pop("embedding")
    return row
def read_header_markdown() -> str:
    """Return the README body used as the page header.

    Reads ``./README.md``, drops the leading ``---``-delimited metadata
    section when there is one, and replaces the ``{{CRAWL_DUMP}}``
    placeholder with the value of ``CRAWL_DUMP``.

    Raises:
        OSError: if ``./README.md`` cannot be opened.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open("./README.md", "r", encoding="utf-8") as fp:
        text = fp.read()

    # Get only the markdown following the HF metadata section.
    # Only a metadata block at the very top is removed, and only up to its
    # closing "---" line; a later "---" horizontal rule in the body must not
    # cause the text before it to be discarded.
    if text.startswith("---\n"):
        _, closing, body = text.partition("\n---\n")
        if closing:
            text = body

    return text.replace("{{CRAWL_DUMP}}", CRAWL_DUMP)
def db():
    """Return the shared LanceDB connection, opening it on first use.

    The connection points at the local ``"data"`` directory and is cached in
    the module-level ``DB`` variable so later calls reuse it.
    """
    global DB
    if DB is not None:
        return DB
    DB = lancedb.connect("data")
    return DB
def load_data_sample():
    """Stream a sample of the dataset into the local ``"data"`` table.

    At most ``MAX_N_ROWS`` rows are taken from the streamed dataset and
    written in batches of ``N_ROWS_BATCH``. The first batch creates the table
    and builds its index; every later batch is appended. Progress is printed
    after each batch.
    """
    # NOTE(review): the reason for this delay is not stated in the file —
    # presumably it gives the app time to start first; confirm.
    time.sleep(5)

    # Start from scratch: anything left on disk from an earlier run is
    # replaced rather than extended.
    if os.path.exists("data"):
        shutil.rmtree("data")

    stream = load_dataset(
        "airtrain-ai/fineweb-edu-fortified",
        name=CRAWL_DUMP,
        split="train",
        streaming=True,
    )
    print("Loading data")

    # islice pulls only as many rows from the stream as were asked for.
    capped = islice(stream, MAX_N_ROWS)

    table = None
    total_rows = 0
    # Keep slicing batches off the capped stream until it runs dry.
    while chunk := list(islice(capped, N_ROWS_BATCH)):
        # The rows go into a vector DB for vector search, which needs the
        # "embedding" column renamed to "vector".
        records = [rename_embedding_column(record) for record in chunk]
        total_rows += len(records)

        if table is not None:
            table.add(records)
        else:
            print("Creating table")
            table = db().create_table("data", data=records)
            # Index the vector column for fast search.
            print("Indexing table")
            table.create_index(num_sub_vectors=1)

        print(f"Loaded {total_rows} rows")

    print("Done loading data")
def search(search_phrase: str) -> tuple[pl.DataFrame, int]:
    """Run a vector search for ``search_phrase`` against the ``"data"`` table.

    Blocks (polling once per second) until the table exists, embeds the
    phrase with ``EMBEDDING_MODEL``, and fetches the ``N_SEARCH_RESULTS``
    nearest rows.

    Returns:
        A 2-tuple of (results restricted to ``DISPLAY_COLUMNS``, total number
        of rows currently in the table).

    NOTE(review): the annotation names a polars DataFrame, but the value
    actually returned has been converted with ``.to_pandas()``.
    """
    # Data is loaded asynchronously, so wait until the table has been
    # created before trying to query it.
    while True:
        if "data" in db().table_names():
            break
        time.sleep(1)

    # Embed the phrase; encode() takes a batch, so wrap and unwrap.
    query_vector = EMBEDDING_MODEL.encode([search_phrase])[0]

    table = db().open_table("data")
    hits = table.search(query_vector).limit(N_SEARCH_RESULTS).to_polars()

    # Keep only the columns that are meant to be displayed.
    visible = hits.select(*(pl.col(name) for name in DISPLAY_COLUMNS))
    return visible.to_pandas(), table.count_rows()
# Page layout: README header, search box + button, row counter, results table.
with gr.Blocks(css=STYLE) as demo:
    # NOTE(review): STYLE is already passed through css= above; this extra
    # <style> tag looks like a deliberate fallback — confirm before removing.
    gr.HTML(f"<style>{STYLE}</style>")

    with gr.Row():
        gr.Markdown(read_header_markdown())

    with gr.Row():
        input_text = gr.Textbox(label="Search phrase", scale=100)
        search_button = gr.Button("Search", scale=1, min_width=100)

    with gr.Row():
        rows_searched = gr.Number(
            label="Rows searched",
            show_label=True,
        )

    with gr.Row():
        search_results = gr.DataFrame(
            headers=DISPLAY_COLUMNS,
            type="pandas",
            datatype=DISPLAY_COLUMN_TYPES,
            row_count=N_SEARCH_RESULTS,
            col_count=(len(DISPLAY_COLUMNS), "fixed"),
            column_widths=DISPLAY_COLUMN_WIDTHS,
            # elem_classes takes a class *name*, not a CSS selector, so it
            # must not start with a dot.
            elem_classes="df-text-col",
        )

    # Clicking the button runs search() on the textbox contents and fills in
    # the results table and the row counter from its two return values.
    search_button.click(
        search,
        [input_text],
        [search_results, rows_searched],
    )

# Load data on another thread so searching can start even before it has all
# been loaded. daemon=True means the loader does not keep the process alive
# once the main thread exits.
data_load_thread = Thread(target=load_data_sample, daemon=True)
data_load_thread.start()

print("Launching app")
demo.launch()