llmahmad commited on
Commit
97814f0
·
verified ·
1 Parent(s): 9d7ed27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -1,18 +1,21 @@
1
  import os
2
- os.system('pip install streamlit transformers torch')
3
-
4
  import streamlit as st
5
- from transformers import pipeline
6
- import torch
7
 
8
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
9
 
10
  # Load the model and tokenizer
11
  model_path = '.' # Path to the current directory where files are located
12
 
13
- tokenizer = AutoTokenizer.from_pretrained(model_path)
14
- model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
15
- summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)
 
 
 
 
 
16
 
17
  st.title("Text Summarization with Fine-Tuned Model")
18
  st.write("Enter text to generate a summary using the fine-tuned summarization model.")
@@ -21,9 +24,9 @@ text = st.text_area("Input Text", height=200)
21
  if st.button("Summarize"):
22
  if text:
23
  with st.spinner("Summarizing..."):
24
- summary = summarizer(text, max_length=150, min_length=30, do_sample=False)
25
  st.success("Summary Generated")
26
- st.write(summary[0]['summary_text'])
27
  else:
28
  st.warning("Please enter some text to summarize.")
29
 
@@ -39,4 +42,4 @@ if __name__ == "__main__":
39
  </style>
40
  """,
41
  unsafe_allow_html=True
42
- )
 
1
  import os
 
 
2
  import streamlit as st
3
+ from transformers import PegasusTokenizer, PegasusForConditionalGeneration
 
4
 
5
+ # Install necessary packages (only needed if not already installed)
6
+ os.system('pip install streamlit transformers torch')
7
 
8
  # Load the model and tokenizer
9
  model_path = '.' # Path to the current directory where files are located
10
 
11
+ tokenizer = PegasusTokenizer.from_pretrained(model_path)
12
+ model = PegasusForConditionalGeneration.from_pretrained(model_path)
13
+
14
+ def summarize_text(text):
15
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="longest")
16
+ summary_ids = model.generate(inputs["input_ids"], max_length=150, min_length=30, do_sample=False)
17
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
18
+ return summary
19
 
20
  st.title("Text Summarization with Fine-Tuned Model")
21
  st.write("Enter text to generate a summary using the fine-tuned summarization model.")
 
24
  if st.button("Summarize"):
25
  if text:
26
  with st.spinner("Summarizing..."):
27
+ summary = summarize_text(text)
28
  st.success("Summary Generated")
29
+ st.write(summary)
30
  else:
31
  st.warning("Please enter some text to summarize.")
32
 
 
42
  </style>
43
  """,
44
  unsafe_allow_html=True
45
+ )