giulio98 committed on
Commit 42bc715 · 1 Parent(s): 6da25f1
Files changed (1)
  1. app.py +101 -131
app.py CHANGED
@@ -3,6 +3,7 @@ import math
3
  import os
4
  import time
5
  from threading import Thread
 
6
 
7
  import gradio as gr
8
  import spaces
@@ -18,7 +19,6 @@ from langchain_docling.loader import ExportType
18
  from langchain_text_splitters import RecursiveCharacterTextSplitter
19
  from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
20
  from transformers.models.llama.modeling_llama import rotate_half
21
- import uuid
22
 
23
  from utils import (
24
  calculate_tokens_suggest_compression_ratio,
@@ -26,8 +26,6 @@ from utils import (
26
  update_retrieval_context,
27
  )
28
 
29
-
30
-
31
  # Initialize the model and tokenizer.
32
  api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
33
  model_name = "meta-llama/Llama-3.1-8B-Instruct"
@@ -37,12 +35,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
  model = model.eval()
38
  model.to(device)
39
  embedding_model = HuggingFaceBgeEmbeddings(
40
- model_name="BAAI/bge-large-en-v1.5",
41
- model_kwargs={"device": str(device)},
42
- encode_kwargs={"normalize_embeddings": True},
43
- query_instruction=""
44
- )
45
-
46
 
47
  # Create a chat template and split into prefix and suffix.
48
  content_system = ""
@@ -121,7 +118,6 @@ class FinchCache(DynamicCache):
121
  self._seen_tokens = new_length
122
 
123
  def save(self, path: str):
124
- """Save the cache to disk, moving tensors to CPU."""
125
  try:
126
  os.makedirs(os.path.dirname(path), exist_ok=True)
127
  torch.save(
@@ -133,7 +129,6 @@ class FinchCache(DynamicCache):
133
 
134
  @classmethod
135
  def load(cls, path: str, device: str = "cpu") -> "FinchCache":
136
- """Load the cache from disk and move tensors to the specified device."""
137
  data = torch.load(path, map_location=device)
138
  cache = cls()
139
  cache.key_cache = [k.to(device) for k in data["key_cache"]]
@@ -141,8 +136,6 @@ class FinchCache(DynamicCache):
141
  cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
142
  return cache
143
 
144
-
145
-
146
  def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
147
  file_path = file_objs if file_objs is not None else url
148
  pipeline_options = PdfPipelineOptions()
@@ -154,12 +147,8 @@ def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
154
  )
155
  doc_converter = DocumentConverter(
156
  allowed_formats=[InputFormat.PDF],
157
- format_options={
158
- InputFormat.PDF: pdf_format_options
159
- }
160
  )
161
-
162
- # Pass the custom converter to the DoclingLoader.
163
  loader = DoclingLoader(
164
  file_path=file_path,
165
  export_type=ExportType.MARKDOWN,
@@ -168,39 +157,51 @@ def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
168
  docs = loader.load()
169
  return docs[0].page_content
170
 
171
-
172
  def create_rag_index(collection_name, text_no_prefix):
173
- """Loads the PDF, splits its text, and builds a vectorstore for naive RAG."""
174
  text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
175
- tokenizer,
176
- chunk_size=256,
177
- chunk_overlap=0,
178
- add_start_index=True,
179
- strip_whitespace=True,
180
- separators=["\n\n", "\n", ".", " ", ""],
181
- )
182
- # Concatenate pages and create Document objects.
183
  docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
184
  vectorstore = Chroma.from_documents(collection_name=collection_name, persist_directory="./chroma_db", documents=docs, embedding=embedding_model)
185
  return vectorstore
186
 
187
-
188
  @spaces.GPU
189
  def auto_convert(file_objs, url, do_ocr, do_table_structure):
 
 
190
  if file_objs is None and (url is None or url.strip() == ""):
191
  return (
192
  gr.update(value=""),
193
  "Number of tokens before compression: ",
194
- gr.update(),
195
  "Number of tokens after compression: ",
196
  0,
197
- gr.update(interactive=False), # Disable compress button when no input.
198
  False,
199
- {} # return an empty state dictionary
 
200
  )
201
- # Convert the document to markdown.
202
  print("Converting to markdown")
203
- markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
204
  print("Done")
205
  combined_text = prefix + markdown
206
  print("Suggestioning Compression ratio")
@@ -213,8 +214,6 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
213
  token_count_str = f"Number of tokens before compression: {token_count}"
214
  retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
215
  slider_update = gr.update(value=default_ratio, minimum=min_ratio, maximum=max_ratio, step=1)
216
-
217
- # Create the RAG index immediately.
218
  if combined_text.startswith(prefix):
219
  rag_text = combined_text[len(prefix):]
220
  else:
@@ -223,18 +222,17 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
223
  rag_index = create_rag_index(collection_name, rag_text)
224
  state = {"rag_index": collection_name}
225
  print("Done")
226
-
227
  return (
228
- combined_text,
229
- token_count_str,
230
- slider_update,
231
- retrieval_str,
232
- token_count,
233
- gr.update(interactive=True),
234
  False,
235
- state
 
236
  )
237
-
238
 
239
  def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
240
  device = model.device
@@ -250,32 +248,18 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
250
  max_context_tokens_allowed = model.config.max_position_embeddings - question_len
251
  if total_len > max_context_tokens_allowed:
252
  num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
253
-
254
  if total_len <= sink_tokens or num_chunks == 1:
255
- # If the context is too short or only one chunk is desired, use the entire context.
256
  context_ids_list = [context_ids]
257
  context_attention_mask_list = [context_attention_mask]
258
  else:
259
- # Calculate how many tokens remain after the sink tokens.
260
  remainder_len = total_len - sink_tokens
261
-
262
- # Compute the base tokens per chunk and any leftover.
263
  base = remainder_len // num_chunks
264
  leftover = remainder_len % num_chunks
265
-
266
- # Build a list of chunk sizes.
267
- # First chunk gets the sink tokens plus base tokens.
268
  chunk_sizes = [sink_tokens + base]
269
-
270
- # Chunks 2 to num_chunks-1 get base tokens each.
271
  for _ in range(num_chunks - 2):
272
  chunk_sizes.append(base)
273
-
274
- # The last chunk gets the remaining tokens (base + leftover).
275
  if num_chunks > 1:
276
  chunk_sizes.append(base + leftover)
277
-
278
- # Now slice the context using the calculated sizes.
279
  context_ids_list = []
280
  context_attention_mask_list = []
281
  offset = 0
@@ -284,33 +268,23 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
284
  context_ids_list.append(context_ids[:, offset:end])
285
  context_attention_mask_list.append(context_attention_mask[:, offset:end])
286
  offset = end
287
-
288
- # (Optional) Continue with the rest of your processing…
289
  len_rest = max(total_len - sink_tokens, 1)
290
  compression_factor = len_rest // target_token_size
291
  if compression_factor < 1:
292
  compression_factor = 1
293
-
294
  tokenized_doc_chunks = []
295
  for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
296
  tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
297
-
298
  print("Number of chunks: ", len(tokenized_doc_chunks))
299
-
300
  rotary_emb = model.model.rotary_emb.to(device)
301
  inv_freq = rotary_emb.inv_freq
302
  batch_size = question_ids.size(0)
303
  ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
304
-
305
  cache = FinchCache()
306
  past_cache_len = 0
307
  past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
308
  num_chunks = len(tokenized_doc_chunks)
309
-
310
- # Prepare a shared dictionary for hook outputs.
311
  query_context_matrices = {}
312
-
313
- # Define a hook function that uses a per-chunk offset stored on self.
314
  def query_hook_fn(module, input, output):
315
  layer_idx = getattr(module, "layer_idx", None)
316
  if layer_idx is not None:
@@ -323,26 +297,17 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
323
  .transpose(1, 2)
324
  .contiguous()
325
  )
326
- # Use self._current_chunk_offset to select only the new tokens.
327
  query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
328
-
329
- # Pre-register hooks for all layers only once.
330
  hooks = []
331
  for i, layer in enumerate(model.model.layers):
332
- layer.self_attn.q_proj.layer_idx = i # For tracking.
333
  layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
334
  hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
335
  hooks.append(hook)
336
-
337
- # Process each document chunk sequentially.
338
  for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
339
  current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
340
- # Save the offset in an attribute the hook can access.
341
  _current_chunk_offset = current_seq_length
342
- # Clear the dictionary from any previous chunk.
343
  query_context_matrices.clear()
344
-
345
- # These chunks are already on the device.
346
  chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
347
  chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
348
  segment_attention_mask = torch.cat(
@@ -350,7 +315,6 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
350
  ).contiguous()
351
  current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
352
  current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
353
-
354
  past_seen_tokens = cache.get_seq_length() if cache is not None else 0
355
  cache_position = torch.arange(
356
  past_seen_tokens + chunk_input_ids.shape[1],
@@ -366,7 +330,6 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
366
  cache_position=cache_position,
367
  batch_size=current_input_ids.size(0),
368
  ).contiguous()
369
-
370
  with torch.no_grad():
371
  outputs = model.model(
372
  input_ids=current_input_ids,
@@ -374,9 +337,7 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
374
  past_key_values=cache,
375
  )
376
  cache = outputs.past_key_values
377
-
378
  len_question = question_ids.size(1)
379
- # Now, for each transformer layer, update the cache using the query/key attention.
380
  for layer_idx in range(len(model.model.layers)):
381
  key_matrix = cache.key_cache[layer_idx]
382
  query_matrix = query_context_matrices[layer_idx]
@@ -392,13 +353,11 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
392
  query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
393
  num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
394
  key_matrix = repeat_kv(key_matrix, num_repeats)
395
-
396
  scaling = math.sqrt(model.config.head_dim)
397
  attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
398
  causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
399
  attention_matrix = attention_matrix + causal_mask_sliced
400
  attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
401
- # Normalization
402
  tol = 1e-8
403
  binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
404
  non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
@@ -423,30 +382,21 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
423
  to_keep_new = int(current_seq_length // compression_factor)
424
  k = min(past_cache_len + to_keep_new, full_context_size)
425
  else:
426
- desired_final = sink_tokens + target_token_size + len_question# TODO remember to include the question tokens
427
  k = desired_final if full_context_size >= desired_final else full_context_size
428
  k = max(k, sink_tokens)
429
  selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
430
  selected_indices, _ = torch.sort(selected_indices, dim=-1)
431
  cache.compress_cache(layer_idx, selected_indices, inv_freq)
432
-
433
  past_cache_len = cache._seen_tokens
434
  past_attention_mask = torch.ones(1, past_cache_len, device=device)
435
-
436
- # Remove the hooks once after all chunks are processed.
437
  for hook in hooks:
438
  hook.remove()
439
-
440
  return cache
441
 
442
-
443
  def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
444
- """
445
- For naive RAG, retrieves top-k chunks (k based on target token size)
446
- and generates an answer using those chunks.
447
- """
448
  k = max(1, rag_token_size // 256)
449
- vectorstore = Chroma(persist_directory="./chroma_db", embedding=embedding_model, collection_name=collection_name)
450
  retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
451
  retrieved_docs = retriever.invoke(query)
452
  for doc in retrieved_docs:
@@ -454,17 +404,11 @@ def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, fe
454
  print(doc.page_content)
455
  print("=================")
456
  formatted_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
457
-
458
  rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
459
-
460
  return rag_context
461
 
462
-
463
  @spaces.GPU
464
  def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state):
465
- """
466
- Prepares the compressed KV cache. Uses the precomputed rag_index from state.
467
- """
468
  percentage = int(global_local_value.replace('%', ''))
469
  question_text = task_description + "\n" + few_shot
470
  context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
@@ -474,15 +418,23 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
474
  question_ids = question_encoding["input_ids"]
475
  question_attention_mask = question_encoding["attention_mask"]
476
  retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
477
-
 
 
 
478
  if percentage > 0:
479
  target_token_size = int(retrieval_context_length * (percentage / 100))
480
  print("Target token size for compression: ", target_token_size)
481
  step_size = 2
482
  start_time_prefill = time.perf_counter()
483
- past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
484
- context_ids, context_attention_mask,
485
- question_ids, question_attention_mask))
486
  compressed_length = past_key_values.get_seq_length()
487
  print("Context size after compression: ", compressed_length)
488
  print("Compression rate: ", context_ids.size(1) / compressed_length)
@@ -491,15 +443,11 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
491
  target_token_size = 0
492
  past_key_values = FinchCache()
493
  compressed_length = past_key_values.get_seq_length()
494
-
495
- cache_name = "default_cache_" + uuid.uuid4().hex[:6]
496
  cache_name = "default_cache_" + uuid.uuid4().hex[:6] + ".pt"
497
  save_dir = "./cache_dir"
498
  os.makedirs(save_dir, exist_ok=True)
499
  save_path = os.path.join(save_dir, cache_name)
500
  past_key_values.save(save_path)
501
-
502
- # Use the precomputed rag_index from state.
503
  collection_name = state.get("rag_index", None)
504
  if collection_name is None:
505
  print("Collection name not found creating a new one.")
@@ -509,7 +457,6 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
509
  rag_text = combined_text
510
  collection_name = "default_collection_" + uuid.uuid4().hex[:6]
511
  rag_index = create_rag_index(collection_name, rag_text)
512
-
513
  state.update({
514
  "compressed_cache": save_path,
515
  "compressed_length": compressed_length,
@@ -520,32 +467,28 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
520
  "task_description": task_description,
521
  "few_shot": few_shot,
522
  "retrieval_slider": retrieval_context_length,
523
- "prefill_time": time.perf_counter() - start_time_prefill
 
 
 
524
  })
525
  return state, True
526
 
527
-
528
  @spaces.GPU
529
  def chat_response_stream(message: str, history: list, state: dict):
530
- """
531
- Generates a chat response with streaming output.
532
- Returns a simple string (not a list of message dicts) for ChatInterface.
533
- """
534
  user_message = message
535
  save_path = state["compressed_cache"]
536
  past_key_values = FinchCache.load(save_path, device=model.device)
537
- try:
538
- os.remove(save_path)
539
- except Exception as e:
540
- print(f"Error removing cache file: {e}")
541
  compressed_length = past_key_values.get_seq_length()
542
  collection_name = state["rag_index"]
543
  retrieval_slider_value = state["retrieval_slider"]
544
  percentage = state["global_local"]
545
-
546
  rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
547
  print("RAG retrieval size: ", rag_retrieval_size)
548
-
549
  if percentage == 0:
550
  rag_prefix = prefix
551
  rag_task = state["task_description"]
@@ -565,7 +508,6 @@ def chat_response_stream(message: str, history: list, state: dict):
565
  eos_block = torch.full((1, compressed_length), tokenizer.eos_token_id, device=device, dtype=torch.long)
566
  new_input_ids = torch.cat([eos_block, tokenized_new_input["input_ids"]], dim=-1)
567
  new_attention_mask = torch.cat([torch.ones((1, compressed_length), device=device), tokenized_new_input["attention_mask"]], dim=-1)
568
-
569
  print("New input is: ", new_input)
570
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
571
  generate_kwargs = dict(
@@ -583,18 +525,28 @@ def chat_response_stream(message: str, history: list, state: dict):
583
  )
584
  t = Thread(target=model.generate, kwargs=generate_kwargs)
585
  t.start()
586
-
587
  full_output = ""
588
  for text in streamer:
589
  full_output += text
590
  time.sleep(0.05)
591
  yield full_output
592
-
593
  state["compressed_cache"] = past_key_values
594
  return full_output
595
 
 
 
 
 
 
 
 
 
 
 
 
 
596
  ##########################################################################
597
- # Gradio Interface: note that we now use ChatInterface instead of a Chatbot.
598
  ##########################################################################
599
  CSS = """
600
  body {
@@ -694,39 +646,57 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
694
  token_count_text = gr.Markdown("Number of tokens before compression: ")
695
  retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
696
  retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
 
 
697
  global_local_slider = gr.Radio(label="Global vs Local (0 is all RAG, 100 is all global)",
698
  choices=["0%", "25%", "50%", "75%", "100%"], value="75%")
699
  compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
 
 
700
 
701
  file_input.change(
702
  fn=auto_convert,
703
  inputs=[file_input, url_input, do_ocr, do_table],
704
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state]
705
  )
706
  url_input.change(
707
  fn=auto_convert,
708
  inputs=[file_input, url_input, do_ocr, do_table],
709
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state]
710
  )
711
  do_ocr.change(
712
  fn=auto_convert,
713
  inputs=[file_input, url_input, do_ocr, do_table],
714
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state]
715
  )
716
  do_table.change(
717
  fn=auto_convert,
718
  inputs=[file_input, url_input, do_ocr, do_table],
719
- outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state]
720
  )
721
  retrieval_slider.change(
722
  fn=update_retrieval_context,
723
  inputs=[hidden_token_count, retrieval_slider],
724
  outputs=retrieval_info_text
725
  )
726
  compress_button.click(
727
  fn=prepare_compression_and_rag,
728
  inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
729
  outputs=[compressed_doc_state, compression_done]
 
 
 
730
  )
731
 
732
  with gr.Column(elem_classes="chatbot-container"):
 
3
  import os
4
  import time
5
  from threading import Thread
6
+ import uuid
7
 
8
  import gradio as gr
9
  import spaces
 
19
  from langchain_text_splitters import RecursiveCharacterTextSplitter
20
  from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
21
  from transformers.models.llama.modeling_llama import rotate_half
 
22
 
23
  from utils import (
24
  calculate_tokens_suggest_compression_ratio,
 
26
  update_retrieval_context,
27
  )
28
 
 
 
29
  # Initialize the model and tokenizer.
30
  api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
31
  model_name = "meta-llama/Llama-3.1-8B-Instruct"
 
35
  model = model.eval()
36
  model.to(device)
37
  embedding_model = HuggingFaceBgeEmbeddings(
38
+ model_name="BAAI/bge-large-en-v1.5",
39
+ model_kwargs={"device": str(device)},
40
+ encode_kwargs={"normalize_embeddings": True},
41
+ query_instruction=""
42
+ )
 
43
 
44
  # Create a chat template and split into prefix and suffix.
45
  content_system = ""
 
118
  self._seen_tokens = new_length
119
 
120
  def save(self, path: str):
 
121
  try:
122
  os.makedirs(os.path.dirname(path), exist_ok=True)
123
  torch.save(
 
129
 
130
  @classmethod
131
  def load(cls, path: str, device: str = "cpu") -> "FinchCache":
 
132
  data = torch.load(path, map_location=device)
133
  cache = cls()
134
  cache.key_cache = [k.to(device) for k in data["key_cache"]]
 
136
  cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
137
  return cache
138
 
 
 
139
  def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
140
  file_path = file_objs if file_objs is not None else url
141
  pipeline_options = PdfPipelineOptions()
 
147
  )
148
  doc_converter = DocumentConverter(
149
  allowed_formats=[InputFormat.PDF],
150
+ format_options={InputFormat.PDF: pdf_format_options}
 
 
151
  )
 
 
152
  loader = DoclingLoader(
153
  file_path=file_path,
154
  export_type=ExportType.MARKDOWN,
 
157
  docs = loader.load()
158
  return docs[0].page_content
159
 
 
160
  def create_rag_index(collection_name, text_no_prefix):
 
161
  text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
162
+ tokenizer,
163
+ chunk_size=256,
164
+ chunk_overlap=0,
165
+ add_start_index=True,
166
+ strip_whitespace=True,
167
+ separators=["\n\n", "\n", ".", " ", ""],
168
+ )
 
169
  docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
170
  vectorstore = Chroma.from_documents(collection_name=collection_name, persist_directory="./chroma_db", documents=docs, embedding=embedding_model)
171
  return vectorstore
172
 
 
173
  @spaces.GPU
174
  def auto_convert(file_objs, url, do_ocr, do_table_structure):
175
+ # When a new file/URL is loaded, disable chat (compression not done)
176
+ chat_status = "Document not compressed yet. Please compress the document to enable chat."
177
  if file_objs is None and (url is None or url.strip() == ""):
178
  return (
179
  gr.update(value=""),
180
  "Number of tokens before compression: ",
181
+ gr.update(),
182
  "Number of tokens after compression: ",
183
  0,
184
+ gr.update(interactive=False),
185
  False,
186
+ {},
187
+ chat_status
188
  )
 
189
  print("Converting to markdown")
190
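+ # Guard the Docling conversion so a failed upload or URL returns an error message instead of crashing the app.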
+ try:
191
+ markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
192
+ except Exception as e:
193
+ print("Error converting to markdown:", e)
194
+ return (
195
+ gr.update(value="Error converting document to markdown. Please try uploading another document format."),
196
+ "Number of tokens before compression: ",
197
+ gr.update(),
198
+ "Number of tokens after compression: ",
199
+ 0,
200
+ gr.update(interactive=False),
201
+ False,
202
+ {},
203
+ chat_status
204
+ )
205
  print("Done")
206
  combined_text = prefix + markdown
207
  print("Suggestioning Compression ratio")
 
214
  token_count_str = f"Number of tokens before compression: {token_count}"
215
  retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
216
  slider_update = gr.update(value=default_ratio, minimum=min_ratio, maximum=max_ratio, step=1)
 
 
217
  if combined_text.startswith(prefix):
218
  rag_text = combined_text[len(prefix):]
219
  else:
 
222
  rag_index = create_rag_index(collection_name, rag_text)
223
  state = {"rag_index": collection_name}
224
  print("Done")
 
225
  return (
226
+ combined_text,
227
+ token_count_str,
228
+ slider_update,
229
+ retrieval_str,
230
+ token_count,
231
+ gr.update(interactive=True), # Enable compress button if conversion succeeds.
232
  False,
233
+ state,
234
+ chat_status
235
  )
 
236
 
237
  def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
238
  device = model.device
 
248
  max_context_tokens_allowed = model.config.max_position_embeddings - question_len
249
  if total_len > max_context_tokens_allowed:
250
  num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
 
251
  if total_len <= sink_tokens or num_chunks == 1:
 
252
  context_ids_list = [context_ids]
253
  context_attention_mask_list = [context_attention_mask]
254
  else:
 
255
  remainder_len = total_len - sink_tokens
 
 
256
  base = remainder_len // num_chunks
257
  leftover = remainder_len % num_chunks
 
 
 
258
  chunk_sizes = [sink_tokens + base]
 
 
259
  for _ in range(num_chunks - 2):
260
  chunk_sizes.append(base)
 
 
261
  if num_chunks > 1:
262
  chunk_sizes.append(base + leftover)
 
 
263
  context_ids_list = []
264
  context_attention_mask_list = []
265
  offset = 0
 
268
  context_ids_list.append(context_ids[:, offset:end])
269
  context_attention_mask_list.append(context_attention_mask[:, offset:end])
270
  offset = end
 
 
271
  len_rest = max(total_len - sink_tokens, 1)
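  # Rough keep-ratio: each chunk later retains about current_seq_length // compression_factor of its new tokens, steering the cache toward target_token_size.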
272
  compression_factor = len_rest // target_token_size
273
  if compression_factor < 1:
274
  compression_factor = 1
 
275
  tokenized_doc_chunks = []
276
  for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
277
  tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
 
278
  print("Number of chunks: ", len(tokenized_doc_chunks))
 
279
  rotary_emb = model.model.rotary_emb.to(device)
280
  inv_freq = rotary_emb.inv_freq
281
  batch_size = question_ids.size(0)
282
  ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
 
283
  cache = FinchCache()
284
  past_cache_len = 0
285
  past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
286
  num_chunks = len(tokenized_doc_chunks)
 
 
287
  query_context_matrices = {}
 
 
288
  def query_hook_fn(module, input, output):
289
  layer_idx = getattr(module, "layer_idx", None)
290
  if layer_idx is not None:
 
297
  .transpose(1, 2)
298
  .contiguous()
299
  )
 
300
  query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
 
 
301
  hooks = []
302
  for i, layer in enumerate(model.model.layers):
303
+ layer.self_attn.q_proj.layer_idx = i
304
  layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
305
  hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
306
  hooks.append(hook)
 
 
307
  for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
308
  current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
 
309
  _current_chunk_offset = current_seq_length
 
310
  query_context_matrices.clear()
 
 
311
  chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
312
  chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
313
  segment_attention_mask = torch.cat(
 
315
  ).contiguous()
316
  current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
317
  current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
 
318
  past_seen_tokens = cache.get_seq_length() if cache is not None else 0
319
  cache_position = torch.arange(
320
  past_seen_tokens + chunk_input_ids.shape[1],
 
330
  cache_position=cache_position,
331
  batch_size=current_input_ids.size(0),
332
  ).contiguous()
 
333
  with torch.no_grad():
334
  outputs = model.model(
335
  input_ids=current_input_ids,
 
337
  past_key_values=cache,
338
  )
339
  cache = outputs.past_key_values
 
340
  len_question = question_ids.size(1)
 
341
  for layer_idx in range(len(model.model.layers)):
342
  key_matrix = cache.key_cache[layer_idx]
343
  query_matrix = query_context_matrices[layer_idx]
 
353
  query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
354
  num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
355
  key_matrix = repeat_kv(key_matrix, num_repeats)
 
356
  scaling = math.sqrt(model.config.head_dim)
357
  attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
358
  causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
359
  attention_matrix = attention_matrix + causal_mask_sliced
360
  attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
 
361
  tol = 1e-8
362
  binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
363
  non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
 
382
  to_keep_new = int(current_seq_length // compression_factor)
383
  k = min(past_cache_len + to_keep_new, full_context_size)
384
  else:
385
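+ # Final chunk budget: the sink tokens, the compressed-context target, and the question tokens.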
+ desired_final = sink_tokens + target_token_size + len_question
386
  k = desired_final if full_context_size >= desired_final else full_context_size
387
  k = max(k, sink_tokens)
388
  selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
389
  selected_indices, _ = torch.sort(selected_indices, dim=-1)
390
  cache.compress_cache(layer_idx, selected_indices, inv_freq)
 
391
  past_cache_len = cache._seen_tokens
392
  past_attention_mask = torch.ones(1, past_cache_len, device=device)
 
 
393
  for hook in hooks:
394
  hook.remove()
 
395
  return cache
396
 
 
397
  def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
 
 
 
 
398
  k = max(1, rag_token_size // 256)
399
+ vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embedding_model, collection_name=collection_name)
400
  retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
401
  retrieved_docs = retriever.invoke(query)
402
  for doc in retrieved_docs:
 
404
  print(doc.page_content)
405
  print("=================")
406
  formatted_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
 
407
  rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
 
408
  return rag_context
409
 
 
410
  @spaces.GPU
411
  def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state):
 
 
 
412
  percentage = int(global_local_value.replace('%', ''))
413
  question_text = task_description + "\n" + few_shot
414
  context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
 
418
  question_ids = question_encoding["input_ids"]
419
  question_attention_mask = question_encoding["attention_mask"]
420
  retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
421
+ # Compute token breakdown for display (KV compress vs RAG tokens)
422
+ rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
423
+ kv_tokens = retrieval_context_length - rag_tokens
424
+ print(f"KV Compress Tokens: {kv_tokens}, RAG Tokens: {rag_tokens}")
425
  if percentage > 0:
426
  target_token_size = int(retrieval_context_length * (percentage / 100))
427
  print("Target token size for compression: ", target_token_size)
428
  step_size = 2
429
  start_time_prefill = time.perf_counter()
430
+ try:
431
+ past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
432
+ context_ids, context_attention_mask,
433
+ question_ids, question_attention_mask))
434
+ except Exception as e:
435
+ print("Error during KV cache compression:", e)
436
+ state["error"] = "Error during KV cache compression. Please try lowering the compression ratio and try again."
437
+ return state, False
438
  compressed_length = past_key_values.get_seq_length()
439
  print("Context size after compression: ", compressed_length)
440
  print("Compression rate: ", context_ids.size(1) / compressed_length)
 
443
  target_token_size = 0
444
  past_key_values = FinchCache()
445
  compressed_length = past_key_values.get_seq_length()
 
 
446
  cache_name = "default_cache_" + uuid.uuid4().hex[:6] + ".pt"
447
  save_dir = "./cache_dir"
448
  os.makedirs(save_dir, exist_ok=True)
449
  save_path = os.path.join(save_dir, cache_name)
450
  past_key_values.save(save_path)
 
 
451
  collection_name = state.get("rag_index", None)
452
  if collection_name is None:
453
  print("Collection name not found creating a new one.")
 
457
  rag_text = combined_text
458
  collection_name = "default_collection_" + uuid.uuid4().hex[:6]
459
  rag_index = create_rag_index(collection_name, rag_text)
 
460
  state.update({
461
  "compressed_cache": save_path,
462
  "compressed_length": compressed_length,
 
467
  "task_description": task_description,
468
  "few_shot": few_shot,
469
  "retrieval_slider": retrieval_context_length,
470
+ "prefill_time": time.perf_counter() - start_time_prefill,
471
+ "compression_done": True,
472
+ "tokens_breakdown": f"KV Compress Tokens: {kv_tokens}, RAG Tokens: {rag_tokens}",
473
+ "chat_feedback": "Document compressed successfully. You can now chat."
474
  })
475
  return state, True
476
 
 
477
  @spaces.GPU
478
  def chat_response_stream(message: str, history: list, state: dict):
479
+ # Check if the document is compressed before allowing chat
480
+ if not state.get("compression_done", False) or "compressed_cache" not in state:
481
+ yield "Document not compressed yet. Please compress the document first to enable chat."
482
+ return
483
  user_message = message
484
  save_path = state["compressed_cache"]
485
  past_key_values = FinchCache.load(save_path, device=model.device)
 
 
 
 
486
  compressed_length = past_key_values.get_seq_length()
487
  collection_name = state["rag_index"]
488
  retrieval_slider_value = state["retrieval_slider"]
489
  percentage = state["global_local"]
 
490
  rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
491
  print("RAG retrieval size: ", rag_retrieval_size)
 
492
  if percentage == 0:
493
  rag_prefix = prefix
494
  rag_task = state["task_description"]
 
508
  eos_block = torch.full((1, compressed_length), tokenizer.eos_token_id, device=device, dtype=torch.long)
509
  new_input_ids = torch.cat([eos_block, tokenized_new_input["input_ids"]], dim=-1)
510
  new_attention_mask = torch.cat([torch.ones((1, compressed_length), device=device), tokenized_new_input["attention_mask"]], dim=-1)
 
511
  print("New input is: ", new_input)
512
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
513
  generate_kwargs = dict(
 
525
  )
526
  t = Thread(target=model.generate, kwargs=generate_kwargs)
527
  t.start()
 
528
  full_output = ""
529
  for text in streamer:
530
  full_output += text
531
  time.sleep(0.05)
532
  yield full_output
 
533
  state["compressed_cache"] = past_key_values
534
  return full_output
535
 
536
+ def update_token_breakdown(token_count, retrieval_slider_value, global_local_value):
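+ # Display helper: estimates how the compressed token budget splits between KV-cache and RAG retrieval for the current slider settings.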
537
+ try:
538
+ token_count = int(token_count)
539
+ slider_val = float(retrieval_slider_value)
540
+ percentage = int(global_local_value.replace('%', ''))
541
+ retrieval_context_length = int(token_count / slider_val)
542
+ rag_tokens = int(retrieval_context_length * (1 - (percentage / 100)))
543
+ kv_tokens = retrieval_context_length - rag_tokens
544
+ return f"KV Compress Tokens: {kv_tokens}, RAG Tokens: {rag_tokens}"
545
+ except Exception as e:
546
+ return "Token breakdown unavailable."
547
+
548
  ##########################################################################
549
+ # Gradio Interface
550
  ##########################################################################
551
  CSS = """
552
  body {
 
646
  token_count_text = gr.Markdown("Number of tokens before compression: ")
647
  retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
648
  retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
649
+ # New widget for token breakdown (KV vs RAG)
650
+ tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
651
  global_local_slider = gr.Radio(label="Global vs Local (0 is all RAG, 100 is all global)",
652
  choices=["0%", "25%", "50%", "75%", "100%"], value="75%")
653
  compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
654
+ # New widget for chat status feedback
655
+ chat_status_text = gr.Markdown("Document not compressed yet. Please compress the document to enable chat.")
656
 
657
  file_input.change(
658
  fn=auto_convert,
659
  inputs=[file_input, url_input, do_ocr, do_table],
660
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
661
  )
662
  url_input.change(
663
  fn=auto_convert,
664
  inputs=[file_input, url_input, do_ocr, do_table],
665
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
666
  )
667
  do_ocr.change(
668
  fn=auto_convert,
669
  inputs=[file_input, url_input, do_ocr, do_table],
670
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
671
  )
672
  do_table.change(
673
  fn=auto_convert,
674
  inputs=[file_input, url_input, do_ocr, do_table],
675
+ outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
676
  )
677
  retrieval_slider.change(
678
  fn=update_retrieval_context,
679
  inputs=[hidden_token_count, retrieval_slider],
680
  outputs=retrieval_info_text
681
  )
682
+ # Update token breakdown when slider or global/local changes
683
+ retrieval_slider.change(
684
+ fn=update_token_breakdown,
685
+ inputs=[hidden_token_count, retrieval_slider, global_local_slider],
686
+ outputs=tokens_breakdown_text
687
+ )
688
+ global_local_slider.change(
689
+ fn=update_token_breakdown,
690
+ inputs=[hidden_token_count, retrieval_slider, global_local_slider],
691
+ outputs=tokens_breakdown_text
692
+ )
693
  compress_button.click(
694
  fn=prepare_compression_and_rag,
695
  inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
696
  outputs=[compressed_doc_state, compression_done]
697
+ ).then(
698
+ fn=lambda state: gr.update(value="Document compressed successfully. You can now chat."),
699
+ outputs=chat_status_text
700
  )
701
 
702
  with gr.Column(elem_classes="chatbot-container"):