Spaces:
Runtime error
Runtime error
File size: 5,363 Bytes
1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 b71da86 1228bd8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# import streamlit as st
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# import nltk
# import math
# import torch
# model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
# # model_name = "afnanmmir/t5-base-axriv-to-abstract-3"
# max_input_length = 1024
# max_output_length = 256
# st.header("Generate summaries")
# st_model_load = st.text('Loading summary generator model...')
# # @st.cache(allow_output_mutation=True)
# @st.cache_data
# def load_model():
# print("Loading model...")
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# nltk.download('punkt')
# print("Model loaded!")
# return tokenizer, model
# tokenizer, model = load_model()
# st.success('Model loaded!')
# st_model_load.text("")
# with st.sidebar:
# # st.header("Model parameters")
# # if 'num_titles' not in st.session_state:
# # st.session_state.num_titles = 5
# # def on_change_num_titles():
# # st.session_state.num_titles = num_titles
# # num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
# # if 'temperature' not in st.session_state:
# # st.session_state.temperature = 0.7
# # def on_change_temperatures():
# # st.session_state.temperature = temperature
# # temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
# # st.markdown("_High temperature means that results are more random_")
# if 'text' not in st.session_state:
# st.session_state.text = ""
# st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
# def generate_summary():
# st.session_state.text = st_text_area
# # tokenize text
# inputs = ["summarize: " + st_text_area]
# # print(inputs)
# inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
# # print("Tokenized inputs: ")
# # print(inputs)
# outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=64)
# # print("outputs", outputs)
# decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# # print("Decoded_outputs", decoded_outputs)
# predicted_summaries = nltk.sent_tokenize(decoded_outputs.strip())
# # print("Predicted summaries", predicted_summaries)
# # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# # predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
# st.session_state.summaries = predicted_summaries
# # generate title button
# st_generate_button = st.button('Generate summary', on_click=generate_summary)
# # title generation labels
# if 'summaries' not in st.session_state:
# st.session_state.summaries = []
# if len(st.session_state.summaries) > 0:
# # print("In summaries if")
# with st.container():
# st.subheader("Generated summaries")
# for summary in st.session_state.summaries:
# st.markdown("__" + summary + "__")
# -------------------------------
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
import math
import torch
# Model identifier passed to both `AutoTokenizer.from_pretrained` and
# `AutoModelForSeq2SeqLM.from_pretrained` in `load_model`.
model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
# Passed as the tokenizer's `max_length` (with `truncation=True`), so longer
# inputs are cut off at this length.
max_input_length = 1024
# Passed to `model.generate` as `max_length` and `min_length` respectively.
max_output_length = 256
min_output_length = 64
# Page header.
st.header("Generate summaries for articles")
# Status text shown while the model loads; it is overwritten with an empty
# string once `load_model` has returned.
st_model_load = st.text('Loading summary generator model...')
# `st.cache(allow_output_mutation=True)` is the deprecated legacy caching API;
# `st.cache_resource` is its documented replacement for global, unserializable
# objects such as models, and returns the same cached instance on every rerun.
@st.cache_resource
def load_model():
    """Load the tokenizer and seq2seq model identified by ``model_name``.

    Also downloads the NLTK ``punkt`` data, which ``nltk.sent_tokenize``
    needs. The function is cached, so the body only runs on a cache miss;
    later script reruns reuse the objects created here.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    nltk.download('punkt')
    print("Model loaded!")
    return tokenizer, model
# Obtain the tokenizer/model pair, then replace the "loading" status text
# with a success banner.
tokenizer, model = load_model()
st.success('Model loaded!')
st_model_load.text("")

# Seed the session with an empty input the first time the script runs.
if 'text' not in st.session_state:
    st.session_state.text = ""

# Input box, pre-filled with the text stored in the session state.
st_text_area = st.text_area(
    'Text to generate the summary for',
    height=500,
    value=st.session_state.text,
)
def generate_summary():
    """Summarize the text-area contents and store the result in the session.

    Used as the ``on_click`` callback of the "Generate summary" button.
    The current text is saved to ``st.session_state.text``, prefixed with
    ``"summarize: "``, tokenized (truncated to ``max_input_length``) and
    passed to ``model.generate``. Only the first sentence of each decoded
    sequence is kept; the resulting list is written to
    ``st.session_state.summaries``.

    Blank input skips generation and clears any previous summaries.
    """
    st.session_state.text = st_text_area

    # Nothing to summarize: avoid running generation on the bare prompt prefix.
    if not st_text_area.strip():
        st.session_state.summaries = []
        return

    # tokenize text
    inputs = tokenizer(
        ["summarize: " + st_text_area],
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )

    # compute predictions
    outputs = model.generate(
        **inputs,
        do_sample=True,
        max_length=max_output_length,
        early_stopping=True,
        num_beams=8,
        length_penalty=2.0,
        no_repeat_ngram_size=2,
        min_length=min_output_length,
    )
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    predicted_summaries = []
    for decoded_output in decoded_outputs:
        sentences = nltk.sent_tokenize(decoded_output.strip())
        # sent_tokenize returns [] for an empty string; indexing [0]
        # unconditionally would raise IndexError.
        if sentences:
            predicted_summaries.append(sentences[0])

    st.session_state.summaries = predicted_summaries
# Generate-summary button; `generate_summary` runs as its click callback.
st_generate_button = st.button('Generate summary', on_click=generate_summary)

# Make sure the summaries list exists before it is read below.
if 'summaries' not in st.session_state:
    st.session_state.summaries = []

# Render each stored summary in bold, under a subheader, when there is any.
if st.session_state.summaries:
    with st.container():
        st.subheader("Generated summaries")
        for summary in st.session_state.summaries:
            st.markdown("__" + summary + "__")