# NOTE: removed scrape artifact line ("Last commit not found") — not part of the source.
import pickle

import datasets
import faiss
import numpy as np
import requests
import streamlit as st
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

from vector_engine.utils import vector_search
#@st.cache | |
def read_data(dataset_repo='dhmeltzer/asks_validation_embedded'): | |
"""Read the data from huggingface.""" | |
return load_dataset(dataset_repo)['validation_asks'] | |
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"): | |
"""Load and deserialize the Faiss index.""" | |
with open(path_to_faiss, "rb") as h: | |
data = pickle.load(h) | |
return faiss.deserialize_index(data) | |
def main(): | |
# Load data and models | |
data = read_data() | |
#model = load_bert_model() | |
#tok = load_tokenizer() | |
faiss_index = load_faiss_index() | |
model_id="sentence-transformers/nli-distilbert-base" | |
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}" | |
headers = {"Authorization": "Bearer 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"} | |
def query(texts): | |
response = requests.post(api_url, headers=headers, json={"inputs": texts, "options":{"wait_for_model":True}}) | |
return response.json() | |
st.title("Vector-based searches with Sentence Transformers and Faiss") | |
# User search | |
user_input = st.text_area("Search box", "ELI5 Dataset") | |
# Filters | |
st.sidebar.markdown("**Filters**") | |
num_results = st.sidebar.slider("Number of search results", 1, 50, 1) | |
vector = query([user_input]) | |
# Fetch results | |
if user_input: | |
# Get paper IDs | |
_, I = faiss_index.search(np.array(vector).astype("float32"), k=num_results) | |
answers=row['answers']['text'] | |
answers_URLs = row['answers_urls']['url'] | |
for k in range(len(answers_URLs)): | |
answers = [answer.replace(f'_URL_{k}_',answers_URLs[k]) | |
for answer in answers] | |
frame = data | |
# Get individual results | |
for id_ in I.flatten().tolist(): | |
f = frame[id_] | |
st.write( | |
f"""**Title**: {f['title']} | |
**Top Answer**: {answers[0]} | |
""" | |
) | |
st.write("-"*20) | |
if __name__ == "__main__": | |
main() | |
# Commented-out local-model loaders kept for reference (the app currently
# embeds queries via the HF Inference API instead of a local model).
#@st.cache(allow_output_mutation=True)
#def load_bert_model(name="nli-distilbert-base"):
#    """Instantiate a sentence-level DistilBERT model."""
#    return AutoModel.from_pretrained(f'sentence-transformers/{name}')
#
#@st.cache(allow_output_mutation=True)
#def load_tokenizer(name="nli-distilbert-base"):
#    return AutoTokenizer.from_pretrained(f'sentence-transformers/{name}')
#@st.cache(allow_output_mutation=True)