Spaces:
Runtime error
Runtime error
File size: 4,996 Bytes
237aa79 d15d6ce 237aa79 d15d6ce 237aa79 d15d6ce 237aa79 d15d6ce 237aa79 6f50845 237aa79 6f50845 237aa79 d15d6ce 237aa79 d15d6ce 237aa79 d15d6ce 237aa79 d15d6ce a232db1 6f50845 d15d6ce 6f50845 a232db1 6f50845 d15d6ce 6f50845 237aa79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import html

import arxiv
import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hugging Face Hub ids of the fine-tuned seq2seq models offered in the UI.
_MODELS = [
    "AryanLala/autonlp-Scientific_Title_Generator-34558227",
    "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full",
]

# CSS prepended to the result snippet (kept flush-left: the snippet goes
# through st.markdown, where indented lines would become code blocks).
_TITLE_STYLE = """
<style>
p.a {
font: 20px Courier;
}
</style>
"""

# Body of the "Additional Information" expander (markdown + inline HTML).
_ABOUT_MARKDOWN = """
The models used were fine-tuned on a subset of data from the [Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset).

The task of the models is to suggest an appropriate title from the abstract of a scientific paper.

The model [AryanLala/autonlp-Scientific_Title_Generator-34558227](https://huggingface.co/AryanLala/autonlp-Scientific_Title_Generator-34558227) was trained on data
from the cs.AI (Artificial Intelligence) category of papers.

The model [shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full](https://huggingface.co/shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full)
was trained on the categories: cs.AI, cs.LG, cs.NI, cs.GR, cs.CL, cs.CV (Artificial Intelligence, Machine Learning, Networking and Internet Architecture, Graphics, Computation and Language, Computer Vision and Pattern Recognition).

Also, <b>Thank you to arXiv for use of its open access interoperability.</b> It allows us to pull the required abstracts from the ids you provide.
"""

# Footer crediting the authors.
_CREDITS_HTML = '''<span style="color:blue; font-size:10px">App created by [@akshay7](https://huggingface.co/akshay7), [@AryanLala](https://huggingface.co/AryanLala) and [@shamikbose89](https://huggingface.co/shamikbose89)
</span>'''


def _arxiv_id_from_text(text: str) -> str:
    """Return the arXiv identifier contained in a link or a bare id.

    Handles ``https://arxiv.org/abs/<id>``, ``https://arxiv.org/pdf/<id>[.pdf]``
    (including old-style ids such as ``cs/0112017``) and bare ids. Surrounding
    whitespace, a query string/fragment and a trailing slash are ignored. Text
    without an ``/abs/`` or ``/pdf/`` segment falls back to its last path
    segment. An empty string is returned for empty input.
    """
    ref = text.strip().split("#", 1)[0].split("?", 1)[0].rstrip("/")
    for marker in ("/abs/", "/pdf/"):
        if marker in ref:
            ref = ref.split(marker, 1)[1]
            break
    else:
        ref = ref.split("/")[-1]
    if ref.endswith(".pdf"):
        ref = ref[: -len(".pdf")]
    return ref


def _fetch_arxiv_paper(paper_id: str):
    """Look up ``paper_id`` on arXiv.

    Returns:
        ``(abstract, title)`` of the first result, or ``None`` when arXiv
        returns no result. Errors raised by the ``arxiv`` client propagate.
    """
    search = arxiv.Search(id_list=[paper_id])
    # arxiv >= 2.0 deprecates Search.results() in favour of Client.results();
    # keep the old call for versions without a Client class.
    if hasattr(arxiv, "Client"):
        results = arxiv.Client().results(search)
    else:
        results = search.results()
    for result in results:
        return result.summary, result.title
    return None


def _render_title_html(summary: str, title=None) -> str:
    """Build the HTML snippet showing the generated (and original) title.

    Both strings are HTML-escaped: the snippet is rendered with
    ``unsafe_allow_html=True`` and its text comes from the model and from arXiv.

    Args:
        summary: the generated title.
        title: the paper's original arXiv title, or None to omit that line.
    """
    parts = [
        _TITLE_STYLE,
        f'<p class="a"><b>Title Generated:</b> {html.escape(summary)} </p>\n',
    ]
    if title is not None:
        parts.append(
            f'<p class="a"><b>Original Title:</b> {html.escape(title)} </p>\n'
        )
    return "".join(parts)


def main():
    """Render the "Title Generator" Streamlit page.

    The user pastes an abstract, or leaves that box empty and supplies an arXiv
    link/id instead. On Submit, the selected model generates a title. For arXiv
    input, the paper's original title is shown next to the generated one.
    """
    # Must be the first Streamlit command executed in the script run.
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )
    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")
    example = st.text_area(
        "Provide the link/id for an arxiv paper",
        """https://arxiv.org/abs/2111.10339""",
    )
    message = st.text_area("...or paste a paper's abstract to generate a title")
    st.text("")
    base_model = st.selectbox("Choose a model to generate the title", _MODELS)

    # st.cache is deprecated since Streamlit 1.18; st.cache_resource hashes only
    # the function arguments, which is why the model name is a parameter below.
    if hasattr(st, "cache_resource"):
        cache_model = st.cache_resource(show_spinner=False)
    else:
        cache_model = st.cache(
            allow_output_mutation=True, suppress_st_warning=True, show_spinner=False
        )

    @cache_model
    def load_model(model_name):
        """Load (once per model name) the seq2seq model and its tokenizer."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        return model, tokenizer

    def get_summary(text):
        """Generate a single title for ``text`` with the selected model."""
        with st.spinner(text="Processing your request"):
            model, tokenizer = load_model(base_model)
            inputs = tokenizer(
                [text], truncation=True, padding="longest", return_tensors="pt"
            )
            # No `temperature`: without do_sample=True it is ignored by beam
            # search (and transformers warns about it).
            output = model.generate(
                **inputs,
                max_length=60,
                num_beams=10,
                num_return_sequences=1,
            )
            return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    def submit():
        """Resolve the input text, generate a title and display the result."""
        title = None
        text = message
        if not text.strip():
            # No abstract pasted: fetch abstract and original title from arXiv.
            paper_id = _arxiv_id_from_text(example)
            if not paper_id:
                st.error("The text can't be empty")
                return
            try:
                paper = _fetch_arxiv_paper(paper_id)
            except Exception as err:  # UI boundary: report any arXiv/network failure
                st.error(f"Could not fetch arXiv paper '{paper_id}': {err}")
                return
            if paper is None:
                st.error(f"No arXiv paper found for id '{paper_id}'")
                return
            text, title = paper
        summary = get_summary(text)
        st.markdown(_render_title_html(summary, title), unsafe_allow_html=True)

    # Run the model only when the submit button is clicked.
    if st.button("Submit"):
        submit()

    with st.expander("Additional Information"):
        st.markdown(_ABOUT_MARKDOWN, unsafe_allow_html=True)
    st.text('\n')
    st.text('\n')
    st.markdown(_CREDITS_HTML, unsafe_allow_html=True)
# Script entry point (Streamlit executes this file as __main__).
if __name__ == "__main__":
    main()
|