Eemansleepdeprived commited on
Commit
651bb25
·
verified ·
1 Parent(s): 0868be8

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +117 -0
  2. requirements.txt +3 -0
main.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
4
+ import re
5
+ import os
6
+
7
+ def load_model():
8
+ """Load the model from local storage"""
9
+ torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ print(f"Using device: {torch_device}")
11
+
12
+ # Load tokenizer and model from local directory
13
+ tokenizer = PegasusTokenizer.from_pretrained('./models')
14
+ model = PegasusForConditionalGeneration.from_pretrained('./models').to(torch_device)
15
+ return tokenizer, model, torch_device
16
+
17
+ def split_into_paragraphs(text):
18
+ """Split text into paragraphs while preserving empty lines."""
19
+ paragraphs = text.split('\n\n')
20
+ return [p.strip() for p in paragraphs if p.strip()]
21
+
22
+ def split_into_sentences(paragraph):
23
+ """Split paragraph into sentences using regex."""
24
+ sentences = re.split(r'(?<=[.!?])\s+', paragraph)
25
+ return [s.strip() for s in sentences if s.strip()]
26
+
27
+ def get_response(input_text, num_return_sequences, tokenizer, model, torch_device):
28
+ batch = tokenizer.prepare_seq2seq_batch(
29
+ [input_text],
30
+ truncation=True,
31
+ padding='longest',
32
+ max_length=80,
33
+ return_tensors="pt"
34
+ ).to(torch_device)
35
+
36
+ translated = model.generate(
37
+ **batch,
38
+ num_beams=10,
39
+ num_return_sequences=num_return_sequences,
40
+ temperature=1.0,
41
+ repetition_penalty=2.8,
42
+ length_penalty=1.2,
43
+ max_length=80,
44
+ min_length=5,
45
+ no_repeat_ngram_size=3
46
+ )
47
+
48
+ tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
49
+ return tgt_text[0]
50
+
51
+ def get_response_from_text(context, tokenizer, model, torch_device):
52
+ """Process entire text while preserving paragraph structure."""
53
+ paragraphs = split_into_paragraphs(context)
54
+ paraphrased_paragraphs = []
55
+
56
+ for paragraph in paragraphs:
57
+ sentences = split_into_sentences(paragraph)
58
+ paraphrased_sentences = []
59
+
60
+ for sentence in sentences:
61
+ if len(sentence.split()) < 3:
62
+ paraphrased_sentences.append(sentence)
63
+ continue
64
+
65
+ try:
66
+ paraphrased = get_response(sentence, 1, tokenizer, model, torch_device)
67
+ if not any(phrase in paraphrased.lower() for phrase in ['it\'s like', 'in other words']):
68
+ paraphrased_sentences.append(paraphrased)
69
+ else:
70
+ paraphrased_sentences.append(sentence)
71
+ except Exception as e:
72
+ print(f"Error processing sentence: {e}")
73
+ paraphrased_sentences.append(sentence)
74
+
75
+ paraphrased_paragraphs.append(' '.join(paraphrased_sentences))
76
+
77
+ return '\n\n'.join(paraphrased_paragraphs)
78
+
79
+ def create_interface():
80
+ """Create and configure the Gradio interface"""
81
+ # Load model and tokenizer
82
+ tokenizer, model, torch_device = load_model()
83
+
84
+ def greet(context):
85
+ return get_response_from_text(context, tokenizer, model, torch_device)
86
+
87
+ # Create interface with improved styling
88
+ iface = gr.Interface(
89
+ fn=greet,
90
+ inputs=gr.Textbox(
91
+ lines=15,
92
+ label="Input Text",
93
+ placeholder="Enter your text here...",
94
+ elem_classes="input-text"
95
+ ),
96
+ outputs=gr.Textbox(
97
+ lines=15,
98
+ label="Paraphrased Text",
99
+ elem_classes="output-text"
100
+ ),
101
+ title="Advanced Text Paraphraser",
102
+ description="Enter text to generate a high-quality paraphrased version while maintaining paragraph structure.",
103
+ theme="default",
104
+ css="""
105
+ .input-text, .output-text {
106
+ font-size: 16px !important;
107
+ font-family: Arial, sans-serif !important;
108
+ min-height: 300px !important;
109
+ }
110
+ """
111
+ )
112
+ return iface
113
+
114
+ if __name__ == "__main__":
115
+ # Create and launch the interface
116
+ interface = create_interface()
117
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==5.6.0
2
+ torch==2.2.2
3
+ transformers==4.45.2