Spaces:
Runtime error
Runtime error
File size: 7,062 Bytes
7f9376c 63b5bc1 7f9376c cb21dfd 7f9376c 63b5bc1 7f9376c cb21dfd 7f9376c cb21dfd 7f9376c cb21dfd 7f9376c cb21dfd 50c3f87 7f9376c cb21dfd 7f9376c cb21dfd 7f9376c 63b5bc1 7f9376c 63b5bc1 7f9376c cb21dfd 7f9376c cb21dfd 7f9376c cb21dfd 7f9376c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
"""Web app page for showing codes for different examples in the dataset."""
import streamlit as st
from streamlit_extras.switch_page_button import switch_page
import code_search_utils
import webapp_utils
# Load widget state through the shared helper (see webapp_utils for details).
webapp_utils.load_widget_state()

# This page needs the codebook activations stored under "cb_acts"; when they
# are absent, hand control over to the Code Browser page.
if "cb_acts" not in st.session_state:
    switch_page("Code_Browser")

# Page-level constants: number of selectable examples and the minimum
# precision a code must reach to be listed.
total_examples, prec_threshold = 2000, 0.01

# Pull everything this page reads out of the session state, in one go.
(
    model_name,
    seq_len,
    tokens_text,
    tokens_str,
    cb_acts,
    act_count_ft_tkns,
    gcb,
) = (
    st.session_state[state_key]
    for state_key in (
        "model_name_id",
        "seq_len",
        "tokens_text",
        "tokens_str",
        "cb_acts",
        "act_count_ft_tkns",
        "gcb",
    )
)
def get_example_topic_codes(example_id):
    """Collect, per codebook, the codes that pass both thresholds on an example.

    Every token position ``(example_id, 0..seq_len-1)`` of the example is
    handed to ``code_search_utils.get_code_precision_and_recall`` for each
    codebook in ``cb_acts``. The returned codes are first filtered by the
    module-level ``prec_threshold`` and then by the module-level
    ``recall_threshold``.

    Returns:
        A list with one ``(cb_name, codes_pr)`` entry per codebook, where
        ``codes_pr`` is a list of ``(code, precision, recall, code_acts)``
        tuples (possibly empty).
    """
    positions = [(example_id, pos) for pos in range(seq_len)]
    per_codebook = []
    for cb_name, cb in cb_acts.items():
        base_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall(
            positions,
            cb,
            act_count_ft_tkns[base_name],
        )

        # Precision filter first ...
        keep = prec >= prec_threshold
        codes, prec, rec, code_acts = (
            values[keep] for values in (codes, prec, rec, code_acts)
        )

        # ... then the recall filter on whatever survived.
        keep = rec >= recall_threshold
        codes, prec, rec, code_acts = (
            values[keep] for values in (codes, prec, rec, code_acts)
        )

        per_codebook.append((cb_name, list(zip(codes, prec, rec, code_acts))))
    return per_codebook
def find_next_example(example_id):
    """Jump to the next example after `example_id` that has topic codes.

    The search walks forward one example at a time, wrapping around after the
    last example (``total_examples - 1``) back to example 0. The first example
    for which ``get_example_topic_codes`` returns at least one code is stored
    in ``st.session_state["example_id"]``. If the search arrives back at the
    starting example without finding any, an error message is displayed and
    the session state is left untouched.

    Args:
        example_id: The example id to start searching after (exclusive).
    """
    initial_example_id = example_id
    # Wrap the very first step as well. A plain `+= 1` from the last example
    # would probe an out-of-range id (and could store it in the session
    # state), and the following wrap would then skip example 0.
    example_id = (example_id + 1) % total_examples
    while example_id != initial_example_id:
        all_codes = get_example_topic_codes(example_id)
        # Any non-empty per-codebook list means this example has a topic code.
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = example_id
            return
        example_id = (example_id + 1) % total_examples
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )
def redirect_to_main_with_code(code, layer, head):
    """Store the chosen code in the session state and open the main page.

    Writes ``code`` and ``layer`` to the ``ct_act_code`` / ``ct_act_layer``
    session keys, writes ``head`` to ``ct_act_head`` only when the session's
    ``is_attn`` flag is truthy, and then switches to the "Code Browser" page.
    """
    state = st.session_state
    state["ct_act_code"] = code
    state["ct_act_layer"] = layer
    if state["is_attn"]:
        # The head is only recorded when the `is_attn` flag is set.
        state["ct_act_head"] = head
    switch_page("Code Browser")
def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3):
    """Render the examples on which the given code fires often enough.

    Fetches ``(text, activation_count)`` pairs from
    ``webapp_utils.get_code_acts`` (with a context size of 5) and keeps only
    those whose count is strictly greater than ``seq_len * code_act_ratio``.
    The kept texts are concatenated, passed through
    ``webapp_utils.escape_markdown`` and shown under an "Examples for Code"
    heading.

    Args:
        code: Code to look up.
        layer: Layer passed through to ``get_code_acts``.
        head: Head passed through to ``get_code_acts``.
        code_act_ratio: Fraction of ``seq_len`` that the activation count
            must exceed for an example to be shown.
    """
    example_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    min_acts = seq_len * code_act_ratio
    selected = [text for text, n_acts in example_acts if n_acts > min_acts]

    st.markdown("#### Examples for Code")
    rendered = webapp_utils.escape_markdown("".join(selected))
    st.markdown(rendered, unsafe_allow_html=True)
# Flag read from the session state; it decides below whether a "Head" column
# is shown.
is_attn = st.session_state["is_attn"]

st.markdown("## Topic Code")

# Help text shown under the title: two paragraphs separated by a blank line.
topic_code_description = "\n\n".join(
    (
        "Topic codes are codes that activate many different times on passages that describe a particular"
        " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking"
        " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction"
        " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible"
        " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at"
        " least one code firing on that example above the Recall Threshold.",
        "Topic codes are displayed for the codebook model selected on the Code Browser page. To view topic codes"
        " for a different model, go to the Code Browser page and select a different model.",
    )
)
st.write(topic_code_description)
# One row of four equally weighted columns holding the page controls.
ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])

# Which example to inspect; bounded by the number of available examples.
example_id = ex_col.number_input(
    "Example ID",
    min_value=0,
    max_value=total_examples - 1,
    value=0,
    key="example_id",
)

# Minimum recall a code needs in order to be listed for the example.
recall_threshold = r_col.slider(
    "Recall Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.2,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)

# Upper bound on how many characters of the example text are printed.
example_truncation = trunc_col.number_input(
    "Max Output Chars",
    min_value=0,
    max_value=102400,
    value=1024,
    key="max_chars",
)

# The selected metric is turned into its position in `sort_by_options`
# (0, 1 or 2), which is later used as an offset when sorting the codes.
sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=1,
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
sort_by = sort_by_options.index(sort_by_name)

# The callback receives the currently selected example id as its argument.
button = st.button(
    "Find Next Example",
    key="find_next_example",
    help="Find an example which has codes above the recall threshold.",
    on_click=find_next_example,
    args=(example_id,),
)
st.markdown("### Example Text")

# Show at most `example_truncation` characters of the selected example and
# append an ellipsis only when something was actually cut off.
example_text = tokens_text[example_id]
if example_truncation < len(example_text):
    trunc_suffix = "..."
else:
    trunc_suffix = ""
st.write(example_text[:example_truncation] + trunc_suffix)
# Header row of the code table. With `is_attn` set there is one extra column
# for the head, so the trailing columns are addressed from the right.
cols = st.columns(7 if is_attn else 6)
cols[0].markdown("Search", help="Button to see token activations for the code.")
cols[1].write("Layer")
if is_attn:
    cols[2].write("Head")
for header_col, header_label in zip(cols[-4:-1], ("Code", "Precision", "Recall")):
    header_col.write(header_label)
cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)
# Flatten the per-codebook results into one row per code:
# (cb_name, (code, precision, recall, num_acts)).
all_codes = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in get_example_topic_codes(example_id)
    for code_pr_info in code_pr_infos
]

# `sort_by` is 0/1/2 for precision/recall/num acts; the +1 skips the code id
# stored at position 0 of each info tuple. Highest value first.
all_codes.sort(key=lambda entry: entry[1][1 + sort_by], reverse=True)

for cb_name, (code, p, r, acts) in all_codes:
    cols = st.columns(7 if is_attn else 6)
    code_button = cols[0].button(
        "🔍",
        key=f"ex-code-{code}-{cb_name}",
    )
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    cols[1].write(str(layer))
    if is_attn:
        cols[2].write(str(head))
    cols[-4].write(code)
    cols[-3].write(f"{p*100:.2f}%")
    cols[-2].write(f"{r*100:.2f}%")
    cols[-1].write(str(acts))

    # When this row's search button was pressed, list the matching examples
    # right below the row, reusing the recall threshold as activation ratio.
    if code_button:
        show_examples_for_topic_code(
            code,
            layer,
            head,
            code_act_ratio=recall_threshold,
        )

# Nothing passed the thresholds: tell the user how to get results.
if not all_codes:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold = {recall_threshold}."
        " Consider decreasing the recall threshold.</div>",
        unsafe_allow_html=True,
    )
|