Razavipour committed on
Commit
8a8b25b
·
verified ·
1 Parent(s): 36a11f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -14
app.py CHANGED
@@ -1,28 +1,44 @@
1
  import gradio as gr
2
  from transformers import LEDForConditionalGeneration, LEDTokenizer
3
  import torch
 
4
 
5
# Model/tokenizer setup: load the fine-tuned LED checkpoint from disk and
# move the model onto the best available device.
model = LEDForConditionalGeneration.from_pretrained("./summary_generation_Led_4")
tokenizer = LEDTokenizer.from_pretrained("./summary_generation_Led_4")

# Prefer GPU for generation speed; fall back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
12
- # Define the function for generating summaries
13
  def generate_summary(plot_synopsis):
14
- inputs = tokenizer(plot_synopsis, max_length=3000, truncation=True, padding="max_length", return_tensors="pt")
 
 
15
  inputs = inputs.to(device)
16
- outputs = model.generate(inputs['input_ids'], max_length=315, min_length=20, length_penalty=2.0, num_beams=4, early_stopping=True)
 
 
 
 
17
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
18
  return summary
19
 
20
# Wire the summarizer into a simple Gradio UI and start serving it.
synopsis_box = gr.Textbox(label="Plot Synopsis", lines=10, placeholder="Enter plot synopsis here...")
summary_box = gr.Textbox(label="Plot Summary")
interface = gr.Interface(fn=generate_summary,
                         inputs=synopsis_box,
                         outputs=summary_box,
                         title="Plot Summary Generator",
                         description="This demo generates a plot summary based on a plot synopsis using a fine-tuned LED model.")

interface.launch()
 
1
import re

import gradio as gr
import torch
from datasets import load_dataset
from transformers import LEDForConditionalGeneration, LEDTokenizer
5
 
6
# Device selection: prefer GPU for generation speed, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned LED summarization checkpoint and its matching tokenizer.
tokenizer = LEDTokenizer.from_pretrained("./summary_generation_Led_4")
model = LEDForConditionalGeneration.from_pretrained("./summary_generation_Led_4").to(device)
12
 
13
# Normalize the input text (plot synopsis) before tokenization.
# Fix: this function uses `re`, which the file never imported — every call
# raised NameError; `import re` is now in the module's import block.
def normalize_text(text):
    """Return a cleaned copy of *text*.

    Lowercases the string, collapses runs of whitespace (including newlines)
    to single spaces, and strips punctuation / non-alphanumeric characters.
    """
    text = text.lower()                       # case-fold for the model input
    text = re.sub(r'\s+', ' ', text).strip()  # collapse whitespace, trim ends
    text = re.sub(r'[^\w\s]', '', text)       # keep only word chars and spaces
    return text
19
 
20
# Preprocess a synopsis and produce an abstractive summary.
def generate_summary(plot_synopsis):
    """Normalize *plot_synopsis*, run LED beam-search generation on it, and
    return the decoded summary string.
    """
    prompt = "summarize: " + normalize_text(plot_synopsis)
    encoded = tokenizer(prompt, max_length=3000, truncation=True,
                        padding="max_length", return_tensors="pt").to(device)
    generated = model.generate(encoded["input_ids"], max_length=315,
                               min_length=20, length_penalty=2.0,
                               num_beams=4, early_stopping=True)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
33
 
34
# Assemble the Gradio demo and launch it.
demo_input = gr.Textbox(label="Plot Synopsis", lines=10, placeholder="Enter the plot synopsis here...")
demo_output = gr.Textbox(label="Generated Summary")
interface = gr.Interface(
    fn=generate_summary,
    inputs=demo_input,
    outputs=demo_output,
    title="Plot Summary Generator",
    description="This demo generates a plot summary based on the plot synopsis using a fine-tuned LED model.",
)

interface.launch()