embeds / main.py
chainyo's picture
add password for hf space
43478d2
raw
history blame
2.33 kB
import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
from config import config
def search(text: str, k: int = 5, timeout: float = 30.0):
    """Get the k closest articles to the text.

    The text is embedded with ``_get_embeddings`` and the resulting vector is
    POSTed to the Pinecone ``/query`` REST endpoint whose host is built from
    ``config.pinecone_index`` and ``config.pinecone_env``.

    Args:
        text: Query text to embed and search for.
        k: Number of matches to request (sent as ``top_k``).
        timeout: Seconds ``requests`` waits on the HTTP call before giving
            up. Without a timeout a stalled endpoint would block the app
            indefinitely.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200.
            ``RuntimeError`` is a subclass of ``Exception``, so handlers
            written for the previous generic ``Exception`` keep working.
        requests.exceptions.RequestException: On connection errors or when
            ``timeout`` elapses (raised by ``requests`` itself).
    """
    embeds = _get_embeddings(text)
    # NOTE(review): the "-5b18b87" host suffix is hard-coded; it looks like a
    # Pinecone project id — confirm and consider moving it into ``config``.
    url = (
        f"https://{config.pinecone_index}-5b18b87.svc."
        f"{config.pinecone_env}.pinecone.io/query"
    )
    r = requests.post(
        url,
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            # Only ids, scores and metadata are requested; the stored vectors
            # themselves are left out of the response.
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )
    if r.status_code != 200:
        raise RuntimeError(f"Error: {r.status_code} - {r.text}")
    return r.json()
def _get_embeddings(text: str):
    """Embed ``text`` with the tokenizer and model kept in ``st.session_state``.

    The text is tokenized into PyTorch tensors, passed through the model with
    gradient tracking disabled, and the first element of the model output is
    averaged over dimension 1. Size-1 dimensions are then squeezed away and
    the result is returned as plain Python lists/floats.
    """
    tokenizer = st.session_state.tokenizer
    model = st.session_state.model

    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = model(**encoded)
        hidden = model_output[0]
    # assumes hidden is (batch, tokens, hidden_size) so dim=1 is the token
    # axis — TODO confirm against the configured model.
    pooled = hidden.mean(dim=1)
    return pooled.squeeze().tolist()
# --- Streamlit page -------------------------------------------------------
# Streamlit re-runs this script from top to bottom on every interaction, so
# the heavy objects (index handle, tokenizer, model) are kept in
# ``st.session_state`` and only created when missing.
st.title("PubMed Embeddings")
st.subheader("Search for a PubMed article and get its id.")

# ``st.text_input`` returns the entered string. A ``str`` is not a context
# manager, so the value is assigned directly instead of being used in a
# ``with`` statement (which fails at runtime before the check is reached).
password = st.text_input("Password", type="password")

if password == config.password:
    st.write("Password correct!")
    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

    with st.spinner("Loading Embedding Model..."):
        # The Pinecone client's documented keyword is ``environment``; ``env``
        # is not a declared parameter of ``pinecone.init``.
        pinecone.init(
            api_key=config.pinecone_api_key,
            environment=config.pinecone_env,
        )
        if "index" not in st.session_state:
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)

    if st.button("Search"):
        with st.spinner("Searching..."):
            results = search(text)
        # One line per match: the match id and its similarity score.
        for res in results["matches"]:
            st.write(f"{res['id']} - confidence: {res['score']:.2f}")
else:
    st.write("Password incorrect!")