import spaces
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from lxt.models.llama import LlamaForCausalLM, attnlrp
from lxt.utils import clean_tokens
import gradio as gr
import numpy as np
from scipy.signal import convolve2d
from huggingface_hub import login
import os
from dotenv import load_dotenv
# Note: this quantization config is defined but never passed to from_pretrained below,
# so the model is actually loaded in bfloat16 rather than 8-bit.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.bfloat16,
)
load_dotenv()
login(os.getenv("HF_TOKEN"))
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
print(f"Loading model {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda", use_safetensors=True)
# model.gradient_checkpointing_enable()
attnlrp.register(model)
print(f"Loaded model.")
def really_clean_tokens(tokens):
    tokens = clean_tokens(tokens)
    cleaned_tokens = []
    for token in tokens:
        # Strip subword markers ("▁" from SentencePiece, "Ġ"/"Ċ" from byte-level BPE)
        token = token.replace("_", " ").replace("▁", " ").replace("<s>", " ").replace("Ġ", " ").replace("Ċ", " ")
        if token.startswith("<0x") and token.endswith(">"):
            # Convert a byte-fallback token like <0x0A> to its character
            char_code = int(token[3:-1], 16)
            token = chr(char_code)
        cleaned_tokens.append(token)
    return cleaned_tokens
@spaces.GPU
def generate_and_visualize(prompt, num_tokens=10):
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
    input_embeds = model.get_input_embeddings()(input_ids)
    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))

    generated_tokens_ids = []
    all_relevances = []
    for _ in range(num_tokens):
        output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
        max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
        # LXT convention: seed the backward pass with the logit value itself, so the
        # embedding gradients hold relevance scores rather than plain gradients
        max_logits.backward(max_logits)
        relevance = input_embeds.grad.float().sum(-1).cpu()[0]  # sum relevance over the embedding dim
        all_relevances.append(relevance)

        next_token = max_indices.unsqueeze(0)
        generated_tokens_ids.append(next_token.item())

        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        input_embeds = model.get_input_embeddings()(input_ids)

        if next_token.item() == tokenizer.eos_token_id:
            print("EOS token generated, stopping generation.")
            break
    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
    return input_tokens, all_relevances, generated_tokens
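# Illustrative return values (hypothetical 12-token prompt, 5 generated tokens):
# input_tokens is a list of 12 strings; all_relevances holds 5 tensors whose lengths
# grow with the sequence (12, 13, ..., 16); generated_tokens is a list of 5 strings.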
def process_relevances(input_tokens, all_relevances, generated_tokens):
    # One row per generated token, truncated to the prompt length so rows align
    attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])

    ### FIND ZONES OF INTEREST
    threshold_per_token = 0.2
    kernel_width = 6
    context_width = 20  # Number of tokens to include as context on each side
    kernel = np.ones((kernel_width, kernel_width))

    if len(generated_tokens) < kernel_width:
        return [(token, None, None) for token in generated_tokens]

    # Compute a rolling average of relevance using a 2D convolution
    rolled_sum = convolve2d(attention_matrix, kernel, mode='valid') / kernel_width**2

    # Find where the rolling average exceeds the threshold
    significant_areas = rolled_sum > threshold_per_token
    print(f"Found {significant_areas.sum()} significant cells (lower the threshold to find more). Max was {rolled_sum.max()}")
    print("LENGTHS:", len(input_tokens), significant_areas.shape, len(generated_tokens))
    def find_largest_contiguous_patch(array):
        current_patch_start = None
        best_width, best_patch_start = None, None
        current_width = 0
        for i in range(len(array)):
            if array[i]:
                if current_patch_start is not None and current_patch_start + current_width == i:
                    current_width += 1
                else:
                    current_patch_start = i
                    current_width = 1
                # "is not None" matters here: a patch starting at index 0 would be falsy
                if current_patch_start is not None and (best_width is None or current_width > best_width):
                    best_patch_start = current_patch_start
                    best_width = current_width
            else:
                current_width = 0
        return best_width, best_patch_start
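    # Hedged example: for a row [False, True, True, False, True],
    # find_largest_contiguous_patch returns (2, 1): the widest run is 2 cells starting at index 1.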
    output_with_notes = []
    for row in range(len(generated_tokens) - kernel_width + 1):
        best_width, best_patch_start = find_largest_contiguous_patch(significant_areas[row])
        if best_width is not None:
            output_with_notes.append((generated_tokens[row], (best_width, best_patch_start)))
        else:
            output_with_notes.append((generated_tokens[row], None))
    # The last kernel_width-1 tokens have no convolution row of their own: add them without notes
    output_with_notes += [(el, None) for el in generated_tokens[-kernel_width+1:]]
    # Fuse the notes of consecutive output tokens if necessary
    for i in range(len(output_with_notes)):
        token, coords = output_with_notes[i]
        if coords is not None:
            best_width, best_patch_start = coords
            note_width_generated = kernel_width
            for next_id in range(i + 1, min(i + 2 * kernel_width, len(output_with_notes))):
                next_token, next_coords = output_with_notes[next_id]
                if next_coords is not None:
                    next_width, next_patch_start = next_coords
                    if best_patch_start + best_width >= next_patch_start:
                        # The notes overlap: delete the later one and widen the first as needed
                        output_with_notes[next_id] = (next_token, None)
                        larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
                        best_width = larger_end - best_patch_start
                        note_width_generated = kernel_width + (next_id - i)
            output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)
        else:
            output_with_notes[i] = (token, None, None)
    # Convert each note to text slices from the prompt
    for i, (token, coords, width) in enumerate(output_with_notes):
        if coords is not None:
            best_width, best_patch_start = coords
            significant_start = max(0, best_patch_start)
            significant_end = best_patch_start + kernel_width + best_width
            context_start = max(0, significant_start - context_width)
            context_end = min(len(input_tokens), significant_end + context_width)
            first_part = "".join(input_tokens[context_start:significant_start])
            significant_part = "".join(input_tokens[significant_start:significant_end])
            final_part = "".join(input_tokens[significant_end:context_end])
            output_with_notes[i] = (token, (first_part, significant_part, final_part), width)
    return output_with_notes
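# Hedged illustration of the returned structure: 3-tuples of
# (generated token, (context before, highlighted source span, context after) or None, note width), e.g.
# [(" The", None, None), (" 1922", ("...pushed the ", "north ridge route up to 8,320 m", " (27,300..."), 7), ...]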
def create_html_with_hover(output_with_notes):
    html = "<div id='output-container'>"
    note_number = 0
    i = 0
    while i < len(output_with_notes):
        (token, notes, width) = output_with_notes[i]
        if notes is None:
            html += token
            i += 1
        else:
            # Group the `width` tokens covered by this note under a single hoverable span
            text = "".join(really_clean_tokens([element[0] for element in output_with_notes[i:i+width]]))
            print(text)
            first_part, significant_part, final_part = notes
            formatted_note = f'{first_part}<strong>{significant_part}</strong>{final_part}'
            html += f'<span class="hoverable" data-note-id="note-{note_number}">{text}<sup>[{note_number+1}]</sup>'
            html += f'<span class="hover-note">{formatted_note}</span></span>'
            note_number += 1
            i += width
    html += "</div>"
    return html
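# Shape of the markup produced: plain tokens are appended as-is, while each noted group becomes
#   <span class="hoverable" data-note-id="note-0">its text<sup>[1]</sup>
#     <span class="hover-note">context <strong>highlighted source</strong> context</span></span>
# so the CSS below can reveal the source excerpt on hover.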
@spaces.GPU
def on_generate(prompt, num_tokens):
    input_tokens, all_relevances, generated_tokens = generate_and_visualize(prompt, num_tokens)
    output_with_notes = process_relevances(input_tokens, all_relevances, generated_tokens)
    html_output = create_html_with_hover(output_with_notes)
    return html_output
css = """
#output-container {
font-size: 18px;
line-height: 1.5;
position: relative;
}
.hoverable {
color: var(--primary-500);
position: relative;
display: inline-block;
}
.hover-note {
display: none;
position: absolute;
padding: 5px;
border-radius: 5px;
bottom: 100%;
left: 0;
white-space: normal;
background-color: var(--input-background-fill);
max-width: 600px;
width: 500px;
word-wrap: break-word;
z-index: 100;
}
.hoverable:hover .hover-note {
display: block;
}
"""
examples = [
"""Context:
The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960.
Question: How many meters above 8000 did the 1922 expedition go?
Answer:""",
"""Context:
Hurricane Katrina killed hundreds of people as it made landfall on New Orleans in 2005 - many of these deaths could have been avoided if alerts had been given one day earlier. Accurate weather forecasts are really life-saving.
🔥 Now, NASA and IBM just dropped a game-changing new model: the first ever foundation model for weather! This means it's the first time we have a generalist model not restricted to one task, but able to predict 160 weather variables!
Prithvi WxC (Prithvi, "पृथ्वी", is the Sanskrit name for Earth) - is a 2.3 billion parameter model, with an architecture close to previous vision transformers like Hiera.
💡 But it comes with some important tweaks: under the hood, Prithvi WxC uses a clever transformer-based architecture with 25 encoder and 5 decoder blocks. It alternates between "local" and "global" attention to capture both regional and global weather patterns.
Question: How many weather variables can Prithvi predict?
Answer:""",
"""Context:
Transformers v4.45.0 released: includes a lightning-fast method to build tools! ⚡️
During user research with colleagues @MoritzLaurer and @Jofthomas , we discovered that the class definition currently used to define a Tool in transformers.agents is a bit tedious to use, because it goes into great detail.
➡️ So I've made an easier way to build tools: just make a function with type hints + a docstring, and add a @tool decorator in front.
Voilà, you're good to go!
Question: How can you build tools simply in transformers?
Answer:""",
]
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RAG with source linking, using source attribution with [LXT](https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama)")
    input_text = gr.Textbox(label="Enter your prompt:", lines=10, value=examples[0])
    num_tokens = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Maximum number of tokens to generate (stops early at EOS)")
    generate_button = gr.Button("Generate")
    output_html = gr.HTML(label="Generated Output")
    generate_button.click(
        on_generate,
        inputs=[input_text, num_tokens],
        outputs=[output_html]
    )
    gr.Markdown("Hover over the blue text with superscript numbers to see the important input tokens for that group.")
    # Add clickable examples
    gr.Examples(
        examples=examples,
        inputs=[input_text],
    )
if __name__ == "__main__":
    demo.launch()