# Spaces:
# Runtime error
# Runtime error
import html

import arxiv
import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face model ids offered in the model selector.
_MODEL_CHOICES = [
    "AryanLala/autonlp-Scientific_Title_Generator-34558227",
    "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full",
]

# Pre-filled value of the link/id input.
_DEFAULT_PAPER_LINK = "https://arxiv.org/abs/2111.10339"

# Body of the "Additional Information" expander.  The text deliberately
# starts at column 0: indented lines would be rendered by Markdown as a
# code block instead of prose.
_ABOUT_MARKDOWN = """
The models used were fine-tuned on subset of data from the [Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset)
The task of the models is to suggest an appropriate title from the abstract of a scientific paper.
The model [AryanLala/autonlp-Scientific_Title_Generator-34558227](https://huggingface.co/AryanLala/autonlp-Scientific_Title_Generator-34558227) was trained on data
from the Cs.AI (Artificial Intelligence) category of papers.
The model [shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full](https://huggingface.co/shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full)
was trained on the categories: cs.AI, cs.LG, cs.NI, cs.GR cs.CL, cs.CV (Artificial Intelligence, Machine Learning, Networking and Internet Architecture, Graphics, Computation and Language, Computer Vision and Pattern Recognition)
Also, <b>Thank you to arXiv for use of its open access interoperability.</b> It allows us to pull the required abstracts from passed ids
"""

# Footer crediting the app authors.
_CREDITS_HTML = '''<span style="color:blue; font-size:10px">App created by [@akshay7](https://huggingface.co/akshay7), [@AryanLala](https://huggingface.co/AryanLala) and [@shamikbose89](https://huggingface.co/shamikbose89)
</span>'''


def _extract_arxiv_id(reference):
    """Return the arXiv identifier contained in a link or a bare id.

    Surrounding whitespace and trailing slashes are ignored, so a pasted
    link followed by a newline still resolves.  For links, the identifier
    is whatever follows ``/abs/`` or ``/pdf/`` (a trailing ``.pdf`` is
    dropped); any other input falls back to its last ``/``-separated
    segment.  An empty input yields an empty string.
    """
    cleaned = reference.strip().rstrip("/")
    for marker in ("/abs/", "/pdf/"):
        _, found, tail = cleaned.partition(marker)
        if found:
            return tail[:-4] if tail.endswith(".pdf") else tail
    return cleaned.split("/")[-1]


def _fetch_arxiv_paper(paper_id):
    """Look up *paper_id* on arXiv.

    Returns:
        ``(abstract, title)`` of the first matching paper, or ``None`` when
        *paper_id* is empty or the search yields no result.
    """
    if not paper_id:
        return None
    search = arxiv.Search(id_list=[paper_id])
    paper = next(iter(search.results()), None)
    if paper is None:
        return None
    return paper.summary, paper.title


def _load_model(model_name):
    """Instantiate the seq2seq model and its tokenizer for *model_name*.

    Returns:
        ``(model, tokenizer)``.
    """
    # NOTE(review): both objects are re-created on every call, i.e. on every
    # submit; consider Streamlit's resource caching if the deployed version
    # provides it.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def _generate_title(text, model_name):
    """Generate one title for *text* with the model *model_name*.

    A spinner is shown while the model is loaded and run.  The text is
    passed to the tokenizer as a single-element batch and the first decoded
    sequence is returned.
    """
    with st.spinner(text="Processing your request"):
        model, tokenizer = _load_model(model_name)
        inputs = tokenizer(
            [text], truncation=True, padding="longest", return_tensors="pt"
        )
        # NOTE(review): temperature is normally only consulted when sampling
        # (do_sample=True); confirm whether it has any effect with pure beam
        # search before relying on it.
        output = model.generate(
            **inputs,
            max_length=60,
            num_beams=10,
            num_return_sequences=1,
            temperature=1.5,
        )
        decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
    return decoded[0]


def _render_result_html(generated_title, original_title=None):
    """Build the HTML snippet that shows the generated title.

    Both titles are HTML-escaped because the snippet is rendered with
    ``unsafe_allow_html=True`` and the values come from model output and
    from arXiv, not from this program.  The "Original Title" row is only
    included when *original_title* is not ``None``.
    """
    # Lines are joined without indentation so that Markdown treats the
    # snippet as an HTML block rather than as an indented code block.
    rows = [
        "",
        "<style>",
        "p.a {",
        "font: 20px Courier;",
        "}",
        "</style>",
        f'<p class="a"><b>Title Generated:></b> {html.escape(generated_title)} </p>',
    ]
    if original_title is not None:
        rows.append(
            f'<p class="a"><b>Original Title:></b> {html.escape(original_title)} </p>'
        )
    rows.append("")
    return "\n".join(rows)


def _handle_submit(abstract, paper_link, model_name):
    """Generate and display a title for the pasted abstract or linked paper.

    A non-blank *abstract* takes precedence.  Otherwise the paper referenced
    by *paper_link* is fetched from arXiv and its abstract is used, in which
    case the paper's real title is displayed next to the generated one.
    Problems are reported with ``st.error`` instead of raising.
    """
    if abstract.strip():
        text, original_title = abstract, None
    else:
        paper_id = _extract_arxiv_id(paper_link)
        paper = _fetch_arxiv_paper(paper_id)
        if paper is None:
            # Without a paper there is neither an abstract to summarise nor
            # an original title to show, so stop here.
            if paper_id:
                st.error(f"No arXiv paper found for '{paper_id}'")
            else:
                st.error("The text can't be empty")
            return
        text, original_title = paper

    if not text:
        st.error("The text can't be empty")
        return

    generated_title = _generate_title(text, model_name)
    st.markdown(
        _render_result_html(generated_title, original_title),
        unsafe_allow_html=True,
    )


def main():
    """Render the title-generator page and react to the Submit button.

    The page offers an arXiv link/id input, an abstract input and a model
    selector.  The arXiv lookup and the model run only happen when Submit is
    clicked, not on every rerun of the script.
    """
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )
    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")

    paper_link = st.text_area(
        "Provide the link/id for an arxiv paper",
        _DEFAULT_PAPER_LINK,
    )
    abstract = st.text_area("...or paste a paper's abstract to generate a title")
    st.text("")

    # The selectbox always returns one of the listed options, so no separate
    # "no model selected" branch is needed.
    model_name = st.selectbox("Choose a model to generate the title", _MODEL_CHOICES)

    if st.button("Submit"):
        _handle_submit(abstract, paper_link, model_name)

    with st.expander("Additional Information"):
        st.markdown(_ABOUT_MARKDOWN, unsafe_allow_html=True)
    st.text('\n')
    st.text('\n')
    st.markdown(_CREDITS_HTML, unsafe_allow_html=True)
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()