Spaces:
Sleeping
Sleeping
Commit
·
0860d85
1
Parent(s):
fbe88e2
update manual control to sliders
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
import nltk
|
2 |
import spacy
|
3 |
-
nltk
|
|
|
4 |
spacy.cli.download('en_core_web_sm')
|
5 |
|
6 |
import torch
|
@@ -18,13 +18,11 @@ from sklearn.impute import IterativeImputer
|
|
18 |
from sklearn.linear_model import Ridge
|
19 |
|
20 |
|
21 |
-
def process_examples(samples
|
22 |
processed = []
|
23 |
for sample in samples:
|
24 |
-
|
25 |
-
|
26 |
-
pd.DataFrame({'Index': full_names, 'Source': sample['sentence1_ling'], 'Target': sample['sentence2_ling']})
|
27 |
-
])
|
28 |
return processed
|
29 |
|
30 |
args, args_list, lng_names = parse_args(ckpt='./ckpt/model.pt')
|
@@ -34,7 +32,9 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
34 |
|
35 |
lng_names = [name_map[x] for x in lng_names]
|
36 |
examples = json.load(open('assets/examples.json'))
|
37 |
-
|
|
|
|
|
38 |
|
39 |
stats = json.load(open('assets/stats.json'))
|
40 |
|
@@ -86,10 +86,10 @@ def visibility(mode):
|
|
86 |
output.append(gr.update(visible=False))
|
87 |
return output
|
88 |
|
89 |
-
def generate(sent1,
|
90 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
91 |
-
ling1 = scaler.transform([
|
92 |
-
ling2 = scaler.transform([
|
93 |
inputs = {'sentence1_input_ids': input_ids,
|
94 |
'sentence1_ling': torch.tensor(ling1).float().to(device),
|
95 |
'sentence2_ling': torch.tensor(ling2).float().to(device),
|
@@ -102,13 +102,37 @@ def generate(sent1, ling):
|
|
102 |
|
103 |
return pred
|
104 |
|
105 |
-
def
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
|
|
|
|
|
|
109 |
|
|
|
|
|
|
|
|
|
110 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
111 |
-
ling2 = torch.tensor(scaler.transform([
|
112 |
inputs = {
|
113 |
'sentence1_input_ids': input_ids,
|
114 |
'sentence2_ling': ling2,
|
@@ -118,20 +142,23 @@ def generate_with_feedback(sent1, ling, approx):
|
|
118 |
pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer)
|
119 |
|
120 |
interpolation = '-- ' + '\n-- '.join(interpolations)
|
121 |
-
|
|
|
122 |
|
123 |
-
def generate_random(sent1,
|
|
|
|
|
124 |
preds, interpolations = [], []
|
125 |
for c in range(count):
|
126 |
idx = np.random.randint(0, len(ling_collection))
|
127 |
ling_ex = ling_collection[idx]
|
128 |
-
|
129 |
-
pred, interpolation =
|
130 |
preds.append(pred)
|
131 |
interpolations.append(interpolation)
|
132 |
-
return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
|
133 |
|
134 |
-
def estimate_gen(sent1, sent2,
|
135 |
if 'approximate' in approx:
|
136 |
input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
|
137 |
with torch.no_grad():
|
@@ -143,13 +170,12 @@ def estimate_gen(sent1, sent2, ling, approx):
|
|
143 |
raise ValueError()
|
144 |
|
145 |
ling_pred = round_ling(ling_pred)
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
return results
|
151 |
|
152 |
-
def estimate_tgt(sent2,
|
153 |
if 'approximate' in approx:
|
154 |
input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
|
155 |
with torch.no_grad():
|
@@ -161,10 +187,10 @@ def estimate_tgt(sent2, ling, approx):
|
|
161 |
raise ValueError()
|
162 |
|
163 |
ling_pred = round_ling(ling_pred)
|
164 |
-
|
165 |
-
return
|
166 |
|
167 |
-
def estimate_src(sent1,
|
168 |
if 'approximate' in approx:
|
169 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
170 |
with torch.no_grad():
|
@@ -175,51 +201,41 @@ def estimate_src(sent1, ling, approx):
|
|
175 |
else:
|
176 |
raise ValueError()
|
177 |
|
178 |
-
|
179 |
-
return
|
180 |
|
181 |
-
def
|
182 |
-
ling['Target'] = scaler.inverse_transform([np.random.randn(*ling['Target'].shape)])[0]
|
183 |
-
return ling
|
184 |
-
|
185 |
-
def rand_ex_target(ling):
|
186 |
idx = np.random.randint(0, len(ling_collection))
|
187 |
ling_ex = ling_collection[idx]
|
188 |
-
|
189 |
-
return
|
190 |
-
|
191 |
-
def
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
|
|
196 |
scale_stepsize = np.random.uniform(1.0, 5.0)
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
203 |
scale_stepsize = np.random.uniform(1.0, 5.0)
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
def impute(ling):
|
210 |
-
ling['Target'] = ling['Target'].replace("", np.nan)
|
211 |
-
ling['Target'] = scaler.transform([ling['Target']])[0]
|
212 |
-
estimator = Ridge(alpha=1e3, fit_intercept=False)
|
213 |
-
imputer = IterativeImputer(estimator=estimator, imputation_order='random', max_iter=100)
|
214 |
|
215 |
-
combined_matrix = np.vstack([ling_collection, ling['Target']])
|
216 |
-
interpolated_matrix = imputer.fit_transform(combined_matrix)
|
217 |
-
interpolated_vector = interpolated_matrix[-1]
|
218 |
-
|
219 |
-
interp_raw = scaler.inverse_transform([interpolated_vector])[0]
|
220 |
-
|
221 |
-
ling['Target'] = round_ling(interp_raw)
|
222 |
-
return ling
|
223 |
|
224 |
title = """
|
225 |
<h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
|
@@ -228,46 +244,44 @@ title = """
|
|
228 |
The model can generate diverse paraphrases of a given sentence, each adjusted to maintain consistent meaning while varying
|
229 |
in linguistic complexity according to the desired level.</p>
|
230 |
<p style="font-size:1.2em;">It is important to note that not all index combinations are feasible (e.g., a sentence of "length" 5 with 10 "unique words").
|
231 |
-
To ensure high-quality outputs, our approach
|
232 |
achievable set of indices for the given target.</p>
|
233 |
"""
|
234 |
|
235 |
guide = """
|
236 |
-
|
237 |
-
|
238 |
-
**
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
**Complexity-Matched Paraphrasing**:
|
243 |
-
|
244 |
-
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
|
247 |
-
|
248 |
-
generated text. We provided a set of tools for manual adjustments of the desired linguistic complexity of
|
249 |
-
the target sentence. These tools enable the user to extract linguistic indices from a given sentence,
|
250 |
-
generate a random (yet coherent) set of linguistic indices, and add or subtract to them.
|
251 |
-
These tools are designed for experimental use and require the user to possess linguistic expertise for
|
252 |
-
effective input of linguistic indices. To use these tools, select "Tools to assist in setting linguistic
|
253 |
-
indices." Once indices are entered, click "Generate."
|
254 |
-
|
255 |
-
|
256 |
-
Second, you may select to use exact or approximate computation of linguistic indices. Approximate computation is significantly faster.
|
257 |
-
|
258 |
-
Third, you may view the intermediate sentences of the quality control process by selecting the checkbox under "Advanced Options".
|
259 |
-
|
260 |
-
Fourth, you may try out some examples by clicking on "Examples...". Examples consist of a source sentences
|
261 |
-
and a sample set of target linguistic indices.
|
262 |
-
|
263 |
-
Please make your choice below.
|
264 |
|
|
|
|
|
|
|
|
|
|
|
265 |
"""
|
266 |
|
267 |
-
|
268 |
-
ling = gr.Dataframe(value = [[x, 0, 0] for x in lng_names],
|
269 |
-
headers=['Index', 'Source', 'Target'],
|
270 |
-
datatype=['str', 'number', 'number'], visible=False)
|
271 |
css = """
|
272 |
#guide span.svelte-1w6vloh {font-size: 22px !important; font-weight: 600 !important}
|
273 |
#mode span.svelte-1gfkn6j {font-size: 18px !important; font-weight: 600 !important}
|
@@ -300,61 +314,164 @@ body {
|
|
300 |
background-color: #000; /* Adjust the color as needed */
|
301 |
margin-bottom: 20px; /* Adjust the margin as needed */
|
302 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
"""
|
304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
with gr.Blocks(
|
306 |
theme=gr.themes.Default(
|
307 |
spacing_size=gr.themes.sizes.spacing_md,
|
308 |
text_size=gr.themes.sizes.text_md,
|
309 |
),
|
310 |
css=css) as demo:
|
311 |
-
|
|
|
312 |
gr.Markdown(title)
|
|
|
|
|
313 |
with gr.Accordion("🚀 Quick Start Guide", open=False, elem_id='guide'):
|
314 |
gr.Markdown(guide)
|
315 |
|
316 |
with gr.Group(elem_classes='top-separator'):
|
317 |
pass
|
|
|
|
|
318 |
with gr.Group(elem_id='mode'):
|
319 |
mode = gr.Radio(
|
320 |
-
value='
|
321 |
label='Operation Modes',
|
322 |
type="index",
|
323 |
-
choices=['🔄
|
324 |
'⚖️ Complexity-Matched Paraphrasing',
|
325 |
'🎛️ Manual Linguistic Control'],
|
326 |
)
|
327 |
with gr.Accordion("⚙️ Advanced Options", open=False):
|
|
|
328 |
approx = gr.Radio(value='Use approximate computation of linguistic indices (faster)',
|
329 |
choices=['Use approximate computation of linguistic indices (faster)',
|
330 |
'Use exact computation of linguistic indices'], container=False, show_label=False)
|
331 |
control_interpolation = gr.Checkbox(label='View the intermediate sentences in the interpolation of linguistic indices')
|
332 |
|
333 |
-
with gr.Accordion("📑 Examples...", open=False):
|
334 |
-
gr.Examples(examples, [sent1, ling], examples_per_page=4, label=None)
|
335 |
|
|
|
336 |
with gr.Row():
|
337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
with gr.Column():
|
339 |
sent2 = gr.Textbox(label='Generated text')
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
with gr.
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
|
359 |
rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
|
360 |
impute_btn = gr.Button("Impute Missing Values", size='lg', visible=False)
|
@@ -362,33 +479,195 @@ with gr.Blocks(
|
|
362 |
estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
|
363 |
copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
|
364 |
with gr.Row():
|
365 |
-
sub_btn = gr.Button('
|
366 |
-
add_btn = gr.Button('
|
367 |
with gr.Row():
|
368 |
estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
|
369 |
sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False, elem_id='estimate')
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
|
375 |
-
estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
|
376 |
-
rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
|
377 |
-
impute_btn.click(impute, inputs=[ling], outputs=[ling])
|
378 |
-
copy_btn.click(copy, inputs=[ling], outputs=[ling])
|
379 |
-
generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
|
380 |
-
generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
|
381 |
-
outputs=[sent2, interpolation, ling])
|
382 |
-
add_btn.click(add, inputs=[ling], outputs=[ling])
|
383 |
-
sub_btn.click(sub, inputs=[ling], outputs=[ling])
|
384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
group1 = [generate_random_btn, count]
|
386 |
group2 = [estimate_gen_btn, sent_ling_gen]
|
387 |
-
group3 = [generate_btn, estimate_src_btn, impute_btn, estimate_tgt_btn, sent_ling_est,
|
|
|
388 |
components = group1 + group2 + group3
|
|
|
389 |
mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
|
390 |
control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
|
391 |
outputs=[interpolation])
|
392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
print('Finished loading')
|
394 |
demo.launch(share=True)
|
|
|
|
|
1 |
import spacy
|
2 |
+
import nltk
|
3 |
+
nltk.download('wordnet', quiet=True)
|
4 |
spacy.cli.download('en_core_web_sm')
|
5 |
|
6 |
import torch
|
|
|
18 |
from sklearn.linear_model import Ridge
|
19 |
|
20 |
|
21 |
+
def process_examples(samples):
    """Flatten example records into rows suitable for ``gr.Examples``.

    Each returned row holds the source sentence, then the source linguistic
    indices rendered as strings, then the target linguistic indices unchanged.

    Args:
        samples: iterable of dicts with the keys ``'sentence1'``,
            ``'sentence1_ling'`` and ``'sentence2_ling'``.

    Returns:
        A list with one flat list per sample.
    """
    return [
        [record['sentence1']]
        + list(map(str, record['sentence1_ling']))
        + record['sentence2_ling']
        for record in samples
    ]
|
27 |
|
28 |
args, args_list, lng_names = parse_args(ckpt='./ckpt/model.pt')
|
|
|
32 |
|
33 |
lng_names = [name_map[x] for x in lng_names]
|
34 |
# Load the bundled examples; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open('assets/examples.json') as examples_file:
    examples = json.load(examples_file)
# Keep only a hand-picked subset of examples and flatten them into rows.
example_ids = [44, 148, 86, 96, 98, 62, 114, 138]
examples = [examples[i] for i in example_ids]
examples = process_examples(examples)

# Per-feature statistics (read below for slider min/max/is_int).
with open('assets/stats.json') as stats_file:
    stats = json.load(stats_file)
|
40 |
|
|
|
86 |
output.append(gr.update(visible=False))
|
87 |
return output
|
88 |
|
89 |
+
def generate(sent1, ling_dict):
|
90 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
91 |
+
ling1 = scaler.transform([ling_dict['Source']])
|
92 |
+
ling2 = scaler.transform([ling_dict['Target']])
|
93 |
inputs = {'sentence1_input_ids': input_ids,
|
94 |
'sentence1_ling': torch.tensor(ling1).float().to(device),
|
95 |
'sentence2_ling': torch.tensor(ling2).float().to(device),
|
|
|
102 |
|
103 |
return pred
|
104 |
|
105 |
+
def impute_targets():
    """Fill the inactive entries of the shared target vector by imputation.

    Entries whose index is not in ``shared_state.active_indices`` are replaced
    with NaN. The vector is scaled with ``scaler``, stacked as the last row
    beneath ``ling_collection_scaled`` and completed by an
    ``IterativeImputer`` driven by a ``Ridge`` estimator. The completed row is
    inverse-scaled, passed through ``round_ling`` and stored back into
    ``shared_state.target``.

    Returns:
        The updated ``shared_state.target`` list.
    """
    active = shared_state.active_indices
    masked = np.array([
        value if position in active else np.nan
        for position, value in enumerate(shared_state.target)
    ])
    masked_scaled = scaler.transform([masked])[0]

    imputer = IterativeImputer(
        estimator=Ridge(alpha=1e3, fit_intercept=False),
        imputation_order='random',
        max_iter=100,
    )
    # The target row goes last so it can be read back with index -1.
    completed = imputer.fit_transform(np.vstack([ling_collection_scaled, masked_scaled]))
    restored = scaler.inverse_transform([completed[-1]])[0]

    shared_state.target = round_ling(restored).tolist()
    return shared_state.target
|
125 |
|
126 |
+
def generate_with_feedback(sent1, approx):
|
127 |
+
if sent1 == '':
|
128 |
+
raise gr.Error('Please input a source text.')
|
129 |
|
130 |
+
# First impute any inactive targets
|
131 |
+
if len(shared_state.active_indices) < len(shared_state.target):
|
132 |
+
impute_targets()
|
133 |
+
|
134 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
135 |
+
ling2 = torch.tensor(scaler.transform([shared_state.target])).float().to(device)
|
136 |
inputs = {
|
137 |
'sentence1_input_ids': input_ids,
|
138 |
'sentence2_ling': ling2,
|
|
|
142 |
pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer)
|
143 |
|
144 |
interpolation = '-- ' + '\n-- '.join(interpolations)
|
145 |
+
# Return both the generation results and the updated slider values
|
146 |
+
return [pred_text, interpolation] + [gr.update(value=val) for val in shared_state.target]
|
147 |
|
148 |
+
def generate_random(sent1, count, approx):
    """Generate ``count`` paraphrases of ``sent1`` toward randomly drawn targets.

    Each iteration picks a random row of ``ling_collection``, installs a copy
    of it as ``shared_state.target`` and runs ``generate_with_feedback``.

    Args:
        sent1: source text; must not be empty.
        count: number of paraphrases to produce.
        approx: computation-mode string, passed through to
            ``generate_with_feedback``.

    Returns:
        A 2-tuple ``(predictions, interpolations)``, each a single string with
        the per-iteration results joined by ``'\\n***\\n'``.

    Raises:
        gr.Error: if ``sent1`` is the empty string.
    """
    if sent1 == '':
        raise gr.Error('Please input a source text.')
    preds, interpolations = [], []
    for _ in range(count):
        idx = np.random.randint(0, len(ling_collection))
        ling_ex = ling_collection[idx]
        shared_state.target = ling_ex.copy()
        # generate_with_feedback returns [text, interpolation, *slider updates];
        # unpacking that into two names raises ValueError once slider updates
        # are present, so take the first two entries by index instead.
        result = generate_with_feedback(sent1, approx)
        preds.append(result[0])
        interpolations.append(result[1])
    return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
|
160 |
|
161 |
+
def estimate_gen(sent1, sent2, approx):
|
162 |
if 'approximate' in approx:
|
163 |
input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
|
164 |
with torch.no_grad():
|
|
|
170 |
raise ValueError()
|
171 |
|
172 |
ling_pred = round_ling(ling_pred)
|
173 |
+
shared_state.target = ling_pred.copy()
|
174 |
+
|
175 |
+
gen = generate_with_feedback(sent1, approx)
|
176 |
+
return gen[0], gen[1], [gr.update(value=val) for val in shared_state.target]
|
|
|
177 |
|
178 |
+
def estimate_tgt(sent2, ling_dict, approx):
|
179 |
if 'approximate' in approx:
|
180 |
input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
|
181 |
with torch.no_grad():
|
|
|
187 |
raise ValueError()
|
188 |
|
189 |
ling_pred = round_ling(ling_pred)
|
190 |
+
ling_dict['Target'] = ling_pred
|
191 |
+
return ling_dict
|
192 |
|
193 |
+
def estimate_src(sent1, ling_dict, approx):
|
194 |
if 'approximate' in approx:
|
195 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
196 |
with torch.no_grad():
|
|
|
201 |
else:
|
202 |
raise ValueError()
|
203 |
|
204 |
+
ling_dict['Source'] = ling_pred
|
205 |
+
return ling_dict
|
206 |
|
207 |
+
def rand_ex_target():
    """Replace the shared target with a copy of a random ``ling_collection`` row.

    Returns:
        One ``gr.update(value=...)`` per entry of the new target.
    """
    row = np.random.randint(0, len(ling_collection))
    shared_state.target = ling_collection[row].copy()
    return [gr.update(value=entry) for entry in shared_state.target]
|
212 |
+
|
213 |
+
def copy_source_to_target():
    """Copy the shared source values into the shared target.

    Returns:
        One ``gr.update(value=...)`` per entry of the new target.

    Raises:
        gr.Error: if any source entry is still the empty-string placeholder.
    """
    current_source = shared_state.source
    if "" in current_source:
        raise gr.Error("Source linguistic features not initialized. Please estimate them first.")
    shared_state.target = current_source.copy()
    return [gr.update(value=entry) for entry in shared_state.target]
|
218 |
+
|
219 |
+
def add_to_target():
    """Increase every active target feature by a random step.

    A single step size is drawn uniformly from [1.0, 5.0) and multiplied per
    feature by ``scale_ratio[i]``. The result is passed through ``round_ling``
    and stored back into ``shared_state.target``.

    Returns:
        One ``gr.update(value=...)`` per entry of the new target.

    Raises:
        gr.Error: if no feature is active.
    """
    if not shared_state.active_indices:
        raise gr.Error("No features are activated. Please activate features to modify.")
    scale_stepsize = np.random.uniform(1.0, 5.0)
    # Force a float array: with an all-integer target (e.g. the initial zeros)
    # NumPy would build an int array and silently truncate the float step on
    # assignment, before round_ling ever sees it.
    new_targets = np.array(shared_state.target, dtype=float)
    for i in shared_state.active_indices:
        new_targets[i] += scale_stepsize * scale_ratio[i]
    shared_state.target = round_ling(new_targets).tolist()
    return [gr.update(value=val) for val in shared_state.target]
|
228 |
+
|
229 |
+
def subtract_from_target():
    """Decrease every active target feature by a random step.

    A single step size is drawn uniformly from [1.0, 5.0) and multiplied per
    feature by ``scale_ratio[i]``. The result is passed through ``round_ling``
    and stored back into ``shared_state.target``.

    Returns:
        One ``gr.update(value=...)`` per entry of the new target.

    Raises:
        gr.Error: if no feature is active.
    """
    if not shared_state.active_indices:
        raise gr.Error("No features are activated. Please activate features to modify.")
    scale_stepsize = np.random.uniform(1.0, 5.0)
    # Force a float array: with an all-integer target (e.g. the initial zeros)
    # NumPy would build an int array and silently truncate the float step on
    # assignment, before round_ling ever sees it.
    new_targets = np.array(shared_state.target, dtype=float)
    for i in shared_state.active_indices:
        new_targets[i] -= scale_stepsize * scale_ratio[i]
    shared_state.target = round_ling(new_targets).tolist()
    return [gr.update(value=val) for val in shared_state.target]
|
|
|
|
|
|
|
|
|
|
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
|
240 |
title = """
|
241 |
<h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
|
|
|
244 |
The model can generate diverse paraphrases of a given sentence, each adjusted to maintain consistent meaning while varying
|
245 |
in linguistic complexity according to the desired level.</p>
|
246 |
<p style="font-size:1.2em;">It is important to note that not all index combinations are feasible (e.g., a sentence of "length" 5 with 10 "unique words").
|
247 |
+
To ensure high-quality outputs, our approach compares the initial generation with the target linguistic indices, and performs iterative refinement to match the closest, yet coherent
|
248 |
achievable set of indices for the given target.</p>
|
249 |
"""
|
250 |
|
251 |
guide = """
|
252 |
+
1. **Select Operation Mode**: Choose from the available modes:
|
253 |
+
- **Linguistically-diverse Paraphrase Generation**: Generate diverse paraphrases.
|
254 |
+
- **Steps**:
|
255 |
+
1. Enter the source text in the provided textbox.
|
256 |
+
2. Specify the number of paraphrases you want.
|
257 |
+
3. Click "Generate" to produce paraphrases with varying linguistic complexity.
|
258 |
+
- **Complexity-Matched Paraphrasing**: Match the complexity of the input text.
|
259 |
+
- **Steps**:
|
260 |
+
1. Enter the source text in the provided textbox.
|
261 |
+
2. Provide another sentence to extract linguistic indices.
|
262 |
+
3. Click "Generate" to produce a paraphrase matching the complexity of the given sentence.
|
263 |
+
- **Manual Linguistic Control**: Manually adjust linguistic features using sliders.
|
264 |
+
- **Steps**:
|
265 |
+
1. Enter the source text in the provided textbox.
|
266 |
+
2. Activate or deactivate features of interest using the checkboxes.
|
267 |
+
3. Use the sliders to adjust linguistic features.
|
268 |
+
4. **Use Tools**: Access additional tools under "Tools to assist in setting linguistic indices" for advanced control.
|
269 |
+
- **Impute Missing Values**: Automatically fill inactive features.
|
270 |
+
- **Random Target**: Generate a random set of linguistic indices.
|
271 |
+
- **Copy Source to Target**: Copy linguistic indices from the source to the target.
|
272 |
+
- **Add/Subtract Complexity**: Adjust the complexity of the target indices.
|
273 |
+
5. Click "Generate" to produce the output text based on the adjusted features.
|
274 |
|
275 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
+
# Updated Advanced Options Description
|
278 |
+
advanced_options_description = """
|
279 |
+
**Advanced Options**:
|
280 |
+
- **Approximate vs. Exact Computation**: Choose between faster approximate computation or more precise exact computation of linguistic indices.
|
281 |
+
- **View Intermediate Generations**: Enable this option to see the intermediate sentences generated during the quality control process.
|
282 |
"""
|
283 |
|
284 |
+
|
|
|
|
|
|
|
285 |
css = """
|
286 |
#guide span.svelte-1w6vloh {font-size: 22px !important; font-weight: 600 !important}
|
287 |
#mode span.svelte-1gfkn6j {font-size: 18px !important; font-weight: 600 !important}
|
|
|
314 |
background-color: #000; /* Adjust the color as needed */
|
315 |
margin-bottom: 20px; /* Adjust the margin as needed */
|
316 |
}
|
317 |
+
|
318 |
+
.features-container {
|
319 |
+
border: 1px solid rgba(0, 0, 0, 0.1);
|
320 |
+
border-radius: 8px;
|
321 |
+
background: white;
|
322 |
+
}
|
323 |
+
|
324 |
+
/* Style the inner column to be scrollable */
|
325 |
+
.features-container > div > .column {
|
326 |
+
max-height: 400px;
|
327 |
+
overflow-y: scroll;
|
328 |
+
padding: 10px;
|
329 |
+
}
|
330 |
+
|
331 |
+
/* Scrollbar styles now apply to the inner column */
|
332 |
+
.features-container > div > .column::-webkit-scrollbar {
|
333 |
+
width: 8px;
|
334 |
+
}
|
335 |
+
|
336 |
+
.features-container > div > .column::-webkit-scrollbar-track {
|
337 |
+
background: #f1f1f1;
|
338 |
+
border-radius: 4px;
|
339 |
+
}
|
340 |
+
|
341 |
+
.features-container > div > .column::-webkit-scrollbar-thumb {
|
342 |
+
background: #888;
|
343 |
+
border-radius: 4px;
|
344 |
+
}
|
345 |
+
|
346 |
+
.features-container > div > .column::-webkit-scrollbar-thumb:hover {
|
347 |
+
background: #555;
|
348 |
+
}
|
349 |
+
|
350 |
+
.features-container .label-wrap span {
|
351 |
+
font-weight: 600;
|
352 |
+
font-size: 18px;
|
353 |
+
}
|
354 |
"""
|
355 |
|
356 |
+
sent1 = gr.Textbox(label='Source text')
|
357 |
+
ling_sliders = []
|
358 |
+
ling_dict = {'Source': [""] * len(lng_names), 'Target': [0] * len(lng_names)}
|
359 |
+
active_indices = []
|
360 |
+
target_sliders = []
|
361 |
+
source_values = []
|
362 |
+
active_checkboxes = []
|
363 |
+
for i in range(len(lng_names)):
|
364 |
+
source_values.append(gr.Textbox(placeholder="Not initialized",
|
365 |
+
lines=1, label="Source", interactive=False,
|
366 |
+
container=False, scale=1))
|
367 |
+
active_checkboxes.append(gr.Checkbox(label="Activate", value=False))
|
368 |
+
target_sliders.append(
|
369 |
+
gr.Slider(
|
370 |
+
minimum=stats['min'][i],
|
371 |
+
maximum=stats['max'][i],
|
372 |
+
value=stats['min'][i],
|
373 |
+
step=0.001 if not stats['is_int'][i] else 1,
|
374 |
+
label=None,
|
375 |
+
interactive=False
|
376 |
+
)
|
377 |
+
)
|
378 |
+
|
379 |
+
# Move SharedState class and instance to top
|
380 |
+
class SharedState:
    """Holder for the source values, target values and set of active features."""

    def __init__(self, n_features):
        """Start with empty-string sources, zero targets and no active feature."""
        self.active_indices = set()
        self.target = [0] * n_features
        self.source = [""] * n_features

    def update_target(self, index, value):
        """Store ``value`` at ``index`` of the target and return a copy of the target."""
        self.target[index] = value
        snapshot = self.target.copy()
        return snapshot

    def update_source(self, index, value):
        """Store ``value`` at ``index`` of the source and return a copy of the source."""
        self.source[index] = value
        snapshot = self.source.copy()
        return snapshot

    def toggle_active(self, index, value):
        """Activate ``index`` when ``value`` is truthy, otherwise deactivate it.

        Deactivating an index that is not active is a no-op. Returns the active
        indices as a list (in set iteration order).
        """
        apply_change = self.active_indices.add if value else self.active_indices.discard
        apply_change(index)
        return list(self.active_indices)

    def get_state(self):
        """Return a dict with copies of the source, target and active indices."""
        return dict(
            Source=self.source.copy(),
            Target=self.target.copy(),
            active_indices=list(self.active_indices),
        )
|
407 |
+
|
408 |
+
shared_state = SharedState(len(lng_names))
|
409 |
+
|
410 |
with gr.Blocks(
|
411 |
theme=gr.themes.Default(
|
412 |
spacing_size=gr.themes.sizes.spacing_md,
|
413 |
text_size=gr.themes.sizes.text_md,
|
414 |
),
|
415 |
css=css) as demo:
|
416 |
+
# Header
|
417 |
+
gr.Image('assets/logo.png', height=100, container=False, show_download_button=False, show_fullscreen_button=False)
|
418 |
gr.Markdown(title)
|
419 |
+
|
420 |
+
# Guide
|
421 |
with gr.Accordion("🚀 Quick Start Guide", open=False, elem_id='guide'):
|
422 |
gr.Markdown(guide)
|
423 |
|
424 |
with gr.Group(elem_classes='top-separator'):
|
425 |
pass
|
426 |
+
|
427 |
+
# Mode Selection
|
428 |
with gr.Group(elem_id='mode'):
|
429 |
mode = gr.Radio(
|
430 |
+
value='Linguistically-diverse Paraphrase Generation',
|
431 |
label='Operation Modes',
|
432 |
type="index",
|
433 |
+
choices=['🔄 Linguistically-diverse Paraphrase Generation',
|
434 |
'⚖️ Complexity-Matched Paraphrasing',
|
435 |
'🎛️ Manual Linguistic Control'],
|
436 |
)
|
437 |
with gr.Accordion("⚙️ Advanced Options", open=False):
|
438 |
+
gr.Markdown(advanced_options_description)
|
439 |
approx = gr.Radio(value='Use approximate computation of linguistic indices (faster)',
|
440 |
choices=['Use approximate computation of linguistic indices (faster)',
|
441 |
'Use exact computation of linguistic indices'], container=False, show_label=False)
|
442 |
control_interpolation = gr.Checkbox(label='View the intermediate sentences in the interpolation of linguistic indices')
|
443 |
|
|
|
|
|
444 |
|
445 |
+
# Main Input/Output
|
446 |
with gr.Row():
|
447 |
+
with gr.Column():
|
448 |
+
sent1.render()
|
449 |
+
|
450 |
+
count = gr.Number(label='Number of generated sentences', value=3, precision=0, scale=1, visible=True)
|
451 |
+
|
452 |
+
sent_ling_gen = gr.Textbox(label='Copy the style of this sentence', scale=1, visible=False)
|
453 |
+
|
454 |
+
|
455 |
with gr.Column():
|
456 |
sent2 = gr.Textbox(label='Generated text')
|
457 |
+
generate_random_btn = gr.Button("Generate", variant='primary', scale=1, visible=True)
|
458 |
+
estimate_gen_btn = gr.Button("Generate", variant='primary', scale=1, visible=False)
|
459 |
+
generate_btn = gr.Button("Generate", variant='primary', visible=False)
|
460 |
+
# Linguistic Features Container
|
461 |
+
with gr.Accordion("Linguistic Features", elem_classes="features-container", open=True, visible=False) as ling_features:
|
462 |
+
with gr.Row():
|
463 |
+
select_all_btn = gr.Button("Activate All", size='sm')
|
464 |
+
unselect_all_btn = gr.Button("Deactivate All", size='sm')
|
465 |
+
|
466 |
+
for i, name in enumerate(lng_names):
|
467 |
+
with gr.Row():
|
468 |
+
feature_name = gr.Textbox(name, lines=1, label="Feature", container=False, show_label=False, interactive=False)
|
469 |
+
source_values[i].render()
|
470 |
+
active_checkboxes[i].render()
|
471 |
+
target_sliders[i].interactive = False
|
472 |
+
target_sliders[i].render()
|
473 |
+
ling_sliders.append((feature_name, source_values[i], target_sliders[i], active_checkboxes[i], i))
|
474 |
+
# Tools Accordion
|
475 |
with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
|
476 |
rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
|
477 |
impute_btn = gr.Button("Impute Missing Values", size='lg', visible=False)
|
|
|
479 |
estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
|
480 |
copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
|
481 |
with gr.Row():
|
482 |
+
sub_btn = gr.Button('Decrease target complexity by \u03B5', visible=False)
|
483 |
+
add_btn = gr.Button('Increase target complexity by \u03B5', visible=False)
|
484 |
with gr.Row():
|
485 |
estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
|
486 |
sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False, elem_id='estimate')
|
487 |
+
interpolation = gr.Textbox(label='Quality control interpolation', visible=False, lines=5)
|
488 |
+
|
489 |
+
with gr.Group(elem_classes='bottom-separator'):
|
490 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
|
492 |
+
# Examples
|
493 |
+
def load_example(example_text, *values):
    """Load an example's linguistic indices into the shared state.

    Args:
        example_text: the example's source sentence (not used here).
        *values: the source indices followed by the target indices; the first
            ``len(lng_names)`` entries are treated as the source.

    Returns:
        A list of ``True``, one per feature.
    """
    feature_count = len(lng_names)

    # Source entries are converted to float; target entries are kept as given.
    shared_state.source = list(map(float, values[:feature_count]))
    shared_state.target = list(values[feature_count:])
    # Loading an example activates every feature.
    shared_state.active_indices = set(range(feature_count))

    return [True] * feature_count
|
506 |
+
|
507 |
+
gr.Examples(
|
508 |
+
examples=examples,
|
509 |
+
inputs=[sent1] + source_values + target_sliders,
|
510 |
+
outputs=active_checkboxes,
|
511 |
+
example_labels=[ex[0] for ex in examples],
|
512 |
+
fn=load_example,
|
513 |
+
run_on_click=True,
|
514 |
+
)
|
515 |
+
|
516 |
+
|
517 |
+
# Add select/unselect all handlers
|
518 |
+
def select_all():
    """Activate every feature in the shared state.

    Returns:
        ``True`` once per feature, followed by one
        ``gr.update(interactive=True)`` per feature.
    """
    total = len(lng_names)
    for position in range(total):
        shared_state.toggle_active(position, True)
    checked = [True] * total
    unlocked = [gr.update(interactive=True)] * total
    return checked + unlocked
|
522 |
+
|
523 |
+
def unselect_all():
    """Deactivate every feature in the shared state.

    Returns:
        ``False`` once per feature, followed by one
        ``gr.update(interactive=False)`` per feature.
    """
    shared_state.active_indices.clear()
    total = len(lng_names)
    unchecked = [False] * total
    locked = [gr.update(interactive=False)] * total
    return unchecked + locked
|
526 |
+
|
527 |
+
select_all_btn.click(
|
528 |
+
fn=select_all,
|
529 |
+
outputs=active_checkboxes + [slider for _, _, slider, _, _ in ling_sliders]
|
530 |
+
)
|
531 |
+
|
532 |
+
unselect_all_btn.click(
|
533 |
+
fn=unselect_all,
|
534 |
+
outputs=active_checkboxes + [slider for _, _, slider, _, _ in ling_sliders]
|
535 |
+
)
|
536 |
+
|
537 |
+
def update_slider(slider_index, new_value):
    """Record a slider's new value in the shared target.

    Args:
        slider_index: position of the feature in the target list. It is
            supplied through a ``gr.Number`` input and may therefore arrive as
            a float (NOTE(review): confirm against the installed Gradio
            version), so it is coerced to ``int`` before indexing.
        new_value: the slider's current value.
    """
    shared_state.target[int(slider_index)] = new_value
|
539 |
+
|
540 |
+
def update_checkbox(checkbox_index, new_value):
    """Switch one linguistic index on or off and lock/unlock its slider.

    Args:
        checkbox_index: Index forwarded to ``shared_state.toggle_active``.
            The event bindings feed this from a ``gr.Number`` input, which
            may deliver a float, so it is normalised to ``int`` here to keep
            the stored indices usable for list indexing.
        new_value: The checkbox state; passed to ``toggle_active`` and used
            as the slider's ``interactive`` flag.

    Returns:
        ``gr.update(interactive=new_value)`` for the matching slider.
    """
    shared_state.toggle_active(int(checkbox_index), new_value)
    return gr.update(interactive=new_value)
|
543 |
+
|
544 |
+
# Update the event bindings
|
545 |
+
for feature_name, source_value, target_slider, active_checkbox, i in ling_sliders:
    # One hidden component carries the feature index for both handlers.
    # precision=0 makes Gradio hand the value over as an int; without it the
    # index arrives as a float, which cannot be used to index a list.
    index_input = gr.Number(i, visible=False, precision=0)

    # Slider moved -> record the new target value for this feature.
    target_slider.change(
        fn=update_slider,
        inputs=[index_input, target_slider],
    )
    # Checkbox toggled -> update the active state; the returned update is
    # applied to this feature's slider.
    active_checkbox.change(
        fn=update_checkbox,
        inputs=[index_input, active_checkbox],
        outputs=target_slider,
    )
|
555 |
+
|
556 |
+
# Define groups and visibility
|
557 |
# Component groups; their concatenation order is the order in which the
# outputs of `visibility` are applied after sent2 and interpolation.
group1 = [generate_random_btn, count]
group2 = [estimate_gen_btn, sent_ling_gen]
group3 = [
    generate_btn, estimate_src_btn, impute_btn, estimate_tgt_btn, sent_ling_est,
    rand_ex_btn, copy_btn, add_btn, sub_btn, ling_features, ling_tools,
]
components = [*group1, *group2, *group3]

mode.change(
    fn=visibility,
    inputs=[mode],
    outputs=[sent2, interpolation, *components],
)
# The interpolation box is shown or hidden according to the control's value.
control_interpolation.change(
    fn=lambda show: gr.update(visible=show),
    inputs=[control_interpolation],
    outputs=[interpolation],
)
|
566 |
|
567 |
+
def update_sliders_from_state(ling_state, slider_indices):
    """Build a flat list of component values from ``ling_state``.

    For every index in ``slider_indices`` three items are emitted in order:
    the ``'Source'`` entry as a string, the ``'Target'`` entry unchanged, and
    ``gr.update(value=True)``.

    Returns:
        The flat list of items, three per index.
    """
    flat = []
    for position in slider_indices:
        flat.extend((
            str(ling_state['Source'][position]),
            ling_state['Target'][position],
            gr.update(value=True),
        ))
    return flat
|
574 |
+
|
575 |
+
def update_sliders_from_estimate(approx, sent_for_estimate):
    """Estimate the linguistic indices of a sentence and load them into state.

    If ``approx`` contains ``'approximate'`` the indices are predicted with
    ``ling_disc`` and mapped back through ``scaler.inverse_transform``; if it
    contains ``'exact'`` they are computed with ``compute_lng`` and restricted
    to ``used_indices``. The result is passed through ``round_ling`` and
    stored as both ``shared_state.source`` and ``shared_state.target``.

    Args:
        approx: Value naming the estimation mode.
        sent_for_estimate: The sentence to analyse.

    Returns:
        The estimated values followed by one ``True`` per entry of
        ``lng_names``.

    Raises:
        ValueError: If ``approx`` names neither mode.
    """
    if 'approximate' in approx:
        input_ids = tokenizer.encode(sent_for_estimate, return_tensors='pt').to(device)
        with torch.no_grad():
            ling_pred = ling_disc(input_ids=input_ids).cpu().numpy()
        ling_pred = scaler.inverse_transform(ling_pred)[0]
    elif 'exact' in approx:
        ling_pred = np.array(compute_lng(sent_for_estimate))[used_indices]
    else:
        raise ValueError(
            f"approx must contain 'approximate' or 'exact', got {approx!r}"
        )

    ling_pred = round_ling(ling_pred)
    # NOTE(review): the source state is overwritten here as well, although
    # this handler's outputs only cover target sliders and checkboxes --
    # confirm that replacing the source is intended.
    shared_state.source = ling_pred.copy()
    shared_state.target = ling_pred.copy()

    # list(...) guarantees concatenation: were round_ling to return a NumPy
    # array, `array + [True] * n` would add element-wise instead of appending
    # the checkbox values.
    return list(ling_pred) + [True] * len(lng_names)
|
592 |
+
|
593 |
+
def update_sliders_from_source(approx, source_sent):
    """Estimate the linguistic indices of the source sentence.

    If ``approx`` contains ``'approximate'`` the indices are predicted with
    ``ling_disc`` and mapped back through ``scaler.inverse_transform``; if it
    contains ``'exact'`` they are computed with ``compute_lng`` and restricted
    to ``used_indices``. The result is passed through ``round_ling`` and
    stored as ``shared_state.source``.

    Returns:
        One string per entry of ``lng_names``.

    Raises:
        ValueError: If ``approx`` names neither mode.
    """
    use_model = 'approximate' in approx
    if not use_model and 'exact' not in approx:
        raise ValueError()

    if use_model:
        token_ids = tokenizer.encode(source_sent, return_tensors='pt').to(device)
        with torch.no_grad():
            scaled = ling_disc(input_ids=token_ids).cpu().numpy()
        estimate = scaler.inverse_transform(scaled)[0]
    else:
        estimate = np.array(compute_lng(source_sent))[used_indices]

    estimate = round_ling(estimate)
    shared_state.source = estimate.copy()

    return [str(estimate[pos]) for pos in range(len(lng_names))]
|
608 |
+
|
609 |
+
# Index of every entry in ling_sliders, in order.
slider_indices = [index for _name, _src, _slider, _box, index in ling_sliders]
# Flat component list: source, slider, checkbox for each entry in turn.
slider_updates = [
    component
    for _name, src_box, tgt_slider, checkbox, _index in ling_sliders
    for component in (src_box, tgt_slider, checkbox)
]
|
611 |
+
|
612 |
+
# Bind all the event handlers
|
613 |
+
# --- Estimation buttons ---
estimate_src_btn.click(
    fn=update_sliders_from_source,
    inputs=[approx, sent1],
    outputs=source_values,
)
estimate_tgt_btn.click(
    fn=update_sliders_from_estimate,
    inputs=[approx, sent_ling_est],
    outputs=[*target_sliders, *active_checkboxes],
)
estimate_gen_btn.click(
    fn=estimate_gen,
    inputs=[sent1, sent_ling_gen, approx],
    outputs=[sent2, interpolation, *target_sliders],
)

# --- Target helpers: every handler here writes to the target sliders ---
impute_btn.click(
    fn=lambda: [gr.update(value=imputed) for imputed in impute_targets()],
    outputs=target_sliders,
)
copy_btn.click(fn=copy_source_to_target, outputs=target_sliders)

# --- Generation ---
generate_btn.click(
    fn=generate_with_feedback,
    inputs=[sent1, approx],
    outputs=[sent2, interpolation, *target_sliders],
)
generate_random_btn.click(
    fn=generate_random,
    inputs=[sent1, count, approx],
    outputs=[sent2, interpolation],
)

# --- Add / subtract on the target ---
add_btn.click(fn=add_to_target, outputs=target_sliders)
sub_btn.click(fn=subtract_from_target, outputs=target_sliders)
|
650 |
+
|
651 |
+
# Event handlers for the tools
# copy_btn, add_btn and sub_btn already have their click handlers registered
# in the main binding section above. Registering them a second time here
# would attach another listener to each button, so every click would invoke
# the handler twice. Only the random-example button still needs binding.
rand_ex_btn.click(
    fn=rand_ex_target,
    outputs=target_sliders,
)
|
671 |
+
|
672 |
print('Finished loading')
|
673 |
demo.launch(share=True)
|