giulio98 committed on
Commit 2edae76 · 1 Parent(s): 42bc715
Files changed (1)
  1. app.py +240 -189
app.py CHANGED
@@ -19,7 +19,9 @@ from langchain_docling.loader import ExportType
19
  from langchain_text_splitters import RecursiveCharacterTextSplitter
20
  from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
21
  from transformers.models.llama.modeling_llama import rotate_half
22
-
 
 
23
  from utils import (
24
  calculate_tokens_suggest_compression_ratio,
25
  repeat_kv,
@@ -66,6 +68,44 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
66
  answer: Norman
67
  """
68
69
  class FinchCache(DynamicCache):
70
  def __init__(self) -> None:
71
  super().__init__()
@@ -154,8 +194,11 @@ def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
154
  export_type=ExportType.MARKDOWN,
155
  converter=doc_converter
156
  )
157
- docs = loader.load()
158
- return docs[0].page_content
159
 
160
  def create_rag_index(collection_name, text_no_prefix):
161
  text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
@@ -184,15 +227,15 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
184
  gr.update(interactive=False),
185
  False,
186
  {},
187
- chat_status
 
188
  )
189
  print("Converting to markdown")
190
  try:
191
  markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
192
- except Exception as e:
193
- print("Error converting to markdown:", e)
194
  return (
195
- gr.update(value="Error converting document to markdown. Please try uploading another document format."),
196
  "Number of tokens before compression: ",
197
  gr.update(),
198
  "Number of tokens after compression: ",
@@ -200,8 +243,10 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
200
  gr.update(interactive=False),
201
  False,
202
  {},
203
- chat_status
 
204
  )
 
205
  print("Done")
206
  combined_text = prefix + markdown
207
  print("Suggesting compression ratio")
@@ -218,7 +263,8 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
218
  rag_text = combined_text[len(prefix):]
219
  else:
220
  rag_text = combined_text
221
- collection_name = "default_collection_" + uuid.uuid4().hex[:6]
 
222
  rag_index = create_rag_index(collection_name, rag_text)
223
  state = {"rag_index": collection_name}
224
  print("Done")
@@ -231,168 +277,172 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
231
  gr.update(interactive=True), # Enable compress button if conversion succeeds.
232
  False,
233
  state,
234
- chat_status
 
235
  )
236
 
237
  def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
238
- device = model.device
239
- dtype = model.dtype
240
- sink_tokens = sink_tokens
241
- num_chunks = step_size
242
- context_ids = context_ids.to(device)
243
- context_attention_mask = context_attention_mask.to(device)
244
- question_ids = question_ids.to(device)
245
- question_attention_mask = question_attention_mask.to(device)
246
- question_len = question_ids.size(1)
247
- total_len = context_ids.size(1)
248
- max_context_tokens_allowed = model.config.max_position_embeddings - question_len
249
- if total_len > max_context_tokens_allowed:
250
- num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
251
- if total_len <= sink_tokens or num_chunks == 1:
252
- context_ids_list = [context_ids]
253
- context_attention_mask_list = [context_attention_mask]
254
- else:
255
- remainder_len = total_len - sink_tokens
256
- base = remainder_len // num_chunks
257
- leftover = remainder_len % num_chunks
258
- chunk_sizes = [sink_tokens + base]
259
- for _ in range(num_chunks - 2):
260
- chunk_sizes.append(base)
261
- if num_chunks > 1:
262
- chunk_sizes.append(base + leftover)
263
- context_ids_list = []
264
- context_attention_mask_list = []
265
- offset = 0
266
- for size in chunk_sizes:
267
- end = offset + size
268
- context_ids_list.append(context_ids[:, offset:end])
269
- context_attention_mask_list.append(context_attention_mask[:, offset:end])
270
- offset = end
271
- len_rest = max(total_len - sink_tokens, 1)
272
- compression_factor = len_rest // target_token_size
273
- if compression_factor < 1:
274
- compression_factor = 1
275
- tokenized_doc_chunks = []
276
- for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
277
- tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
278
- print("Number of chunks: ", len(tokenized_doc_chunks))
279
- rotary_emb = model.model.rotary_emb.to(device)
280
- inv_freq = rotary_emb.inv_freq
281
- batch_size = question_ids.size(0)
282
- ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
283
- cache = FinchCache()
284
- past_cache_len = 0
285
- past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
286
- num_chunks = len(tokenized_doc_chunks)
287
- query_context_matrices = {}
288
- def query_hook_fn(module, input, output):
289
- layer_idx = getattr(module, "layer_idx", None)
290
- if layer_idx is not None:
291
- query_states = output.detach()
292
- bsz, seq_len, hidden_dim = query_states.size()
293
- num_query_heads = module.num_query_heads
294
- head_dim = hidden_dim // num_query_heads
295
- query_states = (
296
- query_states.view(bsz, seq_len, num_query_heads, head_dim)
297
- .transpose(1, 2)
298
- .contiguous()
299
- )
300
- query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
301
- hooks = []
302
- for i, layer in enumerate(model.model.layers):
303
- layer.self_attn.q_proj.layer_idx = i
304
- layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
305
- hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
306
- hooks.append(hook)
307
- for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
308
- current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
309
- _current_chunk_offset = current_seq_length
310
- query_context_matrices.clear()
311
- chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
312
- chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
313
- segment_attention_mask = torch.cat(
314
- [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
315
- ).contiguous()
316
- current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
317
- current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
318
- past_seen_tokens = cache.get_seq_length() if cache is not None else 0
319
- cache_position = torch.arange(
320
- past_seen_tokens + chunk_input_ids.shape[1],
321
- past_seen_tokens + current_input_ids.shape[1],
322
- device=device
323
- )
324
- causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
325
- current_attention_mask,
326
- sequence_length=question_ids.size(1),
327
- target_length=current_attention_mask.size(-1),
328
- dtype=dtype,
329
- device=device,
330
- cache_position=cache_position,
331
- batch_size=current_input_ids.size(0),
332
- ).contiguous()
333
- with torch.no_grad():
334
- outputs = model.model(
335
- input_ids=current_input_ids,
336
- use_cache=True,
337
- past_key_values=cache,
338
- )
339
- cache = outputs.past_key_values
340
- len_question = question_ids.size(1)
341
- for layer_idx in range(len(model.model.layers)):
342
- key_matrix = cache.key_cache[layer_idx]
343
- query_matrix = query_context_matrices[layer_idx]
344
- layer_cache_pos = torch.arange(
345
- past_cache_len + current_seq_length,
346
- past_cache_len + current_seq_length + len_question,
347
  device=device
348
  )
349
- position_ids = layer_cache_pos.unsqueeze(0)
350
- cos, sin = rotary_emb(query_matrix, position_ids)
351
- cos = cos.unsqueeze(1)
352
- sin = sin.unsqueeze(1)
353
- query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
354
- num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
355
- key_matrix = repeat_kv(key_matrix, num_repeats)
356
- scaling = math.sqrt(model.config.head_dim)
357
- attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
358
- causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
359
- attention_matrix = attention_matrix + causal_mask_sliced
360
- attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
361
- tol = 1e-8
362
- binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
363
- non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
364
- non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
365
- attention_matrix = attention_matrix / non_zero_counts
366
- if j != num_chunks - 1:
367
- attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
368
- else:
369
- attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
370
- attention_matrix = torch.sum(attention_matrix, dim=-2)
371
- attention_matrix = attention_matrix.view(
372
- attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
373
- ).sum(dim=2)
374
- full_context_size = attention_matrix.size(-1)
375
- attention_matrix[..., :sink_tokens] = float("inf")
376
- if j == num_chunks - 1:
377
- attention_matrix[..., -len_question:] = float("inf")
378
- if j == 0:
379
- k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
380
- k = min(k + past_cache_len, full_context_size)
381
- elif j < num_chunks - 1:
382
- to_keep_new = int(current_seq_length // compression_factor)
383
- k = min(past_cache_len + to_keep_new, full_context_size)
384
- else:
385
- desired_final = sink_tokens + target_token_size + len_question
386
- k = desired_final if full_context_size >= desired_final else full_context_size
387
- k = max(k, sink_tokens)
388
- selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
389
- selected_indices, _ = torch.sort(selected_indices, dim=-1)
390
- cache.compress_cache(layer_idx, selected_indices, inv_freq)
391
- past_cache_len = cache._seen_tokens
392
- past_attention_mask = torch.ones(1, past_cache_len, device=device)
393
- for hook in hooks:
394
- hook.remove()
395
- return cache
396
 
397
  def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
398
  k = max(1, rag_token_size // 256)
@@ -443,7 +493,8 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
443
  target_token_size = 0
444
  past_key_values = FinchCache()
445
  compressed_length = past_key_values.get_seq_length()
446
- cache_name = "default_cache_" + uuid.uuid4().hex[:6] + ".pt"
 
447
  save_dir = "./cache_dir"
448
  os.makedirs(save_dir, exist_ok=True)
449
  save_path = os.path.join(save_dir, cache_name)
@@ -455,7 +506,8 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
455
  rag_text = combined_text[len(prefix):]
456
  else:
457
  rag_text = combined_text
458
- collection_name = "default_collection_" + uuid.uuid4().hex[:6]
 
459
  rag_index = create_rag_index(collection_name, rag_text)
460
  state.update({
461
  "compressed_cache": save_path,
@@ -469,7 +521,7 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
469
  "retrieval_slider": retrieval_context_length,
470
  "prefill_time": time.perf_counter() - start_time_prefill,
471
  "compression_done": True,
472
- "tokens_breakdown": f"KV Compress Tokens: {kv_tokens}, RAG Tokens: {rag_tokens}",
473
  "chat_feedback": "Document compressed successfully. You can now chat."
474
  })
475
  return state, True
@@ -530,20 +582,14 @@ def chat_response_stream(message: str, history: list, state: dict):
530
  full_output += text
531
  time.sleep(0.05)
532
  yield full_output
533
- state["compressed_cache"] = past_key_values
534
  return full_output
535
 
536
- def update_token_breakdown(token_count, retrieval_slider_value, global_local_value):
537
- try:
538
- token_count = int(token_count)
539
- slider_val = float(retrieval_slider_value)
540
- percentage = int(global_local_value.replace('%', ''))
541
- retrieval_context_length = int(token_count / slider_val)
542
- rag_tokens = int(retrieval_context_length * (1 - (percentage / 100)))
543
- kv_tokens = retrieval_context_length - rag_tokens
544
- return f"KV Compress Tokens: {kv_tokens}, RAG Tokens: {rag_tokens}"
545
- except Exception as e:
546
- return "Token breakdown unavailable."
547
 
548
  ##########################################################################
549
  # Gradio Interface
@@ -629,6 +675,9 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
629
  compression_done = gr.State(value=False)
630
  compressed_doc_state = gr.State(value={})
631
 
632
  with gr.Row(elem_classes="main-container"):
633
  with gr.Column(elem_classes="upload-section"):
634
  gr.Markdown("## Document Preprocessing")
@@ -646,40 +695,38 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
646
  token_count_text = gr.Markdown("Number of tokens before compression: ")
647
  retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
648
  retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
649
- # New widget for token breakdown (KV vs RAG)
650
  tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
651
  global_local_slider = gr.Radio(label="Global vs Local (0 is all RAG, 100 is all global)",
652
  choices=["0%", "25%", "50%", "75%", "100%"], value="75%")
653
  compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
654
- # New widget for chat status feedback
655
  chat_status_text = gr.Markdown("Document not compressed yet. Please compress the document to enable chat.")
656
 
 
657
  file_input.change(
658
  fn=auto_convert,
659
  inputs=[file_input, url_input, do_ocr, do_table],
660
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
661
  )
662
  url_input.change(
663
  fn=auto_convert,
664
  inputs=[file_input, url_input, do_ocr, do_table],
665
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
666
  )
667
  do_ocr.change(
668
  fn=auto_convert,
669
  inputs=[file_input, url_input, do_ocr, do_table],
670
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
671
  )
672
  do_table.change(
673
  fn=auto_convert,
674
  inputs=[file_input, url_input, do_ocr, do_table],
675
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
676
  )
677
  retrieval_slider.change(
678
  fn=update_retrieval_context,
679
  inputs=[hidden_token_count, retrieval_slider],
680
  outputs=retrieval_info_text
681
  )
682
- # Update token breakdown when slider or global/local changes
683
  retrieval_slider.change(
684
  fn=update_token_breakdown,
685
  inputs=[hidden_token_count, retrieval_slider, global_local_slider],
@@ -697,6 +744,9 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
697
  ).then(
698
  fn=lambda state: gr.update(value="Document compressed successfully. You can now chat."),
699
  outputs=chat_status_text
700
  )
701
 
702
  with gr.Column(elem_classes="chatbot-container"):
@@ -704,7 +754,8 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
704
  chat_interface = gr.ChatInterface(
705
  fn=chat_response_stream,
706
  additional_inputs=[compressed_doc_state],
707
- type="messages"
 
708
  )
709
 
710
  demo.queue(max_size=16).launch()
 
19
  from langchain_text_splitters import RecursiveCharacterTextSplitter
20
  from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
21
  from transformers.models.llama.modeling_llama import rotate_half
22
+ import threading
23
+ import shutil
24
+ import time
25
  from utils import (
26
  calculate_tokens_suggest_compression_ratio,
27
  repeat_kv,
 
68
  answer: Norman
69
  """
70
 
71
+
72
+
73
+ CHROMA_DB_DIR = "./chroma_db"
74
+ CACHE_DIR = "./cache_dir"
75
+ EXPIRATION_SECONDS = 3600
76
+
77
+ def background_cleanup():
78
+ while True:
79
+ current_time = int(time.time())
80
+
81
+ # Clean Chroma collections
82
+ if os.path.exists(CHROMA_DB_DIR):
83
+ for dirname in os.listdir(CHROMA_DB_DIR):
84
+ parts = dirname.split("_")
85
+ if len(parts) >= 3 and parts[1].isdigit():
86
+ timestamp = int(parts[1])
87
+ if current_time - timestamp > EXPIRATION_SECONDS:
88
+ path = os.path.join(CHROMA_DB_DIR, dirname)
89
+ shutil.rmtree(path, ignore_errors=True)
90
+ print(f"[Cleanup] Deleted Chroma collection: {path}")
91
+
92
+ # Clean cache files
93
+ if os.path.exists(CACHE_DIR):
94
+ for filename in os.listdir(CACHE_DIR):
95
+ parts = filename.split("_")
96
+ if len(parts) >= 3 and parts[1].isdigit():
97
+ timestamp = int(parts[1])
98
+ if current_time - timestamp > EXPIRATION_SECONDS:
99
+ path = os.path.join(CACHE_DIR, filename)
100
+ os.remove(path)
101
+ print(f"[Cleanup] Deleted cache file: {path}")
102
+
103
+ time.sleep(600)
104
+
105
+ cleanup_thread = threading.Thread(target=background_cleanup, daemon=True)
106
+ cleanup_thread.start()
107
+
108
+
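For reference, a minimal sketch of the timestamped naming convention this cleanup loop expects. It assumes names of the form `<prefix>_<unix-timestamp>_<hex>` (matching the `default_*` collections and `cache_*.pt` files created further down in this commit); the helper names below are illustrative only.

```python
import time
import uuid

def make_name(prefix: str, suffix: str = "") -> str:
    # e.g. "cache_1712345678_a1b2c3.pt" -- the second "_"-separated field is the Unix timestamp.
    return f"{prefix}_{int(time.time())}_{uuid.uuid4().hex[:6]}{suffix}"

def is_expired(name: str, expiration_seconds: int = 3600) -> bool:
    parts = name.split("_")
    if len(parts) >= 3 and parts[1].isdigit():
        return int(time.time()) - int(parts[1]) > expiration_seconds
    return False  # names that don't follow the convention are never deleted

print(is_expired(make_name("cache", ".pt")))  # False for a freshly created file
```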
109
  class FinchCache(DynamicCache):
110
  def __init__(self) -> None:
111
  super().__init__()
 
194
  export_type=ExportType.MARKDOWN,
195
  converter=doc_converter
196
  )
197
+ try:
198
+ docs = loader.load()
199
+ return docs[0].page_content
200
+ except Exception as e:
201
+ raise RuntimeError(f"Failed to convert document to markdown: {e}")
202
 
203
  def create_rag_index(collection_name, text_no_prefix):
204
  text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
 
227
  gr.update(interactive=False),
228
  False,
229
  {},
230
+ chat_status,
231
+ gr.update(interactive=False) # Disable chat interface
232
  )
233
  print("Converting to markdown")
234
  try:
235
  markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
236
+ except RuntimeError as e:
 
237
  return (
238
+ gr.update(value=f"{str(e)} Please try uploading another document format."),
239
  "Number of tokens before compression: ",
240
  gr.update(),
241
  "Number of tokens after compression: ",
 
243
  gr.update(interactive=False),
244
  False,
245
  {},
246
+ chat_status,
247
+ gr.update(interactive=False) # Disable chat interface on error
248
  )
249
+
250
  print("Done")
251
  combined_text = prefix + markdown
252
  print("Suggesting compression ratio")
 
263
  rag_text = combined_text[len(prefix):]
264
  else:
265
  rag_text = combined_text
266
+ current_timestamp = int(time.time())
267
+ collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
268
  rag_index = create_rag_index(collection_name, rag_text)
269
  state = {"rag_index": collection_name}
270
  print("Done")
 
277
  gr.update(interactive=True), # Enable compress button if conversion succeeds.
278
  False,
279
  state,
280
+ chat_status,
281
+ gr.update(interactive=False) # Ensure chat remains disabled until compression
282
  )
283
 
284
  def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
285
+ try:
286
+ device = model.device
287
+ dtype = model.dtype
288
+ sink_tokens = sink_tokens
289
+ num_chunks = step_size
290
+ context_ids = context_ids.to(device)
291
+ context_attention_mask = context_attention_mask.to(device)
292
+ question_ids = question_ids.to(device)
293
+ question_attention_mask = question_attention_mask.to(device)
294
+ question_len = question_ids.size(1)
295
+ total_len = context_ids.size(1)
296
+ max_context_tokens_allowed = model.config.max_position_embeddings - question_len
297
+ if total_len > max_context_tokens_allowed:
298
+ num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
299
+ if total_len <= sink_tokens or num_chunks == 1:
300
+ context_ids_list = [context_ids]
301
+ context_attention_mask_list = [context_attention_mask]
302
+ else:
303
+ remainder_len = total_len - sink_tokens
304
+ base = remainder_len // num_chunks
305
+ leftover = remainder_len % num_chunks
306
+ chunk_sizes = [sink_tokens + base]
307
+ for _ in range(num_chunks - 2):
308
+ chunk_sizes.append(base)
309
+ if num_chunks > 1:
310
+ chunk_sizes.append(base + leftover)
311
+ context_ids_list = []
312
+ context_attention_mask_list = []
313
+ offset = 0
314
+ for size in chunk_sizes:
315
+ end = offset + size
316
+ context_ids_list.append(context_ids[:, offset:end])
317
+ context_attention_mask_list.append(context_attention_mask[:, offset:end])
318
+ offset = end
319
+ len_rest = max(total_len - sink_tokens, 1)
320
+ compression_factor = len_rest // target_token_size
321
+ if compression_factor < 1:
322
+ compression_factor = 1
323
+ tokenized_doc_chunks = []
324
+ for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
325
+ tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
326
+ print("Number of chunks: ", len(tokenized_doc_chunks))
327
+ rotary_emb = model.model.rotary_emb.to(device)
328
+ inv_freq = rotary_emb.inv_freq
329
+ batch_size = question_ids.size(0)
330
+ ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
331
+ cache = FinchCache()
332
+ past_cache_len = 0
333
+ past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
334
+ num_chunks = len(tokenized_doc_chunks)
335
+ query_context_matrices = {}
336
+ def query_hook_fn(module, input, output):
337
+ layer_idx = getattr(module, "layer_idx", None)
338
+ if layer_idx is not None:
339
+ query_states = output.detach()
340
+ bsz, seq_len, hidden_dim = query_states.size()
341
+ num_query_heads = module.num_query_heads
342
+ head_dim = hidden_dim // num_query_heads
343
+ query_states = (
344
+ query_states.view(bsz, seq_len, num_query_heads, head_dim)
345
+ .transpose(1, 2)
346
+ .contiguous()
347
+ )
348
+ query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
349
+ hooks = []
350
+ for i, layer in enumerate(model.model.layers):
351
+ layer.self_attn.q_proj.layer_idx = i
352
+ layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
353
+ hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
354
+ hooks.append(hook)
355
+ for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
356
+ current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
357
+ _current_chunk_offset = current_seq_length
358
+ query_context_matrices.clear()
359
+ chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
360
+ chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
361
+ segment_attention_mask = torch.cat(
362
+ [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
363
+ ).contiguous()
364
+ current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
365
+ current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
366
+ past_seen_tokens = cache.get_seq_length() if cache is not None else 0
367
+ cache_position = torch.arange(
368
+ past_seen_tokens + chunk_input_ids.shape[1],
369
+ past_seen_tokens + current_input_ids.shape[1],
370
  device=device
371
  )
372
+ causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
373
+ current_attention_mask,
374
+ sequence_length=question_ids.size(1),
375
+ target_length=current_attention_mask.size(-1),
376
+ dtype=dtype,
377
+ device=device,
378
+ cache_position=cache_position,
379
+ batch_size=current_input_ids.size(0),
380
+ ).contiguous()
381
+ with torch.no_grad():
382
+ outputs = model.model(
383
+ input_ids=current_input_ids,
384
+ use_cache=True,
385
+ past_key_values=cache,
386
+ )
387
+ cache = outputs.past_key_values
388
+ len_question = question_ids.size(1)
389
+ for layer_idx in range(len(model.model.layers)):
390
+ key_matrix = cache.key_cache[layer_idx]
391
+ query_matrix = query_context_matrices[layer_idx]
392
+ layer_cache_pos = torch.arange(
393
+ past_cache_len + current_seq_length,
394
+ past_cache_len + current_seq_length + len_question,
395
+ device=device
396
+ )
397
+ position_ids = layer_cache_pos.unsqueeze(0)
398
+ cos, sin = rotary_emb(query_matrix, position_ids)
399
+ cos = cos.unsqueeze(1)
400
+ sin = sin.unsqueeze(1)
401
+ query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
402
+ num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
403
+ key_matrix = repeat_kv(key_matrix, num_repeats)
404
+ scaling = math.sqrt(model.config.head_dim)
405
+ attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
406
+ causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
407
+ attention_matrix = attention_matrix + causal_mask_sliced
408
+ attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
409
+ tol = 1e-8
410
+ binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
411
+ non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
412
+ non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
413
+ attention_matrix = attention_matrix / non_zero_counts
414
+ if j != num_chunks - 1:
415
+ attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
416
+ else:
417
+ attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
418
+ attention_matrix = torch.sum(attention_matrix, dim=-2)
419
+ attention_matrix = attention_matrix.view(
420
+ attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
421
+ ).sum(dim=2)
422
+ full_context_size = attention_matrix.size(-1)
423
+ attention_matrix[..., :sink_tokens] = float("inf")
424
+ if j == num_chunks - 1:
425
+ attention_matrix[..., -len_question:] = float("inf")
426
+ if j == 0:
427
+ k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
428
+ k = min(k + past_cache_len, full_context_size)
429
+ elif j < num_chunks - 1:
430
+ to_keep_new = int(current_seq_length // compression_factor)
431
+ k = min(past_cache_len + to_keep_new, full_context_size)
432
+ else:
433
+ desired_final = sink_tokens + target_token_size + len_question
434
+ k = desired_final if full_context_size >= desired_final else full_context_size
435
+ k = max(k, sink_tokens)
436
+ selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
437
+ selected_indices, _ = torch.sort(selected_indices, dim=-1)
438
+ cache.compress_cache(layer_idx, selected_indices, inv_freq)
439
+ past_cache_len = cache._seen_tokens
440
+ past_attention_mask = torch.ones(1, past_cache_len, device=device)
441
+ for hook in hooks:
442
+ hook.remove()
443
+ return cache
444
+ except Exception as e:
445
+ raise RuntimeError(f"Failed to compress KV cache: {e}")
446
 
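The per-layer selection above scores every cached key position with question-conditioned attention, keeps the top-k positions per KV head, and hands the sorted indices to `cache.compress_cache(...)`. A compact sketch of that selection step with toy shapes (names and dimensions here are illustrative; `compress_cache` itself belongs to `FinchCache` and is not reproduced):

```python
import torch

# Toy shapes: batch, KV heads, head dim, cached sequence length, tokens to keep.
batch, num_kv_heads, head_dim, seq_len, k = 1, 4, 64, 128, 32
scores = torch.rand(batch, num_kv_heads, seq_len)           # attention summed over query positions
keys = torch.rand(batch, num_kv_heads, seq_len, head_dim)   # cached keys for one layer
values = torch.rand(batch, num_kv_heads, seq_len, head_dim)

# Keep the k highest-scoring positions per head, preserving their original order,
# mirroring the torch.topk + torch.sort pair in the loop above.
selected = torch.topk(scores, k, dim=-1).indices
selected, _ = torch.sort(selected, dim=-1)

# Gather the surviving key/value rows for each head.
idx = selected.unsqueeze(-1).expand(-1, -1, -1, head_dim)
compressed_keys = torch.gather(keys, dim=2, index=idx)
compressed_values = torch.gather(values, dim=2, index=idx)
print(compressed_keys.shape)  # torch.Size([1, 4, 32, 64])
```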
447
  def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
448
  k = max(1, rag_token_size // 256)
 
493
  target_token_size = 0
494
  past_key_values = FinchCache()
495
  compressed_length = past_key_values.get_seq_length()
496
+ current_timestamp = int(time.time())
497
+ cache_name = f"cache_{current_timestamp}_{uuid.uuid4().hex[:6]}.pt"
498
  save_dir = "./cache_dir"
499
  os.makedirs(save_dir, exist_ok=True)
500
  save_path = os.path.join(save_dir, cache_name)
 
506
  rag_text = combined_text[len(prefix):]
507
  else:
508
  rag_text = combined_text
509
+ current_timestamp = int(time.time())
510
+ collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
511
  rag_index = create_rag_index(collection_name, rag_text)
512
  state.update({
513
  "compressed_cache": save_path,
 
521
  "retrieval_slider": retrieval_context_length,
522
  "prefill_time": time.perf_counter() - start_time_prefill,
523
  "compression_done": True,
524
+ "tokens_breakdown": f"RAG tokens: {rag_tokens} (for retrieval), {kv_tokens} tokens (for KV compression)",
525
  "chat_feedback": "Document compressed successfully. You can now chat."
526
  })
527
  return state, True
 
582
  full_output += text
583
  time.sleep(0.05)
584
  yield full_output
 
585
  return full_output
586
 
587
+ def update_token_breakdown(token_count, retrieval_slider, global_local_value):
588
+ retrieval_context_length = int(token_count / retrieval_slider)
589
+ percentage = int(global_local_value.replace('%', ''))
590
+ rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
591
+ kv_tokens = retrieval_context_length - rag_tokens
592
+ return f"Token Breakdown: {rag_tokens} tokens will be used for RAG retrieval, and {kv_tokens} tokens for KV compression."
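As a quick worked example, with a hypothetical 8,000-token document, a compression rate of 2, and the default 75% global setting: `retrieval_context_length = 8000 / 2 = 4000`, `rag_tokens = int(4000 * 0.25) = 1000`, and `kv_tokens = 4000 - 1000 = 3000`.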
593
 
594
  ##########################################################################
595
  # Gradio Interface
 
675
  compression_done = gr.State(value=False)
676
  compressed_doc_state = gr.State(value={})
677
 
678
+ def toggle_chat_interactivity(compression_done):
679
+ return gr.update(interactive=compression_done)
680
+
681
  with gr.Row(elem_classes="main-container"):
682
  with gr.Column(elem_classes="upload-section"):
683
  gr.Markdown("## Document Preprocessing")
 
695
  token_count_text = gr.Markdown("Number of tokens before compression: ")
696
  retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
697
  retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
 
698
  tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
699
  global_local_slider = gr.Radio(label="Global vs Local (0 is all RAG, 100 is all global)",
700
  choices=["0%", "25%", "50%", "75%", "100%"], value="75%")
701
  compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
 
702
  chat_status_text = gr.Markdown("Document not compressed yet. Please compress the document to enable chat.")
703
 
704
+ # When document parameters change, disable the chat interface.
705
  file_input.change(
706
  fn=auto_convert,
707
  inputs=[file_input, url_input, do_ocr, do_table],
708
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
709
  )
710
  url_input.change(
711
  fn=auto_convert,
712
  inputs=[file_input, url_input, do_ocr, do_table],
713
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
714
  )
715
  do_ocr.change(
716
  fn=auto_convert,
717
  inputs=[file_input, url_input, do_ocr, do_table],
718
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
719
  )
720
  do_table.change(
721
  fn=auto_convert,
722
  inputs=[file_input, url_input, do_ocr, do_table],
723
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
724
  )
725
  retrieval_slider.change(
726
  fn=update_retrieval_context,
727
  inputs=[hidden_token_count, retrieval_slider],
728
  outputs=retrieval_info_text
729
  )
 
730
  retrieval_slider.change(
731
  fn=update_token_breakdown,
732
  inputs=[hidden_token_count, retrieval_slider, global_local_slider],
 
744
  ).then(
745
  fn=lambda state: gr.update(value="Document compressed successfully. You can now chat."),
746
  outputs=chat_status_text
747
+ ).then(
748
+ fn=lambda: gr.update(interactive=True),
749
+ outputs=lambda: chat_interface # Re-enable chat interface after successful compression.
750
  )
751
 
752
  with gr.Column(elem_classes="chatbot-container"):
 
754
  chat_interface = gr.ChatInterface(
755
  fn=chat_response_stream,
756
  additional_inputs=[compressed_doc_state],
757
+ type="messages",
758
+ interactive=False
759
  )
760
 
761
  demo.queue(max_size=16).launch()