giulio98 committed
Commit d748a9d · 1 Parent(s): 2edae76
Files changed (6)
  1. __pycache__/utils.cpython-310.pyc +0 -0
  2. app.py +104 -68
  3. cache.py +0 -75
  4. global_compression.py +0 -211
  5. preprocess_document.py +0 -34
  6. rag.py +0 -53
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -68,8 +68,6 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
68
  answer: Norman
69
  """
70
 
71
-
72
-
73
  CHROMA_DB_DIR = "./chroma_db"
74
  CACHE_DIR = "./cache_dir"
75
  EXPIRATION_SECONDS = 3600
@@ -227,8 +225,7 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
227
  gr.update(interactive=False),
228
  False,
229
  {},
230
- chat_status,
231
- gr.update(interactive=False) # Disable chat interface
232
  )
233
  print("Converting to markdown")
234
  try:
@@ -243,8 +240,7 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
243
  gr.update(interactive=False),
244
  False,
245
  {},
246
- chat_status,
247
- gr.update(interactive=False) # Disable chat interface on error
248
  )
249
 
250
  print("Done")
@@ -254,7 +250,7 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
254
  print("Done")
255
  min_ratio = min(suggestions)
256
  max_ratio = max(suggestions)
257
- default_ratio = suggestions[len(suggestions) // 2]
258
  retrieval_tokens = int(token_count / default_ratio)
259
  token_count_str = f"Number of tokens before compression: {token_count}"
260
  retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
@@ -277,8 +273,7 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
277
  gr.update(interactive=True), # Enable compress button if conversion succeeds.
278
  False,
279
  state,
280
- chat_status,
281
- gr.update(interactive=False) # Ensure chat remains disabled until compression
282
  )
283
 
284
  def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
@@ -458,8 +453,11 @@ def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, fe
458
  return rag_context
459
 
460
  @spaces.GPU
461
- def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state):
462
  percentage = int(global_local_value.replace('%', ''))
 
463
  question_text = task_description + "\n" + few_shot
464
  context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
465
  question_encoding = tokenizer(question_text, return_tensors="pt").to(device)
@@ -467,41 +465,44 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
467
  context_attention_mask = context_encoding["attention_mask"]
468
  question_ids = question_encoding["input_ids"]
469
  question_attention_mask = question_encoding["attention_mask"]
 
470
  retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
471
- # Compute token breakdown for display (KV compress vs RAG tokens)
472
  rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
473
  kv_tokens = retrieval_context_length - rag_tokens
474
- print(f"KV Compress Tokens: {kv_tokens}, RAG Tokens: {rag_tokens}")
 
475
  if percentage > 0:
476
  target_token_size = int(retrieval_context_length * (percentage / 100))
477
- print("Target token size for compression: ", target_token_size)
478
  step_size = 2
479
- start_time_prefill = time.perf_counter()
480
  try:
481
  past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
482
  context_ids, context_attention_mask,
483
  question_ids, question_attention_mask))
484
  except Exception as e:
 
485
  print("Error during KV cache compression:", e)
486
  state["error"] = "Error during KV cache compression. Please try lowering the compression ratio and try again."
487
  return state, False
488
  compressed_length = past_key_values.get_seq_length()
489
- print("Context size after compression: ", compressed_length)
490
- print("Compression rate: ", context_ids.size(1) / compressed_length)
491
  else:
492
- start_time_prefill = 0
493
  target_token_size = 0
494
  past_key_values = FinchCache()
495
  compressed_length = past_key_values.get_seq_length()
496
  current_timestamp = int(time.time())
497
  cache_name = f"cache_{current_timestamp}_{uuid.uuid4().hex[:6]}.pt"
498
  save_dir = "./cache_dir"
499
  os.makedirs(save_dir, exist_ok=True)
500
  save_path = os.path.join(save_dir, cache_name)
501
  past_key_values.save(save_path)
502
  collection_name = state.get("rag_index", None)
503
  if collection_name is None:
504
- print("Collection name not found creating a new one.")
505
  if combined_text.startswith(prefix):
506
  rag_text = combined_text[len(prefix):]
507
  else:
@@ -509,27 +510,23 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
509
  current_timestamp = int(time.time())
510
  collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
511
  rag_index = create_rag_index(collection_name, rag_text)
 
512
  state.update({
513
  "compressed_cache": save_path,
514
- "compressed_length": compressed_length,
515
  "rag_index": collection_name,
516
- "target_token_size": target_token_size,
517
  "global_local": percentage,
518
- "combined_text": combined_text,
519
  "task_description": task_description,
520
  "few_shot": few_shot,
521
  "retrieval_slider": retrieval_context_length,
522
- "prefill_time": time.perf_counter() - start_time_prefill,
523
- "compression_done": True,
524
- "tokens_breakdown": f"RAG tokens: {rag_tokens} (for retrieval), {kv_tokens} tokens (for KV compression)",
525
- "chat_feedback": "Document compressed successfully. You can now chat."
526
  })
527
- return state, True
528
 
529
  @spaces.GPU
530
- def chat_response_stream(message: str, history: list, state: dict):
531
  # Check if the document is compressed before allowing chat
532
- if not state.get("compression_done", False) or "compressed_cache" not in state:
533
  yield "Document not compressed yet. Please compress the document first to enable chat."
534
  return
535
  user_message = message
@@ -589,7 +586,7 @@ def update_token_breakdown(token_count, retrieval_slider, global_local_value):
589
  percentage = int(global_local_value.replace('%', ''))
590
  rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
591
  kv_tokens = retrieval_context_length - rag_tokens
592
- return f"Token Breakdown: {rag_tokens} tokens will be used for RAG retrieval, and {kv_tokens} tokens for KV compression."
593
 
594
  ##########################################################################
595
  # Gradio Interface
@@ -651,33 +648,22 @@ body {
651
  display: flex;
652
  align-items: center;
653
  }
654
- @media (min-width: 768px) {
655
- .main-container {
656
- display: flex;
657
- justify-content: space-between;
658
- gap: 20px;
659
- }
660
- .upload-section {
661
- flex: 3;
662
- }
663
- .chatbot-container {
664
- flex: 1;
665
- margin-top: 0;
666
- }
667
- }
668
  """
669
 
670
  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
671
  gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
672
  gr.HTML("<center><p>Compress your document and chat with it.</p></center>")
673
674
  hidden_token_count = gr.State(value=0)
675
  compression_done = gr.State(value=False)
676
  compressed_doc_state = gr.State(value={})
677
 
678
- def toggle_chat_interactivity(compression_done):
679
- return gr.update(interactive=compression_done)
680
-
681
  with gr.Row(elem_classes="main-container"):
682
  with gr.Column(elem_classes="upload-section"):
683
  gr.Markdown("## Document Preprocessing")
@@ -696,37 +682,90 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
696
  retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
697
  retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
698
  tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
699
- global_local_slider = gr.Radio(label="Global vs Local (0 is all RAG, 100 is all global)",
700
- choices=["0%", "25%", "50%", "75%", "100%"], value="75%")
701
  compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
702
- chat_status_text = gr.Markdown("Document not compressed yet. Please compress the document to enable chat.")
703
 
704
- # When document parameters change, disable the chat interface.
705
  file_input.change(
706
  fn=auto_convert,
707
  inputs=[file_input, url_input, do_ocr, do_table],
708
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
709
  )
710
  url_input.change(
711
  fn=auto_convert,
712
  inputs=[file_input, url_input, do_ocr, do_table],
713
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
714
  )
715
  do_ocr.change(
716
  fn=auto_convert,
717
  inputs=[file_input, url_input, do_ocr, do_table],
718
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
719
  )
720
  do_table.change(
721
  fn=auto_convert,
722
  inputs=[file_input, url_input, do_ocr, do_table],
723
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text, gr.State().update(interactive=False)]
724
  )
725
  retrieval_slider.change(
726
- fn=update_retrieval_context,
727
- inputs=[hidden_token_count, retrieval_slider],
728
- outputs=retrieval_info_text
729
  )
730
  retrieval_slider.change(
731
  fn=update_token_breakdown,
732
  inputs=[hidden_token_count, retrieval_slider, global_local_slider],
@@ -737,25 +776,22 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
737
  inputs=[hidden_token_count, retrieval_slider, global_local_slider],
738
  outputs=tokens_breakdown_text
739
  )
740
  compress_button.click(
741
  fn=prepare_compression_and_rag,
742
  inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
743
- outputs=[compressed_doc_state, compression_done]
744
- ).then(
745
- fn=lambda state: gr.update(value="Document compressed successfully. You can now chat."),
746
- outputs=chat_status_text
747
- ).then(
748
- fn=lambda: gr.update(interactive=True),
749
- outputs=lambda: chat_interface # Re-enable chat interface after successful compression.
750
  )
751
-
752
  with gr.Column(elem_classes="chatbot-container"):
 
753
  gr.Markdown("## Chat")
754
  chat_interface = gr.ChatInterface(
755
  fn=chat_response_stream,
756
- additional_inputs=[compressed_doc_state],
757
  type="messages",
758
- interactive=False
759
  )
760
 
761
- demo.queue(max_size=16).launch()
 
 
68
  answer: Norman
69
  """
70
 
71
  CHROMA_DB_DIR = "./chroma_db"
72
  CACHE_DIR = "./cache_dir"
73
  EXPIRATION_SECONDS = 3600
 
225
  gr.update(interactive=False),
226
  False,
227
  {},
228
+ chat_status
 
229
  )
230
  print("Converting to markdown")
231
  try:
 
240
  gr.update(interactive=False),
241
  False,
242
  {},
243
+ chat_status
 
244
  )
245
 
246
  print("Done")
 
250
  print("Done")
251
  min_ratio = min(suggestions)
252
  max_ratio = max(suggestions)
253
+ default_ratio = 6
254
  retrieval_tokens = int(token_count / default_ratio)
255
  token_count_str = f"Number of tokens before compression: {token_count}"
256
  retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
 
273
  gr.update(interactive=True), # Enable compress button if conversion succeeds.
274
  False,
275
  state,
276
+ chat_status
 
277
  )
278
 
279
  def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
 
453
  return rag_context
454
 
455
  @spaces.GPU
456
+ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state, progress=gr.Progress()):
457
+ progress(0, desc="Starting compression process")
458
+
459
  percentage = int(global_local_value.replace('%', ''))
460
+ progress(0.1, desc="Tokenizing text and preparing task")
461
  question_text = task_description + "\n" + few_shot
462
  context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
463
  question_encoding = tokenizer(question_text, return_tensors="pt").to(device)
 
465
  context_attention_mask = context_encoding["attention_mask"]
466
  question_ids = question_encoding["input_ids"]
467
  question_attention_mask = question_encoding["attention_mask"]
468
+
469
  retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
 
470
  rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
471
  kv_tokens = retrieval_context_length - rag_tokens
472
+ progress(0.2, desc=f"Token breakdown computed: {kv_tokens} KV tokens, {rag_tokens} RAG tokens")
473
+
474
  if percentage > 0:
475
  target_token_size = int(retrieval_context_length * (percentage / 100))
476
+ progress(0.3, desc="Starting KV cache compression")
477
  step_size = 2
 
478
  try:
479
  past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
480
  context_ids, context_attention_mask,
481
  question_ids, question_attention_mask))
482
  except Exception as e:
483
+ progress(1, desc="Compression failed")
484
  print("Error during KV cache compression:", e)
485
  state["error"] = "Error during KV cache compression. Please try lowering the compression ratio and try again."
486
  return state, False
487
  compressed_length = past_key_values.get_seq_length()
488
+ progress(0.6, desc="KV cache compression completed")
 
489
  else:
 
490
  target_token_size = 0
491
  past_key_values = FinchCache()
492
  compressed_length = past_key_values.get_seq_length()
493
+ progress(0.3, desc="Skipping compression as percentage is 0")
494
+
495
  current_timestamp = int(time.time())
496
  cache_name = f"cache_{current_timestamp}_{uuid.uuid4().hex[:6]}.pt"
497
  save_dir = "./cache_dir"
498
  os.makedirs(save_dir, exist_ok=True)
499
  save_path = os.path.join(save_dir, cache_name)
500
  past_key_values.save(save_path)
501
+ progress(0.8, desc="Cache saved successfully")
502
+
503
  collection_name = state.get("rag_index", None)
504
  if collection_name is None:
505
+ print("Collection name not found; creating a new one.")
506
  if combined_text.startswith(prefix):
507
  rag_text = combined_text[len(prefix):]
508
  else:
 
510
  current_timestamp = int(time.time())
511
  collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
512
  rag_index = create_rag_index(collection_name, rag_text)
513
+
514
  state.update({
515
  "compressed_cache": save_path,
 
516
  "rag_index": collection_name,
 
517
  "global_local": percentage,
 
518
  "task_description": task_description,
519
  "few_shot": few_shot,
520
  "retrieval_slider": retrieval_context_length,
521
  })
522
+ progress(1, desc="Compression complete")
523
+ return state, "Document compressed successfully. You can now chat.", True
524
+
525
 
526
  @spaces.GPU
527
+ def chat_response_stream(message: str, history: list, state: dict, compression_done: bool):
528
  # Check if the document is compressed before allowing chat
529
+ if not compression_done or "compressed_cache" not in state:
530
  yield "Document not compressed yet. Please compress the document first to enable chat."
531
  return
532
  user_message = message
 
586
  percentage = int(global_local_value.replace('%', ''))
587
  rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
588
  kv_tokens = retrieval_context_length - rag_tokens
589
+ return f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)"
590
 
591
  ##########################################################################
592
  # Gradio Interface
 
648
  display: flex;
649
  align-items: center;
650
  }
651
+
652
  """
653
+ def reset_chat_state():
654
+ return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False
655
 
656
  with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
657
  gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
658
  gr.HTML("<center><p>Compress your document and chat with it.</p></center>")
659
 
660
+ # Define chat_status_text as a Textbox with a set elem_id for custom styling.
661
+ chat_status_text = gr.Textbox(value="Document not compressed yet. Please compress the document to enable chat.", interactive=False, show_label=False, render=False, lines=5)
662
+
663
  hidden_token_count = gr.State(value=0)
664
  compression_done = gr.State(value=False)
665
  compressed_doc_state = gr.State(value={})
666
 
667
  with gr.Row(elem_classes="main-container"):
668
  with gr.Column(elem_classes="upload-section"):
669
  gr.Markdown("## Document Preprocessing")
 
682
  retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
683
  retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
684
  tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
685
+ global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
686
+ choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
687
  compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
 
688
 
689
+ # File input: Run auto_convert then chain reset_chat_state.
690
  file_input.change(
691
  fn=auto_convert,
692
  inputs=[file_input, url_input, do_ocr, do_table],
693
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
694
+ hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
695
+ ).then(
696
+ fn=reset_chat_state,
697
+ inputs=None,
698
+ outputs=[chat_status_text, compression_done]
699
  )
700
+
701
+ # URL input: Run auto_convert then chain reset_chat_state.
702
  url_input.change(
703
  fn=auto_convert,
704
  inputs=[file_input, url_input, do_ocr, do_table],
705
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
706
+ hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
707
+ ).then(
708
+ fn=reset_chat_state,
709
+ inputs=None,
710
+ outputs=[chat_status_text, compression_done]
711
  )
712
+
713
+ # OCR checkbox: Run auto_convert then chain reset_chat_state.
714
  do_ocr.change(
715
  fn=auto_convert,
716
  inputs=[file_input, url_input, do_ocr, do_table],
717
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
718
+ hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
719
+ ).then(
720
+ fn=reset_chat_state,
721
+ inputs=None,
722
+ outputs=[chat_status_text, compression_done]
723
  )
724
+
725
+ # Table structure checkbox: Run auto_convert then chain reset_chat_state.
726
  do_table.change(
727
  fn=auto_convert,
728
  inputs=[file_input, url_input, do_ocr, do_table],
729
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
730
+ hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
731
+ ).then(
732
+ fn=reset_chat_state,
733
+ inputs=None,
734
+ outputs=[chat_status_text, compression_done]
735
+ )
736
+
737
+ # Reset chat state when prompt designer fields change.
738
+ task_description_input.change(
739
+ fn=reset_chat_state,
740
+ inputs=None,
741
+ outputs=[chat_status_text, compression_done]
742
+ )
743
+ few_shot_input.change(
744
+ fn=reset_chat_state,
745
+ inputs=None,
746
+ outputs=[chat_status_text, compression_done]
747
+ )
748
+
749
+ # Reset chat state when the Markdown output changes.
750
+ markdown_output.change(
751
+ fn=reset_chat_state,
752
+ inputs=None,
753
+ outputs=[chat_status_text, compression_done]
754
  )
755
+
756
+ # When sliders change, reset chat state.
757
  retrieval_slider.change(
758
+ fn=reset_chat_state,
759
+ inputs=None,
760
+ outputs=[chat_status_text, compression_done]
761
+ )
762
+ global_local_slider.change(
763
+ fn=reset_chat_state,
764
+ inputs=None,
765
+ outputs=[chat_status_text, compression_done]
766
  )
767
+
768
+ # Update token breakdown when sliders change.
769
  retrieval_slider.change(
770
  fn=update_token_breakdown,
771
  inputs=[hidden_token_count, retrieval_slider, global_local_slider],
 
776
  inputs=[hidden_token_count, retrieval_slider, global_local_slider],
777
  outputs=tokens_breakdown_text
778
  )
779
+
780
+ # Compress button: Prepare compression and then update chat status.
781
  compress_button.click(
782
  fn=prepare_compression_and_rag,
783
  inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
784
+ outputs=[compressed_doc_state, chat_status_text, compression_done]
785
  )
 
786
  with gr.Column(elem_classes="chatbot-container"):
787
+ chat_status_text.render()
788
  gr.Markdown("## Chat")
789
  chat_interface = gr.ChatInterface(
790
  fn=chat_response_stream,
791
+ additional_inputs=[compressed_doc_state, compression_done],
792
  type="messages",
793
+ fill_height=True
794
  )
795
 
796
+ demo.queue(max_size=16).launch()
797
+
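
For reference, here is a small worked example of the token-split arithmetic that app.py applies in update_token_breakdown and prepare_compression_and_rag; the input numbers below are illustrative and not taken from the commit.

token_count, compression_rate, global_pct = 8000, 4, 75                 # illustrative inputs
retrieval_context_length = int(token_count / compression_rate)          # 2000 tokens kept overall
rag_tokens = int(retrieval_context_length * (1.0 - global_pct / 100))   # 500 tokens for RAG retrieval
kv_tokens = retrieval_context_length - rag_tokens                       # 1500 tokens for KV compression
print(rag_tokens, kv_tokens)                                            # 500 1500
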
cache.py DELETED
@@ -1,75 +0,0 @@
1
- from transformers import DynamicCache
2
- import torch
3
- import os
4
-
5
- class FinchCache(DynamicCache):
6
- def __init__(self) -> None:
7
- super().__init__()
8
- self.key_cache = []
9
- self.value_cache = []
10
-
11
- @staticmethod
12
- def _rotate_half(x):
13
- x1 = x[..., : x.shape[-1] // 2]
14
- x2 = x[..., x.shape[-1] // 2 :]
15
- return torch.cat((-x2, x1), dim=-1)
16
-
17
- def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
18
- return (key_states * cos) + (self._rotate_half(key_states) * sin)
19
-
20
- @staticmethod
21
- def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
22
- B, H, L = important_pos_batch.shape
23
- device = important_pos_batch.device
24
- device_type = x.device.type
25
- dtype = x.dtype
26
- idx = torch.arange(0, L, device=device)
27
- idx = idx.unsqueeze(0)
28
- inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1) # (B, H, M, 1)
29
- idx = idx[:, None, :].float().expand(B, H, L) # (B, H, L)
30
- delta_pos = idx - important_pos_batch
31
- delta_pos = delta_pos.unsqueeze(2) # (B, H, 1, L)
32
-
33
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
34
-
35
- with torch.autocast(device_type=device_type, enabled=False):
36
- freqs = delta_pos.float() * inv_freq.float()
37
- freqs = freqs.transpose(2, 3)
38
- emb = torch.cat((freqs, freqs), dim=-1)
39
- cos = emb.cos().contiguous()
40
- sin = emb.sin().contiguous()
41
- return cos.to(dtype=dtype), sin.to(dtype=dtype)
42
-
43
- @staticmethod
44
- def gather_important_tokens(states, indices):
45
- return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()
46
-
47
- def compress_cache(self, layer_index, important_pos, inv_freq):
48
- new_length = important_pos.size(2)
49
- new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
50
- gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
51
- self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
52
- gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
53
- self.value_cache[layer_index] = gathered_values
54
- self._seen_tokens = new_length
55
-
56
- def save(self, path: str):
57
- """Save the cache to disk, moving tensors to CPU."""
58
- try:
59
- os.makedirs(os.path.dirname(path), exist_ok=True)
60
- torch.save(
61
- {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
62
- path,
63
- )
64
- except Exception as e:
65
- print(f"Error occurred while saving: {e}")
66
-
67
- @classmethod
68
- def load(cls, path: str, device: str = "cpu") -> "FinchCache":
69
- """Load the cache from disk and move tensors to the specified device."""
70
- data = torch.load(path, map_location=device)
71
- cache = cls()
72
- cache.key_cache = [k.to(device) for k in data["key_cache"]]
73
- cache.value_cache = [v.to(device) for v in data["value_cache"]]
74
- cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
75
- return cache
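
As a sanity check on the save/load pair defined in the removed cache.py, here is a minimal round-trip sketch. It assumes the pre-deletion module is still importable as cache; the tensor shapes and the file path are made up for illustration.

import torch
from cache import FinchCache  # module deleted by this commit; assumes the pre-deletion file is on the path

kv_cache = FinchCache()
# Populate one layer with dummy (batch, heads, seq_len, head_dim) tensors.
kv_cache.key_cache.append(torch.randn(1, 8, 16, 64))
kv_cache.value_cache.append(torch.randn(1, 8, 16, 64))

kv_cache.save("./cache_dir/example.pt")               # save() moves tensors to CPU before writing
restored = FinchCache.load("./cache_dir/example.pt")  # load() defaults to device="cpu"
print(restored.value_cache[0].shape)                  # torch.Size([1, 8, 16, 64])
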
global_compression.py DELETED
@@ -1,211 +0,0 @@
1
- import math
2
- import torch
3
- from cache import FinchCache
4
- from utils import repeat_kv
5
- from transformers.models.llama.modeling_llama import rotate_half
6
- import spaces
7
-
8
- @spaces.GPU
9
- def get_compressed_kv_cache(model, sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
10
- device = model.device
11
- dtype = model.dtype
12
- sink_tokens = sink_tokens
13
- num_chunks = step_size
14
- context_ids = context_ids.to(device)
15
- context_attention_mask = context_attention_mask.to(device)
16
- question_ids = question_ids.to(device)
17
- question_attention_mask = question_attention_mask.to(device)
18
- question_len = question_ids.size(1)
19
- total_len = context_ids.size(1)
20
- max_context_tokens_allowed = model.config.max_position_embeddings - question_len
21
- if total_len > max_context_tokens_allowed:
22
- num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
23
-
24
- if total_len <= sink_tokens or num_chunks == 1:
25
- # If the context is too short or only one chunk is desired, use the entire context.
26
- context_ids_list = [context_ids]
27
- context_attention_mask_list = [context_attention_mask]
28
- else:
29
- # Calculate how many tokens remain after the sink tokens.
30
- remainder_len = total_len - sink_tokens
31
-
32
- # Compute the base tokens per chunk and any leftover.
33
- base = remainder_len // num_chunks
34
- leftover = remainder_len % num_chunks
35
-
36
- # Build a list of chunk sizes.
37
- # First chunk gets the sink tokens plus base tokens.
38
- chunk_sizes = [sink_tokens + base]
39
-
40
- # Chunks 2 to num_chunks-1 get base tokens each.
41
- for _ in range(num_chunks - 2):
42
- chunk_sizes.append(base)
43
-
44
- # The last chunk gets the remaining tokens (base + leftover).
45
- if num_chunks > 1:
46
- chunk_sizes.append(base + leftover)
47
-
48
- # Now slice the context using the calculated sizes.
49
- context_ids_list = []
50
- context_attention_mask_list = []
51
- offset = 0
52
- for size in chunk_sizes:
53
- end = offset + size
54
- context_ids_list.append(context_ids[:, offset:end])
55
- context_attention_mask_list.append(context_attention_mask[:, offset:end])
56
- offset = end
57
-
58
- # (Optional) Continue with the rest of your processing…
59
- len_rest = max(total_len - sink_tokens, 1)
60
- compression_factor = len_rest // target_token_size
61
- if compression_factor < 1:
62
- compression_factor = 1
63
-
64
- tokenized_doc_chunks = []
65
- for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
66
- tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
67
-
68
- print("Number of chunks: ", len(tokenized_doc_chunks))
69
-
70
- rotary_emb = model.model.rotary_emb.to(device)
71
- inv_freq = rotary_emb.inv_freq
72
- batch_size = question_ids.size(0)
73
- ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
74
-
75
- cache = FinchCache()
76
- past_cache_len = 0
77
- past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
78
- num_chunks = len(tokenized_doc_chunks)
79
-
80
- # Prepare a shared dictionary for hook outputs.
81
- query_context_matrices = {}
82
-
83
- # Define a hook function that uses a per-chunk offset stored on self.
84
- def query_hook_fn(module, input, output):
85
- layer_idx = getattr(module, "layer_idx", None)
86
- if layer_idx is not None:
87
- query_states = output.detach()
88
- bsz, seq_len, hidden_dim = query_states.size()
89
- num_query_heads = module.num_query_heads
90
- head_dim = hidden_dim // num_query_heads
91
- query_states = (
92
- query_states.view(bsz, seq_len, num_query_heads, head_dim)
93
- .transpose(1, 2)
94
- .contiguous()
95
- )
96
- # Use self._current_chunk_offset to select only the new tokens.
97
- query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
98
-
99
- # Pre-register hooks for all layers only once.
100
- hooks = []
101
- for i, layer in enumerate(model.model.layers):
102
- layer.self_attn.q_proj.layer_idx = i # For tracking.
103
- layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
104
- hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
105
- hooks.append(hook)
106
-
107
- # Process each document chunk sequentially.
108
- for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
109
- current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
110
- # Save the offset in an attribute the hook can access.
111
- _current_chunk_offset = current_seq_length
112
- # Clear the dictionary from any previous chunk.
113
- query_context_matrices.clear()
114
-
115
- # These chunks are already on the device.
116
- chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
117
- chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
118
- segment_attention_mask = torch.cat(
119
- [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
120
- ).contiguous()
121
- current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
122
- current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
123
-
124
- past_seen_tokens = cache.get_seq_length() if cache is not None else 0
125
- cache_position = torch.arange(
126
- past_seen_tokens + chunk_input_ids.shape[1],
127
- past_seen_tokens + current_input_ids.shape[1],
128
- device=device
129
- )
130
- causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
131
- current_attention_mask,
132
- sequence_length=question_ids.size(1),
133
- target_length=current_attention_mask.size(-1),
134
- dtype=dtype,
135
- device=device,
136
- cache_position=cache_position,
137
- batch_size=current_input_ids.size(0),
138
- ).contiguous()
139
-
140
- with torch.no_grad():
141
- outputs = model.model(
142
- input_ids=current_input_ids,
143
- use_cache=True,
144
- past_key_values=cache,
145
- )
146
- cache = outputs.past_key_values
147
-
148
- len_question = question_ids.size(1)
149
- # Now, for each transformer layer, update the cache using the query/key attention.
150
- for layer_idx in range(len(model.model.layers)):
151
- key_matrix = cache.key_cache[layer_idx]
152
- query_matrix = query_context_matrices[layer_idx]
153
- layer_cache_pos = torch.arange(
154
- past_cache_len + current_seq_length,
155
- past_cache_len + current_seq_length + len_question,
156
- device=device
157
- )
158
- position_ids = layer_cache_pos.unsqueeze(0)
159
- cos, sin = rotary_emb(query_matrix, position_ids)
160
- cos = cos.unsqueeze(1)
161
- sin = sin.unsqueeze(1)
162
- query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
163
- num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
164
- key_matrix = repeat_kv(key_matrix, num_repeats)
165
-
166
- scaling = math.sqrt(model.config.head_dim)
167
- attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
168
- causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
169
- attention_matrix = attention_matrix + causal_mask_sliced
170
- attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
171
- # Normalization
172
- tol = 1e-8
173
- binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
174
- non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
175
- non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
176
- attention_matrix = attention_matrix / non_zero_counts
177
- if j != num_chunks - 1:
178
- attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
179
- else:
180
- attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
181
- attention_matrix = torch.sum(attention_matrix, dim=-2)
182
- attention_matrix = attention_matrix.view(
183
- attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
184
- ).sum(dim=2)
185
- full_context_size = attention_matrix.size(-1)
186
- attention_matrix[..., :sink_tokens] = float("inf")
187
- if j == num_chunks - 1:
188
- attention_matrix[..., -len_question:] = float("inf")
189
- if j == 0:
190
- k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
191
- k = min(k + past_cache_len, full_context_size)
192
- elif j < num_chunks - 1:
193
- to_keep_new = int(current_seq_length // compression_factor)
194
- k = min(past_cache_len + to_keep_new, full_context_size)
195
- else:
196
- desired_final = sink_tokens + target_token_size + len_question# TODO remember to include the question tokens
197
- k = desired_final if full_context_size >= desired_final else full_context_size
198
- k = max(k, sink_tokens)
199
- selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
200
- selected_indices, _ = torch.sort(selected_indices, dim=-1)
201
- cache.compress_cache(layer_idx, selected_indices, inv_freq)
202
-
203
- past_cache_len = cache._seen_tokens
204
- past_attention_mask = torch.ones(1, past_cache_len, device=device)
205
-
206
- # Remove the hooks once after all chunks are processed.
207
- for hook in hooks:
208
- hook.remove()
209
-
210
- return cache
211
-
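
The chunk partitioning above (sink tokens plus base-sized chunks, with the leftover folded into the last chunk) can be checked with a short worked example; the lengths are illustrative.

total_len, sink_tokens, num_chunks = 1000, 4, 3                  # illustrative lengths
remainder_len = total_len - sink_tokens                          # 996
base, leftover = divmod(remainder_len, num_chunks)               # 332, 0
chunk_sizes = [sink_tokens + base] + [base] * (num_chunks - 2) + [base + leftover]
print(chunk_sizes, sum(chunk_sizes))                             # [336, 332, 332] 1000
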
preprocess_document.py DELETED
@@ -1,34 +0,0 @@
1
- from langchain_docling import DoclingLoader
2
- from langchain_docling.loader import ExportType
3
-
4
- # Import required classes for building a custom converter
5
- from docling.document_converter import DocumentConverter, PdfFormatOption, InputFormat
6
- from docling.datamodel.pipeline_options import PdfPipelineOptions
7
- from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
8
- import spaces
9
-
10
- @spaces.GPU
11
- def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
12
- file_path = file_objs if file_objs is not None else url
13
- pipeline_options = PdfPipelineOptions()
14
- pipeline_options.do_ocr = do_ocr
15
- pipeline_options.do_table_structure = do_table_structure
16
- pdf_format_options = PdfFormatOption(
17
- pipeline_options=pipeline_options,
18
- backend=PyPdfiumDocumentBackend,
19
- )
20
- doc_converter = DocumentConverter(
21
- allowed_formats=[InputFormat.PDF],
22
- format_options={
23
- InputFormat.PDF: pdf_format_options
24
- }
25
- )
26
-
27
- # Pass the custom converter to the DoclingLoader.
28
- loader = DoclingLoader(
29
- file_path=file_path,
30
- export_type=ExportType.MARKDOWN,
31
- converter=doc_converter
32
- )
33
- docs = loader.load()
34
- return docs[0].page_content
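
For completeness, an illustrative call to the removed convert_to_markdown helper, matching the signature above; the PDF path is a hypothetical placeholder and the OCR/table-structure flags are example values.

from preprocess_document import convert_to_markdown  # module deleted by this commit

markdown = convert_to_markdown(
    file_objs="sample.pdf",    # hypothetical local PDF; pass None here and a URL instead if preferred
    url=None,
    do_ocr=False,
    do_table_structure=True,
)
print(markdown[:500])
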
rag.py DELETED
@@ -1,53 +0,0 @@
1
- from langchain_text_splitters import RecursiveCharacterTextSplitter
2
- from langchain.schema.document import Document
3
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
4
- from langchain_chroma import Chroma
5
- import spaces
6
- from langchain_text_splitters import MarkdownHeaderTextSplitter
7
- import os
8
- from transformers import AutoTokenizer
9
- api_token = os.getenv("HF_TOKEN")
10
- model_name = "meta-llama/Llama-3.1-8B-Instruct"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
12
-
13
- embedding_model = HuggingFaceBgeEmbeddings(
14
- model_name="BAAI/bge-large-en-v1.5",
15
- model_kwargs={"device": "cuda"},
16
- encode_kwargs={"normalize_embeddings": True},
17
- query_instruction=""
18
- )
19
-
20
-
21
- def create_rag_index(text_no_prefix):
22
- """Loads the PDF, splits its text, and builds a vectorstore for naive RAG."""
23
- text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
24
- tokenizer,
25
- chunk_size=256,
26
- chunk_overlap=0,
27
- add_start_index=True,
28
- strip_whitespace=True,
29
- separators=["\n\n", "\n", ".", " ", ""],
30
- )
31
- # Concatenate pages and create Document objects.
32
- docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
33
-
34
- vectorstore = Chroma.from_documents(documents=docs, embedding=embedding_model)
35
- return vectorstore
36
-
37
- def run_naive_rag_query(vectorstore, query, rag_token_size, prefix, task, few_shot_examples):
38
- """
39
- For naive RAG, retrieves top-k chunks (k based on target token size)
40
- and generates an answer using those chunks.
41
- """
42
- k = max(1, rag_token_size // 256)
43
- retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
44
- retrieved_docs = retriever.invoke(query)
45
- for doc in retrieved_docs:
46
- print("=================")
47
- print(doc.page_content)
48
- print("=================")
49
- formatted_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
50
-
51
- rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
52
-
53
- return rag_context
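
Finally, a sketch of how the removed rag.py helpers fit together (app.py keeps its own collection-named variants of both). The document text, query, and token budget are placeholders, and the module expects a CUDA device and HF_TOKEN exactly as configured in the code above.

from rag import create_rag_index, run_naive_rag_query  # module deleted by this commit

vectorstore = create_rag_index("...long markdown document text...")   # placeholder document
rag_context = run_naive_rag_query(
    vectorstore,
    query="What is the main topic?",   # placeholder query
    rag_token_size=512,                # k = max(1, 512 // 256) = 2 chunks retrieved
    prefix="", task="", few_shot_examples="",
)
print(rag_context)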