"""Panel dashboard comparing two sentences via pyvene causal interventions on GPT-2.

For a base/rival sentence pair, the base sentence's hidden states are patched
with the rival's at every (layer, position) for the ``attention_input``
component, and the probabilities of the " True"/" False" continuations are
collected into a DataFrame shown in the dashboard.
"""

import sys

import numpy as np
import pandas as pd
import panel as pn
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

import pyvene as pv
from pyvene import embed_to_distrib, format_token
from pyvene import RepresentationConfig, IntervenableConfig, IntervenableModel
from pyvene import VanillaIntervention

pn.extension(sizing_mode="stretch_width")

# Initialize model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
num_layers = gpt2.config.n_layer

# GPT-2 ships without a pad token; reuse EOS so padded tokenization works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    gpt2.config.pad_token_id = tokenizer.eos_token_id

device = "cuda" if torch.cuda.is_available() else "cpu"
gpt2.to(device)

# BUGFIX: the intervention forward pass below reads
# counterfactual_outputs.hidden_states, but nothing requested hidden states
# for that call, so it was always None and every iteration failed inside the
# try/except. Enabling the flag on the config makes every forward return them.
gpt2.config.output_hidden_states = True

# Keep a handle on the library implementation so the patched version can
# delegate without calling itself (pv.embed_to_distrib is rebound below).
_original_embed_to_distrib = embed_to_distrib


# Monkey patch the embed_to_distrib function to use the correct attribute
def patched_embed_to_distrib(model, embed, log=False, logits=False):
    """Project hidden states onto the GPT-2 vocabulary.

    Args:
        model: a HuggingFace model; GPT-2 architectures are handled here,
            anything else is delegated to pyvene's own implementation.
        embed: hidden-state tensor to project through the tied embedding.
        log: return log-probabilities instead of probabilities.
        logits: return raw logits (takes precedence over ``log``).

    BUGFIX: ``log`` now defaults to False, matching pyvene's signature. The
    old default of True returned log-softmax values that the callers then fed
    into ``np.log`` again, producing NaN and silently disabling the whole
    intervention loop.
    """
    if "gpt2" in model.config.architectures[0].lower():
        with torch.inference_mode():
            vocab = torch.matmul(embed, model.transformer.wte.weight.t())
        if logits:
            return vocab
        if log:
            return torch.log_softmax(vocab, dim=-1)
        return torch.softmax(vocab, dim=-1)
    # BUGFIX: delegate to the saved original; calling pv.embed_to_distrib here
    # would recurse into this very function after the rebinding below.
    return _original_embed_to_distrib(model, embed, log, logits)


pv.embed_to_distrib = patched_embed_to_distrib


def simple_position_config(model_type, component, layer):
    """Build a single-position vanilla-intervention config for one layer/component."""
    config = IntervenableConfig(
        model_type=model_type,
        representations=[
            RepresentationConfig(
                layer,      # layer
                component,  # component
                "pos",      # intervention unit
                1,          # max number of unit
            ),
        ],
        intervention_types=VanillaIntervention,
    )
    return config


def process_sentences(base_sentence, rival_sentence):
    """Run per-layer, per-position interventions for a base/rival sentence pair.

    Returns a DataFrame with columns token/prob/layer/pos/type. The frame is
    empty when the base sentence does not already prefer " True" over " False"
    as its next token.
    """
    base = tokenizer(
        base_sentence, return_tensors="pt", padding=True, truncation=True, max_length=64
    ).to(device)
    rival = tokenizer(
        rival_sentence, return_tensors="pt", padding=True, truncation=True, max_length=64
    ).to(device)
    # Token ids for the " True" and " False" continuations (leading space matters).
    tokens = tokenizer.encode(" True False")

    data = []
    with torch.no_grad():
        base_outputs = gpt2(**base, output_hidden_states=True)
        # Use the last hidden state from the output.
        last_hidden_state = base_outputs.hidden_states[-1]
        distrib_base = pv.embed_to_distrib(gpt2, last_hidden_state, logits=False)
        logprob_true_base = np.log(float(distrib_base[0][-1][tokens[0]]))
        logprob_false_base = np.log(float(distrib_base[0][-1][tokens[1]]))
        base_tokens = tokenizer.convert_ids_to_tokens(base.input_ids[0])

        # Only intervene when the base sentence already prefers " True".
        if logprob_true_base - logprob_false_base > 0:
            # Hoisted loop invariants: the truncated inputs do not depend on
            # the layer or position being intervened on.
            max_length = min(base.input_ids.shape[1], rival.input_ids.shape[1])
            base_input = {k: v[:, :max_length].to(device) for k, v in base.items()}
            rival_input = {k: v[:, :max_length].to(device) for k, v in rival.items()}

            for layer_i in range(num_layers):
                for component in ["attention_input"]:
                    try:
                        config = simple_position_config(type(gpt2), component, layer_i)
                        intervenable = IntervenableModel(config, gpt2).to(device)
                        for pos_i in range(max_length):
                            _, counterfactual_outputs = intervenable(
                                base_input,
                                [rival_input],
                                {"sources->base": pos_i},
                            )
                            # Use the last hidden state from the counterfactual output.
                            last_hidden_state = counterfactual_outputs.hidden_states[-1]
                            distrib = pv.embed_to_distrib(
                                gpt2, last_hidden_state, logits=False
                            )
                            for token in tokens:
                                data.append({
                                    "token": format_token(tokenizer, token),
                                    "prob": float(distrib[0][-1][token]),
                                    "layer": f"a{layer_i}",
                                    "pos": base_tokens[pos_i] if pos_i < len(base_tokens) else "[PAD]",
                                    "type": component,
                                })
                    except Exception as e:
                        # Best-effort: skip this layer/component and keep going.
                        print(f"Error in layer {layer_i}, component {component}: {str(e)}")
                        continue
    return pd.DataFrame(data)


async def process_inputs(base_sentence: str, rival_sentence: str):
    """Panel-bound async generator: validate inputs, run the model, yield results."""
    try:
        main.disabled = True
        if not base_sentence or not rival_sentence:
            yield "##### ⚠️ Please provide both base and rival sentences"
            return

        yield "##### ⚙ Processing sentences and running model..."
        try:
            result_df = process_sentences(base_sentence, rival_sentence)
        except Exception as e:
            yield f"##### 😔 Something went wrong, please try different sentences! Error: {str(e)}"
            return

        # build the results column
        results = pn.Column("##### 🎉 Here are the results!")
        # Display the DataFrame
        results.append(pn.pane.DataFrame(result_df))
        yield results
    finally:
        main.disabled = False


# create widgets
base_sentence = pn.widgets.TextInput(
    name="Base Sentence",
    placeholder="Enter the base sentence",
    value="Jane got some weird looks because she wore sunglasses outside at 4 PM.",
)
rival_sentence = pn.widgets.TextInput(
    name="Rival Sentence",
    placeholder="Enter the rival sentence",
    value="Jane got some weird looks because she wore sunglasses outside at 4 AM.",
)

input_widgets = pn.Column(
    "##### 😊 Enter base and rival sentences to start comparing!",
    base_sentence,
    rival_sentence,
)

# add interactivity
interactive_result = pn.panel(
    pn.bind(process_inputs, base_sentence=base_sentence, rival_sentence=rival_sentence),
    height=600,
)

# create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
)

title = "Sentence Comparison Demo"
pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(80%, 1200px)",
    header_background="#4B0082",
).servable(title=title)