Spaces:
Runtime error
Runtime error
File size: 6,951 Bytes
7f9376c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
"""Web app page for showing codes for different examples in the dataset."""
import streamlit as st
from streamlit_extras.switch_page_button import switch_page
import code_search_utils
import webapp_utils
webapp_utils.load_widget_state()

# The values read below are presumably populated by the Code Browser page;
# when they are missing, send the user to that page instead.
if "cb_acts" not in st.session_state:
    switch_page("Code_Browser")

total_examples = 2000  # valid example IDs are 0 .. total_examples - 1
prec_threshold = 0.01  # minimum precision a code needs to be listed

# Pull everything this page needs out of session state, in a fixed order.
(
    model_name,
    seq_len,
    tokens_text,
    tokens_str,
    cb_acts,
    act_count_ft_tkns,
    ccb,
) = (
    st.session_state[state_key]
    for state_key in (
        "model_name_id",
        "seq_len",
        "tokens_text",
        "tokens_str",
        "cb_acts",
        "act_count_ft_tkns",
        "ccb",
    )
)
def get_example_concept_codes(example_id):
    """Return, per codebook, the codes passing the precision/recall filters for one example.

    For every codebook in ``cb_acts``, ``code_search_utils.get_code_pr`` is
    asked for per-code precision/recall over all ``seq_len`` token positions
    of the example. A code is kept only if its precision is at least the
    module-level ``prec_threshold`` and its recall is at least the
    module-level ``recall_threshold``.

    Args:
        example_id: Index of the example whose token positions are scored.

    Returns:
        A list with one ``(cb_name, codes_pr)`` pair per codebook, where
        ``codes_pr`` is a (possibly empty) list of
        ``(code, precision, recall, code_acts)`` tuples.
    """
    positions = [(example_id, pos) for pos in range(seq_len)]
    per_codebook = []
    for cb_name, cb in cb_acts.items():
        base_name = code_search_utils.convert_to_base_name(cb_name, ccb=ccb)
        codes, prec, rec, code_acts = code_search_utils.get_code_pr(
            positions,
            cb,
            act_count_ft_tkns[base_name],
        )
        # Precision filter first, then recall filter on the survivors.
        keep = prec >= prec_threshold
        codes, prec, rec, code_acts = (
            arr[keep] for arr in (codes, prec, rec, code_acts)
        )
        keep = rec >= recall_threshold
        codes, prec, rec, code_acts = (
            arr[keep] for arr in (codes, prec, rec, code_acts)
        )
        per_codebook.append((cb_name, list(zip(codes, prec, rec, code_acts))))
    return per_codebook
def find_next_example(example_id):
    """Jump to the next example, cyclically, that has at least one concept code.

    Scans forward from ``example_id + 1``, wrapping around at
    ``total_examples``. The first example for which
    ``get_example_concept_codes`` returns any code is stored in
    ``st.session_state["example_id"]``. If a full cycle finds nothing, an
    error message is shown instead.

    Args:
        example_id: The currently selected example ID.
    """
    # Normalize so the loop below is guaranteed to come back to ``start``.
    start = example_id % total_examples
    # Wrap on the very first step too: from the last example we must continue
    # at 0 rather than probe the non-existent example ``total_examples``.
    candidate = (start + 1) % total_examples
    while candidate != start:
        all_codes = get_example_concept_codes(candidate)
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = candidate
            return
        candidate = (candidate + 1) % total_examples
    # ``icon`` must be a real emoji; the previous mis-encoded "π¨" is rejected
    # by Streamlit's icon validation.
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )
def redirect_to_main_with_code(code, layer, head):
    """Switch to the Code Browser page with the given code pre-selected.

    Writes ``code`` and ``layer`` into session state under the ``ct_act_*``
    keys, and also ``head`` when ``st.session_state["is_attn"]`` is truthy,
    then switches pages.
    """
    state = st.session_state
    state["ct_act_code"] = code
    state["ct_act_layer"] = layer
    if state["is_attn"]:
        state["ct_act_head"] = head
    switch_page("Code Browser")
def show_examples_for_concept_code(code, layer, head, code_act_ratio=0.3):
    """Render the examples on which ``code`` activates densely.

    Activation snippets come from ``webapp_utils.get_code_acts``. Only
    examples where the code fires on more than ``seq_len * code_act_ratio``
    tokens are kept. They are concatenated, escaped with
    ``webapp_utils.escape_markdown`` and rendered under an "Examples for Code"
    heading.

    Args:
        code: The code to look up.
        layer: Layer of the codebook containing ``code``.
        head: Head passed through to ``webapp_utils.get_code_acts``.
        code_act_ratio: Fraction of ``seq_len`` tokens an example's activation
            count must exceed to be shown.
    """
    example_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    min_acts = seq_len * code_act_ratio
    dense_examples = [
        snippet for snippet, n_acts in example_acts if n_acts > min_acts
    ]
    st.markdown("#### Examples for Code")
    st.markdown(
        webapp_utils.escape_markdown("".join(dense_examples)),
        unsafe_allow_html=True,
    )
is_attn = st.session_state["is_attn"]

# Page title and a short explanation of what a concept code is.
st.markdown("## Concept Code")
concept_code_description = (
    "Concept codes are codes that activate a lot on only a particular set of examples that share a concept. "
    "Hence such codes can be thought to correspond to higher-level concepts or features and "
    "can activate on most tokens that belong to an example text. This interface provides a way to search for such "
    "codes by going through different examples using Example ID."
)
st.write(concept_code_description)
# Controls row: example selector, recall threshold, output truncation and
# sort key. Precision is not user-adjustable on this page; the fixed
# module-level ``prec_threshold`` is used instead.
ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
example_id = ex_col.number_input(
    "Example ID",
    0,
    total_examples - 1,
    0,
    key="example_id",
)
recall_threshold = r_col.slider(
    "Recall Threshold",
    0.0,
    1.0,
    0.3,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)
example_truncation = trunc_col.number_input(
    "Max Output Chars", 0, 10240, 1024, key="max_chars"
)
sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=0,
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
# 0 = precision, 1 = recall, 2 = num acts; selects the code table's sort key.
sort_by = sort_by_options.index(sort_by_name)
# All work happens in the ``on_click`` callback, which updates
# ``st.session_state["example_id"]``, so the button's return value is unused.
st.button(
    "Find Next Example",
    key="find_next_example",
    on_click=find_next_example,
    args=(example_id,),
    help="Find an example which has codes above the recall threshold.",
)
st.markdown("### Example Text")
# Show at most ``example_truncation`` characters, marking cut-off text with "...".
example_text = tokens_text[example_id]
more_marker = "..." if len(example_text) > example_truncation else ""
st.write(example_text[:example_truncation] + more_marker)
# Header row of the code table. The extra "Head" column only exists when
# ``is_attn`` is set, so the trailing columns are addressed from the end.
header_cols = st.columns(7 if is_attn else 6)
header_cols[0].markdown("Search", help="Button to see token activations for the code.")
header_cols[1].write("Layer")
if is_attn:
    header_cols[2].write("Head")
for header_col, title in zip(header_cols[-4:-1], ("Code", "Precision", "Recall")):
    header_col.write(title)
header_cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)
# Flatten into one (cb_name, (code, precision, recall, num_acts)) row per
# code, then sort descending by the metric picked in "Sort By".
flat_rows = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in get_example_concept_codes(example_id)
    for code_pr_info in code_pr_infos
]
metric_pos = 1 + sort_by  # position 0 of each info tuple is the code itself
all_codes = sorted(flat_rows, key=lambda row: row[1][metric_pos], reverse=True)

for cb_name, (code, p, r, acts) in all_codes:
    cols = st.columns(7 if is_attn else 6)
    code_button = cols[0].button("π", key=f"ex-code-{code}-{cb_name}")
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    cols[1].write(str(layer))
    if is_attn:
        cols[2].write(str(head))
    cols[-4].write(code)
    cols[-3].write(f"{p*100:.2f}%")
    cols[-2].write(f"{r*100:.2f}%")
    cols[-1].write(str(acts))
    # The row's search button expands the examples this code fires on,
    # reusing the recall threshold as the minimum activation ratio.
    if code_button:
        show_examples_for_concept_code(
            code, layer, head, code_act_ratio=recall_threshold
        )

if not all_codes:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold: {recall_threshold}</div>",
        unsafe_allow_html=True,
    )
|