import os
import torch

llama_layers_format = 'model.layers.{k}'
gpt_layers_format = 'transformer.h.{k}'
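# These format strings resolve to Hugging Face module names, e.g.
# llama_layers_format.format(k=12) -> 'model.layers.12' (decoder block 12 of a
# LLaMA-family model) and gpt_layers_format.format(k=0) -> 'transformer.h.0'
# (block 0 of GPT-2 / GPT-J), matching the names reported by model.named_modules().
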
dataset_info = [
    {'name': 'Common Sense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
    {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
     'filter': lambda x: x['label'] == 1},
    # {'name': 'Physical Understanding', 'hf_repo': 'piqa', 'text_col': 'goal'},
    {'name': 'Social Reasoning', 'hf_repo': 'ProlificAI/social-reasoning-rlhf', 'text_col': 'question'}
]
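

# Illustrative helper (an assumption about how the entries above are consumed, not part
# of the original app logic): load the repo with `datasets.load_dataset`, apply the
# optional 'filter' callable, and read prompts from 'text_col'. A '+' in 'text_col'
# (e.g. 'subject+predicate') is taken to mean "join those columns with a space";
# split names and column names are assumptions.
def iter_prompts(info, split='train'):
    from datasets import load_dataset  # local import keeps this config module cheap to import
    ds = load_dataset(info['hf_repo'], split=split)
    if 'filter' in info:
        ds = ds.filter(info['filter'])
    for row in ds:
        if '+' in info['text_col']:
            yield ' '.join(str(row[col]) for col in info['text_col'].split('+'))
        else:
            yield row[info['text_col']]
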
model_info = {
    'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', token=os.environ['hf_token'],
                      original_prompt_template='<s>{prompt}',
                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                      layers_format=llama_layers_format),
    'LLAMA2-13B': dict(model_path='meta-llama/Llama-2-13b-chat-hf',
                       token=os.environ['hf_token'], torch_dtype=torch.float16,
                       wait_with_hidden_states=True,
                       # device_map='auto', max_memory={0: "15GB", 1: "30GB"}, dont_cuda=True,  # load_in_8bit=True,
                       original_prompt_template='<s>{prompt}',
                       interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                       layers_format=llama_layers_format),
    'Mixtral 8x7B Instruct (Experimental)': dict(model_path='TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ',
                                                 token=os.environ['hf_token'], wait_with_hidden_states=True,
                                                 original_prompt_template='<s>{prompt}',
                                                 interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                                 layers_format=llama_layers_format
                                                 ),
    'CodeLLAMA 70B Instruct (Experimental)': dict(model_path='TheBloke/CodeLlama-70B-Instruct-GPTQ',
                                                  token=os.environ['hf_token'],
                                                  wait_with_hidden_states=True, dont_cuda=True, device_map='cuda',  # disable_exllama=True,
                                                  original_prompt_template='<s>{prompt}',
                                                  interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                                  layers_format=llama_layers_format
                                                  ),
    'GPT-2 Small': dict(model_path='gpt2', original_prompt_template='{prompt}',
                        interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
                        layers_format=gpt_layers_format),
    # 'GPT-2 Medium': dict(model_path='gpt2-medium', original_prompt_template='{prompt}',
    #                      interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
    #                      layers_format=gpt_layers_format),
    # 'GPT-2 Large': dict(model_path='gpt2-large', original_prompt_template='{prompt}',
    #                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
    #                     layers_format=gpt_layers_format),
    # 'GPT-2 XL': dict(model_path='gpt2-xl', original_prompt_template='{prompt}',
    #                  interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
    #                  layers_format=gpt_layers_format),
    'GPT-J 6B': dict(model_path='EleutherAI/gpt-j-6b', original_prompt_template='{prompt}',
                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
                     layers_format=gpt_layers_format),
    'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
                                original_prompt_template='<s>{prompt}',
                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                layers_format=llama_layers_format),
    # 'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
    #                  original_prompt_template='<bos>{prompt}',
    #                  interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
    #                  ),
    # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
    #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
    #                                                model_type='llama', hf=True, ctransformers=True,
    #                                                original_prompt_template='<s>[INST] {prompt} [/INST]',
    #                                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    #                                                )
}
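

# A minimal usage sketch, assuming the templates are filled with str.format and that
# '[X]' is a placeholder the app later substitutes (e.g. the position where a hidden
# state is patched in). The chosen model and prompt are illustrative only; note the
# module itself requires the 'hf_token' environment variable at import time.
if __name__ == '__main__':
    cfg = model_info['GPT-2 Small']
    # Plain prompt fed to the model being inspected.
    print(cfg['original_prompt_template'].format(prompt='The capital of France is'))
    # Interpretation prompt; '[X]' survives .format() untouched because it is not a {field}.
    print(cfg['interpretation_prompt_template'].format(prompt='The capital of France is'))
    # Module name of layer 5, e.g. 'transformer.h.5' for GPT-2.
    print(cfg['layers_format'].format(k=5))
    # Requires the `datasets` package and network access:
    # print(next(iter_prompts(dataset_info[0])))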