File size: 3,399 Bytes
ef56763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

import gradio as gr
from datasets import load_from_disk
import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import numpy as np
import os

# Every artifact (dataset folder and both index files) is looked up relative
# to the current working directory of the process.
base_path = os.getcwd()
full_path = os.path.join(base_path, 'conala')
_int8_index_file = os.path.join(base_path, 'conala_int8_usearch.index')
_binary_index_file = os.path.join(base_path, 'conala.index')

# Dataset rows that search hits are mapped back to.
conala_dataset = load_from_disk(full_path)

# int8 embeddings used for rescoring; restored with view=True
# (presumably so the index is viewed from disk rather than fully loaded —
# see the usearch docs).
int8_view = Index.restore(_int8_index_file, view=True)

# Binary (bit-packed) embeddings used for the first-stage search.
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(_binary_index_file)

# Encoder used to embed incoming queries.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def search(query, top_k: int = 20, rescore_multiplier: int = 1):
    """Return the ``top_k`` dataset rows that best match ``query``.

    Retrieval happens in two stages: candidates are fetched from the binary
    faiss index using the ubinary-quantized query, then rescored with a dot
    product between the un-quantized query embedding and the candidates'
    int8 embeddings.

    Args:
        query: Text handed to ``model.encode``.
        top_k: Number of rows to return.
        rescore_multiplier: Oversampling factor for the binary stage;
            ``top_k * rescore_multiplier`` candidates are fetched and
            rescored before the best ``top_k`` are kept. The default of 1
            matches the previous behaviour.

    Returns:
        ``conala_dataset`` indexed by the selected ids, ordered by
        descending rescored similarity.

    Raises:
        ValueError: If ``rescore_multiplier`` is smaller than 1.
    """
    if rescore_multiplier < 1:
        raise ValueError("rescore_multiplier must be >= 1")

    # 1. Embed the query as float32
    query_embedding = model.encode(query)

    # 2. Quantize the query to ubinary. To perform actual search with faiss
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")

    # 3. Search the binary index, oversampling by rescore_multiplier
    _, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rescore_multiplier)
    binary_ids = binary_ids[0]
    # faiss fills unused result slots with -1 when it finds fewer hits than
    # requested; drop them so they are never used as embedding/dataset ids.
    binary_ids = binary_ids[binary_ids >= 0]

    # 4. Load the corresponding int8 embeddings. To perform rescoring to calculate score of fetched documents.
    int8_embeddings = int8_view[binary_ids].astype(int)

    # 5. Rescore the candidates using the float32 query embedding and the int8 document embeddings
    scores = query_embedding @ int8_embeddings.T

    # 6. Sort the scores (descending) and keep the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]

    return conala_dataset[top_k_indices]


def response_generator(user_prompt):
    """Build the assistant text for ``user_prompt`` from the search results.

    Runs ``search`` on the prompt, orders the returned snippets by their
    ``'prob'`` value (highest first), keeps the first three distinct
    snippets and formats them as a numbered list.

    Args:
        user_prompt: Query text entered by the user.

    Returns:
        A dict mapping the ``output_box`` component to the formatted text.
        If fewer than three distinct snippets are available, only those are
        listed; if none are available, a short "nothing found" message is
        returned instead of raising ``IndexError``.
    """
    top_k_outputs = search(user_prompt)
    probs = top_k_outputs['prob']
    snippets = top_k_outputs['snippet']

    # Highest 'prob' first.
    ranked_snippets = np.array(snippets)[np.argsort(probs)[::-1]]

    # Keep the first three distinct snippets, preserving rank order.
    filtered_results = []
    for item in ranked_snippets:
        if item in filtered_results:
            continue
        filtered_results.append(item)
        if len(filtered_results) == 3:
            break

    if not filtered_results:
        output = "No relevant code snippets were found in the database."
    else:
        numbered = "\n\n".join(
            f"{rank}. {snippet}"
            for rank, snippet in enumerate(filtered_results, start=1)
        )
        output = f"The top three most relevant code snippets from the database are:\n\n{numbered}"

    return {output_box: output}


# UI definition: a Markdown header plus a single-input / single-output
# interface wired to response_generator.
with gr.Blocks() as demo:
    
    gr.Markdown(
    """
    # Embedding Quantization

    ## Quantized Semantic Search

    - ***Embedding:*** all-MiniLM-L6-v2
    - ***Vector DB:*** faiss, USearch
    - ***Vector_DB Size:*** `5,93,891`

    """)

    # NOTE(review): not referenced by anything visible in this file.
    state_var = gr.State([])


    input_box = gr.Textbox(autoscroll=True,visible=True,label='User',info="Enter a query.",value="How to extract the n-th elements from a list of tuples in python?")
    # response_generator returns its result keyed by this component.
    output_box = gr.Textbox(autoscroll=True,max_lines=30,value="Output",label='Assistant')
    gr.Interface(fn=response_generator, inputs=[input_box], outputs=[output_box],
                 delete_cache=(20,10),
                 allow_flagging='never')
    
# Enable request queuing, then start the server.
demo.queue()
demo.launch()