import nltk
nltk.download('stopwords')  # ensure the stopword corpus is available before the masking helpers use it

import random

import gradio as gr
import plotly.graph_objs as go
from matplotlib.colors import ListedColormap, rgb2hex
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences
from highlighter import highlight_common_words, highlight_common_words_dict
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word

# Function behind the Gradio interface: paraphrase the prompt, split the
# paraphrases into accepted/discarded by entailment score (threshold 0.7),
# mask candidate watermark positions, then re-sample the masked words.
def model(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
    length_accepted_sentences = len(selected_sentences)
    common_grams = find_common_subsequences(user_prompt, selected_sentences)
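
    # Build three masked variants of every paraphrased sentence, one per
    # masking strategy; each helper returns (masked_sentence, logits, words).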
    masked_sentences = []
    masked_words = []
    masked_logits = []

    for sentence in paraphrased_sentences:
        masked_sent, logits, words = mask_non_stopword(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)
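
    # Re-fill each masked sentence with four sampling techniques, so every
    # paraphrase yields 3 masked x 4 sampled = 12 candidates (the slice size
    # used when building the trees below).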
    sampled_sentences = []
    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))

    print(len(sampled_sentences))  # debug: total number of sampled sentences
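
    # Assign a color to each common n-gram ("non-melting point") so it is
    # highlighted consistently across the prompt and the paraphrases.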
    colors = ["red", "blue", "brown", "green"]

    def select_color():
        return random.choice(colors)

    highlight_info = [(word, select_color()) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
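
    # One pair of tree plots per paraphrase: tree1 shows the three masked
    # variants ("where to watermark"), tree2 the twelve sampled variants
    # ("how to watermark").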
    trees1 = []
    trees2 = []
    masked_index = 0
    sampled_index = 0
    for sentence in paraphrased_sentences:
        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]

        tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
        trees1.append(tree1)

        tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
        trees2.append(tree2)

        masked_index += 3
        sampled_index += 12

    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2
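
# Gradio UI: a prompt box, highlighted HTML panes for the prompt and the
# accepted/discarded paraphrases, and two tabbed banks of tree plots.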
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()
    # Labels before the tree plots
    with gr.Row():
        gr.Markdown("### Where to Watermark?")  # label for the masked-sentence trees

    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(10):  # adjust this range to match the number of trees
                with gr.TabItem(f"Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_tabs.append(tree1)
    with gr.Row():
        gr.Markdown("### How to Watermark?")  # label for the sampled-sentence trees

    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(10):  # adjust this range to match the number of trees
                with gr.TabItem(f"Sentence {i+1}"):
                    tree2 = gr.Plot()
                    tree2_tabs.append(tree2)
    submit_button.click(
        model,
        inputs=user_input,
        outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs,
    )

    # Clearing must return one value per output component: empty strings for
    # the textbox and HTML panes, None for each plot.
    clear_button.click(
        lambda: ["", "", "", ""] + [None] * (len(tree1_tabs) + len(tree2_tabs)),
        inputs=None,
        outputs=[user_input, highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs,
    )
demo.launch(share=True)
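
# Note: share=True publishes the app through a temporary public gradio.live
# link; drop the argument to serve locally only.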