"""Streamlit app that suggests a paper title from an arXiv abstract."""

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np  # not used below; retained from the original import list
import torch  # not used below; retained from the original import list
import arxiv

# Hugging Face model ids offered in the model selector.
MODELS_TO_CHOOSE = [
    "AryanLala/autonlp-Scientific_Title_Generator-34558227",
    "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full",
]

# Markdown shown inside the "Additional Information" expander.
ADDITIONAL_INFO = (
    "The models used were fine-tuned on subset of data from the "
    "[Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset) "
    "The task of the models is to suggest an appropriate title from the "
    "abstract of a scientific paper. "
    "The model [AryanLala/autonlp-Scientific_Title_Generator-34558227]"
    "(https://huggingface.co/AryanLala/"
    "autonlp-Scientific_Title_Generator-34558227) "
    "was trained on data from the Cs.AI (Artificial Intelligence) category "
    "of papers. "
    "The model "
    "[shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full]"
    "(https://huggingface.co/shamikbose89/"
    "mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full) "
    "was trained on the categories: cs.AI, cs.LG, cs.NI, cs.GR, cs.CL, cs.CV "
    "(Artificial Intelligence, Machine Learning, Networking and Internet "
    "Architecture, Graphics, Computation and Language, Computer Vision and "
    "Pattern Recognition) "
    "Also, Thank you to arXiv for use of its open access interoperability. "
    "It allows us to pull the required abstracts from passed ids "
)

# Markdown footer crediting the app authors.
CREDITS = (
    "App created by [@akshay7](https://huggingface.co/akshay7), "
    "[@AryanLala](https://huggingface.co/AryanLala) and "
    "[@shamikbose89](https://huggingface.co/shamikbose89) "
)


def extract_arxiv_id(reference):
    """Return the arXiv identifier contained in a link or a bare id.

    Surrounding whitespace, trailing slashes, an ``arXiv:`` prefix and a
    ``.pdf`` suffix are removed.  For links containing ``/abs/`` or ``/pdf/``
    everything after that marker is kept, so ids that themselves contain a
    slash (e.g. ``hep-th/9901001``) are preserved.  Other URLs fall back to
    their last path segment.

    Args:
        reference: Text typed by the user (link or id).

    Returns:
        The extracted id, or an empty string when nothing usable remains.
    """
    cleaned = reference.strip().rstrip("/")
    if cleaned.lower().startswith("arxiv:"):
        cleaned = cleaned[len("arxiv:"):]

    for marker in ("/abs/", "/pdf/"):
        if marker in cleaned:
            cleaned = cleaned.split(marker, 1)[1]
            break
    else:
        # No recognised marker: a URL keeps only its last segment, while a
        # bare id is used unchanged.
        if "://" in cleaned:
            cleaned = cleaned.split("/")[-1]

    if cleaned.lower().endswith(".pdf"):
        cleaned = cleaned[: -len(".pdf")]
    return cleaned


def fetch_paper(reference):
    """Look up a paper on arXiv from a link or id.

    Errors raised by the ``arxiv`` client itself are not caught here and
    propagate to the caller.

    Args:
        reference: Link or id as typed by the user.

    Returns:
        The first search result, or ``None`` when no id could be extracted
        or the search yields no result.
    """
    paper_id = extract_arxiv_id(reference)
    if not paper_id:
        return None
    search = arxiv.Search(id_list=[paper_id])
    return next(iter(search.results()), None)


# NOTE(review): st.cache is deprecated in recent Streamlit releases in favour
# of st.cache_resource -- confirm the pinned Streamlit version before changing.
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_model(model_name):
    """Load (and cache) the seq2seq model and tokenizer for ``model_name``.

    The model id is an explicit argument so that it is part of the cache key
    and each selectable model gets its own cache entry.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def generate_title(text, model_name):
    """Generate a title for ``text`` (an abstract) with the chosen model.

    Args:
        text: Abstract to summarise into a title.
        model_name: Hugging Face id of the model to use.

    Returns:
        The decoded title string.
    """
    with st.spinner(text="Processing your request"):
        model, tokenizer = load_model(model_name)
        # The tokenizer is called with a batch containing the single abstract.
        inputs = tokenizer(
            [text], truncation=True, padding="longest", return_tensors="pt"
        )
        # NOTE(review): temperature presumably has no effect unless sampling
        # is enabled; kept as in the original call -- confirm.
        output = model.generate(
            **inputs,
            max_length=60,
            num_beams=10,
            num_return_sequences=1,
            temperature=1.5,
        )
        decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
    return decoded[0]


def show_result(summary, original_title=None):
    """Render the generated title and, when known, the paper's real title."""
    if original_title is not None:
        html_str = (
            f"\nTitle Generated:> {summary}\nOriginal Title:> {original_title}\n"
        )
    else:
        html_str = f"Title Generated:> {summary}\n"
    st.markdown(html_str, unsafe_allow_html=True)


def main():
    """Build the Streamlit page and run title generation on submit."""
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )
    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")

    example = st.text_area(
        "Provide the link/id for an arxiv paper",
        """https://arxiv.org/abs/2111.10339""",
    )
    message = st.text_area("...or paste a paper's abstract to generate a title")

    original_title = None
    lookup_failed = False
    # A pasted abstract takes priority; otherwise the abstract is fetched from
    # arXiv using the link/id field.
    if not message.strip():
        message = ""
        if example.strip():
            paper = fetch_paper(example)
            if paper is None:
                lookup_failed = True
            else:
                message = paper.summary
                original_title = paper.title

    st.text("")
    base_model = st.selectbox(
        "Choose a model to generate the title", MODELS_TO_CHOOSE
    )

    # Run the model only when the submit button is clicked.
    if st.button("Submit"):
        if lookup_failed:
            st.error("No arXiv paper could be found for the provided link/id")
        elif len(message) > 0:
            show_result(generate_title(message, base_model), original_title)
        else:
            st.error("The text can't be empty")

    with st.expander("Additional Information"):
        st.markdown(ADDITIONAL_INFO, unsafe_allow_html=True)

    st.text('\n')
    st.text('\n')
    st.markdown(CREDITS, unsafe_allow_html=True)


if __name__ == "__main__":
    main()