File size: 13,452 Bytes
4b89d6b
 
ea7f5b6
 
2471de4
63b3783
ea7f5b6
0840f0a
436c4c1
92afc5b
ee305a4
 
436c4c1
 
 
 
0840f0a
 
 
2471de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee305a4
2471de4
 
 
 
ee305a4
2471de4
 
0840f0a
 
 
 
 
 
 
 
 
 
2471de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960f419
2471de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea7f5b6
2471de4
 
 
 
 
 
 
 
 
ea7f5b6
2471de4
 
 
ea7f5b6
ee305a4
2471de4
ee305a4
2471de4
ee305a4
2471de4
 
 
 
 
63b3783
 
2471de4
 
63b3783
 
2471de4
 
 
 
 
ee305a4
63b3783
2471de4
 
63b3783
 
2471de4
 
 
 
 
436c4c1
 
2471de4
 
 
436c4c1
2471de4
 
 
 
 
436c4c1
2471de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
import nltk
nltk.download('stopwords')
import random
import gradio as gr
import time
from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
# from lcs import find_common_subsequences, find_common_gram_positions
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word
from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from euclidean_distance import SentenceEuclideanDistanceCalculator
from threeD_plot import gen_three_D_plot

from twokenize import tokenize_sentences, tokenize_sentence
from non_melting_points import find_non_melting_points

class WatermarkingPipeline:
    """Stateful five-step watermarking demo pipeline.

    Each ``stepN_*`` method stores its intermediate results on the instance
    so the following step can consume them, and returns the values for the
    UI outputs it is wired to. The last returned value of every step is a
    status string (timing on success, an error message otherwise).

    The number of values a step returns must equal the number of output
    components registered for it in ``create_gradio_interface``; the class
    constants below mirror those counts.
    """

    # Number of plot tabs built for step 2 and for step 3 in the UI.
    NUM_TREE_SLOTS = 10
    # Number of HTML tabs built for step 4 in the UI.
    NUM_REPARAPHRASE_SLOTS = 120
    # Step 4 only re-paraphrases this many sampled sentences.
    MAX_REPARAPHRASE_INPUTS = 13
    # Step 4 renders re-paraphrase results in groups of this size.
    REPARAPHRASE_BATCH_SIZE = 10
    # Step 2 applies three masking strategies per sentence; step 3 walks the
    # flat masked lists in groups of the same size.
    MASKS_PER_SENTENCE = 3
    # Colours randomly assigned to the common words highlighted in the trees.
    HIGHLIGHT_COLORS = ["red", "blue", "brown", "green"]

    def __init__(self):
        # Step 1 state.
        self.user_prompt = None
        self.paraphrased_sentences = None
        self.analyzed_paraphrased_sentences = None
        self.selected_sentences = None
        self.discarded_sentences = None
        self.user_prompt_tokenized = None
        self.selected_sentences_tokenized = None
        self.discarded_sentences_tokenized = None
        self.common_grams = None
        self.subsequences = None
        self.common_grams_position = None
        # Step 2 state: flat lists, MASKS_PER_SENTENCE entries per sentence.
        self.masked_sentences = None
        self.masked_words = None
        self.masked_logits = None
        # Step 3 / step 4 state.
        self.sampled_sentences = None
        self.reparaphrased_sentences = None
        # Step 5 state.
        self.distortion_list = None
        self.detectability_list = None
        self.euclidean_dist_list = None

    @staticmethod
    def _fit_to_slots(items, slots, filler=None):
        """Return ``items`` truncated or padded with ``filler`` to ``slots`` entries."""
        fitted = list(items[:slots])
        fitted.extend([filler] * (slots - len(fitted)))
        return fitted

    def _highlight_info(self):
        """Pair the second element of every common gram with a random colour."""
        return [
            (word, random.choice(self.HIGHLIGHT_COLORS))
            for _, word in self.common_grams
        ]

    def step1_paraphrasing(self, prompt, threshold=0.7):
        """Paraphrase ``prompt``, filter by entailment and highlight common words.

        Args:
            prompt: The user-supplied text.
            threshold: Value forwarded to ``analyze_entailment``.

        Returns:
            A list of four values: highlighted prompt, highlighted accepted
            sentences, highlighted discarded sentences, and a status string.
            On paraphrase failure the first entry carries the error text and
            the two sentence panes are empty.
        """
        start_time = time.time()

        self.user_prompt = prompt
        self.paraphrased_sentences = generate_paraphrase(prompt)
        if self.paraphrased_sentences is None:
            # Four values are required because this handler is wired to four
            # output components; returning fewer makes the UI callback fail.
            return [
                "Error in generating paraphrases",
                "",
                "",
                "Error: Could not complete step",
            ]

        (
            self.analyzed_paraphrased_sentences,
            self.selected_sentences,
            self.discarded_sentences,
        ) = analyze_entailment(self.user_prompt, self.paraphrased_sentences, threshold)

        self.user_prompt_tokenized = tokenize_sentence(self.user_prompt)
        self.selected_sentences_tokenized = tokenize_sentences(self.selected_sentences)
        self.discarded_sentences_tokenized = tokenize_sentences(self.discarded_sentences)

        # Common grams are computed over the prompt plus the accepted
        # paraphrases only; discarded sentences do not take part.
        all_tokenized_sentences = [
            self.user_prompt_tokenized,
            *self.selected_sentences_tokenized,
        ]
        self.common_grams = find_non_melting_points(all_tokenized_sentences)

        highlighted_user_prompt = highlight_common_words(
            self.common_grams, [self.user_prompt], "Highlighted LCS in the User Prompt"
        )
        highlighted_accepted_sentences = highlight_common_words_dict(
            self.common_grams, self.selected_sentences, "Paraphrased Sentences"
        )
        highlighted_discarded_sentences = highlight_common_words_dict(
            self.common_grams, self.discarded_sentences, "Discarded Sentences"
        )

        execution_time = time.time() - start_time
        time_info = f"Step 1 completed in {execution_time:.2f} seconds"

        return [
            highlighted_user_prompt,
            highlighted_accepted_sentences,
            highlighted_discarded_sentences,
            time_info,
        ]

    def step2_masking(self):
        """Mask every paraphrased sentence three ways and plot one tree each.

        Returns:
            ``NUM_TREE_SLOTS`` tree plots (padded with ``None`` or truncated)
            followed by a status string.
        """
        start_time = time.time()

        # common_grams is checked as well: it is only set at the end of a
        # successful step 1 and is required below.
        if self.paraphrased_sentences is None or self.common_grams is None:
            return [None] * self.NUM_TREE_SLOTS + ["Error: Please complete step 1 first"]

        self.masked_sentences = []
        self.masked_words = []
        self.masked_logits = []

        # Must contain exactly MASKS_PER_SENTENCE strategies.
        mask_funcs = [
            mask_non_stopword,
            mask_non_stopword_pseudorandom,
            lambda s: high_entropy_words(s, self.common_grams),
        ]
        for sentence in self.paraphrased_sentences:
            for mask_func in mask_funcs:
                masked_sent, logits, words = mask_func(sentence)
                self.masked_sentences.append(masked_sent)
                self.masked_words.append(words)
                self.masked_logits.append(logits)

        highlight_info = self._highlight_info()

        trees = []
        for index, sentence in enumerate(self.paraphrased_sentences):
            start = index * self.MASKS_PER_SENTENCE
            next_masked = self.masked_sentences[start:start + self.MASKS_PER_SENTENCE]
            trees.append(
                generate_subplot1(sentence, next_masked, highlight_info, self.common_grams)
            )

        execution_time = time.time() - start_time
        time_info = f"Step 2 completed in {execution_time:.2f} seconds"

        # Same padding/truncation as step 3 so the number of returned values
        # always equals the number of wired outputs.
        return self._fit_to_slots(trees, self.NUM_TREE_SLOTS) + [time_info]

    def step3_sampling(self):
        """Fill each masked sentence with every sampling technique and plot trees.

        Returns:
            ``NUM_TREE_SLOTS`` tree plots (padded with ``None`` or truncated)
            followed by a status string.
        """
        start_time = time.time()

        if self.masked_sentences is None:
            return [None] * self.NUM_TREE_SLOTS + ["Error: Please complete step 2 first"]

        self.sampled_sentences = []
        trees = []
        highlight_info = self._highlight_info()

        # (technique name, temperature) pairs forwarded to sample_word.
        sampling_techniques = [
            ('inverse_transform', 1.0),
            ('exponential_minimum', 1.0),
            ('temperature', 1.0),
            ('greedy', 1.0),
        ]

        group = self.MASKS_PER_SENTENCE
        for start in range(0, len(self.masked_sentences), group):
            current_masked = self.masked_sentences[start:start + group]
            current_words = self.masked_words[start:start + group]
            current_logits = self.masked_logits[start:start + group]

            batch_samples = [
                sample_word(
                    masked_sent,
                    words,
                    logits,
                    sampling_technique=technique,
                    temperature=temp,
                )
                for masked_sent, words, logits in zip(
                    current_masked, current_words, current_logits
                )
                for technique, temp in sampling_techniques
            ]
            self.sampled_sentences.extend(batch_samples)

            trees.append(
                generate_subplot2(
                    current_masked,
                    batch_samples,
                    highlight_info,
                    self.common_grams,
                )
            )

        execution_time = time.time() - start_time
        time_info = f"Step 3 completed in {execution_time:.2f} seconds"

        return self._fit_to_slots(trees, self.NUM_TREE_SLOTS) + [time_info]

    def step4_reparaphrase(self):
        """Re-paraphrase the first sampled sentences and render them as HTML.

        At most ``MAX_REPARAPHRASE_INPUTS`` sampled sentences are processed.

        Returns:
            ``NUM_REPARAPHRASE_SLOTS`` HTML strings (unused slots are empty)
            followed by a status string.
        """
        start_time = time.time()

        if self.sampled_sentences is None:
            message = "Error: Please complete step 3 first"
            return [message] * self.NUM_REPARAPHRASE_SLOTS + [message]

        # Slicing instead of indexing range(13): fewer than 13 sampled
        # sentences must not raise IndexError.
        self.reparaphrased_sentences = [
            generate_paraphrase(sentence)
            for sentence in self.sampled_sentences[:self.MAX_REPARAPHRASE_INPUTS]
        ]

        batch_size = self.REPARAPHRASE_BATCH_SIZE
        reparaphrased_sentences_list = []
        for start in range(0, len(self.reparaphrased_sentences), batch_size):
            batch = self.reparaphrased_sentences[start:start + batch_size]
            # Only full batches are rendered; a trailing partial batch is
            # dropped. NOTE(review): presumably reparaphrased_sentences_html
            # expects exactly 10 entries - confirm before relaxing this.
            if len(batch) == batch_size:
                reparaphrased_sentences_list.append(reparaphrased_sentences_html(batch))

        execution_time = time.time() - start_time
        time_info = f"Step 4 completed in {execution_time:.2f} seconds"

        # Pad to the number of wired HTML outputs; returning fewer values
        # than outputs makes the UI callback fail.
        html_slots = self._fit_to_slots(
            reparaphrased_sentences_list, self.NUM_REPARAPHRASE_SLOTS, filler=""
        )
        return html_slots + [time_info]

    def step5_metrics(self):
        """Compute distortion/detectability/distance metrics and plot them in 3D.

        Returns:
            A ``(plot, status)`` pair. ``plot`` is ``None`` when step 4 has
            not been run yet.
        """
        start_time = time.time()

        if self.reparaphrased_sentences is None:
            # The first output is a plot component, so no text is sent there.
            return None, "Error: Please complete step 4 first"

        distortion_calculator = SentenceDistortionCalculator(
            self.user_prompt, self.reparaphrased_sentences
        )
        distortion_calculator.calculate_all_metrics()
        distortion_calculator.normalize_metrics()
        distortion_calculator.calculate_combined_distortion()
        distortion = distortion_calculator.get_combined_distortions()
        self.distortion_list = [value for _, value in distortion.items()]

        detectability_calculator = SentenceDetectabilityCalculator(
            self.user_prompt, self.reparaphrased_sentences
        )
        detectability_calculator.calculate_all_metrics()
        detectability_calculator.normalize_metrics()
        detectability_calculator.calculate_combined_detectability()
        detectability = detectability_calculator.get_combined_detectabilities()
        self.detectability_list = [value for _, value in detectability.items()]

        euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(
            self.user_prompt, self.reparaphrased_sentences
        )
        euclidean_dist_calculator.calculate_all_metrics()
        euclidean_dist_calculator.normalize_metrics()
        # NOTE(review): the values below are read from detectability_calculator,
        # so euclidean_dist_list ends up equal to detectability_list and the
        # results of euclidean_dist_calculator are never used. This looks like
        # a copy/paste slip; the matching getter on
        # SentenceEuclideanDistanceCalculator is not visible here - confirm
        # its API and switch this call over.
        euclidean_dist = detectability_calculator.get_combined_detectabilities()
        self.euclidean_dist_list = [value for _, value in euclidean_dist.items()]

        three_D_plot = gen_three_D_plot(
            self.detectability_list,
            self.distortion_list,
            self.euclidean_dist_list,
        )

        execution_time = time.time() - start_time
        time_info = f"Step 5 completed in {execution_time:.2f} seconds"

        return three_D_plot, time_info

def create_gradio_interface():
    """Assemble the Gradio Blocks UI for the five-step watermarking demo.

    One ``WatermarkingPipeline`` instance is created here and its step
    methods are registered as the click handlers of the five step buttons.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    pipeline = WatermarkingPipeline()

    def timing_box():
        # Read-only textbox that shows a step's status/timing string.
        return gr.Textbox(label="Execution Time", interactive=False)

    def tab_group(title, count, component_cls):
        # Build `count` tabs named "<title> 1" .. "<title> count", each
        # holding one fresh component, and return those components in order.
        created = []
        with gr.Tabs():
            for number in range(1, count + 1):
                with gr.TabItem(f"{title} {number}"):
                    created.append(component_cls())
        return created

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# **AIISC Watermarking Model**")

        # --- Prompt input ---
        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(label="Enter Your Prompt")

        # --- Step 1 ---
        gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
        paraphrase_button = gr.Button("Generate Paraphrases")
        highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")

        with gr.Tabs():
            with gr.TabItem("Accepted Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Paraphrased Sentences"):
                highlighted_discarded_sentences = gr.HTML()
        step1_time = timing_box()

        # --- Step 2 ---
        gr.Markdown("## Step 2: Where to Mask?")
        masking_button = gr.Button("Apply Masking")
        gr.Markdown("### Masked Sentence Trees")
        tree1_tabs = tab_group("Masked Sentence", 10, gr.Plot)
        step2_time = timing_box()

        # --- Step 3 ---
        gr.Markdown("## Step 3: How to Mask?")
        sampling_button = gr.Button("Sample Words")
        gr.Markdown("### Sampled Sentence Trees")
        tree2_tabs = tab_group("Sampled Sentence", 10, gr.Plot)
        step3_time = timing_box()

        # --- Step 4 ---
        gr.Markdown("## Step 4: Re-paraphrasing")
        reparaphrase_button = gr.Button("Re-paraphrase")
        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = tab_group("Reparaphrased Batch", 120, gr.HTML)
        step4_time = timing_box()

        # --- Step 5 ---
        gr.Markdown("## Step 5: Finding Sweet Spot")
        metrics_button = gr.Button("Calculate Metrics")
        gr.Markdown("### 3D Visualization of Metrics")
        three_D_plot = gr.Plot()
        step5_time = timing_box()

        # --- Event wiring: each handler's outputs end with its timing box ---
        paraphrase_button.click(
            pipeline.step1_paraphrasing,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time,
            ],
        )
        masking_button.click(
            pipeline.step2_masking,
            inputs=None,
            outputs=[*tree1_tabs, step2_time],
        )
        sampling_button.click(
            pipeline.step3_sampling,
            inputs=None,
            outputs=[*tree2_tabs, step3_time],
            show_progress=True,
        )
        reparaphrase_button.click(
            pipeline.step4_reparaphrase,
            inputs=None,
            outputs=[*reparaphrased_sentences_tabs, step4_time],
        )
        metrics_button.click(
            pipeline.step5_metrics,
            inputs=None,
            outputs=[three_D_plot, step5_time],
        )

    return demo

# Script entry point: build the UI and start serving it only when this file
# is executed directly (not when it is imported).
if __name__ == "__main__":
    demo = create_gradio_interface()
    # share=True asks Gradio to also create a publicly reachable share link.
    demo.launch(share=True)