File size: 6,951 Bytes
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Web app page for showing codes for different examples in the dataset."""


import streamlit as st
from streamlit_extras.switch_page_button import switch_page

import code_search_utils
import webapp_utils

# Restore widget values persisted by ``webapp_utils`` before any widget is built.
webapp_utils.load_widget_state()

# The data read below is presumably populated by the Code Browser page; if it
# is missing from the session, send the user there first.
if "cb_acts" not in st.session_state:
    switch_page("Code_Browser")

# Number of examples reachable through the "Example ID" input (ids wrap modulo this).
total_examples = 2000
# Minimum precision a code must reach to be listed for an example.
prec_threshold = 0.01

# Copy shared session data into module globals read by the functions below.
_state = st.session_state
model_name = _state["model_name_id"]
seq_len = _state["seq_len"]
tokens_text = _state["tokens_text"]
tokens_str = _state["tokens_str"]
cb_acts = _state["cb_acts"]
act_count_ft_tkns = _state["act_count_ft_tkns"]
ccb = _state["ccb"]


def _apply_mask(arrays, mask):
    """Return a tuple with ``mask`` applied as an index to each array in ``arrays``."""
    return tuple(arr[mask] for arr in arrays)


def get_example_concept_codes(example_id):
    """Get concept codes for the given example id.

    For every codebook in ``cb_acts`` this scores all codes against the tokens
    ``(example_id, 0) .. (example_id, seq_len - 1)`` via
    ``code_search_utils.get_code_pr``. It keeps a code only if its precision is
    at least ``prec_threshold`` and then its recall is at least the
    module-level ``recall_threshold``.

    Returns:
        A list with one ``(cb_name, codes_pr)`` entry per codebook, where
        ``codes_pr`` is a list of ``(code, precision, recall, code_acts)``
        tuples (possibly empty).
    """
    example_positions = [(example_id, pos) for pos in range(seq_len)]
    per_codebook = []
    for cb_name, cb in cb_acts.items():
        base_name = code_search_utils.convert_to_base_name(cb_name, ccb=ccb)
        codes, prec, rec, code_acts = code_search_utils.get_code_pr(
            example_positions,
            cb,
            act_count_ft_tkns[base_name],
        )
        pr_info = (codes, prec, rec, code_acts)
        # Filter on precision first, then on recall of the surviving codes.
        pr_info = _apply_mask(pr_info, pr_info[1] >= prec_threshold)
        pr_info = _apply_mask(pr_info, pr_info[2] >= recall_threshold)
        per_codebook.append((cb_name, list(zip(*pr_info))))
    return per_codebook


def find_next_example(example_id):
    """Find the example after `example_id` that has concept codes.

    Scans forward from ``example_id`` (exclusive), wrapping around modulo
    ``total_examples``, and checks every other example exactly once. The first
    example for which ``get_example_concept_codes`` reports at least one code
    is written to ``st.session_state["example_id"]``. If none qualifies, an
    error is shown instead.

    Args:
        example_id: The currently displayed example id.
    """
    # Wrap every step (including the first) so that the last example is
    # followed by example 0. A bounded range also guarantees termination
    # even if ``example_id`` is outside [0, total_examples).
    for offset in range(1, total_examples):
        candidate = (example_id + offset) % total_examples
        all_codes = get_example_concept_codes(candidate)
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = candidate
            return
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )


def redirect_to_main_with_code(code, layer, head):
    """Store ``code``/``layer`` (and ``head`` for attention) in session state, then open the Code Browser page.

    ``head`` is only recorded when ``st.session_state["is_attn"]`` is truthy.
    """
    state = st.session_state
    state["ct_act_code"] = code
    state["ct_act_layer"] = layer
    if state["is_attn"]:
        state["ct_act_head"] = head
    switch_page("Code Browser")


def show_examples_for_concept_code(code, layer, head, code_act_ratio=0.3):
    """Show examples that the code activates on.

    Fetches per-example activation snippets for ``code`` from
    ``webapp_utils.get_code_acts``. It keeps only the snippets whose activation
    count exceeds ``seq_len * code_act_ratio`` and renders them under an
    "Examples for Code" heading.
    """
    example_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    kept_snippets = [
        snippet
        for snippet, n_acts in example_acts
        if n_acts > seq_len * code_act_ratio
    ]
    st.markdown("#### Examples for Code")
    rendered = webapp_utils.escape_markdown("".join(kept_snippets))
    st.markdown(rendered, unsafe_allow_html=True)


is_attn = st.session_state["is_attn"]
# Attention codebooks need an extra "Head" column in the results table.
num_table_cols = 7 if is_attn else 6

st.markdown("## Concept Code")
concept_code_description = (
    "Concept codes are codes that activate a lot on only a particular set of examples that share a concept. "
    "Hence such codes can be thought to correspond to more higher-level concepts or features and "
    "can activate on most tokens that belong in an example text. This interface provides a way to search for such "
    "codes by going through different examples using Example ID."
)
st.write(concept_code_description)

# NOTE: the precision threshold is not user-tunable; it is the module-level
# ``prec_threshold`` constant. Only recall is exposed as a slider. The
# ``recall_threshold`` global below is also read by the helper functions.
ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
example_id = ex_col.number_input(
    "Example ID",
    0,
    total_examples - 1,
    0,
    key="example_id",
)
recall_threshold = r_col.slider(
    "Recall Threshold",
    0.0,
    1.0,
    0.3,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)
example_truncation = trunc_col.number_input(
    "Max Output Chars", 0, 10240, 1024, key="max_chars"
)
sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=0,
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
# Index into (precision, recall, num_acts); shifted by one below to skip the
# leading code in each (code, precision, recall, num_acts) tuple.
sort_by = sort_by_options.index(sort_by_name)

# The search happens in the on_click callback, which sets the "example_id"
# widget state; Streamlit runs callbacks before the rerun, so the return value
# of st.button is not needed here.
st.button(
    "Find Next Example",
    key="find_next_example",
    on_click=find_next_example,
    args=(example_id,),
    help="Find an example which has codes above the recall threshold.",
)


st.markdown("### Example Text")
example_text = tokens_text[example_id]
trunc_suffix = "..." if example_truncation < len(example_text) else ""
st.write(example_text[:example_truncation] + trunc_suffix)

cols = st.columns(num_table_cols)
cols[0].markdown("Search", help="Button to see token activations for the code.")
cols[1].write("Layer")
if is_attn:
    cols[2].write("Head")
cols[-4].write("Code")
cols[-3].write("Precision")
cols[-2].write("Recall")
cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)

# Flatten [(cb_name, [code_info, ...]), ...] into [(cb_name, code_info), ...],
# where code_info is (code, precision, recall, num_acts).
all_codes = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in get_example_concept_codes(example_id)
    for code_pr_info in code_pr_infos
]
all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)

for cb_name, (code, p, r, acts) in all_codes:
    cols = st.columns(num_table_cols)
    # Search-icon button (🔍); the label was previously mis-encoded mojibake.
    code_button = cols[0].button(
        "🔍",
        key=f"ex-code-{code}-{cb_name}",
    )
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    cols[1].write(str(layer))
    if is_attn:
        cols[2].write(str(head))

    cols[-4].write(code)
    cols[-3].write(f"{p*100:.2f}%")
    cols[-2].write(f"{r*100:.2f}%")
    cols[-1].write(str(acts))

    if code_button:
        show_examples_for_concept_code(
            code,
            layer,
            head,
            code_act_ratio=recall_threshold,
        )
if not all_codes:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold: {recall_threshold}</div>",
        unsafe_allow_html=True,
    )