File size: 12,289 Bytes
760edf9
3c1036c
 
 
 
119ef11
3c1036c
a02bd43
e76d6fa
 
 
119ef11
4cbfdf7
 
 
 
 
 
 
e76d6fa
 
4cbfdf7
 
 
e76d6fa
4cbfdf7
 
 
7b3d61a
 
 
a02bd43
7b3d61a
4cbfdf7
119ef11
3c1036c
9570f3d
e90d7e4
 
7e6471e
e90d7e4
 
 
 
 
 
3c1036c
 
 
 
 
5574047
3c1036c
 
 
 
 
96879fc
3c1036c
 
e5a9bbf
 
 
 
3c1036c
 
 
 
 
c3e1717
 
 
 
3c1036c
 
 
9570f3d
3c1036c
e76d6fa
3c1036c
a02bd43
 
5574047
e90d7e4
719af86
a02bd43
e76d6fa
 
 
3c1036c
a02bd43
4cbfdf7
3c1036c
719af86
4cbfdf7
5574047
 
a02bd43
 
96879fc
 
719af86
a02bd43
 
96879fc
719af86
a02bd43
96879fc
a02bd43
96879fc
 
a02bd43
 
 
96879fc
a02bd43
50b0870
 
 
a02bd43
96879fc
e90d7e4
 
50b0870
e90d7e4
96879fc
 
 
e90d7e4
96879fc
 
8d2fb93
96879fc
 
 
8d2fb93
96879fc
 
 
 
 
 
 
 
 
8d2fb93
96879fc
 
 
 
 
9570f3d
 
2343db7
 
 
96879fc
e90d7e4
3c1036c
 
 
 
2343db7
96879fc
 
 
 
 
8d2fb93
96879fc
7e6471e
 
2343db7
 
 
719af86
2343db7
8d2fb93
3c1036c
 
 
 
 
 
 
 
 
 
 
e93a4e5
 
 
 
 
 
38fce31
e93a4e5
 
 
3c1036c
 
 
 
 
 
e93a4e5
719af86
38fce31
719af86
e93a4e5
719af86
5574047
3c1036c
e93a4e5
 
 
3c1036c
 
5574047
 
3c1036c
7f3d9d5
e93a4e5
 
 
 
3c1036c
 
 
 
 
 
 
e93a4e5
 
 
 
 
3c1036c
 
 
 
 
 
 
e93a4e5
4cbfdf7
 
3c1036c
 
5574047
e93a4e5
3c1036c
2343db7
c3e1717
3c1036c
 
 
 
 
 
 
 
 
 
 
 
2343db7
 
 
 
 
 
3c1036c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import spaces
import torch
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForCausalLM, attnlrp
from lxt.utils import clean_tokens
import gradio as gr
import numpy as np
from scipy.signal import convolve2d
from huggingface_hub import login
import os
from dotenv import load_dotenv

from transformers import BitsAndBytesConfig

# NOTE(review): this 8-bit config is never passed to from_pretrained below, so
# the model is loaded in bfloat16. Pass quantization_config=... if 8-bit was
# intended (and confirm LXT's relevance rules support bitsandbytes layers).
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.bfloat16,
)

# Load environment variables (e.g. HF_TOKEN) from a local .env file, if any.
load_dotenv()

# Log in to the Hugging Face Hub with the token from HF_TOKEN.
login(os.getenv("HF_TOKEN"))

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

print(f"Loading model {model_id}...")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda", use_safetensors=True)
# Optional memory saving (extra forward pass during backward):
# model.gradient_checkpointing_enable()

# Relevance is read only from the gradient of the input embeddings, so the
# weights need no gradients. Freezing them stops every backward pass from
# allocating and accumulating a gradient per parameter. It also makes the
# embedding-layer output a leaf tensor, whose ``.grad`` autograd populates.
model.requires_grad_(False)

# Install LXT's AttnLRP rules on the model; generate_and_visualize reads the
# resulting input-embedding gradients as relevance.
attnlrp.register(model)

print("Loaded model.")

def really_clean_tokens(tokens):
    """Turn raw tokenizer pieces into display strings.

    First applies LXT's ``clean_tokens``. Then it replaces the remaining
    tokenizer boundary markers with a space and decodes byte-fallback tokens
    such as ``<0x0A>`` into a single character.

    Args:
        tokens: token strings, e.g. from ``tokenizer.convert_ids_to_tokens``.

    Returns:
        A new list with one display string per input token.
    """
    cleaned_tokens = []
    for token in clean_tokens(tokens):
        # "▁" is the SentencePiece word boundary, "Ġ"/"Ċ" are the byte-level
        # BPE space/newline markers and "<s>" is a BOS string: all render as a
        # space. Plain underscores are NOT replaced; they are real text.
        for marker in ("▁", "<s>", "Ċ", "Ġ"):
            token = token.replace(marker, " ")
        if token.startswith("<0x") and token.endswith(">"):
            # Byte-fallback token: map the hex byte value to a character.
            try:
                token = chr(int(token[3:-1], 16))
            except ValueError:
                # Not a valid hex byte (e.g. "<0x>"): keep the text unchanged.
                pass
        cleaned_tokens.append(token)
    return cleaned_tokens

@spaces.GPU
def generate_and_visualize(prompt, num_tokens=10):
    """Greedily generate up to ``num_tokens`` tokens, recording relevance per step.

    At each step the whole current sequence is re-embedded and run through the
    model. The top logit of the last position is back-propagated, seeded with
    its own value. The gradient of the input embeddings, summed over the last
    (hidden) dimension, is kept as that step's per-token relevance.

    Args:
        prompt: prompt text; tokenized with special tokens added.
        num_tokens: maximum number of tokens to generate. Generation stops
            early right after the tokenizer's EOS token is produced.

    Returns:
        ``(input_tokens, all_relevances, generated_tokens)``:
        - the cleaned prompt tokens;
        - one 1-D float CPU tensor per generated token, covering the prompt
          plus every token generated before it;
        - the cleaned generated tokens, including a final EOS token if one
          was produced.
    """
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))

    generated_tokens_ids = []
    all_relevances = []

    # Gradio sliders may deliver a float; range() requires an int.
    for _ in range(int(num_tokens)):
        # detach() makes the embeddings a leaf tensor, so autograd fills its
        # ``.grad`` even when the embedding weights themselves require grad.
        input_embeds = model.get_input_embeddings()(input_ids).detach().requires_grad_()
        # Every step re-runs the full sequence to get a fresh graph back to
        # the inputs, so a KV cache would only cost memory.
        output_logits = model(inputs_embeds=input_embeds, use_cache=False).logits
        max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)

        max_logits.backward(max_logits)
        relevance = input_embeds.grad.float().sum(-1).cpu()[0]
        all_relevances.append(relevance)
        # Drop any parameter gradients so they do not pile up across steps
        # (a no-op when the weights are frozen).
        model.zero_grad(set_to_none=True)

        next_token = max_indices.unsqueeze(0)
        generated_tokens_ids.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            print("EOS token generated, stopping generation.")
            break

    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
    return input_tokens, all_relevances, generated_tokens

def process_relevances(input_tokens, all_relevances, generated_tokens, *,
                       threshold_per_token=0.2, kernel_width=6, context_width=20):
    """Attach to generated tokens notes pointing at the input spans that drove them.

    The relevance vectors form a matrix (generation step x input token). A
    ``kernel_width`` x ``kernel_width`` box average slides over it, and windows
    whose mean exceeds ``threshold_per_token`` count as significant. For each
    generated row, the longest run of significant windows becomes a note.
    Notes of nearby rows whose input spans overlap are then fused.

    Args:
        input_tokens: cleaned input tokens, joined to build the note text.
        all_relevances: one 1-D relevance sequence per generated token. Only the
            first ``len(all_relevances[0])`` entries of each are used.
        generated_tokens: cleaned generated tokens, aligned with
            ``all_relevances``.
        threshold_per_token: minimum window mean for a window to be significant.
        kernel_width: side length of the square averaging window.
        context_width: number of input tokens of context shown on each side of
            the significant span.

    Returns:
        One ``(token, notes, width)`` triple per generated token. ``notes`` is
        either None or a ``(before, significant, after)`` tuple of joined input
        tokens. ``width`` is the number of generated tokens, starting at this
        one, that the note covers (None when there is no note).
    """
    no_notes = [(token, None, None) for token in generated_tokens]
    # Fewer generated tokens than one window: nothing to convolve. Checked before
    # touching all_relevances[0], so an empty generation is handled too.
    if len(generated_tokens) < kernel_width or len(all_relevances) == 0:
        return no_notes

    # Truncate every row to the first row's length so the matrix is rectangular.
    input_length = len(all_relevances[0])
    attention_matrix = np.array([el[:input_length] for el in all_relevances])

    # An input shorter than the kernel has no 'valid' window position;
    # convolve2d would silently swap its operands and return nonsense.
    if input_length < kernel_width:
        return no_notes

    ### FIND ZONES OF INTEREST
    kernel = np.ones((kernel_width, kernel_width))

    # Rolling mean over every kernel_width x kernel_width window.
    rolled_sum = convolve2d(attention_matrix, kernel, mode='valid') / kernel_width**2

    # Find where the rolled sum is greater than the threshold
    significant_areas = rolled_sum > threshold_per_token
    print(f"Found {significant_areas.sum()} relevant tokens: lower threshold to find more. Max was {rolled_sum.max()}")
    print("LENGTHS:", len(input_tokens), significant_areas.shape, len(generated_tokens))

    def find_largest_contiguous_patch(array):
        """Return ``(width, start)`` of the longest run of True values in ``array``.

        Returns ``(None, None)`` when no run is recorded. NOTE(review): a run
        starting at index 0 is never recorded, because ``current_patch_start``
        is tested for truthiness. Window 0 is the one containing the first
        input token (a special token when add_special_tokens=True), so this may
        be an intentional filter. TODO confirm.
        """
        current_patch_start = None
        best_width, best_patch_start = None, None
        current_width = 0
        for i in range(len(array)):
            if array[i]:
                if current_patch_start is not None and current_patch_start + current_width == i:
                    current_width += 1
                else:
                    current_patch_start = i
                    current_width = 1
                if current_patch_start and (best_width is None or current_width > best_width):
                    best_patch_start = current_patch_start
                    best_width = current_width
            else:
                current_width = 0
        return best_width, best_patch_start

    # One note candidate per row that starts a full window of generated tokens.
    num_rows = len(generated_tokens) - kernel_width + 1
    output_with_notes = []
    for row in range(num_rows):
        best_width, best_patch_start = find_largest_contiguous_patch(significant_areas[row])
        if best_width is not None:
            output_with_notes.append((generated_tokens[row], (best_width, best_patch_start)))
        else:
            output_with_notes.append((generated_tokens[row], None))
    # The trailing kernel_width - 1 tokens start no full window and get no note.
    # Slice from an explicit index: ``[-kernel_width+1:]`` is ``[0:]`` when
    # kernel_width == 1.
    output_with_notes += [(el, None) for el in generated_tokens[num_rows:]]

    # Fuse the notes for consecutive output tokens if necessary
    for i in range(len(output_with_notes)):
        token, coords = output_with_notes[i]
        if coords is not None:
            best_width, best_patch_start = coords
            note_width_generated = kernel_width
            for next_id in range(i+1, min(i+2*kernel_width, len(output_with_notes))):
                next_token, next_coords = output_with_notes[next_id]
                if next_coords is not None:
                    next_width, next_patch_start = next_coords
                    # NOTE(review): only checks that the next patch starts before
                    # this one ends, so a patch lying entirely before this one is
                    # absorbed as well. TODO confirm intent.
                    if best_patch_start + best_width >= next_patch_start:
                        # Overlapping notes: drop the later one and widen this
                        # one (in input and generated tokens) to cover both.
                        output_with_notes[next_id] = (next_token, None)
                        larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
                        best_width = larger_end - best_patch_start
                        note_width_generated = kernel_width + (next_id-i)
            output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)
        else:
            output_with_notes[i] = (token, None, None)

    # Convert to text slices
    for i, (token, coords, width) in enumerate(output_with_notes):
        if coords is not None:
            best_width, best_patch_start = coords
            significant_start = max(0, best_patch_start)
            # best_width consecutive windows starting at best_patch_start cover
            # input tokens [start, start + best_width + kernel_width - 1).
            significant_end = best_patch_start + best_width + kernel_width - 1
            context_start = max(0, significant_start - context_width)
            context_end = min(len(input_tokens), significant_end + context_width)
            first_part = "".join(input_tokens[context_start:significant_start])
            significant_part = "".join(input_tokens[significant_start:significant_end])
            final_part = "".join(input_tokens[significant_end:context_end])
            output_with_notes[i] = (token, (first_part, significant_part, final_part), width)

    return output_with_notes

def create_html_with_hover(output_with_notes):
    """Render annotated generated tokens as HTML with hoverable source notes.

    Args:
        output_with_notes: list of ``(token, notes, width)`` triples. ``notes`` is
            None for a plain token. Otherwise it is a ``(before, significant,
            after)`` string triple, and ``width`` is how many consecutive
            tokens, starting at this one, belong to the note group.

    Returns:
        An HTML string wrapped in ``<div id='output-container'>``. Each note group
        becomes a ``span.hoverable`` numbered from 1. It contains a
        ``span.hover-note`` showing the context with the significant part in
        bold. All token and note text is HTML-escaped.
    """
    # Local import: the name ``html`` below is a local string, not the module.
    from html import escape

    html = "<div id='output-container'>"
    note_number = 0
    i = 0
    while i < len(output_with_notes):
        (token, notes, width) = output_with_notes[i]
        if notes is None:
            html += escape(token)
            i += 1
        else:
            # Tokens are used as given (plain tokens above are not re-cleaned
            # either), so both paths render the same text.
            text = "".join(element[0] for element in output_with_notes[i:i+width])
            print(text)
            first_part, significant_part, final_part = (escape(part) for part in notes)
            formatted_note = f'{first_part}<strong>{significant_part}</strong>{final_part}'
            html += f'<span class="hoverable" data-note-id="note-{note_number}">{escape(text)}<sup>[{note_number+1}]</sup>'
            html += f'<span class="hover-note">{formatted_note}</span></span>'
            note_number += 1
            i += width
    html += "</div>"
    return html

@spaces.GPU
def on_generate(prompt, num_tokens):
    """Gradio click handler: generate, attribute, and render the result as HTML.

    Chains generate_and_visualize -> process_relevances -> create_html_with_hover
    and returns the HTML string shown in the output panel.
    """
    generation = generate_and_visualize(prompt, num_tokens)
    annotated = process_relevances(*generation)
    return create_html_with_hover(annotated)

css = """
#output-container { 
    font-size: 18px; 
    line-height: 1.5; 
    position: relative;
}
.hoverable { 
    color: var(--primary-500);
    position: relative;
    display: inline-block;
}
.hover-note {
    display: none;
    position: absolute;
    padding: 5px;
    border-radius: 5px;
    bottom: 100%;
    left: 0;
    white-space: normal;
    background-color: var(--input-background-fill);
    max-width: 600px;
    width: 500px;
    word-wrap: break-word;
    z-index: 100;
}
.hoverable:hover .hover-note { 
    display: block; 
}
"""
# Example RAG-style prompts (context, question, "Answer:") for the demo.
# The first one also prefills the prompt box.
examples = [
    """Context:
The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960.

Question: How many meters above 8000 did the 1922 expedition go?

Answer:""",
    """Context:
Hurricane Katrina killed hundreds of people as it made landfall on New Orleans in 2005 - many of these deaths could have been avoided if alerts had been given one day earlier. Accurate weather forecasts are really life-saving.

🔥 Now, NASA and IBM just dropped a game-changing new model: the first ever foundation model for weather! This means, it's the first time we have a generalist model not restricted to one task, but able to predict 160 weather variables!

Prithvi WxC (Prithvi, "पृथ्वी", is the Sanskrit name for Earth) - is a 2.3 billion parameter model, with an architecture close to previous vision transformers like Hiera.

💡 But it comes with some important tweaks: under the hood, Prithvi WxC uses a clever transformer-based architecture with 25 encoder and 5 decoder blocks. It alternates between "local" and "global" attention to capture both regional and global weather patterns.

Question: How many weather variables can Prithvi predict?

Answer:""",
    """Context:
Transformers v4.45.0 released: includes a lightning-fast method to build tools! ⚡️

During user research with colleagues @MoritzLaurer and @Jofthomas , we discovered that the class definition currently in used to define a Tool in transformers.agents is a bit tedious to use, because it goes in great detail.

➡️ So I've made an easier way to build tools: just make a function with type hints + a docstring, and add a @tool decorator in front.

✅ Voilà, you're good to go!

Question: How can you build tools simply in transformers?

Answer:""",
]

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    # Page title.
    gr.Markdown("# RAG with source linking using Source attribution with [LXT](https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama)")

    # Inputs: the prompt (prefilled with the first example) and a token budget.
    input_text = gr.Textbox(
        label="Enter your prompt:",
        lines=10,
        value=examples[0],
    )
    num_tokens = gr.Slider(
        minimum=1,
        maximum=100,
        value=20,
        step=1,
        label="Number of tokens to generate (while no EOS token)",
    )
    generate_button = gr.Button("Generate")

    # Output panel filled with the HTML returned by on_generate.
    output_html = gr.HTML(label="Generated Output")

    generate_button.click(
        fn=on_generate,
        inputs=[input_text, num_tokens],
        outputs=[output_html],
    )

    gr.Markdown("Hover over the blue text with superscript numbers to see the important input tokens for that group.")

    # Clicking an example copies it into the prompt box.
    gr.Examples(
        inputs=[input_text],
        examples=examples,
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()