tetrisd committed on
Commit
8ad9dbd
·
1 Parent(s): 306e122

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +113 -0
  2. probe.pt +3 -0
  3. requirements.txt +3 -0
  4. scrollbar.css +46 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Lock
2
+ import argparse
3
+
4
+ import numpy as np
5
+ from matplotlib import pyplot as plt
6
+ import gradio as gr
7
+ import torch
8
+ import pandas as pd
9
+
10
+ from biasprobe import BinaryProbe, PairwiseExtractionRunner, SimplePairPromptBuilder, ProbeConfig
11
+
12
+
13
def get_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which
            is what existing zero-argument callers get.

    Returns:
        argparse.Namespace with ``seed`` (int, default 0), ``port`` (int,
        default 8080) and ``no_cuda`` (bool, default False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-s', type=int, default=0, help="the random seed")
    parser.add_argument('--port', '-p', type=int, default=8080, help="the port to launch the demo")
    parser.add_argument('--no-cuda', action='store_true', help="Use CPUs instead of GPUs")
    return parser.parse_args(argv)
20
+
21
+
22
def _win_rate_table(raw_scores, words):
    """Aggregate per-comparison probe scores into a per-word win-rate table.

    Args:
        raw_scores: One probe score per comparison; a score > 0 counts as a win.
        words: The word credited with each comparison, aligned with ``raw_scores``.

    Returns:
        A DataFrame with columns ``Word`` and ``Win Rate``. ``Win Rate`` holds
        percentage strings such as ``'50.0%'``; rows are sorted by ascending
        win rate.
    """
    wins = (np.asarray(raw_scores, dtype=float) > 0).astype(float)
    df = pd.DataFrame({'Word': list(words), 'Win Rate': wins})
    # Select the column explicitly instead of passing its name positionally to mean().
    win_df = df.groupby('Word')['Win Rate'].mean().reset_index().sort_values('Win Rate')
    win_df['Win Rate'] = [str(x) + '%' for x in (win_df['Win Rate'] * 100).round(2).tolist()]
    return win_df


def main():
    """Load the probe and the extraction runner, then serve the Gradio demo.

    Reads the command-line options via ``get_args()``, loads ``probe.pt`` into a
    ``BinaryProbe``, builds a ``PairwiseExtractionRunner`` for
    Mistral-7B-Instruct and launches a Gradio Blocks app on ``0.0.0.0`` whose
    submit button runs ``run_extraction`` on the comma-separated words typed
    by the user.
    """
    args = get_args()
    plt.switch_backend('agg')

    model_name = 'mistralai/Mistral-7B-Instruct-v0.1'
    # Honour --no-cuda for the probe and its inputs (previously hard-wired to .cuda()).
    device = "cpu" if args.no_cuda else "cuda"
    # NOTE(review): --seed and --port are parsed but still not consumed here.
    # Wiring --port into launch() would change the port the demo listens on by
    # default, so that is left for a deliberate decision.
    dmap = 'auto'
    mdict = {0: '24GIB'}
    config = ProbeConfig.create_for_model(model_name)
    probe = BinaryProbe(config).to(device)
    probe.load_state_dict(torch.load('probe.pt', map_location=device))

    # NOTE(review): the runner is still placed via device_map/max_memory and does
    # not look at --no-cuda; confirm how biasprobe should be configured for CPU.
    runner = PairwiseExtractionRunner.from_pretrained(model_name, optimize=True, max_memory=mdict, device_map=dmap, low_cpu_mem_usage=True)
    lock = Lock()

    @torch.no_grad()
    def run_extraction(prompt):
        """Score the comma-separated words in ``prompt`` and return the win-rate table."""
        builder = SimplePairPromptBuilder(criterion='more positive')
        # Cap the input: first 300 characters, at most 100 words.
        words = [w.strip() for w in prompt.lower()[:300].split(',')][:100]

        raw_scores = []
        credited_words = []

        # Serialize access to the shared model and probe across concurrent requests.
        with lock:
            exp = runner.run_extraction(words, words, layers=[15], num_repeat=100, builder=builder, parallel=False, run_inference=True, debug=True, max_new_tokens=2)
            test_ds = exp.make_dataset(15)

            for idx, (tensor, labels) in enumerate(test_ds):
                # Without the original examples there is no word to attribute the
                # score to; skipping keeps scores and words the same length.
                if tensor.shape[0] != 2 or test_ds.original_examples is None:
                    continue

                labels = labels - 1  # 1-indexed

                try:
                    out = probe(tensor.unsqueeze(0).to(device).float()).squeeze()
                except IndexError:
                    continue

                score = out.item()
                order = np.array([0, 1] if score > 0 else [1, 0])

                items = [hit.content for hit in test_ds.original_examples[idx].hits]
                ranked = np.array(items, dtype=object)[labels][order].tolist()
                # The first entry after reordering is the word reported for this comparison.
                credited_words.append(ranked[0])
                raw_scores.append(score)

        return _win_rate_table(raw_scores, credited_words)

    with gr.Blocks(css='scrollbar.css') as demo:
        # NOTE(review): confirm the arXiv id in the link below points at this paper.
        md = '''# BiasProbe: Revealing Preference Biases in Language Model Representations
What do llamas really "think"? Type some words below to see how Mistral-7B-Instruct associates them with
positive and negative emotions. Higher win rates indicate that the word is more likely to be associated with
positive emotions than other words in the list.

Check out our paper, [What Do Llamas Really Think? Revealing Preference Biases in Language Model Representations](http://arxiv.org/abs/2210.04885).
See our [codebase](https://github.com/castorini/biasprobe) on GitHub.
'''
        gr.Markdown(md)

        # NOTE(review): widget nesting was reconstructed from a flattened diff;
        # confirm that the output table belongs inside this column.
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label='Words', value='Republican, democrat, libertarian, authoritarian')
                submit_btn = gr.Button('Submit', elem_id='submit-btn')
                output = gr.DataFrame(pd.DataFrame({'Word': ['authoritarian', 'republican', 'democrat', 'libertarian'],
                                                    'Win Rate': ['44.44%', '81.82%', '100%', '100%']}))

        submit_btn.click(
            fn=run_extraction,
            inputs=[text],
            outputs=[output])

    # Keep relaunching after an OSError; stop on Ctrl-C.
    while True:
        try:
            demo.launch(server_name='0.0.0.0')
        except OSError:
            gr.close_all()
        except KeyboardInterrupt:
            gr.close_all()
            break
110
+
111
+
112
# Start the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
probe.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc369595d41f7a7339d4bd84790c7e117207087eb00b90762848eddcfb7a6c91
3
+ size 17659
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==3.36.1
2
+ biasprobe
3
+ flash-attn
scrollbar.css ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Scrollable HTML output: horizontal scrolling only, with a custom WebKit scrollbar. */
.output-html {
    overflow-x: auto;
}

.output-html::-webkit-scrollbar {
    -webkit-appearance: none;
}

/* Zero width hides the vertical scrollbar. */
.output-html::-webkit-scrollbar:vertical {
    width: 0px;
}

.output-html::-webkit-scrollbar:horizontal {
    height: 11px;
}

/* Rounded, semi-transparent thumb on a white rounded track. */
.output-html::-webkit-scrollbar-thumb {
    background-color: rgba(0, 0, 0, .5);
    border: 2px solid white;
    border-radius: 8px;
}

.output-html::-webkit-scrollbar-track {
    border-radius: 8px;
    background-color: #fff;
}

.spans {
    min-height: 75px;
}

/* Center SVGs horizontally within their container. */
svg {
    display: block;
    margin: auto;
}

/* Matches the elem_id given to the submit button in app.py. */
#submit-btn {
    z-index: 999;
}

/* NOTE(review): `top` only affects positioned elements and no `position` is
   set for #viz in this file; confirm whether one is applied elsewhere. */
#viz {
    width: 100%;
    top: -30px;
    object-fit: scale-down;
    object-position: 0 100%;
}