import spacy
import nltk

# One-time downloads of NLP resources used at runtime
nltk.download('wordnet', quiet=True)
spacy.cli.download('en_core_web_sm')

from compute_lng import compute_lng

import torch
import joblib, json
import numpy as np
import pandas as pd
import gradio as gr

from const import used_indices, name_map
from model import get_model
from options import parse_args
from transformers import T5Tokenizer
from sklearn.experimental import enable_iterative_imputer  # noqa: F401 -- enables sklearn.impute.IterativeImputer
from sklearn.impute import IterativeImputer
from sklearn.linear_model import Ridge
def process_examples(samples):
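    """Flatten example dicts into the row format expected by gr.Examples:
    the source sentence, the source feature values as strings (for the
    read-only textboxes), then the target feature values (for the sliders)."""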
    processed = []
    for sample in samples:
        example = [sample['sentence1']] + [str(x) for x in sample['sentence1_ling']] + sample['sentence2_ling']
        processed.append(example)
    return processed
args, args_list, lng_names = parse_args(ckpt='./ckpt/model_fixed.pt')
tokenizer = T5Tokenizer.from_pretrained(args.model_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lng_names = [name_map[x] for x in lng_names]

with open('assets/examples.json') as f:
    examples = json.load(f)
example_ids = [44, 148, 86, 96, 98, 62, 114, 138]
examples = [examples[i] for i in example_ids]
examples = process_examples(examples)

with open('assets/stats.json') as f:
    stats = json.load(f)
scaler = joblib.load('assets/scaler.bin')
scale_ratio = np.load('assets/ratios.npy')
ling_collection = np.load('assets/ling_collection.npy')
ling_collection_scaled = scaler.transform(ling_collection)
model, ling_disc, sem_emb = get_model(args, tokenizer, device)

# state = torch.load(args.ckpt, map_location=torch.device('cpu'))
# model.load_state_dict(state['model'], strict=True)
# model.eval()
# ling_disc.eval()

# state = torch.load(args.sem_ckpt, map_location=torch.device('cpu'))
# sem_emb.load_state_dict(state['model'], strict=True)
# sem_emb.eval()

############# Start demo code
def round_ling(x):
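    """Round feature values to three decimals and clip them to the
    min/max observed in the training statistics."""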
    mins = stats['min']
    maxs = stats['max']
    for i in range(len(x)):
        x[i] = round(x[i], 3)
    return np.clip(x, mins, maxs)
def visibility(mode):
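    """Clear the output textboxes and show only the components that
    belong to the selected operation mode."""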
    if mode == 0:
        vis_group = group1
    elif mode == 1:
        vis_group = group2
    elif mode == 2:
        vis_group = group3
    else:
        raise ValueError(f'Unknown mode: {mode}')
    output = [gr.update(value=''), gr.update(value='')]
    for component in components:
        output.append(gr.update(visible=component in vis_group))
    return output
def generate(sent1, ling_dict):
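    """Generate one paraphrase of `sent1` conditioned on the scaled
    source and target linguistic feature vectors in `ling_dict`."""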
    input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
    ling1 = scaler.transform([ling_dict['Source']])
    ling2 = scaler.transform([ling_dict['Target']])
    inputs = {'sentence1_input_ids': input_ids,
              'sentence1_ling': torch.tensor(ling1).float().to(device),
              'sentence2_ling': torch.tensor(ling2).float().to(device),
              'sentence1_attention_mask': torch.ones_like(input_ids)}
    with torch.no_grad():
        pred = model.infer(inputs).cpu().numpy()
    pred = tokenizer.batch_decode(pred, skip_special_tokens=True)[0]
    return pred
def impute_targets():
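    """Fill in the values of inactive target features by imputation: the
    scaled target vector, with inactive entries set to NaN, is stacked onto
    the collection of known feature vectors and completed with a
    Ridge-based IterativeImputer."""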
    target_values = []
    for i in range(len(shared_state.target)):
        if i in shared_state.active_indices:
            target_values.append(shared_state.target[i])
        else:
            target_values.append(np.nan)
    target_values = np.array(target_values)
    target_values_scaled = scaler.transform([target_values])[0]
    estimator = Ridge(alpha=1e3, fit_intercept=False)
    imputer = IterativeImputer(estimator=estimator, imputation_order='random', max_iter=100)
    combined_matrix = np.vstack([ling_collection_scaled, target_values_scaled])
    interpolated_matrix = imputer.fit_transform(combined_matrix)
    interpolated_vector = interpolated_matrix[-1]
    interp_raw = scaler.inverse_transform([interpolated_vector])[0]
    shared_state.target = round_ling(interp_raw).tolist()
    return shared_state.target
def generate_with_feedback(sent1, approx):
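    """Generate a paraphrase matching the current target features,
    iteratively refining the generation with feedback from the linguistic
    discriminator and the semantic embedder."""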
    if sent1 == '':
        raise gr.Error('Please input a source text.')

    # First impute any inactive targets
    if len(shared_state.active_indices) < len(shared_state.target):
        impute_targets()

    input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
    ling2 = torch.tensor(scaler.transform([shared_state.target])).float().to(device)
    inputs = {
        'sentence1_input_ids': input_ids,
        'sentence2_ling': ling2,
        'sentence1_attention_mask': torch.ones_like(input_ids)
    }
    print('generating...')
    pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer)
    interpolation = '-- ' + '\n-- '.join(interpolations)
    # Return both the generation results and the updated slider values
    return [pred_text, interpolation] + [gr.update(value=val) for val in shared_state.target]
def generate_random(sent1, count, approx):
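    """Generate `count` distinct paraphrases by sampling random target
    feature vectors from the collection. When a sample yields a duplicate,
    the target is nudged up or down with growing patience before a fresh
    vector is drawn."""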
    if sent1 == '':
        raise gr.Error('Please input a source text.')
    preds, interpolations = [], []
    orig_active_indices = shared_state.active_indices
    shared_state.active_indices = set(range(len(lng_names)))
    for c in range(count):
        idx = np.random.randint(0, len(ling_collection))
        ling_ex = ling_collection[idx]
        shared_state.target = ling_ex.copy()
        success = False
        patience = 0
        while not success:
            print(c, patience)
            pred, interpolation = generate_with_feedback(sent1, approx)[:2]
            print(pred)
            if pred not in preds:
                success = True
            elif patience < 10:
                patience += 1
                if np.random.rand() < 0.5:
                    for _ in range(patience):
                        add_to_target()
                else:
                    for _ in range(patience):
                        subtract_from_target()
            else:
                idx = np.random.randint(0, len(ling_collection))
                ling_ex = ling_collection[idx]
                shared_state.target = ling_ex.copy()
                patience = 0
        preds.append(pred)
        interpolations.append(interpolation)
    shared_state.active_indices = orig_active_indices
    return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
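
# Shared helper for the handlers below: estimate the linguistic indices of a
# sentence either with the trained discriminator (approximate, faster) or by
# computing them directly (exact, slower). Rounding is left to the callers.
def estimate_ling(sent, approx):
    if 'approximate' in approx:
        input_ids = tokenizer.encode(sent, return_tensors='pt').to(device)
        with torch.no_grad():
            ling_pred = ling_disc(input_ids=input_ids).cpu().numpy()
        ling_pred = scaler.inverse_transform(ling_pred)[0]
    elif 'exact' in approx:
        ling_pred = np.array(compute_lng(sent))[used_indices]
    else:
        raise ValueError(f'Unknown computation mode: {approx}')
    return ling_pred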
def estimate_gen(sent1, sent2, approx):
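    """Estimate the features of `sent2`, set them as the target, and
    generate a paraphrase of `sent1` matching that target."""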
    ling_pred = round_ling(estimate_ling(sent2, approx))
    shared_state.target = ling_pred.copy()
    orig_active_indices = shared_state.active_indices
    shared_state.active_indices = set(range(len(lng_names)))
    gen = generate_with_feedback(sent1, approx)[:2]
    shared_state.active_indices = orig_active_indices
    return gen + [gr.update(value=val) for val in shared_state.target]
def estimate_tgt(sent2, ling_dict, approx):
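    """Estimate the features of `sent2` and store them as the target."""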
    ling_dict['Target'] = round_ling(estimate_ling(sent2, approx))
    return ling_dict
def estimate_src(sent1, ling_dict, approx):
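    """Estimate the features of `sent1` and store them as the source."""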
    ling_dict['Source'] = estimate_ling(sent1, approx)
    return ling_dict
def rand_ex_target():
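    """Set the target to a random feature vector from the collection."""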
    idx = np.random.randint(0, len(ling_collection))
    ling_ex = ling_collection[idx]
    shared_state.target = ling_ex.copy()
    return [gr.update(value=val) for val in shared_state.target]
def copy_source_to_target():
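    """Copy the source feature vector onto the target sliders."""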
if "" in shared_state.source: | |
raise gr.Error("Source linguistic features not initialized. Please estimate them first.") | |
shared_state.target = shared_state.source.copy() | |
return [gr.update(value=val) for val in shared_state.target] | |
def add_to_target():
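    """Increase every active target feature by a random step scaled by
    that feature's typical range."""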
    if not shared_state.active_indices:
        raise gr.Error("No features are activated. Please activate features to modify.")
    scale_stepsize = np.random.uniform(1.0, 5.0)
    new_targets = np.array(shared_state.target)
    for i in shared_state.active_indices:
        new_targets[i] += scale_stepsize * scale_ratio[i]
    shared_state.target = round_ling(new_targets).tolist()
    return [gr.update(value=val) for val in shared_state.target]
def subtract_from_target():
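    """Decrease every active target feature by a random step scaled by
    that feature's typical range."""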
    if not shared_state.active_indices:
        raise gr.Error("No features are activated. Please activate features to modify.")
    scale_stepsize = np.random.uniform(1.0, 5.0)
    new_targets = np.array(shared_state.target)
    for i in shared_state.active_indices:
        new_targets[i] -= scale_stepsize * scale_ratio[i]
    shared_state.target = round_ling(new_targets).tolist()
    return [gr.update(value=val) for val in shared_state.target]
title = """ | |
<h1 style="text-align: center;">Controlled Paraphrase Generation with Linguistic Feature Control</h1> | |
<p style="font-size:1.2em;">This system utilizes an encoder-decoder model to generate text with controlled complexity, guided by 40 linguistic complexity indices. | |
The model can generate diverse paraphrases of a given sentence, each adjusted to maintain consistent meaning while varying | |
in linguistic complexity according to the desired level.</p> | |
<p style="font-size:1.2em;">It is important to note that not all index combinations are feasible (e.g., a sentence of "length" 5 with 10 "unique words"). | |
To ensure high-quality outputs, our approach compares the initial generation with the target linguistic indices, and performs iterative refinement to match the closest, yet coherent | |
achievable set of indices for the given target.</p> | |
""" | |
guide = """ | |
1. **Select Operation Mode**: Choose from the available modes: | |
- **Linguistically-diverse Paraphrase Generation**: Generate diverse paraphrases. | |
- **Steps**: | |
1. Enter the source text in the provided textbox. | |
2. Specify the number of paraphrases you want. | |
3. Click "Generate" to produce paraphrases with varying linguistic complexity. | |
- **Complexity-Matched Paraphrasing**: Match the complexity of the input text. | |
- **Steps**: | |
1. Enter the source text in the provided textbox. | |
2. Provide another sentence to extract linguistic indices. | |
3. Click "Generate" to produce a paraphrase matching the complexity of the given sentence. | |
- **Manual Linguistic Control**: Manually adjust linguistic features using sliders. | |
- **Steps**: | |
1. Enter the source text in the provided textbox. | |
2. Activate or deactivate features of interest using the checkboxes. | |
3. Use the sliders to adjust linguistic features. | |
4. **Use Tools**: Access additional tools under "Tools to assist in setting linguistic indices" for advanced control. | |
- **Impute Missing Values**: Automatically fill inactive features. | |
- **Random Target**: Generate a random set of linguistic indices. | |
- **Copy Source to Target**: Copy linguistic indices from the source to the target. | |
- **Add/Subtract Complexity**: Adjust the complexity of the target indices. | |
5. Click "Generate" to produce the output text based on the adjusted features. | |
""" | |
# Advanced options description
advanced_options_description = """
**Advanced Options**:
- **Approximate vs. Exact Computation**: Choose between faster approximate computation or more precise exact computation of linguistic indices.
- **View Intermediate Generations**: Enable this option to see the intermediate sentences generated during the quality control process.
"""
css = """ | |
#guide span.svelte-1w6vloh {font-size: 22px !important; font-weight: 600 !important} | |
#mode span.svelte-1gfkn6j {font-size: 18px !important; font-weight: 600 !important} | |
#mode {border: 0px; box-shadow: none} | |
#mode .block {padding: 0px} | |
#estimate textarea {border: 1px solid; border-radius: 7px} | |
div.gradio-container {color: black} | |
div.form {background: inherit} | |
body { | |
--text-sm: 12px; | |
--text-md: 16px; | |
--text-lg: 18px; | |
--input-text-size: 16px; | |
--section-text-size: 16px; | |
--input-background: --neutral-50; | |
} | |
.top-separator { | |
width: 100%; | |
height: 4px; /* Adjust the height for boldness */ | |
background-color: #000; /* Adjust the color as needed */ | |
margin-top: 20px; /* Adjust the margin as needed */ | |
} | |
.bottom-separator { | |
width: 100%; | |
height: 4px; /* Adjust the height for boldness */ | |
background-color: #000; /* Adjust the color as needed */ | |
margin-bottom: 20px; /* Adjust the margin as needed */ | |
} | |
.features-container { | |
border: 1px solid rgba(0, 0, 0, 0.1); | |
border-radius: 8px; | |
background: white; | |
} | |
/* Style the inner column to be scrollable */ | |
.features-container > div > .column { | |
max-height: 400px; | |
overflow-y: scroll; | |
padding: 10px; | |
} | |
/* Scrollbar styles now apply to the inner column */ | |
.features-container > div > .column::-webkit-scrollbar { | |
width: 8px; | |
} | |
.features-container > div > .column::-webkit-scrollbar-track { | |
background: #f1f1f1; | |
border-radius: 4px; | |
} | |
.features-container > div > .column::-webkit-scrollbar-thumb { | |
background: #888; | |
border-radius: 4px; | |
} | |
.features-container > div > .column::-webkit-scrollbar-thumb:hover { | |
background: #555; | |
} | |
.features-container .label-wrap span { | |
font-weight: 600; | |
font-size: 18px; | |
} | |
""" | |
sent1 = gr.Textbox(label='Source text')
ling_sliders = []
ling_dict = {'Source': [""] * len(lng_names), 'Target': [0] * len(lng_names)}
active_indices = []
target_sliders = []
source_values = []
active_checkboxes = []
for i in range(len(lng_names)):
    source_values.append(gr.Textbox(placeholder="Not initialized",
                                    lines=1, label="Source", interactive=False,
                                    container=False, scale=1))
    active_checkboxes.append(gr.Checkbox(label="Activate", value=False))
    target_sliders.append(
        gr.Slider(
            minimum=stats['min'][i],
            maximum=stats['max'][i],
            value=stats['min'][i],
            step=0.001 if not stats['is_int'][i] else 1,
            label=None,
            interactive=False
        )
    )
# Shared state backing the UI components
class SharedState:
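    """Mutable state shared across Gradio event handlers: the source and
    target feature vectors and the set of active feature indices."""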
    def __init__(self, n_features):
        self.source = [""] * n_features
        self.target = [0] * n_features
        self.active_indices = set()

    def update_target(self, index, value):
        self.target[index] = value
        return self.target.copy()

    def update_source(self, index, value):
        self.source[index] = value
        return self.source.copy()

    def toggle_active(self, index, value):
        if value:
            self.active_indices.add(index)
        else:
            self.active_indices.discard(index)
        return list(self.active_indices)

    def get_state(self):
        return {
            'Source': self.source.copy(),
            'Target': self.target.copy(),
            'active_indices': list(self.active_indices)
        }

shared_state = SharedState(len(lng_names))
with gr.Blocks(
        theme=gr.themes.Default(
            spacing_size=gr.themes.sizes.spacing_md,
            text_size=gr.themes.sizes.text_md,
        ),
        css=css) as demo:
    # Header
    gr.Image('assets/logo.png', height=100, container=False, show_download_button=False, show_fullscreen_button=False)
    gr.Markdown(title)

    # Guide
    with gr.Accordion("🚀 Quick Start Guide", open=False, elem_id='guide'):
        gr.Markdown(guide)
    with gr.Group(elem_classes='top-separator'):
        pass

    # Mode Selection
    with gr.Group(elem_id='mode'):
        mode = gr.Radio(
            value='🔄 Linguistically-diverse Paraphrase Generation',
            label='Operation Modes',
            type="index",
            choices=['🔄 Linguistically-diverse Paraphrase Generation',
                     '⚖️ Complexity-Matched Paraphrasing',
                     '🎛️ Manual Linguistic Control'],
        )

    with gr.Accordion("⚙️ Advanced Options", open=False):
        gr.Markdown(advanced_options_description)
        approx = gr.Radio(value='Use approximate computation of linguistic indices (faster)',
                          choices=['Use approximate computation of linguistic indices (faster)',
                                   'Use exact computation of linguistic indices'], container=False, show_label=False)
        control_interpolation = gr.Checkbox(label='View the intermediate sentences in the interpolation of linguistic indices')
    # Main Input/Output
    with gr.Row():
        with gr.Column():
            sent1.render()
            count = gr.Number(label='Number of generated sentences', value=3, precision=0, scale=1, visible=True)
            sent_ling_gen = gr.Textbox(label='Copy the style of this sentence', scale=1, visible=False)
        with gr.Column():
            sent2 = gr.Textbox(label='Generated text')

    generate_random_btn = gr.Button("Generate", variant='primary', scale=1, visible=True)
    estimate_gen_btn = gr.Button("Generate", variant='primary', scale=1, visible=False)
    generate_btn = gr.Button("Generate", variant='primary', visible=False)

    # Linguistic Features Container
    with gr.Accordion("Linguistic Features", elem_classes="features-container", open=True, visible=False) as ling_features:
        with gr.Row():
            select_all_btn = gr.Button("Activate All", size='sm')
            unselect_all_btn = gr.Button("Deactivate All", size='sm')
        for i, name in enumerate(lng_names):
            with gr.Row():
                feature_name = gr.Textbox(name, lines=1, label="Feature", container=False, show_label=False, interactive=False)
                source_values[i].render()
                active_checkboxes[i].render()
                target_sliders[i].interactive = False
                target_sliders[i].render()
                ling_sliders.append((feature_name, source_values[i], target_sliders[i], active_checkboxes[i], i))
    # Tools Accordion
    with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
        rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
        impute_btn = gr.Button("Impute Missing Values", size='lg', visible=False)
        with gr.Row():
            estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
            copy_btn = gr.Button("Copy linguistic indices of source to target", size='lg', visible=False)
        with gr.Row():
            sub_btn = gr.Button('Decrease target complexity by \u03B5', visible=False)
            add_btn = gr.Button('Increase target complexity by \u03B5', visible=False)
        with gr.Row():
            estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence →", visible=False)
            sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False, container=False, elem_id='estimate')

    interpolation = gr.Textbox(label='Quality control interpolation', visible=False, lines=5)
    with gr.Group(elem_classes='bottom-separator'):
        pass
    # Examples
    def load_example(example_text, *values):
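        """Populate the shared state from a clicked example row and
        activate all features (checking every 'Activate' box)."""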
        # Split values into source and target values
        n = len(lng_names)
        source_values = values[:n]
        target_values = values[n:]
        # Update shared state
        shared_state.source = [float(x) for x in source_values]
        shared_state.target = list(target_values)
        shared_state.active_indices = set(range(n))  # Activate all indices
        # Check all 'Activate' checkboxes
        return [True] * n
    gr.Examples(
        examples=examples,
        inputs=[sent1] + source_values + target_sliders,
        outputs=active_checkboxes,
        example_labels=[ex[0] for ex in examples],
        fn=load_example,
        run_on_click=True,
    )
    # Select/unselect-all handlers
    def select_all():
        for i in range(len(lng_names)):
            shared_state.toggle_active(i, True)
        return [True] * len(lng_names) + [gr.update(interactive=True)] * len(lng_names)

    def unselect_all():
        shared_state.active_indices.clear()
        return [False] * len(lng_names) + [gr.update(interactive=False)] * len(lng_names)

    select_all_btn.click(
        fn=select_all,
        outputs=active_checkboxes + [slider for _, _, slider, _, _ in ling_sliders]
    )
    unselect_all_btn.click(
        fn=unselect_all,
        outputs=active_checkboxes + [slider for _, _, slider, _, _ in ling_sliders]
    )
    def update_slider(slider_index, new_value):
        shared_state.target[slider_index] = new_value

    def update_checkbox(checkbox_index, new_value):
        shared_state.toggle_active(checkbox_index, new_value)
        return gr.update(interactive=new_value)

    # Bind slider and checkbox events; a hidden gr.Number carries each
    # component's index into its callback.
    for feature_name, source_value, target_slider, active_checkbox, i in ling_sliders:
        target_slider.change(
            fn=update_slider,
            inputs=[gr.Number(i, visible=False), target_slider],
        )
        active_checkbox.change(
            fn=update_checkbox,
            inputs=[gr.Number(i, visible=False), active_checkbox],
            outputs=target_slider
        )
    # Define groups and visibility
    group1 = [generate_random_btn, count]
    group2 = [estimate_gen_btn, sent_ling_gen]
    group3 = [generate_btn, estimate_src_btn, impute_btn, estimate_tgt_btn, sent_ling_est,
              rand_ex_btn, copy_btn, add_btn, sub_btn, ling_features, ling_tools]
    components = group1 + group2 + group3

    mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
    control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
                                 outputs=[interpolation])
    def update_sliders_from_state(ling_state, slider_indices):
        updates = []
        for i in slider_indices:
            updates.append(str(ling_state['Source'][i]))
            updates.append(ling_state['Target'][i])
            updates.append(gr.update(value=True))
        return updates
    def update_sliders_from_estimate(approx, sent_for_estimate):
        ling_pred = round_ling(estimate_ling(sent_for_estimate, approx))
        shared_state.source = ling_pred.copy()
        shared_state.target = ling_pred.copy()
        # Return the slider values, then the checked 'Activate' boxes
        return list(ling_pred) + [True] * len(lng_names)
    def update_sliders_from_source(approx, source_sent):
        ling_pred = round_ling(estimate_ling(source_sent, approx))
        shared_state.source = ling_pred.copy()
        return [str(v) for v in ling_pred]
    slider_indices = [i for _, _, _, _, i in ling_sliders]
    slider_updates = [elem for _, source, slider, active, _ in ling_sliders for elem in [source, slider, active]]
    # Bind all the event handlers
    estimate_src_btn.click(update_sliders_from_source,
                           inputs=[approx, sent1],
                           outputs=source_values)
    estimate_tgt_btn.click(update_sliders_from_estimate,
                           inputs=[approx, sent_ling_est],
                           outputs=target_sliders + active_checkboxes)
    estimate_gen_btn.click(
        fn=estimate_gen,
        inputs=[sent1, sent_ling_gen, approx],
        outputs=[sent2, interpolation] + target_sliders
    )
    impute_btn.click(
        fn=lambda: [gr.update(value=val) for val in impute_targets()],
        outputs=target_sliders
    )
    copy_btn.click(
        fn=copy_source_to_target,
        outputs=target_sliders
    )
    generate_btn.click(
        fn=generate_with_feedback,
        inputs=[sent1, approx],
        outputs=[sent2, interpolation] + target_sliders
    )
    generate_random_btn.click(
        fn=generate_random,
        inputs=[sent1, count, approx],
        outputs=[sent2, interpolation]
    )
    add_btn.click(
        fn=add_to_target,
        outputs=target_sliders
    )
    sub_btn.click(
        fn=subtract_from_target,
        outputs=target_sliders
    )
    # Event handler for the random-target tool
    rand_ex_btn.click(
        fn=rand_ex_target,
        outputs=target_sliders
    )
print('Finished loading')
demo.launch(share=True)