# pecore / app.py — Hugging Face Space source (author: gsarti)
# Last commit e0116a9: "Remove api usage" · 27.7 kB
# (header copied from the Hub file viewer: raw · history · blame)
import json
import os
import gradio as gr
import spaces
from contents import (
pecore_citation,
inseq_citation,
description,
examples,
how_it_works_intro,
cti_explanation,
cci_explanation,
how_to_use,
example_explanation,
show_code_modal,
subtitle,
title,
powered_by,
support,
)
from gradio_highlightedtextbox import HighlightedTextbox
from gradio_modal import Modal
from presets import (
set_chatml_preset,
set_cora_preset,
set_default_preset,
set_mmt_preset,
set_towerinstruct_preset,
set_zephyr_preset,
set_gemma_preset,
set_mistral_instruct_preset,
update_code_snippets_fn,
)
from style import custom_css
from utils import get_formatted_attribute_context_results
from inseq.commands.attribute_context.attribute_context import (
AttributeContextArgs,
attribute_context_with_model,
)
from inseq.models import HuggingfaceModel
# Cache of the most recently loaded model. It starts empty and is (re)assigned by
# the model-loading code when a different model name is requested, so consecutive
# runs on the same model skip reloading.
loaded_model: HuggingfaceModel | None = None
# Files produced by each PECoRe run and offered through the download buttons.
_OUTPUT_JSON_PATH = os.path.join(os.path.dirname(__file__), "outputs/output.json")
_OUTPUT_HTML_PATH = os.path.join(os.path.dirname(__file__), "outputs/output.html")


def _parse_json_kwargs(value: str | None, field_name: str) -> dict:
    """Parse the JSON text of one of the "kwargs" code boxes into a dictionary.

    Args:
        value: Raw text typed in the code box. Empty or whitespace-only text is
            treated as ``{}``.
        field_name: Human-readable name of the box, used in error messages.

    Returns:
        The parsed JSON object.

    Raises:
        gr.Error: If ``value`` is not valid JSON or is not a JSON object, so the
            problem is reported in the UI instead of as a raw ``JSONDecodeError``.
    """
    if value is None or not value.strip():
        return {}
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        raise gr.Error(f"Invalid JSON in '{field_name}': {err}") from err
    if not isinstance(parsed, dict):
        raise gr.Error(f"'{field_name}' must be a JSON object, e.g. {{}}.")
    return parsed


def _ensure_model_loaded(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
) -> HuggingfaceModel:
    """Return the cached model, loading ``model_name_or_path`` first if needed.

    The cache is keyed on the model name only: requesting the already-loaded
    model with a different ``attribution_method`` or different model/tokenizer
    kwargs does not trigger a reload.

    Args:
        model_name_or_path: Hugging Face Hub identifier of the model.
        attribution_method: Attribution method passed to ``HuggingfaceModel.load``.
        model_kwargs: JSON text with model initialization kwargs.
        tokenizer_kwargs: JSON text with tokenizer initialization kwargs.

    Returns:
        The model stored in the module-level ``loaded_model`` cache.

    Raises:
        gr.Error: If a load is needed and either kwargs box holds invalid JSON.
    """
    global loaded_model
    if loaded_model is None or model_name_or_path != loaded_model.model_name:
        # Parse before announcing the load so malformed JSON fails fast; the loader
        # gets its own freshly parsed dicts, independent of any other consumer.
        parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs")
        parsed_tokenizer_kwargs = _parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs")
        gr.Info("Loading model...")
        loaded_model = HuggingfaceModel.load(
            model_name_or_path,
            attribution_method,
            model_kwargs=parsed_model_kwargs,
            tokenizer_kwargs=parsed_tokenizer_kwargs,
        )
    return loaded_model


@spaces.GPU()
def pecore(
    input_current_text: str,
    input_context_text: str,
    output_current_text: str,
    output_context_text: str,
    model_name_or_path: str,
    attribution_method: str,
    attributed_fn: str | None,
    context_sensitivity_metric: str,
    context_sensitivity_std_threshold: float,
    context_sensitivity_topk: int,
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
    output_template: str,
    contextless_input_template: str,
    contextless_output_template: str,
    special_tokens_to_keep: str | list[str] | None,
    decoder_input_output_separator: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
):
    """Run PECoRe on the form inputs and return the highlighted results.

    Validates the form, makes sure the requested model is loaded, runs
    ``attribute_context_with_model`` and formats its output for the
    ``HighlightedTextbox``. The JSON and HTML outputs are saved under
    ``outputs/`` next to this file and exposed through the two download buttons.

    Args:
        input_current_text: Input query.
        input_context_text: Input context; only forwarded when non-empty.
        output_current_text: Output to force-decode; empty means free generation.
        output_context_text: Generation context; required when ``output_template``
            contains ``{context}``.
        model_name_or_path: Hugging Face Hub identifier of the model.
        attribution_method: Attribution method identifier.
        attributed_fn: Function of the model logits used as attribution target.
        context_sensitivity_metric: Metric measuring context sensitivity.
        context_sensitivity_std_threshold: Keep context-sensitive tokens scoring
            above this many standard deviations (0 = above mean).
        context_sensitivity_topk: Keep only the top-k context-sensitive tokens;
            values <= 0 (or an empty field) leave the limit unset.
        attribution_std_threshold: Same as above, for attributed context tokens.
        attribution_topk: Same as above, for attributed context tokens.
        input_template: Contextual input template (``{current}``/``{context}``).
        output_template: Contextual output template (``{current}``/``{context}``).
        contextless_input_template: Input template without context.
        contextless_output_template: Output template without context.
        special_tokens_to_keep: Special tokens to keep in the attribution.
        decoder_input_output_separator: Separator between decoder input and
            output; only forwarded when non-empty.
        model_kwargs: JSON text of model initialization kwargs.
        tokenizer_kwargs: JSON text of tokenizer initialization kwargs.
        generation_kwargs: JSON text of generation kwargs.
        attribution_kwargs: JSON text of attribution kwargs.

    Returns:
        ``[highlights, json_download_button, html_download_button]``, where
        ``highlights`` is the formatted result list (``(text, label)`` pairs, as
        in the no-results fallback below).

    Raises:
        gr.Error: On an invalid form (missing generation context, malformed JSON).
    """
    if "{context}" in output_template and not output_context_text:
        raise gr.Error(
            "Parameter 'Generation context' must be set when including {context} in the output template."
        )

    # Validate every JSON box up front, before any (slow) model loading happens.
    parsed_model_kwargs = _parse_json_kwargs(model_kwargs, "Model kwargs")
    parsed_tokenizer_kwargs = _parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs")
    parsed_generation_kwargs = _parse_json_kwargs(generation_kwargs, "Generation kwargs")
    parsed_attribution_kwargs = _parse_json_kwargs(attribution_kwargs, "Attribution kwargs")

    model = _ensure_model_loaded(
        model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs
    )

    # Optional arguments are only passed when set; otherwise AttributeContextArgs
    # keeps its own defaults. The None checks guard against a cleared number field.
    optional_kwargs = {}
    if context_sensitivity_topk is not None and context_sensitivity_topk > 0:
        optional_kwargs["context_sensitivity_topk"] = context_sensitivity_topk
    if attribution_topk is not None and attribution_topk > 0:
        optional_kwargs["attribution_topk"] = attribution_topk
    optional_texts = {
        "input_context_text": input_context_text,
        "output_context_text": output_context_text,
        "output_current_text": output_current_text,
        "decoder_input_output_separator": decoder_input_output_separator,
    }
    optional_kwargs.update({name: text for name, text in optional_texts.items() if text})

    # Make sure the output folder exists before the results are written to it.
    os.makedirs(os.path.dirname(_OUTPUT_JSON_PATH), exist_ok=True)
    attribute_context_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=_OUTPUT_JSON_PATH,
        add_output_info=True,
        viz_path=_OUTPUT_HTML_PATH,
        show_viz=False,
        model_name_or_path=model_name_or_path,
        attribution_method=attribution_method,
        attributed_fn=attributed_fn,
        attribution_selectors=None,
        attribution_aggregators=None,
        normalize_attributions=True,
        model_kwargs=parsed_model_kwargs,
        tokenizer_kwargs=parsed_tokenizer_kwargs,
        generation_kwargs=parsed_generation_kwargs,
        attribution_kwargs=parsed_attribution_kwargs,
        context_sensitivity_metric=context_sensitivity_metric,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
        attribution_std_threshold=attribution_std_threshold,
        input_current_text=input_current_text,
        input_template=input_template,
        output_template=output_template,
        contextless_input_current_text=contextless_input_template,
        contextless_output_current_text=contextless_output_template,
        handle_output_context_strategy="pre",
        **optional_kwargs,
    )
    out = attribute_context_with_model(attribute_context_args, model)
    tuples = get_formatted_attribute_context_results(model, out.info, out)
    if not tuples:
        msg = f"Output: {out.output_current}\nWarning: No pairs were found by PECoRe.\nTry adjusting Results Selection parameters to soften selection constraints (e.g. setting Context sensitivity threshold to 0)."
        tuples = [(msg, None)]
    return [
        tuples,
        gr.DownloadButton(
            label="📂 Download output",
            value=_OUTPUT_JSON_PATH,
            visible=True,
        ),
        gr.DownloadButton(
            label="🔍 Download HTML",
            value=_OUTPUT_HTML_PATH,
            visible=True,
        ),
    ]


@spaces.GPU()
def preload_model(
    model_name_or_path: str,
    attribution_method: str,
    model_kwargs: str,
    tokenizer_kwargs: str,
):
    """Load the selected model ahead of time so the next PECoRe run starts faster.

    Used by the "Load model" button and by the presets. Does nothing if a model
    with the same name is already loaded.

    Args:
        model_name_or_path: Hugging Face Hub identifier of the model.
        attribution_method: Attribution method passed to ``HuggingfaceModel.load``.
        model_kwargs: JSON text with model initialization kwargs.
        tokenizer_kwargs: JSON text with tokenizer initialization kwargs.

    Raises:
        gr.Error: If a load is needed and either kwargs box holds invalid JSON.
    """
    _ensure_model_loaded(model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs)
with gr.Blocks(css=custom_css) as demo:
    # ---- Header: logo | title + subtitle | logo ----
    with gr.Row():
        with gr.Column(scale=0.1, min_width=100):
            gr.HTML('<img src="file/img/pecore_logo_white_contour.png" width=100px />')
        with gr.Column(scale=0.8):
            gr.Markdown(title)
            gr.Markdown(subtitle)
        with gr.Column(scale=0.1, min_width=100):
            gr.HTML('<img src="file/img/pecore_logo_white_contour.png" width=100px />')
    gr.Markdown(description)

    # ---- Demo tab: inputs on the left, highlighted PECoRe output on the right ----
    with gr.Tab("🐑 Demo"):
        with gr.Row():
            with gr.Column():
                input_context_text = gr.Textbox(
                    label="Input context", lines=3, placeholder="Your input context..."
                )
                input_current_text = gr.Textbox(
                    label="Input query", placeholder="Your input query..."
                )
                with gr.Row(equal_height=True):
                    show_code_btn = gr.Button("Show code", variant="secondary")
                    attribute_input_button = gr.Button("Run PECoRe", variant="primary")
            with gr.Column():
                pecore_output_highlights = HighlightedTextbox(
                    value=[
                        ("This output will contain ", None),
                        ("context sensitive", "Context sensitive"),
                        (" generated tokens and ", None),
                        ("influential context", "Influential context"),
                        (" tokens.", None),
                    ],
                    color_map={
                        "Context sensitive": "#5fb77d",
                        "Influential context": "#80ace8",
                    },
                    show_legend=True,
                    label="PECoRe Output",
                    combine_adjacent=True,
                    interactive=False,
                )
                # Hidden until a run completes; pecore() makes them visible.
                with gr.Row(equal_height=True):
                    download_output_file_button = gr.DownloadButton(
                        "📂 Download output",
                        visible=False,
                    )
                    download_output_html_button = gr.DownloadButton(
                        "🔍 Download HTML",
                        visible=False,
                        value=os.path.join(
                            os.path.dirname(__file__), "outputs/output.html"
                        ),
                    )
        preset_comment = gr.Markdown(
            "<i>The <a href='https://huggingface.co/gsarti/cora_mgen' target='_blank'>CORA Multilingual QA</a> model by <a href='https://openreview.net/forum?id=e8blYRui3j' target='_blank'>Asai et al. (2021)</a> is set as default and can be used with the examples below. Explore other presets in the ⚙️ Parameters tab.</i>"
        )
        attribute_input_examples = gr.Examples(
            examples,
            inputs=[input_current_text, input_context_text],
            examples_per_page=1,
        )

    # ---- Parameters tab: presets, PECoRe, generation and model parameters ----
    with gr.Tab("⚙️ Parameters") as params_tab:
        gr.Markdown(
            "## ✨ Presets\nSelect a preset to load the selected model and its default parameters (e.g. prompt template, special tokens, etc.) into the fields below.<br>⚠️ **This will overwrite existing parameters. If you intend to use large models that could crash the demo, please clone this Space and allocate appropriate resources for them to run comfortably.**"
        )
        # Large-model presets below start disabled; this checkbox unlocks them.
        check_enable_large_models = gr.Checkbox(False, label="I understand, enable large models presets")
        with gr.Row(equal_height=True):
            with gr.Column():
                default_preset = gr.Button("Default", variant="secondary")
                gr.Markdown(
                    "Default preset using templates without special tokens or parameters.\nCan be used with most decoder-only and encoder-decoder models."
                )
            with gr.Column():
                cora_preset = gr.Button("CORA mQA", variant="secondary")
                gr.Markdown(
                    "Preset for the <a href='https://huggingface.co/gsarti/cora_mgen' target='_blank'>CORA Multilingual QA</a> model.\nUses special templates for inputs."
                )
            with gr.Column():
                zephyr_preset = gr.Button("Zephyr Template", variant="secondary", interactive=False)
                gr.Markdown(
                    "Preset for models using the <a href='https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b' target='_blank'>StableLM 2 Zephyr conversational template</a>.\nUses <code><|system|></code>, <code><|user|></code> and <code><|assistant|></code> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                multilingual_mt_template = gr.Button(
                    "Multilingual MT", variant="secondary"
                )
                gr.Markdown(
                    "Preset for multilingual MT models such as <a href='https://huggingface.co/facebook/nllb-200-distilled-600M' target='_blank'>NLLB</a> and <a href='https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt' target='_blank'>mBART</a> using language tags."
                )
            with gr.Column(scale=1):
                chatml_template = gr.Button("Qwen ChatML", variant="secondary")
                gr.Markdown(
                    "Preset for models using the <a href='https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md' target='_blank'>ChatML conversational template</a>.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
                )
            with gr.Column(scale=1):
                towerinstruct_template = gr.Button(
                    "Unbabel TowerInstruct", variant="secondary", interactive=False
                )
                gr.Markdown(
                    "Preset for models using the <a href='https://huggingface.co/Unbabel/TowerInstruct-7B-v0.1' target='_blank'>Unbabel TowerInstruct</a> conversational template.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gemma_template = gr.Button(
                    "Gemma Chat Template", variant="secondary", interactive=False
                )
                gr.Markdown(
                    "Preset for <a href='https://huggingface.co/google/gemma-2b-it' target='_blank'>Gemma</a> instruction-tuned models."
                )
            with gr.Column(scale=1):
                mistral_instruct_template = gr.Button(
                    "Mistral Instruct", variant="secondary", interactive=False
                )
                gr.Markdown(
                    "Preset for models using the <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2' target='_blank'>Mistral Instruct template</a>.\nUses <code>[INST]...[/INST]</code> special tokens."
                )

        gr.Markdown("## ⚙️ PECoRe Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                model_name_or_path = gr.Textbox(
                    value="gsarti/cora_mgen",
                    label="Model",
                    info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
                    interactive=True,
                )
                load_model_button = gr.Button(
                    "Load model",
                    variant="secondary",
                )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
                info="Metric to use to measure context sensitivity of generated tokens.",
                choices=[
                    "probability",
                    "logit",
                    "kl_divergence",
                    "contrast_logits_diff",
                    "contrast_prob_diff",
                    "pcxmi",
                ],
                interactive=True,
            )
            attribution_method = gr.Dropdown(
                value="saliency",
                label="Attribution method",
                info="Attribution method identifier to identify relevant context tokens.",
                choices=[
                    "saliency",
                    "input_x_gradient",
                    "value_zeroing",
                ],
                interactive=True,
            )
            attributed_fn = gr.Dropdown(
                value="contrast_prob_diff",
                label="Attributed function",
                info="Function of model logits to use as target for the attribution method.",
                choices=[
                    "probability",
                    "logit",
                    "contrast_logits_diff",
                    "contrast_prob_diff",
                ],
                interactive=True,
            )
        gr.Markdown("#### Results Selection Parameters")
        with gr.Row(equal_height=True):
            context_sensitivity_std_threshold = gr.Number(
                value=0.0,
                label="Context sensitivity threshold",
                info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            context_sensitivity_topk = gr.Number(
                value=0,
                label="Context sensitivity top-k",
                info="Select N to keep top N context sensitive tokens. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=10,
            )
            attribution_std_threshold = gr.Number(
                value=2.0,
                label="Attribution threshold",
                info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            attribution_topk = gr.Number(
                value=5,
                label="Attribution top-k",
                info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=100,
            )
        gr.Markdown("#### Text Format Parameters")
        with gr.Row(equal_height=True):
            input_template = gr.Textbox(
                value="<Q>:{current} <P>:{context}",
                label="Contextual input template",
                info="Template to format the input for the model. Use {current} and {context} placeholders for Input Query and Input Context, respectively.",
                interactive=True,
            )
            output_template = gr.Textbox(
                value="{current}",
                label="Contextual output template",
                info="Template to format the output from the model. Use {current} and {context} placeholders for Generation Output and Generation Context, respectively.",
                interactive=True,
            )
            contextless_input_template = gr.Textbox(
                value="<Q>:{current}",
                label="Contextless input template",
                info="Template to format the input query in the non-contextual setting. Use {current} placeholder for Input Query.",
                interactive=True,
            )
            contextless_output_template = gr.Textbox(
                value="{current}",
                label="Contextless output template",
                info="Template to format the output from the model. Use {current} placeholder for Generation Output.",
                interactive=True,
            )
        with gr.Row(equal_height=True):
            special_tokens_to_keep = gr.Dropdown(
                label="Special tokens to keep",
                info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
                value=None,
                multiselect=True,
                allow_custom_value=True,
            )
            decoder_input_output_separator = gr.Textbox(
                label="Decoder input/output separator",
                info="Separator to use between input and output in the decoder input.",
                value="",
                interactive=True,
                lines=1,
            )

        gr.Markdown("## ⚙️ Generation Parameters")
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.5):
                gr.Markdown(
                    "The following arguments can be used to control generation parameters and force specific model outputs."
                )
            with gr.Column(scale=1):
                generation_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Generation kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
        with gr.Row(equal_height=True):
            output_current_text = gr.Textbox(
                label="Generation output",
                info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
                interactive=True,
            )
            output_context_text = gr.Textbox(
                label="Generation context",
                info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
                interactive=True,
            )

        gr.Markdown("## ⚙️ Other Parameters")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown(
                    "The following arguments will be passed to initialize the Hugging Face model and tokenizer, and to the `inseq_model.attribute` method."
                )
            with gr.Column():
                model_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Model kwargs (JSON)",
                    interactive=True,
                    lines=1,
                    min_width=160,
                )
            with gr.Column():
                tokenizer_kwargs = gr.Code(
                    value="{}",
                    language="json",
                    label="Tokenizer kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )
            with gr.Column():
                attribution_kwargs = gr.Code(
                    value='{\n\t"logprob": true\n}',
                    language="json",
                    label="Attribution kwargs (JSON)",
                    interactive=True,
                    lines=1,
                )

    # ---- Documentation tabs ----
    with gr.Tab("🔍 How Does It Work?"):
        gr.Markdown(how_it_works_intro)
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.60):
                gr.Markdown(cti_explanation)
            with gr.Column(scale=0.30):
                gr.HTML('<img src="file/img/cti_white_outline.png" width=100% />')
        with gr.Row(equal_height=True):
            with gr.Column(scale=0.35):
                gr.HTML('<img src="file/img/cci_white_outline.png" width=100% />')
            with gr.Column(scale=0.65):
                gr.Markdown(cci_explanation)
    with gr.Tab("🔧 Usage Guide"):
        gr.Markdown(how_to_use)
        gr.Markdown(example_explanation)
    with gr.Tab("📚 Citing PECoRe"):
        gr.Markdown("To refer to the PECoRe framework for context usage detection, cite:")
        gr.Code(pecore_citation, interactive=False, label="PECoRe (Sarti et al., 2024)")
        gr.Markdown("If you use the Inseq implementation of PECoRe (<a href=\"https://inseq.org/en/latest/main_classes/cli.html#attribute-context\"><code>inseq attribute-context</code></a>, including this demo), please also cite:")
        gr.Code(inseq_citation, interactive=False, label="Inseq (Sarti et al., 2023)")

    with gr.Row(elem_classes="footer-container"):
        gr.Markdown(powered_by)
        gr.Markdown(support)

    # Modal showing code equivalent to the current settings ("Show code" button).
    with Modal(visible=False) as code_modal:
        gr.Markdown(show_code_modal)
        with gr.Row(equal_height=True):
            python_code_snippet = gr.Code(
                value="""Generate Python code snippet by pressing the button.""",
                language="python",
                label="Python",
                interactive=False,
                show_label=True,
            )
            shell_code_snippet = gr.Code(
                value="""Generate Shell code snippet by pressing the button.""",
                language="shell",
                label="Shell",
                interactive=False,
                show_label=True,
            )

    # ---- Event wiring ----
    # Inputs of preload_model(), in signature order.
    load_model_args = [
        model_name_or_path,
        attribution_method,
        model_kwargs,
        tokenizer_kwargs,
    ]
    # Inputs of pecore() and update_code_snippets_fn(), in signature order.
    pecore_args = [
        input_current_text,
        input_context_text,
        output_current_text,
        output_context_text,
        model_name_or_path,
        attribution_method,
        attributed_fn,
        context_sensitivity_metric,
        context_sensitivity_std_threshold,
        context_sensitivity_topk,
        attribution_std_threshold,
        attribution_topk,
        input_template,
        output_template,
        contextless_input_template,
        contextless_output_template,
        special_tokens_to_keep,
        decoder_input_output_separator,
        model_kwargs,
        tokenizer_kwargs,
        generation_kwargs,
        attribution_kwargs,
    ]
    # Hide the download buttons of a previous run, then run PECoRe.
    attribute_input_button.click(
        lambda: [gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)],
        inputs=[],
        outputs=[download_output_file_button, download_output_html_button],
    ).then(
        pecore,
        inputs=pecore_args,
        outputs=[
            pecore_output_highlights,
            download_output_file_button,
            download_output_html_button,
        ],
    )
    load_model_event = load_model_button.click(
        preload_model,
        inputs=load_model_args,
        outputs=[],
    )

    # Toggle the large-model presets with the checkbox; only its value is needed.
    large_model_presets = [zephyr_preset, towerinstruct_template, gemma_template, mistral_instruct_template]
    check_enable_large_models.input(
        lambda enabled: [gr.Button(interactive=enabled) for _ in large_model_presets],
        inputs=[check_enable_large_models],
        outputs=large_model_presets,
    )

    # Every preset first resets these fields via set_default_preset ...
    outputs_to_reset = [
        model_name_or_path,
        input_template,
        output_template,
        contextless_input_template,
        contextless_output_template,
        special_tokens_to_keep,
        decoder_input_output_separator,
        model_kwargs,
        tokenizer_kwargs,
        generation_kwargs,
        attribution_kwargs,
    ]
    reset_kwargs = {
        "fn": set_default_preset,
        "inputs": None,
        "outputs": outputs_to_reset,
    }
    # ... then applies its own overrides, and finally preloads the selected model,
    # cancelling any load started from the "Load model" button.
    default_preset.click(**reset_kwargs).success(preload_model, inputs=load_model_args, cancels=load_model_event)
    chat_template_outputs = [
        model_name_or_path,
        input_template,
        contextless_input_template,
        decoder_input_output_separator,
        special_tokens_to_keep,
    ]
    preset_overrides = [
        (cora_preset, set_cora_preset, [model_name_or_path, input_template, contextless_input_template]),
        (zephyr_preset, set_zephyr_preset, chat_template_outputs),
        (multilingual_mt_template, set_mmt_preset, [model_name_or_path, input_template, output_template, tokenizer_kwargs]),
        (chatml_template, set_chatml_preset, chat_template_outputs),
        (towerinstruct_template, set_towerinstruct_preset, chat_template_outputs),
        (gemma_template, set_gemma_preset, chat_template_outputs),
        (
            mistral_instruct_template,
            set_mistral_instruct_preset,
            [model_name_or_path, input_template, contextless_input_template, decoder_input_output_separator],
        ),
    ]
    for preset_button, preset_fn, preset_outputs in preset_overrides:
        preset_button.click(**reset_kwargs).then(
            preset_fn,
            outputs=preset_outputs,
        ).success(preload_model, inputs=load_model_args, cancels=load_model_event)

    # Fill the code snippets for the current settings, then open the modal.
    show_code_btn.click(
        update_code_snippets_fn,
        inputs=pecore_args,
        outputs=[python_code_snippet, shell_code_snippet],
    ).then(lambda: Modal(visible=True), None, code_modal)
# Serve the app with a bounded queue (max 20 pending requests) and the API disabled.
# Allowed paths are anchored to this file's directory, where run outputs are written
# and images live, instead of depending on the current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
demo.queue(api_open=False, max_size=20).launch(
    allowed_paths=[os.path.join(_APP_DIR, "outputs"), os.path.join(_APP_DIR, "img")],
    show_api=False,
)