Guy24 committed on
Commit
1d82598
·
1 Parent(s): 0276ae6

adding application

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -240,15 +240,16 @@ def find_last_token_index(full_ids, word_ids):
240
  return None
241
 
242
  @GPU # this block runs on a job GPU
243
- def analyse_word(model_name: str, word: str, patchscopes_template: str):
244
  try:
 
245
  model, tokenizer = get_model_and_tokenizer(model_name)
246
 
247
  # Build extraction prompt (where hidden states will be collected)
248
  extraction_prompt ="X"
249
 
250
  # Identify last token position of the *word* inside the prompt IDs
251
- word_token_ids = tokenizer.encode(word, add_special_tokens=False)
252
 
253
  # Instantiate Patchscopes retriever
254
  patch_retriever = PatchscopesRetriever(
@@ -261,7 +262,7 @@ def analyse_word(model_name: str, word: str, patchscopes_template: str):
261
 
262
  # Run retrieval for the word across all layers (one pass)
263
  retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
264
- word,
265
  num_tokens_to_generate=len(tokenizer.tokenize(word)),
266
  )[0]
267
 
@@ -308,6 +309,7 @@ with gr.Blocks(theme="soft") as demo:
308
  label="Patchscopes prompt (use X as placeholder)",
309
  value="repeat the following word X twice: 1)X 2)",
310
  )
 
311
  word_box = gr.Textbox(label="Word to test", value="interpretable")
312
  run_btn = gr.Button("Analyse")
313
  out_html = gr.HTML()
 
240
  return None
241
 
242
  @GPU # this block runs on a job GPU
243
+ def analyse_word(model_name: str, word: str, patchscopes_template: str, context:str = ""):
244
  try:
245
+ text = context+ " " + word
246
  model, tokenizer = get_model_and_tokenizer(model_name)
247
 
248
  # Build extraction prompt (where hidden states will be collected)
249
  extraction_prompt ="X"
250
 
251
  # Identify last token position of the *word* inside the prompt IDs
252
+ word_token_ids = tokenizer.encode(text, add_special_tokens=False)
253
 
254
  # Instantiate Patchscopes retriever
255
  patch_retriever = PatchscopesRetriever(
 
262
 
263
  # Run retrieval for the word across all layers (one pass)
264
  retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
265
+ text,
266
  num_tokens_to_generate=len(tokenizer.tokenize(word)),
267
  )[0]
268
 
 
309
  label="Patchscopes prompt (use X as placeholder)",
310
  value="repeat the following word X twice: 1)X 2)",
311
  )
312
+ word_box = gr.Textbox(label="context", value="")
313
  word_box = gr.Textbox(label="Word to test", value="interpretable")
314
  run_btn = gr.Button("Analyse")
315
  out_html = gr.HTML()