dar-tau committed
Commit e4c230b · verified · 1 Parent(s): 7520e6c

Update app.py

Files changed (1)
  1. app.py +10 -2
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
 import torch
 from datasets import load_dataset
 from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
+from sentence_transformers import SentenceTransformer
 from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
 from interpret import InterpretationPrompt
 from configs import model_info, dataset_info
@@ -27,6 +28,7 @@ class LocalState:
 class GlobalState:
     tokenizer : Optional[PreTrainedTokenizer] = None
     model : Optional[PreTrainedModel] = None
+    sentence_transformer: Optional[PreTrainedModel] = None
     local_state : LocalState = LocalState()
     wait_with_hidden_state : bool = False
     interpretation_prompt_template : str = '{prompt}'
@@ -49,7 +51,7 @@ suggested_interpretation_prompts = [
 def initialize_gpu():
     pass
 
-def reset_model(model_name, *extra_components, with_extra_components=True):
+def reset_model(model_name, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -66,6 +68,9 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
     global_state.model, global_state.tokenizer, global_state.local_state.hidden_states = None, None, None
     gc.collect()
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
+    if reset_sentence_transformer:
+        global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
+        gc.collect()
     if not dont_cuda:
         global_state.model.to('cuda')
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
@@ -131,8 +136,11 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
                        **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
 
+    # try identifying important layers
+    diff_score = F.normalize(global_state.sentence_transformer.encode(generation_texts), dim=-1).diff(dim=0)
+    important_idxs = 1 + diff_score.topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
+
     # create GUI output
-    important_idxs = 1 + ((interpreted_vectors - hidden_means)).diff(dim=0).norm(dim=-1).topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
     print(f'{important_idxs=}')
     progress_dummy_output = ''
     elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
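The substantive change is the last hunk: instead of scoring layers by the geometry of the hidden states (`(interpreted_vectors - hidden_means).diff(dim=0).norm(dim=-1)`), importance is now scored by how much the decoded interpretation text shifts between consecutive layers, measured in all-MiniLM-L6-v2 embedding space. Note that, as committed, the new lines assume tensor semantics that sentence_transformers does not provide by default: `encode` returns a NumPy array (which has no `.diff` or `.topk` methods), and `diff_score` is still 2-D when `topk` is called. A minimal runnable sketch of the same heuristic, assuming `convert_to_tensor=True` and reusing the `.norm(dim=-1)` reduction from the removed line:

import numpy as np
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

def important_layer_idxs(generation_texts, top_frac=0.2):
    # generation_texts: one decoded interpretation per layer, in layer order
    st = SentenceTransformer('all-MiniLM-L6-v2')
    # convert_to_tensor=True is an assumption; the commit uses the NumPy default
    emb = F.normalize(st.encode(generation_texts, convert_to_tensor=True), dim=-1)  # (n_layers, d)
    diff_score = emb.diff(dim=0).norm(dim=-1)  # drift between consecutive layers, (n_layers - 1,)
    k = int(np.ceil(top_frac * len(generation_texts)))
    # +1 because diff i scores the transition from layer i to layer i + 1
    return 1 + diff_score.topk(k=k).indices.cpu().numpy()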
 
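The embedder itself is loaded lazily: `reset_sentence_transformer` is keyword-only and defaults to `False`, so ordinary model switches reuse the cached `global_state.sentence_transformer` and only an explicit request pays the load cost. (The annotation `Optional[PreTrainedModel]` is loose here; `SentenceTransformer` is a torch module rather than a `transformers.PreTrainedModel`, so `Optional[SentenceTransformer]` would be tighter.) A hypothetical pair of call sites, since the real ones fall outside the shown hunks:

# Hypothetical call sites (not shown in the commit's hunks); model names are placeholders.
reset_model(initial_model_name, reset_sentence_transformer=True)  # app startup: load the embedder once

# Later model switches keep the cached embedder:
reset_model(new_model_name)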