Spaces:
Sleeping
Sleeping
Commit
·
54ba470
1
Parent(s):
01ac9fe
implement imputation
Browse files
app.py
CHANGED
@@ -13,6 +13,9 @@ from model import get_model
|
|
13 |
from options import parse_args
|
14 |
from transformers import T5Tokenizer
|
15 |
from compute_lng import compute_lng
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
def process_examples(samples, full_names):
|
@@ -35,10 +38,12 @@ examples = process_examples(examples, lng_names)
|
|
35 |
|
36 |
stats = json.load(open('assets/stats.json'))
|
37 |
|
38 |
-
ling_collection = np.load('assets/ling_collection.npy')
|
39 |
scaler = joblib.load('assets/scaler.bin')
|
40 |
scale_ratio = np.load('assets/ratios.npy')
|
41 |
|
|
|
|
|
|
|
42 |
model, ling_disc, sem_emb = get_model(args, tokenizer, device)
|
43 |
|
44 |
state = torch.load(args.ckpt, map_location=torch.device('cpu'))
|
@@ -201,6 +206,21 @@ def sub(ling):
|
|
201 |
ling['Target'] = x
|
202 |
return ling
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
title = """
|
205 |
<h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
|
206 |
|
@@ -255,6 +275,8 @@ css = """
|
|
255 |
#mode {border: 0px; box-shadow: none}
|
256 |
#mode .block {padding: 0px}
|
257 |
|
|
|
|
|
258 |
div.gradio-container {color: black}
|
259 |
div.form {background: inherit}
|
260 |
|
@@ -336,6 +358,7 @@ with gr.Blocks(
|
|
336 |
generate_btn = gr.Button("Generate", variant='primary', visible=False)
|
337 |
with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
|
338 |
rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
|
|
|
339 |
with gr.Row():
|
340 |
estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
|
341 |
copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
|
@@ -344,28 +367,25 @@ with gr.Blocks(
|
|
344 |
add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
|
345 |
with gr.Row():
|
346 |
estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
|
347 |
-
sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False)
|
348 |
ling.render()
|
349 |
#####################
|
350 |
|
351 |
estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
|
352 |
estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
|
353 |
-
# estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling, ling], outputs=[ling])
|
354 |
estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
|
355 |
-
# rand_btn.click(rand_target, inputs=[ling], outputs=[ling])
|
356 |
rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
|
|
|
357 |
copy_btn.click(copy, inputs=[ling], outputs=[ling])
|
358 |
generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
|
359 |
generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
|
360 |
outputs=[sent2, interpolation, ling])
|
361 |
-
# generate_fb_btn.click(generate_with_feedback, inputs=[sent1, ling], outputs=sent2s)
|
362 |
-
# generate_fb_s_btn.click(generate_with_feedbacks, inputs=[sent1, ling], outputs=sent2s)
|
363 |
add_btn.click(add, inputs=[ling], outputs=[ling])
|
364 |
sub_btn.click(sub, inputs=[ling], outputs=[ling])
|
365 |
|
366 |
group1 = [generate_random_btn, count]
|
367 |
group2 = [estimate_gen_btn, sent_ling_gen]
|
368 |
-
group3 = [generate_btn, estimate_src_btn, estimate_tgt_btn, sent_ling_est, rand_ex_btn, copy_btn, add_btn, sub_btn, ling, ling_tools]
|
369 |
components = group1 + group2 + group3
|
370 |
mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
|
371 |
control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
|
|
|
13 |
from options import parse_args
|
14 |
from transformers import T5Tokenizer
|
15 |
from compute_lng import compute_lng
|
16 |
+
from sklearn.experimental import enable_iterative_imputer
|
17 |
+
from sklearn.impute import IterativeImputer
|
18 |
+
from sklearn.linear_model import Ridge
|
19 |
|
20 |
|
21 |
def process_examples(samples, full_names):
|
|
|
38 |
|
39 |
stats = json.load(open('assets/stats.json'))
|
40 |
|
|
|
41 |
scaler = joblib.load('assets/scaler.bin')
|
42 |
scale_ratio = np.load('assets/ratios.npy')
|
43 |
|
44 |
+
ling_collection = np.load('assets/ling_collection.npy')
|
45 |
+
ling_collection_scaled = scaler.transform(ling_collection)
|
46 |
+
|
47 |
model, ling_disc, sem_emb = get_model(args, tokenizer, device)
|
48 |
|
49 |
state = torch.load(args.ckpt, map_location=torch.device('cpu'))
|
|
|
206 |
ling['Target'] = x
|
207 |
return ling
|
208 |
|
209 |
+
def impute(ling):
|
210 |
+
ling['Target'] = ling['Target'].replace("", np.nan)
|
211 |
+
ling['Target'] = scaler.transform([ling['Target']])[0]
|
212 |
+
estimator = Ridge(alpha=1e3, fit_intercept=False)
|
213 |
+
imputer = IterativeImputer(estimator=estimator, imputation_order='random', max_iter=100)
|
214 |
+
|
215 |
+
combined_matrix = np.vstack([ling_collection, ling['Target']])
|
216 |
+
interpolated_matrix = imputer.fit_transform(combined_matrix)
|
217 |
+
interpolated_vector = interpolated_matrix[-1]
|
218 |
+
|
219 |
+
interp_raw = scaler.inverse_transform([interpolated_vector])[0]
|
220 |
+
|
221 |
+
ling['Target'] = round_ling(interp_raw)
|
222 |
+
return ling
|
223 |
+
|
224 |
title = """
|
225 |
<h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1>
|
226 |
|
|
|
275 |
#mode {border: 0px; box-shadow: none}
|
276 |
#mode .block {padding: 0px}
|
277 |
|
278 |
+
#estimate textarea {border: 1px solid; border-radius: 7px}
|
279 |
+
|
280 |
div.gradio-container {color: black}
|
281 |
div.form {background: inherit}
|
282 |
|
|
|
358 |
generate_btn = gr.Button("Generate", variant='primary', visible=False)
|
359 |
with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
|
360 |
rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
|
361 |
+
impute_btn = gr.Button("Impute Missing Values", size='lg', visible=False)
|
362 |
with gr.Row():
|
363 |
estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
|
364 |
copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
|
|
|
367 |
add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
|
368 |
with gr.Row():
|
369 |
estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
|
370 |
+
sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False, elem_id='estimate')
|
371 |
ling.render()
|
372 |
#####################
|
373 |
|
374 |
estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
|
375 |
estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
|
|
|
376 |
estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
|
|
|
377 |
rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
|
378 |
+
impute_btn.click(impute, inputs=[ling], outputs=[ling])
|
379 |
copy_btn.click(copy, inputs=[ling], outputs=[ling])
|
380 |
generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
|
381 |
generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
|
382 |
outputs=[sent2, interpolation, ling])
|
|
|
|
|
383 |
add_btn.click(add, inputs=[ling], outputs=[ling])
|
384 |
sub_btn.click(sub, inputs=[ling], outputs=[ling])
|
385 |
|
386 |
group1 = [generate_random_btn, count]
|
387 |
group2 = [estimate_gen_btn, sent_ling_gen]
|
388 |
+
group3 = [generate_btn, estimate_src_btn, impute_btn, estimate_tgt_btn, sent_ling_est, rand_ex_btn, copy_btn, add_btn, sub_btn, ling, ling_tools]
|
389 |
components = group1 + group2 + group3
|
390 |
mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
|
391 |
control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
|