VishnuPottabatthini commited on
Commit
231a99e
·
verified ·
1 Parent(s): 4bb2b59

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import BartTokenizer, BartForConditionalGeneration
4
+
5
+ # Define the BART model and tokenizer
6
+ MODEL_NAME = 'VishnuPottabatthini/BART_demo' # Change this to the model you want to use
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load the tokenizer and model
10
+ tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
11
+ model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
12
+
13
+ # Define the summarization function
14
+ def summarize(text, state):
15
+ try:
16
+ # Tokenize the input text
17
+ inputs = tokenizer(
18
+ text,
19
+ return_tensors="pt",
20
+ truncation=True,
21
+ max_length=1024 # Adjust max length according to your model's capabilities
22
+ ).to(device)
23
+
24
+ # Generate the summary
25
+ summary_ids = model.generate(
26
+ inputs['input_ids'],
27
+ attention_mask=inputs['attention_mask'],
28
+ max_length=150, # Maximum length of the summary
29
+ min_length=30, # Minimum length of the summary
30
+ num_beams=4, # Beam search to improve the quality of generated text
31
+ early_stopping=True
32
+ )
33
+
34
+ # Decode the summary
35
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
36
+ return state + "\n" + summary, state + "\n" + summary
37
+
38
+ except Exception as e:
39
+ return str(e), state
40
+
41
+ # Create the Gradio interface
42
+ mf_summarize = gr.Interface(
43
+ fn=summarize,
44
+ inputs=[
45
+ gr.Textbox(placeholder="Enter text to summarize...", lines=10),
46
+ gr.State(value="")
47
+ ],
48
+ outputs=[
49
+ gr.Textbox(lines=15, label="Summary"),
50
+ gr.State()
51
+ ],
52
+ theme="huggingface",
53
+ title="BART Summarization",
54
+ live=True,
55
+ description=(
56
+ "Enter a long piece of text to generate a concise summary using a BART model. "
57
+ "This demo uses a custom BART model from 🤗 Transformers."
58
+ )
59
+ )
60
+
61
+ # Launch the Gradio interface
62
+ mf_summarize.launch()