import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Mean-pool the token embeddings, zeroing out padding positions so they
    # contribute neither to the sum nor to the token count.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


@st.cache_resource
def load_model_and_tokenizer():
    # Cache the encoder across Streamlit reruns so it is only loaded once.
    tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
    model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
    model.eval()
    return model, tokenizer


@st.cache_data
def load_title_data():
    # pid / title pairs for the NLP2024 papers, one paper per line.
    title_df = pd.read_csv('anlp2024.tsv', names=["pid", "title"], sep="\t")
    return title_df


@st.cache_data
def load_title_embeddings():
    # Precomputed title embeddings, stored as the first (unnamed) array in the npz file.
    npz_comp = np.load("anlp2024.npz")
    title_embeddings = npz_comp["arr_0"]
    return title_embeddings


def get_retrieval_results(index, input_text, top_k, model, tokenizer, title_df):
    # E5 expects the "query: " prefix on search queries ("passage: " is used for documents).
    batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True,
                           truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**batch_dict)
    query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
    # faiss expects a float32 array of shape (n_queries, dim).
    _, ids = index.search(query_embeddings.numpy().astype(np.float32), top_k)
    retrieved_titles = []
    retrieved_pids = []
    for idx in ids[0]:
        retrieved_titles.append(title_df.loc[idx, "title"])
        retrieved_pids.append(title_df.loc[idx, "pid"])
    df = pd.DataFrame({"pids": retrieved_pids, "paper": retrieved_titles})
    return df


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    title_df = load_title_data()
    title_embeddings = load_title_embeddings()
    # Build an exact L2 index whose dimensionality matches the stored embeddings
    # (1024 for multilingual-e5-large).
    index = faiss.IndexFlatL2(title_embeddings.shape[1])
    index.add(title_embeddings.astype(np.float32))
    st.markdown("## NLP2024 類似論文検索")
    input_text = st.text_input('input', '', placeholder='ここに論文のタイトルを入力してください')
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)
    if st.button('検索'):
        stripped_input_text = input_text.strip()
        df = get_retrieval_results(index, stripped_input_text, top_k, model, tokenizer, title_df)
        st.table(df)
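

# ---------------------------------------------------------------------------
# The app assumes anlp2024.npz already holds one embedding per row of
# anlp2024.tsv, produced with the same multilingual-e5-large model. The build
# step is not part of this file, so the helper below is only a hypothetical
# sketch of how such a file could be created offline: E5 uses the "passage: "
# prefix for documents, mirroring the "query: " prefix above, and np.savez
# stores the array under the default key "arr_0" expected by
# load_title_embeddings(). The function is never called by the app itself.
# ---------------------------------------------------------------------------
def build_title_embeddings(tsv_path: str = "anlp2024.tsv",
                           out_path: str = "anlp2024.npz",
                           batch_size: int = 32) -> None:
    tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
    model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
    model.eval()
    title_df = pd.read_csv(tsv_path, names=["pid", "title"], sep="\t")
    all_embeddings = []
    for start in range(0, len(title_df), batch_size):
        batch_titles = title_df["title"].iloc[start:start + batch_size].tolist()
        batch_dict = tokenizer([f"passage: {t}" for t in batch_titles],
                               max_length=512, padding=True, truncation=True,
                               return_tensors='pt')
        with torch.no_grad():
            outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        all_embeddings.append(embeddings.numpy())
    # Saved positionally, so the array is stored under the default key "arr_0".
    np.savez(out_path, np.concatenate(all_embeddings).astype(np.float32))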