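"""Panel app: compare GPT-2's " True"/" False" predictions for a base vs. a
rival sentence by swapping activations between the two runs with pyvene
(vanilla interchange interventions on attention inputs, position by position).
"""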
import panel as pn
import pandas as pd
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pyvene as pv
from pyvene import embed_to_distrib, format_token  # embed_to_distrib: original, used by the patch fallback below
from pyvene import RepresentationConfig, IntervenableConfig, IntervenableModel
from pyvene import VanillaIntervention

pn.extension(sizing_mode="stretch_width")
# Initialize model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
num_layers = gpt2.config.n_layer

# Set padding token for the tokenizer (GPT-2 has none by default)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    gpt2.config.pad_token_id = tokenizer.eos_token_id

# Return hidden states on every forward pass, including the intervened runs
# below, which do not take output_hidden_states as a call argument
gpt2.config.output_hidden_states = True

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpt2.to(device)
gpt2.eval()
# Monkey patch embed_to_distrib to use the correct attribute: pyvene's version
# looks for model.wte, but GPT2LMHeadModel nests it under model.transformer.wte.
# Defaults match pyvene's original signature (log=False, logits=False), so the
# function returns probabilities unless told otherwise.
def patched_embed_to_distrib(model, embed, log=False, logits=False):
    if "gpt2" in model.config.architectures[0].lower():
        with torch.inference_mode():
            vocab = torch.matmul(embed, model.transformer.wte.weight.t())
        if logits:
            return vocab
        return torch.log_softmax(vocab, dim=-1) if log else torch.softmax(vocab, dim=-1)
    # Fall back to the original function imported above; calling
    # pv.embed_to_distrib here would recurse into this patch once assigned
    return embed_to_distrib(model, embed, log=log, logits=logits)

pv.embed_to_distrib = patched_embed_to_distrib
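# Minimal sanity check for the patch (hedged sketch, any short input works):
#   enc = tokenizer("Hello world", return_tensors="pt").to(device)
#   hidden = gpt2(**enc, output_hidden_states=True).hidden_states[-1]
#   probs = pv.embed_to_distrib(gpt2, hidden, logits=False)  # softmax over vocab
#   assert torch.allclose(probs.sum(-1), torch.ones_like(probs.sum(-1)))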
def simple_position_config(model_type, component, layer):
    config = IntervenableConfig(
        model_type=model_type,
        representations=[
            RepresentationConfig(
                layer,      # layer
                component,  # component
                "pos",      # intervention unit
                1,          # max number of units
            ),
        ],
        intervention_types=VanillaIntervention,
    )
    return config
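# For example, simple_position_config(type(gpt2), "attention_input", 3) builds a
# config that swaps the attention-input activation at a single token position in
# layer 3 of the base run for the corresponding activation from the rival run.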
def process_sentences(base_sentence, rival_sentence):
    base = tokenizer(base_sentence, return_tensors="pt", padding=True, truncation=True, max_length=64).to(device)
    rival = tokenizer(rival_sentence, return_tensors="pt", padding=True, truncation=True, max_length=64).to(device)
    # Token ids for " True" and " False" (each is a single GPT-2 token)
    tokens = tokenizer.encode(" True False")
    data = []
    with torch.no_grad():
        base_outputs = gpt2(**base, output_hidden_states=True)
        # Use the last hidden state from the output
        last_hidden_state = base_outputs.hidden_states[-1]
        distrib_base = pv.embed_to_distrib(gpt2, last_hidden_state, logits=False)
        logprob_true_base = np.log(float(distrib_base[0][-1][tokens[0]]))
        logprob_false_base = np.log(float(distrib_base[0][-1][tokens[1]]))
        base_tokens = tokenizer.convert_ids_to_tokens(base.input_ids[0])
        # Only run interventions when the base sentence already favors " True"
        if logprob_true_base - logprob_false_base > 0:
            for layer_i in range(num_layers):
                for component in ["attention_input"]:
                    try:
                        config = simple_position_config(type(gpt2), component, layer_i)
                        intervenable = IntervenableModel(config, gpt2).to(device)
                        # Truncate both inputs to a shared length so positions align
                        max_length = min(base.input_ids.shape[1], rival.input_ids.shape[1])
                        for pos_i in range(max_length):
                            base_input = {key: val[:, :max_length].to(device) for key, val in base.items()}
                            rival_input = {key: val[:, :max_length].to(device) for key, val in rival.items()}
                            # Swap the rival's activation into the base run at position pos_i
                            _, counterfactual_outputs = intervenable(
                                base_input, [rival_input], {"sources->base": pos_i}
                            )
                            # Use the last hidden state from the counterfactual output
                            last_hidden_state = counterfactual_outputs.hidden_states[-1]
                            distrib = pv.embed_to_distrib(gpt2, last_hidden_state, logits=False)
                            for token in tokens:
                                data.append({
                                    "token": format_token(tokenizer, token),
                                    "prob": float(distrib[0][-1][token]),
                                    "layer": f"a{layer_i}",
                                    "pos": base_tokens[pos_i] if pos_i < len(base_tokens) else "[PAD]",
                                    "type": component,
                                })
                    except Exception as e:
                        print(f"Error in layer {layer_i}, component {component}: {e}")
                        continue
    return pd.DataFrame(data)
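# Hedged usage sketch: the same pipeline can be exercised outside the UI, e.g.
#   df = process_sentences(
#       "Jane got some weird looks because she wore sunglasses outside at 4 PM.",
#       "Jane got some weird looks because she wore sunglasses outside at 4 AM.",
#   )
#   print(df.head())  # one row per (layer, position, token) with the post-swap prob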
async def process_inputs(base_sentence: str, rival_sentence: str):
    try:
        main.disabled = True
        if not base_sentence or not rival_sentence:
            yield "##### ⚠️ Please provide both base and rival sentences"
            return
        yield "##### ⏳ Processing sentences and running model..."
        try:
            result_df = process_sentences(base_sentence, rival_sentence)
        except Exception as e:
            yield f"##### 😔 Something went wrong, please try different sentences! Error: {e}"
            return
        # Build the results column
        results = pn.Column("##### 🎉 Here are the results!")
        # Display the DataFrame
        results.append(pn.pane.DataFrame(result_df))
        yield results
    finally:
        main.disabled = False
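# Panel re-renders the bound output on every `yield`, so the status messages
# above stream to the page while process_sentences runs.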
# Create widgets
base_sentence = pn.widgets.TextInput(
    name="Base Sentence",
    placeholder="Enter the base sentence",
    value="Jane got some weird looks because she wore sunglasses outside at 4 PM.",
)
rival_sentence = pn.widgets.TextInput(
    name="Rival Sentence",
    placeholder="Enter the rival sentence",
    value="Jane got some weird looks because she wore sunglasses outside at 4 AM.",
)

input_widgets = pn.Column(
    "##### 📝 Enter base and rival sentences to start comparing!",
    base_sentence,
    rival_sentence,
)
# Add interactivity
interactive_result = pn.panel(
    pn.bind(process_inputs, base_sentence=base_sentence, rival_sentence=rival_sentence),
    height=600,
)

# Create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
)
title = "Sentence Comparison Demo"
pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(80%, 1200px)",
    header_background="#4B0082",
).servable(title=title)
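# To run locally (assuming this file is saved as app.py):
#   panel serve app.py --autoreload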