Update app.py
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
 import torch
 from datasets import load_dataset
 from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
+from sentence_transformers import SentenceTransformer
 from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
 from interpret import InterpretationPrompt
 from configs import model_info, dataset_info
@@ -27,6 +28,7 @@ class LocalState:
 class GlobalState:
     tokenizer : Optional[PreTrainedTokenizer] = None
     model : Optional[PreTrainedModel] = None
+    sentence_transformer : Optional[PreTrainedModel] = None
     local_state : LocalState = LocalState()
     wait_with_hidden_state : bool = False
     interpretation_prompt_template : str = '{prompt}'
@@ -49,7 +51,7 @@ suggested_interpretation_prompts = [
 def initialize_gpu():
     pass
 
-def reset_model(model_name, *extra_components, with_extra_components=True):
+def reset_model(model_name, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -66,6 +68,9 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
     global_state.model, global_state.tokenizer, global_state.local_state.hidden_states = None, None, None
     gc.collect()
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
+    if reset_sentence_transformer:
+        global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
+    gc.collect()
     if not dont_cuda:
         global_state.model.to('cuda')
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
@@ -131,8 +136,11 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
                        **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
 
+    # try identifying important layers
+    diff_score = F.normalize(global_state.sentence_transformer.encode(generation_texts), dim=-1).diff(dim=0)
+    important_idxs = 1 + diff_score.topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
+
     # create GUI output
-    important_idxs = 1 + ((interpreted_vectors - hidden_means)).diff(dim=0).norm(dim=-1).topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
     print(f'{important_idxs=}')
     progress_dummy_output = ''
     elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
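The `reset_model` change follows a drop-then-reload pattern: null out the old references, run `gc.collect()` so the freed weights can be reclaimed before the new load, and rebuild the sentence transformer only when `reset_sentence_transformer` is set. A minimal, self-contained sketch of that pattern; the `GlobalState` here is trimmed to the relevant fields, and the real app additionally handles CUDA placement, per-model kwargs, and a gated-tokenizer token:

# Minimal sketch of the reset pattern introduced above (assumed simplification:
# the real app also moves the model to CUDA and passes an HF token).
import gc
from typing import Optional

from sentence_transformers import SentenceTransformer
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedModel, PreTrainedTokenizer)


class GlobalState:
    tokenizer: Optional[PreTrainedTokenizer] = None
    model: Optional[PreTrainedModel] = None
    sentence_transformer: Optional[SentenceTransformer] = None


global_state = GlobalState()


def reset_model(model_path: str, reset_sentence_transformer: bool = False):
    # drop the old references first so the collector can free them
    # before the replacement weights are loaded
    global_state.model, global_state.tokenizer = None, None
    gc.collect()
    global_state.model = AutoModelForCausalLM.from_pretrained(model_path)
    if reset_sentence_transformer:
        # the embedder is independent of the interpreted model,
        # so it is only reloaded on demand
        global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    gc.collect()
    global_state.tokenizer = AutoTokenizer.from_pretrained(model_path)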
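The new scoring step embeds each layer's decoded interpretation with `all-MiniLM-L6-v2` and flags the layers whose interpretation differs most from the previous layer's, replacing the old hidden-state distance against `hidden_means`. As committed, the line looks fragile: `SentenceTransformer.encode` returns a NumPy array by default, so the tensor calls `.diff()` and `.topk()` would fail without `convert_to_tensor=True`, and the norm over the embedding dimension that the old version took before `topk` is dropped. A minimal sketch of the intended computation, with `important_layer_idxs` as a hypothetical helper name:

# Sketch of the layer-scoring step, assuming generation_texts holds one
# decoded interpretation per layer (hypothetical helper, not the app's API).
import numpy as np
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer


def important_layer_idxs(generation_texts, top_frac=0.2):
    st = SentenceTransformer('all-MiniLM-L6-v2')
    # convert_to_tensor=True is needed for the tensor ops below
    emb = st.encode(generation_texts, convert_to_tensor=True)  # (num_layers, dim)
    emb = F.normalize(emb, dim=-1)
    # semantic change between consecutive layers' interpretations
    diff_score = emb.diff(dim=0).norm(dim=-1)                  # (num_layers - 1,)
    k = int(np.ceil(top_frac * len(generation_texts)))
    # +1 because a large change at position i implicates layer i + 1
    return 1 + diff_score.topk(k=k).indices.cpu().numpy()

Taking the top 20% of consecutive-diff scores mirrors the `np.ceil(0.2 * len(generation_texts))` cutoff in the commit; the `1 +` offset likewise matches the committed code, which skips the first layer since a diff sequence has no entry for it.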