# Author: akshay7
# Commit d15d6ce: ADD: Ability to pull abstracts from article ids
import html

import arxiv
import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def _extract_arxiv_id(link_or_id):
    """Return the bare arXiv identifier contained in a link or id string.

    Accepted inputs include ``2111.10339``, ``2111.10339v2``,
    ``arXiv:2111.10339``, ``https://arxiv.org/abs/2111.10339``,
    ``https://arxiv.org/pdf/2111.10339.pdf`` and old-style ids such as
    ``https://arxiv.org/abs/hep-th/9901001``. Surrounding whitespace, a query
    string / fragment, trailing slashes and a ``.pdf`` suffix are removed.

    Returns:
        The identifier, or ``""`` when nothing usable remains.
    """
    text = link_or_id.strip()
    # arXiv ids never contain "?" or "#", so anything after them is URL noise.
    text = text.split("?", 1)[0].split("#", 1)[0].rstrip("/")
    marker = next((m for m in ("/abs/", "/pdf/") if m in text), None)
    if marker is not None:
        # Keep everything after the marker so old-style ids ("hep-th/9901001")
        # are not truncated to their last path segment.
        text = text.split(marker, 1)[1]
    elif "://" in text:
        # Some other URL shape: fall back to its last path segment.
        text = text.rsplit("/", 1)[-1]
    if text.lower().startswith("arxiv:"):
        text = text[len("arxiv:"):]
    if text.endswith(".pdf"):
        text = text[: -len(".pdf")]
    return text


def _build_result_html(generated_title, original_title=None):
    """Return the HTML snippet that displays the generated title.

    When ``original_title`` is given (the abstract came from arXiv), a second
    line shows it for comparison. Both titles are HTML-escaped because the
    snippet is rendered with ``unsafe_allow_html=True`` and its text comes
    from the model output / the arXiv API rather than from this app.
    """
    lines = [
        "<style>",
        "p.a {",
        "    font: 20px Courier;",
        "}",
        "</style>",
        f'<p class="a"><b>Title Generated:</b> {html.escape(generated_title)} </p>',
    ]
    if original_title is not None:
        lines.append(
            f'<p class="a"><b>Original Title:</b> {html.escape(original_title)} </p>'
        )
    return "\n".join(lines)


def main():
    """Render the title-generator page and handle the "Submit" action.

    The user either gives an arXiv link/id, whose abstract (and original
    title) is fetched from arXiv, or pastes an abstract directly; a non-blank
    pasted abstract takes precedence. The selected seq2seq model then
    generates a title from the abstract.
    """
    # set_page_config must be the first Streamlit call of the script run.
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )
    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")
    link_or_id = st.text_area(
        "Provide the link/id for an arxiv paper",
        "https://arxiv.org/abs/2111.10339",
    )
    pasted_abstract = st.text_area("...or paste a paper's abstract to generate a title")
    st.text("")
    models_to_choose = [
        "AryanLala/autonlp-Scientific_Title_Generator-34558227",
        "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full",
    ]
    base_model = st.selectbox("Choose a model to generate the title", models_to_choose)

    # Cached across reruns. The model name is an explicit argument so it is
    # part of the cache key when the user switches models.
    @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
    def load_model(model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        return model, tokenizer

    def resolve_abstract():
        """Return ``(abstract, original_title)``, or ``None`` after showing an error.

        ``original_title`` is ``None`` when the user pasted the abstract.
        """
        if pasted_abstract.strip():
            return pasted_abstract, None
        arxiv_id = _extract_arxiv_id(link_or_id)
        if not arxiv_id:
            st.error("The text can't be empty")
            return None
        try:
            search = arxiv.Search(id_list=[arxiv_id])
            # A single id yields at most one result.
            paper = next(iter(search.results()), None)
        except Exception as err:  # UI boundary: report API/network failures to the user
            st.error(f"Could not fetch arXiv paper '{arxiv_id}': {err}")
            return None
        if paper is None or not paper.summary.strip():
            st.error(f"No arXiv abstract found for '{arxiv_id}'")
            return None
        # arXiv titles may contain hard line breaks; collapse whitespace runs.
        return paper.summary, " ".join(paper.title.split())

    def generate_title(abstract):
        """Return the title generated for ``abstract`` by the selected model."""
        with st.spinner(text="Processing your request"):
            model, tokenizer = load_model(base_model)
            inputs = tokenizer(
                [abstract], truncation=True, padding="longest", return_tensors="pt"
            )
            # NOTE(review): with plain beam search (do_sample unset/False),
            # temperature has no effect; kept in case a model config enables sampling.
            output = model.generate(
                **inputs,
                max_length=60,
                num_beams=10,
                num_return_sequences=1,
                temperature=1.5,
            )
            return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    # Only contact arXiv and run the model when the user actually submits,
    # not on every widget-triggered rerun.
    if st.button("Submit"):
        resolved = resolve_abstract()
        if resolved is not None:
            abstract, original_title = resolved
            generated_title = generate_title(abstract)
            st.markdown(
                _build_result_html(generated_title, original_title),
                unsafe_allow_html=True,
            )

    with st.expander("Additional Information"):
        # Streamlit dedents markdown bodies, so the indentation below is not
        # rendered as a code block.
        st.markdown(
            """
            The models used were fine-tuned on a subset of data from the [Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset).
            The task of the models is to suggest an appropriate title from the abstract of a scientific paper.
            The model [AryanLala/autonlp-Scientific_Title_Generator-34558227](https://huggingface.co/AryanLala/autonlp-Scientific_Title_Generator-34558227) was trained on data
            from the cs.AI (Artificial Intelligence) category of papers.
            The model [shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full](https://huggingface.co/shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full)
            was trained on the categories: cs.AI, cs.LG, cs.NI, cs.GR, cs.CL, cs.CV (Artificial Intelligence, Machine Learning, Networking and Internet Architecture, Graphics, Computation and Language, Computer Vision and Pattern Recognition).
            Also, <b>Thank you to arXiv for use of its open access interoperability.</b> It allows us to pull the required abstracts from the provided ids.
            """,
            unsafe_allow_html=True,
        )
    st.text('\n')
    st.text('\n')
    st.markdown(
        '<span style="color:blue; font-size:10px">App created by '
        "[@akshay7](https://huggingface.co/akshay7), "
        "[@AryanLala](https://huggingface.co/AryanLala) and "
        "[@shamikbose89](https://huggingface.co/shamikbose89)\n"
        "</span>",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()