File size: 7,062 Bytes
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
7f9376c
 
cb21dfd
 
7f9376c
 
 
63b5bc1
 
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb21dfd
7f9376c
 
 
cb21dfd
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb21dfd
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb21dfd
 
 
 
 
 
 
50c3f87
 
 
7f9376c
cb21dfd
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
cb21dfd
7f9376c
 
 
 
63b5bc1
7f9376c
 
 
 
 
63b5bc1
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb21dfd
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb21dfd
7f9376c
 
 
 
 
 
 
cb21dfd
 
7f9376c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""Web app page for showing codes for different examples in the dataset."""


import streamlit as st
from streamlit_extras.switch_page_button import switch_page

import code_search_utils
import webapp_utils

# Runs first, before any session state is read below.
# NOTE(review): presumably restores widget values persisted by other pages — confirm in webapp_utils.
webapp_utils.load_widget_state()

# Without cached codebook activations this page has nothing to show, so send
# the user to the Code Browser page instead.
if not ("cb_acts" in st.session_state):
    switch_page("Code_Browser")

# Number of selectable examples and the minimum precision a code must reach.
total_examples, prec_threshold = 2000, 0.01

# Pull everything this page needs out of the shared session state once.
_session = st.session_state
model_name = _session["model_name_id"]
seq_len = _session["seq_len"]
tokens_text = _session["tokens_text"]
tokens_str = _session["tokens_str"]
cb_acts = _session["cb_acts"]
act_count_ft_tkns = _session["act_count_ft_tkns"]
gcb = _session["gcb"]


def get_example_topic_codes(example_id):
    """Collect the codes that pass both thresholds on a single example.

    For every codebook in `cb_acts`, precision/recall statistics are computed
    over all `seq_len` token positions of the example, then filtered first by
    the module-level `prec_threshold` and then by `recall_threshold`.

    Args:
        example_id: Index of the example whose token positions are inspected.

    Returns:
        A list with one `(cb_name, rows)` entry per codebook, where `rows` is a
        list of `(code, precision, recall, code_acts)` tuples for the codes
        that survived both filters (possibly empty).
    """
    positions = [(example_id, pos) for pos in range(seq_len)]
    per_codebook = []
    for cb_name, cb in cb_acts.items():
        base_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall(
            positions,
            cb,
            act_count_ft_tkns[base_name],
        )
        stats = (codes, prec, rec, code_acts)
        # Apply the precision filter, then the recall filter, to all four
        # parallel sequences at once (column 1 is precision, column 2 recall).
        # Assumes the sequences support boolean-mask indexing — TODO confirm.
        for column, minimum in ((1, prec_threshold), (2, recall_threshold)):
            keep = stats[column] >= minimum
            stats = tuple(values[keep] for values in stats)
        per_codebook.append((cb_name, list(zip(*stats))))
    return per_codebook


def find_next_example(example_id):
    """Find the example after `example_id` that has topic codes.

    Used as the `on_click` callback of the "Find Next Example" button. Every
    other example is checked exactly once, in increasing order and wrapping
    from the last example back to example 0. The first example with at least
    one code passing the thresholds is stored in
    `st.session_state["example_id"]`. If none qualifies, an error message is
    shown instead and the session state is left unchanged.

    Args:
        example_id: ID of the example to start searching after. This example
            itself is never checked.
    """
    # Walk offsets rather than incrementing the ID directly: the modulo keeps
    # every candidate inside [0, total_examples), so starting from the last
    # example continues at example 0 instead of probing an out-of-range ID and
    # skipping 0. The bounded range also guarantees termination.
    for offset in range(1, total_examples):
        candidate = (example_id + offset) % total_examples
        all_codes = get_example_topic_codes(candidate)
        # Each entry is (cb_name, rows); a non-empty `rows` means a hit.
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = candidate
            return
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )


def redirect_to_main_with_code(code, layer, head):
    """Store the selected code in session state and open the Code Browser page.

    Args:
        code: Code to store under `ct_act_code`.
        layer: Layer to store under `ct_act_layer`.
        head: Head to store under `ct_act_head`; only written when the
            session's `is_attn` flag is truthy.
    """
    session = st.session_state
    session["ct_act_code"] = code
    session["ct_act_layer"] = layer
    # The head entry is only written when the session's `is_attn` flag is set.
    if session["is_attn"]:
        session["ct_act_head"] = head
    switch_page("Code Browser")


def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3):
    """Render the examples on which the given code activates often enough.

    Args:
        code: Code whose activations are looked up.
        layer: Layer passed through to `webapp_utils.get_code_acts`.
        head: Head passed through to `webapp_utils.get_code_acts`.
        code_act_ratio: An example is shown only if its activation count
            exceeds `seq_len * code_act_ratio`.
    """
    example_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    # Strictly more activations than this are required for an example to show.
    min_acts = seq_len * code_act_ratio
    dense_examples = [text for text, count in example_acts if count > min_acts]
    st.markdown("#### Examples for Code")
    rendered = webapp_utils.escape_markdown("".join(dense_examples))
    st.markdown(rendered, unsafe_allow_html=True)


# `is_attn` decides whether the results table gets an extra "Head" column.
is_attn = st.session_state["is_attn"]

st.markdown("## Topic Code")
topic_code_description = (
    "Topic codes are codes that activate many different times on passages that describe a particular"
    " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking"
    " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction"
    " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible"
    " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at"
    " least one code firing on that example above the Recall Threshold.\n\n"
    "Topic codes are displayed for the codebook model selected on the Code Browser page. To view topic codes"
    " for a different model, go to the Code Browser page and select a different model."
)
st.write(topic_code_description)

# Controls row: example selector, recall threshold, output truncation and
# sort order, laid out in four equally wide columns.
ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
# Keyed as "example_id" so that find_next_example can move the selection by
# writing to st.session_state["example_id"].
example_id = ex_col.number_input(
    "Example ID",
    0,
    total_examples - 1,
    0,
    key="example_id",
)
# Read as a module-level global by get_example_topic_codes and
# find_next_example, so it must be defined before either is called.
recall_threshold = r_col.slider(
    "Recall Threshold",
    0.0,
    1.0,
    0.2,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)
example_truncation = trunc_col.number_input(
    "Max Output Chars", 0, 102400, 1024, key="max_chars"
)
sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=1,
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
# Position of the chosen metric in `sort_by_options`; used further down to
# pick the matching field of each code tuple.
sort_by = sort_by_options.index(sort_by_name)


# The search itself runs in the on_click callback, which receives the
# currently selected example ID; the button's return value is not used.
button = st.button(
    "Find Next Example",
    key="find_next_example",
    on_click=find_next_example,
    args=(example_id,),
    help="Find an example which has codes above the recall threshold.",
)

st.markdown("### Example Text")
# Append an ellipsis only when the text is actually cut off.
trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else ""
st.write(tokens_text[example_id][:example_truncation] + trunc_suffix)

# Table header. The last four columns are addressed with negative indices so
# the same code works with and without the optional "Head" column.
cols = st.columns(7 if is_attn else 6)
cols[0].markdown("Search", help="Button to see token activations for the code.")
cols[1].write("Layer")
if is_attn:
    cols[2].write("Head")
cols[-4].write("Code")
cols[-3].write("Precision")
cols[-2].write("Recall")
cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)

# Flatten the per-codebook results into one row per code:
# (cb_name, (code, precision, recall, num_acts)).
all_codes = get_example_topic_codes(example_id)
all_codes = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in all_codes
    for code_pr_info in code_pr_infos
]
# Index 0 of the inner tuple is the code itself, so `1 + sort_by` selects
# precision, recall or num_acts; highest values come first.
all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)

# One table row per surviving code.
for cb_name, (code, p, r, acts) in all_codes:
    cols = st.columns(7 if is_attn else 6)
    # The key combines code and codebook name so every row's button is unique.
    code_button = cols[0].button(
        "🔍",
        key=f"ex-code-{code}-{cb_name}",
    )
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    cols[1].write(str(layer))
    if is_attn:
        cols[2].write(str(head))

    cols[-4].write(code)
    cols[-3].write(f"{p*100:.2f}%")
    cols[-2].write(f"{r*100:.2f}%")
    cols[-1].write(str(acts))

    # Only the row whose button was clicked renders its examples; the recall
    # threshold is reused as the minimum activation ratio.
    if code_button:
        show_examples_for_topic_code(
            code,
            layer,
            head,
            code_act_ratio=recall_threshold,
        )
# Nothing passed the thresholds: tell the user how to widen the search.
if len(all_codes) == 0:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold = {recall_threshold}."
        " Consider decreasing the recall threshold.</div>",
        unsafe_allow_html=True,
    )