File size: 3,826 Bytes
d40f2bc
 
 
 
 
 
 
 
 
 
 
85aba76
 
85657cd
d40f2bc
bc258f6
d40f2bc
85aba76
85657cd
d40f2bc
 
 
 
 
 
 
 
 
 
85657cd
d40f2bc
 
309aedc
d40f2bc
 
 
 
 
 
f935acc
d40f2bc
f935acc
 
 
 
 
 
 
669cc9e
 
 
f935acc
d40f2bc
3dadf7d
d40f2bc
 
 
 
 
 
 
 
 
 
 
 
 
ee89cad
 
 
 
 
 
 
 
afbfe79
85657cd
afbfe79
 
be14f81
4cd2aa6
d40f2bc
 
 
 
4cd2aa6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import faiss
import pickle
import datasets
import numpy as np
import requests
import streamlit as st
from vector_engine.utils import vector_search 
from transformers import AutoModel, AutoTokenizer

from datasets import load_dataset

#@st.cache
@st.cache_data
def read_data(dataset_repo='dhmeltzer/ELI5_embedded'):
    """Fetch *dataset_repo* from the Hugging Face Hub and return its train split.

    The result is memoised by Streamlit's ``st.cache_data`` decorator, so the
    download only happens on the first call for a given repository name.
    """
    dataset_splits = load_dataset(dataset_repo)
    train_split = dataset_splits['train']
    return train_split

@st.cache_resource
def load_faiss_index(path_to_faiss="./faiss_index.pickle"):
    """Load and deserialize the Faiss index stored at *path_to_faiss*.

    The file holds a pickled object in the form ``faiss.deserialize_index``
    expects (i.e. the output of ``faiss.serialize_index``). It is turned back
    into a live index here.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``.
    ``cache_data`` pickles and copies its return value on every call, which for
    a SWIG-backed index is either wasteful or fails outright. The index is only
    searched, never mutated, so a single shared instance is appropriate.

    Args:
        path_to_faiss: Path to the pickled, serialized index.

    Returns:
        The deserialized ``faiss.Index``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load index files produced by this project, never user-supplied ones.
    with open(path_to_faiss, "rb") as handle:
        serialized = pickle.load(handle)
    return faiss.deserialize_index(serialized)

def _fill_url_placeholders(text, urls):
    """Return *text* with each ``_URL_{k}_`` placeholder replaced by ``urls[k]``.

    Placeholders are substituted in index order (0, 1, 2, ...).
    """
    for k, url in enumerate(urls):
        text = text.replace(f'_URL_{k}_', url)
    return text


def _render_result(row):
    """Write one search hit (title, split, score, top answer) to the page.

    Rows without any answers get a short notice instead of raising IndexError.
    """
    answers = row['answers']['text']
    scores = row['answers']['score']

    st.write(f"**Title**: {row['title']}")
    st.write(f"**Split**: {row['split']}")
    if answers and scores:
        # Only the first answer is displayed, so only its placeholders are filled.
        top_answer = _fill_url_placeholders(answers[0], row['answers_urls']['url'])
        st.write(f"**Score**: {scores[0]}")
        st.write(f"**Top Answer**: {top_answer}")
    else:
        st.write("**Top Answer**: (no answers recorded for this question)")
    st.write("-"*20)


def main():
    """Render the ELI5 semantic-search page.

    Loads the cached dataset and FAISS index and embeds the user's query
    through the Hugging Face Inference API. It then lists the nearest
    questions, each with its first listed answer.
    """
    # Load data and models
    data = read_data()
    faiss_index = load_faiss_index()

    model_id = "sentence-transformers/all-MiniLM-L6-v2"

    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
    headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}

    def query(texts):
        """Return the embeddings the Inference API computes for *texts*.

        Raises requests.RequestException (including HTTPError on a non-2xx
        status) so an API error payload is never mistaken for an embedding.
        """
        response = requests.post(
            api_url,
            headers=headers,
            json={"inputs": texts, "options": {"wait_for_model": True}},
            # Generous, because wait_for_model may block while the model loads,
            # but bounded so a stalled connection cannot hang the app forever.
            timeout=120,
        )
        response.raise_for_status()
        return response.json()

    st.title("Vector-based search of the r/ELI5 dataset with Sentence Transformers and Faiss")

    # Markdown links are used because st.markdown escapes raw HTML such as
    # <a href=...> unless unsafe_allow_html=True is passed.
    st.markdown("""This application lets you perform a semantic search through questions in the r/ELI5 [dataset](https://huggingface.co/datasets/eli5).
                    The questions and user input are encoded into a high-dimensional vector space using a Sentence-Transformer model, and in particular the checkpoint [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2).
                    To perform the search we use FAISS, which performs an efficient similarity search through the (vectorized) questions.
                    The ELI5 dataset contains posts from three subreddits, AskScience (asks), AskHistorians (askh), and ExplainLikeImFive (eli5).
                    The score corresponds to the rating each answer received when posted on Reddit.
                    We unfortunately cannot verify the veracity of any of the answers posted!
                    """)

    st.markdown(""" To change the number of search results listed, simply move the slider located in the sidebar. 
                """)

    # User search
    user_input = st.text_area("Search box", "What is spacetime made out of?")

    # Filters
    st.sidebar.markdown("**Filters**")
    num_results = st.sidebar.slider("Number of search results", 1, 50, 1)

    # Nothing to search for: skip the remote embedding request entirely.
    if not user_input.strip():
        return

    try:
        vector = query([user_input])
    except requests.RequestException as err:
        st.error(f"Could not embed the query: {err}")
        return

    # Row indices of the nearest questions in `data`.
    _, ids = faiss_index.search(np.array(vector).astype("float32"), k=num_results)

    for id_ in ids.flatten().tolist():
        # FAISS pads with -1 when fewer than k neighbours exist; data[-1]
        # would otherwise silently show the dataset's last row.
        if id_ < 0:
            continue
        _render_result(data[id_])


if __name__ == "__main__":
    main()

#@st.cache(allow_output_mutation=True)
#def load_bert_model(name="nli-distilbert-base"):
#    """Instantiate a sentence-level DistilBERT model."""
#    return AutoModel.from_pretrained(f'sentence-transformers/{name}')
#
#@st.cache(allow_output_mutation=True)
#def load_tokenizer(name="nli-distilbert-base"):
#    return AutoTokenizer.from_pretrained(f'sentence-transformers/{name}')

#@st.cache(allow_output_mutation=True)