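"""Streamlit app for semantic search over the dhmeltzer/asks_validation_embedded dataset.

User queries are embedded through the Hugging Face Inference API and matched
against a pre-built Faiss index; the closest entries and their top answers are
displayed in the app.
"""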
import faiss
import pickle
import numpy as np
import requests
import streamlit as st
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer
from vector_engine.utils import vector_search

#@st.cache
@st.cache_data
def read_data(dataset_repo='dhmeltzer/asks_validation_embedded'):
    """Read the data from huggingface."""
    return load_dataset(dataset_repo)['validation_asks']

@st.cache_data
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"):
    """Load and deserialize the Faiss index."""
    with open(path_to_faiss, "rb") as h:
        data = pickle.load(h)
    return faiss.deserialize_index(data)

def main():
    # Load data and models
    data = read_data()
    #model = load_bert_model()
    #tok = load_tokenizer()
    faiss_index = load_faiss_index()

    # Embed queries remotely with the Hugging Face Inference API instead of
    # loading the sentence-transformer model locally.
    model_id = "sentence-transformers/nli-distilbert-base"

    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}

    def query(texts):
        """Embed a list of texts via the Inference API feature-extraction pipeline."""
        payload = {"inputs": texts, "options": {"wait_for_model": True}}
        response = requests.post(api_url, headers=headers, json=payload)
        return response.json()


    st.title("Vector-based searches with Sentence Transformers and Faiss")

    # User search
    user_input = st.text_area("Search box", "What is spacetime made out of?")

    # Filters
    st.sidebar.markdown("**Filters**")
    num_results = st.sidebar.slider("Number of search results", 1, 50, 1)

    # Fetch results
    if user_input:
        # Embed the query and retrieve the indices of the nearest entries in the index
        vector = query([user_input])
        _, I = faiss_index.search(np.array(vector).astype("float32"), k=num_results)
        
        # Get individual results
        for id_ in I.flatten().tolist():
            row = data[id_]
            
            answers = row['answers']['text']
            answers_URLs = row['answers_urls']['url']
            # The dataset stores links as _URL_k_ placeholders; swap the real URLs back in
            for k in range(len(answers_URLs)):
                answers = [answer.replace(f'_URL_{k}_', answers_URLs[k]) for answer in answers]
            
            
            st.write(f"**Title**: {row['title']}")
            st.write(f"**Score**: {row['answers']['score'][0]}")
            st.write(f"**Top Answer**: {answers[0]}")
            st.write("-" * 20)



if __name__ == "__main__":
    main()
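
# Alternative path (currently unused): load the sentence-transformer model and
# tokenizer locally instead of calling the Inference API.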

#@st.cache(allow_output_mutation=True)
#def load_bert_model(name="nli-distilbert-base"):
#    """Instantiate a sentence-level DistilBERT model."""
#    return AutoModel.from_pretrained(f'sentence-transformers/{name}')
#
#@st.cache(allow_output_mutation=True)
#def load_tokenizer(name="nli-distilbert-base"):
#    return AutoTokenizer.from_pretrained(f'sentence-transformers/{name}')
