Spaces:
Runtime error
Runtime error
import gradio as gr | |
from datasets import load_from_disk | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.quantization import quantize_embeddings | |
import faiss | |
from usearch.index import Index | |
import numpy as np | |
import os | |
base_path = os.getcwd() | |
full_path = os.path.join(base_path, 'conala') | |
conala_dataset = load_from_disk(full_path) | |
int8_view = Index.restore(os.path.join(base_path, 'conala_int8_usearch.index'), view=True) | |
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(os.path.join(base_path, 'conala.index')) | |
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") | |
def search(query, top_k: int = 20): | |
# 1. Embed the query as float32 | |
query_embedding = model.encode(query) | |
# 2. Quantize the query to ubinary. To perform actual search with faiss | |
query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary") | |
# 3. Search the binary index | |
index = binary_index | |
_scores, binary_ids = index.search(query_embedding_ubinary, top_k) | |
binary_ids = binary_ids[0] | |
# 4. Load the corresponding int8 embeddings. To perform rescoring to calculate score of fetched documents. | |
int8_embeddings = int8_view[binary_ids].astype(int) | |
# 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings | |
scores = query_embedding @ int8_embeddings.T | |
# 6. Sort the scores and return the top_k | |
indices = scores.argsort()[::-1][:top_k] | |
top_k_indices = binary_ids[indices] | |
top_k_scores = scores[indices] | |
top_k_codes = conala_dataset[top_k_indices] | |
return top_k_codes | |
def response_generator(user_prompt): | |
top_k_outputs = search(user_prompt) | |
probs = top_k_outputs['prob'] | |
snippets = top_k_outputs['snippet'] | |
idx = np.argsort(probs)[::-1] | |
results = np.array(snippets)[idx] | |
filtered_results = [] | |
for item in results: | |
if len(filtered_results)<3: | |
if item not in filtered_results: | |
filtered_results.append(item) | |
output_template = "User Query: {user_query}\nBelow are some examples of previous conversations.\nQuery: {query1} Solution: {solution1}\nQuery: {query2} Solution: {solution2}\nYou may use the above examples for reference only. Create your own solution and provide only the solution" | |
output_template = "The top three most relevant code snippets from the database are:\n\n1. {snippet1}\n\n2. {snippet2}\n\n3. {snippet3}" | |
output = f'{output_template.format(snippet1=filtered_results[0],snippet2=filtered_results[1],snippet3=filtered_results[2])}' | |
return {output_box:output} | |
with gr.Blocks() as demo: | |
gr.Markdown( | |
""" | |
# Embedding Quantization | |
## Quantized Semantic Search | |
- ***Embedding:*** all-MiniLM-L6-v2 | |
- ***Vetor DB:*** faiss, USearch | |
- ***Vector_DB Size:*** `5,93,891` | |
""") | |
state_var = gr.State([]) | |
input_box = gr.Textbox(autoscroll=True,visible=True,label='User',info="Enter a query.",value="How to extract the n-th elements from a list of tuples in python?") | |
output_box = gr.Textbox(autoscroll=True,max_lines=30,value="Output",label='Assistant') | |
gr.Interface(fn=response_generator, inputs=[input_box], outputs=[output_box], | |
delete_cache=(20,10), | |
allow_flagging='never') | |
demo.queue() | |
demo.launch() | |