# kaisugi's picture
# update
# de40660
from transformers import AutoModel, AutoTokenizer
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
import math
import os
import re
# Set KMP_DUPLICATE_LIB_OK=True for the OpenMP runtime.
# NOTE(review): presumably this works around the "duplicate OpenMP runtime" abort
# seen when several native libraries (e.g. torch and faiss) each load their own
# copy; it runs after those imports, so confirm it still takes effect.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    """Load the ScitoricsBERT encoder and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair. The model is created with
        ``output_attentions=True`` and switched to evaluation mode.
    """
    checkpoint = "kaisugi/scitoricsbert"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # eval() returns the module itself, so it can be chained onto loading.
    encoder = AutoModel.from_pretrained(checkpoint, output_attentions=True).eval()
    return encoder, tok
@st.cache(allow_output_mutation=True)
def load_sentence_data():
    """Read the sentence corpus from the bundled gzip-compressed CSV file.

    Returns:
        The DataFrame produced by ``pd.read_csv``. The retrieval code reads
        its ``sentence`` and ``file_id`` columns.
    """
    return pd.read_csv("sentence_data_858k.csv.gz")
@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
    """Load precomputed sentence embeddings and build a faiss IVF-PQ index over them.

    The embeddings are L2-normalised in place with ``faiss.normalize_L2``. The
    index uses an HNSW coarse quantizer over ``nlist`` IVF cells, and each
    vector is product-quantised into ``M`` sub-vectors of ``nbits`` bits.

    Returns:
        A ``(sentence_embeddings, index)`` pair:
        - ``sentence_embeddings``: the normalised float32 matrix, one row per
          sentence. Row ``i`` is presumably row ``i`` of the sentence CSV.
        - ``index``: the trained and populated ``faiss.IndexIVFPQ``.
    """
    # Use the NpzFile as a context manager so the underlying file handle is closed.
    with np.load("sentence_embeddings_858k.npz") as npz_comp:
        # faiss expects C-contiguous float32 data, and normalize_L2 works in place.
        # This is a no-op conversion when the stored array already qualifies.
        sentence_embeddings = np.ascontiguousarray(npz_comp["arr_0"], dtype=np.float32)
    faiss.normalize_L2(sentence_embeddings)

    # Derive sizes from the data rather than hard-coding them.
    # The bundled file has N = 857610 and D = 768.
    N, D = sentence_embeddings.shape
    Xt = sentence_embeddings[:100000]  # training subset
    X = sentence_embeddings

    # Param of PQ
    M = 16  # The number of sub-vector. Typically this is 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
    # Param of IVF
    nlist = max(1, int(math.sqrt(N)))  # The number of cells (space partition). Typical value is sqrt(N)
    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32

    # Setup
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)
    # Train
    index.train(Xt)
    # Add
    index.add(X)
    # Search
    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.
    return sentence_embeddings, index
@st.cache(allow_output_mutation=True)
def formulaic_phrase_extraction(sentences, model, tokenizer):
    """Highlight the words of each sentence that receive strong [CLS] attention.

    For every sentence, take the attention paid by position 0 ([CLS]) at layer
    ``LAYER`` and average it over heads. A whitespace-delimited word is wrapped
    in ``<font color="coral">`` when every wordpiece it splits into receives
    more than ``THRESHOLD`` of that attention.

    Args:
        sentences: List of sentence strings.
        model: Encoder that returns attentions (loaded with ``output_attentions=True``).
        tokenizer: The tokenizer belonging to ``model``.

    Returns:
        One HTML string per input sentence, in input order. Words are
        HTML-escaped, so literal text such as ``<unk>`` is displayed rather
        than parsed as a tag.
    """
    import html

    THRESHOLD = 0.01
    LAYER = 10
    if not sentences:
        # Nothing to annotate; this also avoids tokenizing an empty batch.
        return []

    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)

    # Each layer's attention tensor is (batch, heads, seq, seq). Select query
    # position 0 ([CLS]) for EVERY sentence and average over heads, giving
    # (batch, seq). Indexing [0] on the batch axis instead would reuse the
    # first sentence's attention matrix for all sentences.
    cls_attentions = outputs.attentions[LAYER][:, :, 0, :].mean(dim=1)

    # Number of real (non-padding) positions per sentence, including [CLS]/[SEP].
    mask = inputs.get("attention_mask")
    if mask is not None:
        lengths = mask.sum(dim=1).tolist()
    else:
        lengths = [cls_attentions.shape[1]] * len(sentences)

    output_sentences = []
    for sentence, cls_attention, length in zip(sentences, cls_attentions, lengths):
        # One flag per wordpiece: drop [CLS] at the front and [SEP] at the end,
        # and ignore any padding beyond this sentence's length.
        above = (cls_attention[1:length - 1] > THRESHOLD).tolist()
        pos = 0
        rendered = []
        for word in sentence.split():
            # Assumes a BERT/WordPiece-style tokenizer that splits on whitespace
            # before wordpiece/punctuation splitting. Under that assumption,
            # tokenizing a word alone yields the same pieces it has inside the
            # full sentence, which keeps flags aligned even for "word," etc.
            n_pieces = len(tokenizer.tokenize(word))
            flags = above[pos:pos + n_pieces]
            pos += n_pieces
            # Pieces cut off by truncation give a short `flags`; never highlight those words.
            highlight = n_pieces > 0 and len(flags) == n_pieces and all(flags)
            escaped = html.escape(word)
            rendered.append(f'<font color="coral">{escaped}</font>' if highlight else escaped)
        output_sentences.append(" ".join(rendered))
    return output_sentences
@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
    """Retrieve corpus sentences that are functionally similar to ``input_text``.

    The query is embedded as the L2-normalised [CLS] vector of ``model``, and
    its ``top_k`` nearest neighbours are looked up in ``index``. Neighbours
    containing any of ``exclude_word_list`` (plain substring, case-sensitive)
    are dropped, so at most ``top_k`` results are returned.

    Args:
        index: faiss index whose ids are row labels of ``sentence_df``.
        input_text: Query sentence.
        top_k: Number of neighbours to request from the index.
        model, tokenizer: Encoder and tokenizer used to embed the query.
        sentence_df: DataFrame with ``sentence`` and ``file_id`` columns.
        exclude_word_list: Words to filter out. They are matched literally,
            not as regular expressions; empty strings are ignored.
        phrase_annotated: When true, pass the sentences through
            ``formulaic_phrase_extraction`` to add HTML highlights.

    Returns:
        A ``(retrieved_sentences, retrieved_paper_ids)`` pair of parallel
        lists: the sentences and their ACL Anthology URLs.
    """
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
    # [CLS] vector of the single query sequence.
    query_embedding = outputs.last_hidden_state[0, 0, :].detach().cpu().numpy()
    query_embedding = query_embedding / np.linalg.norm(query_embedding, ord=2)
    _, ids = index.search(x=np.array([query_embedding], dtype=np.float32), k=int(top_k))

    # Escape each user-supplied word so characters such as "(" or "+" are matched
    # literally instead of raising re.error. Compile the pattern once, outside the loop.
    exclude_words = [w for w in exclude_word_list if w]
    exclude_pattern = (
        re.compile("|".join(re.escape(w) for w in exclude_words)) if exclude_words else None
    )

    retrieved_sentences = []
    retrieved_paper_ids = []
    for row_id in ids[0]:
        if row_id < 0:
            # faiss pads the result with -1 when it finds fewer than k neighbours.
            continue
        cur_sentence = sentence_df.loc[row_id, "sentence"]
        if exclude_pattern is not None and exclude_pattern.search(cur_sentence):
            continue
        retrieved_sentences.append(cur_sentence)
        retrieved_paper_ids.append(f"https://aclanthology.org/{sentence_df.loc[row_id, 'file_id']}")

    # Skip annotation when every candidate was filtered out (empty batch).
    if phrase_annotated and retrieved_sentences:
        retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)
    return retrieved_sentences, retrieved_paper_ids
if __name__ == "__main__":
    import html

    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")
    input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "good, result")
    agree = st.checkbox('Include phrase annotation')

    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)

        if not retrieved_sentences:
            st.info("No sentences matched. Try fewer exclude words or a larger top_k.")
        else:
            table_lines = ["| sentence | source link |", "|:---|:---|"]
            for retrieved_sentence, retrieved_paper_id in zip(retrieved_sentences, retrieved_paper_ids):
                # Annotated sentences are already HTML. Plain corpus text must be
                # escaped because the table is rendered with unsafe_allow_html;
                # otherwise tokens like "<s>" or "<unk>" would be parsed as tags.
                cell = retrieved_sentence if agree else html.escape(retrieved_sentence)
                # A literal "|" (e.g. "|V|") or a newline would break the markdown table row.
                cell = cell.replace("\n", " ").replace("|", "\\|")
                table_lines.append(f"| {cell} | {retrieved_paper_id} |")
            st.markdown("\n".join(table_lines) + "\n", unsafe_allow_html=True)

    st.markdown("---\n#### How this works")
    st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord to the function, by leveraging self-attention patterns within ScitoricsBERT.")