|
import json |
|
import os |
|
|
|
import gradio as gr |
|
import spaces |
|
from contents import ( |
|
citation, |
|
description, |
|
examples, |
|
how_it_works, |
|
how_to_use, |
|
subtitle, |
|
title, |
|
) |
|
from gradio_highlightedtextbox import HighlightedTextbox |
|
from presets import ( |
|
set_chatml_preset, |
|
set_cora_preset, |
|
set_default_preset, |
|
set_mmt_preset, |
|
set_towerinstruct_preset, |
|
set_zephyr_preset, |
|
set_gemma_preset, |
|
) |
|
from style import custom_css |
|
from utils import get_formatted_attribute_context_results |
|
|
|
from inseq import list_feature_attribution_methods, list_step_functions |
|
from inseq.commands.attribute_context.attribute_context import ( |
|
AttributeContextArgs, |
|
attribute_context_with_model, |
|
) |
|
from inseq.models import HuggingfaceModel |
|
|
|
# Module-level cache of the most recently loaded model. `pecore` and
# `preload_model` reuse it when the requested model name matches, and replace
# it otherwise. It stays None until the first load.
loaded_model: HuggingfaceModel | None = None
|
|
|
|
|
@spaces.GPU()
def pecore(
    input_current_text: str,
    input_context_text: str,
    output_current_text: str,
    output_context_text: str,
    model_name_or_path: str,
    attribution_method: str,
    attributed_fn: str | None,
    context_sensitivity_metric: str,
    context_sensitivity_std_threshold: float,
    context_sensitivity_topk: int,
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
    contextless_input_current_text: str,
    output_template: str,
    special_tokens_to_keep: str | list[str] | None,
    decoder_input_output_separator: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
):
    """Run PECoRe context attribution and format the results for the demo UI.

    Loads ``model_name_or_path`` into the module-level ``loaded_model`` cache
    unless a model with that name is already loaded. It then builds an
    ``AttributeContextArgs`` from the UI fields and runs
    ``attribute_context_with_model``. The JSON and HTML result paths are passed
    as ``save_path`` / ``viz_path`` under ``outputs/`` next to this file; they
    are presumably written by inseq during the run.

    Optional fields are omitted from the arguments so that
    ``AttributeContextArgs`` uses its own defaults for them. These are top-k
    values that are unset or <= 0, and empty context/output texts or separator.

    The four ``*_kwargs`` arguments are JSON strings from the UI code editors.
    An empty string is treated as ``{}``.

    Returns:
        A 3-tuple. The first item is the ``(text, label)`` pairs produced by
        ``get_formatted_attribute_context_results``, or a single unlabelled
        warning message when no pairs were found. It is followed by two
        ``gr.Button(visible=True)`` updates.

    Raises:
        gr.Error: if ``{context}`` appears in ``output_template`` but no
            generation context was given, or if a kwargs field is not a valid
            JSON object.
    """
    global loaded_model

    def _parse_json_kwargs(raw: str | None, field_label: str) -> dict:
        # An emptied code editor means "no extra kwargs", not a JSON error.
        if raw is None or not raw.strip():
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as err:
            raise gr.Error(f"'{field_label}' is not valid JSON: {err}") from err
        if not isinstance(parsed, dict):
            raise gr.Error(f"'{field_label}' must be a JSON object, e.g. {{}}.")
        return parsed

    if "{context}" in output_template and not output_context_text:
        # The label must match the UI textbox the user has to fill in.
        raise gr.Error(
            "Parameter 'Generation context' is required when using {context} in the output template."
        )

    # Validate every JSON field before the (potentially slow) model load, so
    # that a typo fails fast instead of after downloading weights.
    parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs (JSON)")
    parsed_tokenizer_kwargs = _parse_json_kwargs(
        tokenizer_kwargs, "Tokenizer kwargs (JSON)"
    )
    parsed_generation_kwargs = _parse_json_kwargs(
        generation_kwargs, "Generation kwargs (JSON)"
    )
    parsed_attribution_kwargs = _parse_json_kwargs(
        attribution_kwargs, "Attribution kwargs (JSON)"
    )

    if loaded_model is None or model_name_or_path != loaded_model.model_name:
        gr.Info("Loading model...")
        # Pass copies: NOTE(review) the loader may mutate the dicts it receives,
        # and the originals are reused for AttributeContextArgs below.
        loaded_model = HuggingfaceModel.load(
            model_name_or_path,
            attribution_method,
            model_kwargs=dict(parsed_model_kwargs),
            tokenizer_kwargs=dict(parsed_tokenizer_kwargs),
        )

    kwargs = {}
    # `or 0` guards against a cleared gr.Number field, which yields None.
    if (context_sensitivity_topk or 0) > 0:
        kwargs["context_sensitivity_topk"] = context_sensitivity_topk
    if (attribution_topk or 0) > 0:
        kwargs["attribution_topk"] = attribution_topk
    optional_texts = {
        "input_context_text": input_context_text,
        "output_context_text": output_context_text,
        "output_current_text": output_current_text,
        "decoder_input_output_separator": decoder_input_output_separator,
    }
    kwargs.update({name: value for name, value in optional_texts.items() if value})

    app_dir = os.path.dirname(__file__)
    pecore_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=os.path.join(app_dir, "outputs/output.json"),
        add_output_info=True,
        viz_path=os.path.join(app_dir, "outputs/output.html"),
        show_viz=False,
        model_name_or_path=model_name_or_path,
        attribution_method=attribution_method,
        attributed_fn=attributed_fn,
        attribution_selectors=None,
        attribution_aggregators=None,
        normalize_attributions=True,
        model_kwargs=parsed_model_kwargs,
        tokenizer_kwargs=parsed_tokenizer_kwargs,
        generation_kwargs=parsed_generation_kwargs,
        attribution_kwargs=parsed_attribution_kwargs,
        context_sensitivity_metric=context_sensitivity_metric,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
        attribution_std_threshold=attribution_std_threshold,
        input_current_text=input_current_text,
        input_template=input_template,
        output_template=output_template,
        contextless_input_current_text=contextless_input_current_text,
        handle_output_context_strategy="pre",
        **kwargs,
    )
    out = attribute_context_with_model(pecore_args, loaded_model)
    tuples = get_formatted_attribute_context_results(loaded_model, out.info, out)
    if not tuples:
        msg = f"Output: {out.output_current}\nWarning: No pairs were found by PECoRe. Try adjusting Results Selection parameters."
        tuples = [(msg, None)]
    return tuples, gr.Button(visible=True), gr.Button(visible=True)
|
|
|
|
|
@spaces.GPU()
def preload_model(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
):
    """Load ``model_name_or_path`` into the module-level ``loaded_model`` cache.

    This does nothing when the cached model already has the requested name.
    ``model_kwargs`` and ``tokenizer_kwargs`` are JSON strings from the UI
    code editors, and an empty string is treated as ``{}``. They are only
    parsed when a load actually happens.

    Raises:
        gr.Error: if a kwargs field is not a valid JSON object.
    """
    global loaded_model

    def _parse_json_kwargs(raw: str | None, field_label: str) -> dict:
        # An emptied code editor means "no extra kwargs", not a JSON error.
        if raw is None or not raw.strip():
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as err:
            raise gr.Error(f"'{field_label}' is not valid JSON: {err}") from err
        if not isinstance(parsed, dict):
            raise gr.Error(f"'{field_label}' must be a JSON object, e.g. {{}}.")
        return parsed

    if loaded_model is not None and model_name_or_path == loaded_model.model_name:
        return

    parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs (JSON)")
    parsed_tokenizer_kwargs = _parse_json_kwargs(
        tokenizer_kwargs, "Tokenizer kwargs (JSON)"
    )
    gr.Info("Loading model...")
    loaded_model = HuggingfaceModel.load(
        model_name_or_path,
        attribution_method,
        model_kwargs=parsed_model_kwargs,
        tokenizer_kwargs=parsed_tokenizer_kwargs,
    )
|
|
|
|
|
# Single absolute location for the result files. `pecore` writes them under
# os.path.dirname(__file__)/outputs, so the download links and the launch
# `allowed_paths` must resolve to the same directory regardless of the CWD.
_OUTPUTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "outputs")
_OUTPUT_JSON_PATH = os.path.join(_OUTPUTS_DIR, "output.json")
_OUTPUT_HTML_PATH = os.path.join(_OUTPUTS_DIR, "output.html")

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(title)
    gr.Markdown(subtitle)
    gr.Markdown(description)

    # ---- Main tab: inputs on the left, highlighted PECoRe output on the right.
    with gr.Tab("π Attributing Context"):
        with gr.Row():
            with gr.Column():
                input_context_text = gr.Textbox(
                    label="Input context", lines=4, placeholder="Your input context..."
                )
                input_current_text = gr.Textbox(
                    label="Input query", placeholder="Your input query..."
                )
                attribute_input_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                pecore_output_highlights = HighlightedTextbox(
                    value=[
                        ("This output will contain ", None),
                        ("context sensitive", "Context sensitive"),
                        (" generated tokens and ", None),
                        ("influential context", "Influential context"),
                        (" tokens.", None),
                    ],
                    color_map={
                        "Context sensitive": "green",
                        "Influential context": "blue",
                    },
                    show_legend=True,
                    label="PECoRe Output",
                    combine_adjacent=True,
                    interactive=False,
                )
                # Hidden until `pecore` returns visible Button updates. The links
                # are built from the absolute outputs dir: the original
                # os.path.join(dirname, "/file=...") silently dropped the
                # dirname, because the second component is absolute.
                with gr.Row(equal_height=True):
                    download_output_file_button = gr.Button(
                        "β Download output",
                        visible=False,
                        link=f"/file={_OUTPUT_JSON_PATH}",
                    )
                    download_output_html_button = gr.Button(
                        "π Download HTML",
                        visible=False,
                        link=f"/file={_OUTPUT_HTML_PATH}",
                    )
        attribute_input_examples = gr.Examples(
            examples,
            inputs=[input_current_text, input_context_text],
            outputs=pecore_output_highlights,
        )

    # ---- Parameters tab: presets plus every knob forwarded to `pecore`.
    with gr.Tab("βοΈ Parameters") as params_tab:
        gr.Markdown(
            "## β¨ Presets\nSelect a preset to load default parameters into the fields below."
        )
        with gr.Row(equal_height=True):
            with gr.Column():
                default_preset = gr.Button("Default", variant="secondary")
                gr.Markdown(
                    "Default preset using templates without special tokens or parameters.\nCan be used with most decoder-only and encoder-decoder models."
                )
            with gr.Column():
                cora_preset = gr.Button("CORA mQA", variant="secondary")
                gr.Markdown(
                    "Preset for the <a href='https://huggingface.co/gsarti/cora_mgen' target='_blank'>CORA Multilingual QA</a> model.\nUses special templates for inputs."
                )
            with gr.Column():
                zephyr_preset = gr.Button("Zephyr Template", variant="secondary")
                gr.Markdown(
                    "Preset for models using the <a href='https://huggingface.co/HuggingFaceH4/zephyr-7b-beta' target='_blank'>Zephyr conversational template</a>.\nUses <code><|system|></code>, <code><|user|></code> and <code><|assistant|></code> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                multilingual_mt_template = gr.Button(
                    "Multilingual MT", variant="secondary"
                )
                gr.Markdown(
                    "Preset for multilingual MT models such as <a href='https://huggingface.co/facebook/nllb-200-distilled-600M' target='_blank'>NLLB</a> and <a href='https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt' target='_blank'>mBART</a> using language tags."
                )
            with gr.Column(scale=1):
                chatml_template = gr.Button("Qwen ChatML", variant="secondary")
                gr.Markdown(
                    "Preset for models using the <a href='https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md' target='_blank'>ChatML conversational template</a>.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
                )
            with gr.Column(scale=1):
                towerinstruct_template = gr.Button(
                    "Unbabel TowerInstruct", variant="secondary"
                )
                gr.Markdown(
                    "Preset for models using the <a href='https://huggingface.co/Unbabel/TowerInstruct-7B-v0.1' target='_blank'>Unbabel TowerInstruct</a> conversational template.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gemma_template = gr.Button(
                    "Gemma Chat Template", variant="secondary"
                )
                gr.Markdown(
                    "Preset for <a href='https://huggingface.co/google/gemma-2b-it' target='_blank'>Gemma</a> instruction-tuned models."
                )

        gr.Markdown("## βοΈ PECoRe Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                model_name_or_path = gr.Textbox(
                    value="gpt2",
                    label="Model",
                    info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
                    interactive=True,
                )
                load_model_button = gr.Button(
                    "Load model",
                    variant="secondary",
                )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
                info="Metric to use to measure context sensitivity of generated tokens.",
                choices=list_step_functions(),
                interactive=True,
            )
            attribution_method = gr.Dropdown(
                value="saliency",
                label="Attribution method",
                info="Attribution method identifier to identify relevant context tokens.",
                choices=list_feature_attribution_methods(),
                interactive=True,
            )
            attributed_fn = gr.Dropdown(
                value="contrast_prob_diff",
                label="Attributed function",
                info="Function of model logits to use as target for the attribution method.",
                choices=list_step_functions(),
                interactive=True,
            )

        gr.Markdown("#### Results Selection Parameters")
        with gr.Row(equal_height=True):
            context_sensitivity_std_threshold = gr.Number(
                value=1.0,
                label="Context sensitivity threshold",
                info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            context_sensitivity_topk = gr.Number(
                value=0,
                label="Context sensitivity top-k",
                info="Select N to keep top N context sensitive tokens. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=10,
            )
            attribution_std_threshold = gr.Number(
                value=1.0,
                label="Attribution threshold",
                info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            attribution_topk = gr.Number(
                value=0,
                label="Attribution top-k",
                info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=50,
            )

        gr.Markdown("#### Text Format Parameters")
        with gr.Row(equal_height=True):
            input_template = gr.Textbox(
                value="{current} <P>:{context}",
                label="Input template",
                info="Template to format the input for the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
            output_template = gr.Textbox(
                value="{current}",
                label="Output template",
                info="Template to format the output from the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
            contextless_input_current_text = gr.Textbox(
                value="<Q>:{current}",
                label="Input current text template",
                info="Template to format the input query for the model. Use {current} placeholder.",
                interactive=True,
            )
        with gr.Row(equal_height=True):
            special_tokens_to_keep = gr.Dropdown(
                label="Special tokens to keep",
                info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
                value=None,
                multiselect=True,
                allow_custom_value=True,
            )
            decoder_input_output_separator = gr.Textbox(
                label="Decoder input/output separator",
                info="Separator to use between input and output in the decoder input.",
                value="",
                interactive=True,
                lines=1,
            )

        gr.Markdown("## βοΈ Generation Parameters")
        with gr.Row(equal_height=True):
            # Gradio Column `scale` must be an integer; 1:2 keeps the original
            # 0.5:1 width ratio without the non-integer warning.
            with gr.Column(scale=1):
                gr.Markdown(
                    "The following arguments can be used to control generation parameters and force specific model outputs."
                )
            with gr.Column(scale=2):
                generation_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Generation kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
        with gr.Row(equal_height=True):
            output_current_text = gr.Textbox(
                label="Generation output",
                info="Specifies an output to force-decode during generation. If blank, the model will generate freely.",
                interactive=True,
            )
            output_context_text = gr.Textbox(
                label="Generation context",
                info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
                interactive=True,
            )

        gr.Markdown("## βοΈ Other Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown(
                    "The following arguments will be passed to initialize the Hugging Face model and tokenizer, and to the `inseq_model.attribute` method."
                )
            with gr.Column():
                model_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Model kwargs (JSON)",
                    interactive=True,
                    lines=1,
                    min_width=160,
                )
            with gr.Column():
                tokenizer_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Tokenizer kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
            with gr.Column():
                attribution_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Attribution kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )

    gr.Markdown(how_it_works)
    gr.Markdown(how_to_use)
    gr.Markdown(citation)

    # ---- Main event wiring.
    # Inputs shared by the explicit "Load model" button and every preset.
    load_model_args = [
        model_name_or_path,
        attribution_method,
        model_kwargs,
        tokenizer_kwargs,
    ]

    # Order must match the positional parameters of `pecore`.
    attribute_input_button.click(
        pecore,
        inputs=[
            input_current_text,
            input_context_text,
            output_current_text,
            output_context_text,
            model_name_or_path,
            attribution_method,
            attributed_fn,
            context_sensitivity_metric,
            context_sensitivity_std_threshold,
            context_sensitivity_topk,
            attribution_std_threshold,
            attribution_topk,
            input_template,
            contextless_input_current_text,
            output_template,
            special_tokens_to_keep,
            decoder_input_output_separator,
            model_kwargs,
            tokenizer_kwargs,
            generation_kwargs,
            attribution_kwargs,
        ],
        outputs=[
            pecore_output_highlights,
            download_output_file_button,
            download_output_html_button,
        ],
    )

    load_model_button.click(
        preload_model,
        inputs=load_model_args,
        outputs=[],
    )

    # ---- Presets.
    # Every preset first restores the fields in `outputs_to_reset` through
    # `set_default_preset`. Non-default presets then apply their own overrides.
    # Finally the selected model is preloaded once the previous step succeeded.
    outputs_to_reset = [
        model_name_or_path,
        input_template,
        contextless_input_current_text,
        output_template,
        special_tokens_to_keep,
        decoder_input_output_separator,
        model_kwargs,
        tokenizer_kwargs,
        generation_kwargs,
        attribution_kwargs,
    ]
    reset_kwargs = {
        "fn": set_default_preset,
        "inputs": None,
        "outputs": outputs_to_reset,
    }

    default_preset.click(**reset_kwargs).success(preload_model, inputs=load_model_args)

    # Fields overridden by the chat-template presets (ChatML, TowerInstruct, Gemma).
    chat_template_outputs = [
        model_name_or_path,
        input_template,
        contextless_input_current_text,
        decoder_input_output_separator,
        special_tokens_to_keep,
    ]
    # (button, preset function, fields that function overrides)
    preset_bindings = [
        (
            cora_preset,
            set_cora_preset,
            [model_name_or_path, input_template, contextless_input_current_text],
        ),
        (
            zephyr_preset,
            set_zephyr_preset,
            [
                model_name_or_path,
                input_template,
                contextless_input_current_text,
                decoder_input_output_separator,
            ],
        ),
        (
            multilingual_mt_template,
            set_mmt_preset,
            [model_name_or_path, input_template, output_template, tokenizer_kwargs],
        ),
        (chatml_template, set_chatml_preset, chat_template_outputs),
        (towerinstruct_template, set_towerinstruct_preset, chat_template_outputs),
        (gemma_template, set_gemma_preset, chat_template_outputs),
    ]
    for preset_button, preset_fn, preset_outputs in preset_bindings:
        preset_button.click(**reset_kwargs).then(
            preset_fn,
            outputs=preset_outputs,
        ).success(preload_model, inputs=load_model_args)

# Serve the same absolute directory the download links point into.
demo.launch(allowed_paths=[_OUTPUTS_DIR])
|
|