giulio98 committed on
Commit fa8f9de · 1 Parent(s): 0551c31

saving collection

Files changed (1)
  1. app.py +24 -22
app.py CHANGED
@@ -18,6 +18,7 @@ from langchain_docling.loader import ExportType
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
 from transformers.models.llama.modeling_llama import rotate_half
+import uuid
 
 from utils import (
     calculate_tokens_suggest_compression_ratio,
@@ -28,7 +29,7 @@ from utils import (
 
 
 # Initialize the model and tokenizer.
-api_token = os.getenv("HF_TOKEN")
+api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
 model_name = "meta-llama/Llama-3.1-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
 model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
@@ -68,8 +69,6 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
 answer: Norman
 """
 
-global_rag_index = None
-
 class FinchCache(DynamicCache):
     def __init__(self) -> None:
         super().__init__()
@@ -169,7 +168,8 @@ def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
     docs = loader.load()
     return docs[0].page_content
 
-def create_rag_index(text_no_prefix):
+
+def create_rag_index(collection_name, text_no_prefix):
     """Loads the PDF, splits its text, and builds a vectorstore for naive RAG."""
     text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
         tokenizer,
@@ -181,13 +181,12 @@ def create_rag_index(text_no_prefix):
     )
     # Concatenate pages and create Document objects.
     docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
-    vectorstore = Chroma.from_documents(documents=docs, embedding=embedding_model)
+    vectorstore = Chroma.from_documents(collection_name=collection_name, persist_directory="./chroma_db", documents=docs, embedding=embedding_model)
     return vectorstore
 
 
 @spaces.GPU
 def auto_convert(file_objs, url, do_ocr, do_table_structure):
-    global global_rag_index
     if file_objs is None and (url is None or url.strip() == ""):
         return (
             gr.update(value=""),
@@ -204,7 +203,7 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
     markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
     print("Done")
     combined_text = prefix + markdown
-    print("Calculating tokens")
+    print("Suggestioning Compression ratio")
     token_count, suggestions, _ = calculate_tokens_suggest_compression_ratio(combined_text, tokenizer, model)
     print("Done")
     min_ratio = min(suggestions)
@@ -220,10 +219,10 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
         rag_text = combined_text[len(prefix):]
     else:
         rag_text = combined_text
-    print("Creating RAG index")
-    global_rag_index = create_rag_index(rag_text)
+    collection_name = "default_collection_" + uuid.uuid4().hex[:6]
+    rag_index = create_rag_index(collection_name, rag_text)
+    state = {"rag_index": collection_name}
     print("Done")
-    state = {}
 
     return (
         combined_text,
@@ -441,14 +440,14 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
     return cache
 
 
-def run_naive_rag_query(query, rag_token_size, prefix, task, few_shot_examples):
+def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
     """
     For naive RAG, retrieves top-k chunks (k based on target token size)
     and generates an answer using those chunks.
     """
-    global global_rag_index
     k = max(1, rag_token_size // 256)
-    retriever = global_rag_index.as_retriever(search_type="similarity", search_kwargs={"k": k})
+    vectorstore = Chroma(persist_directory="./chroma_db", embedding=embedding_model, collection_name=collection_name)
+    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
     retrieved_docs = retriever.invoke(query)
     for doc in retrieved_docs:
         print("=================")
@@ -466,7 +465,6 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     """
     Prepares the compressed KV cache. Uses the precomputed rag_index from state.
     """
-    global global_rag_index
     percentage = int(global_local_value.replace('%', ''))
     question_text = task_description + "\n" + few_shot
    context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
@@ -482,11 +480,9 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     print("Target token size for compression: ", target_token_size)
     step_size = 2
     start_time_prefill = time.perf_counter()
-    print("Compressing KV Cache")
     past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
                                                             context_ids, context_attention_mask,
                                                             question_ids, question_attention_mask))
-    print("Done")
     compressed_length = past_key_values.get_seq_length()
     print("Context size after compression: ", compressed_length)
     print("Compression rate: ", context_ids.size(1) / compressed_length)
@@ -497,17 +493,21 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     compressed_length = past_key_values.get_seq_length()
 
 
-
-    if global_rag_index is None:
+    # Use the precomputed rag_index from state.
+    collection_name = state.get("rag_index", None)
+    if collection_name is None:
+        print("Collection name not found creating a new one.")
         if combined_text.startswith(prefix):
             rag_text = combined_text[len(prefix):]
         else:
             rag_text = combined_text
-        global_rag_index = create_rag_index(rag_text, device)
-
+        collection_name = "default_collection_" + uuid.uuid4().hex[:6]
+        rag_index = create_rag_index(collection_name, rag_text)
+
     state.update({
         "compressed_cache": past_key_values,
         "compressed_length": compressed_length,
+        "rag_index": collection_name,
         "target_token_size": target_token_size,
         "global_local": percentage,
         "combined_text": combined_text,
@@ -528,6 +528,7 @@ def chat_response_stream(message: str, history: list, state: dict):
     user_message = message
     past_key_values = state["compressed_cache"]
     compressed_length = past_key_values.get_seq_length()
+    collection_name = state["rag_index"]
     retrieval_slider_value = state["retrieval_slider"]
     percentage = state["global_local"]
 
@@ -544,7 +545,8 @@ def chat_response_stream(message: str, history: list, state: dict):
     rag_few_shot = ""
     print("user message: ", user_message)
     if rag_retrieval_size != 0:
-        rag_context = run_naive_rag_query(user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
+        print("Running RAG query")
+        rag_context = run_naive_rag_query(collection_name, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
         new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
     else:
         new_input = "\nquestion: " + user_message + suffix + "answer:"
@@ -724,4 +726,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         type="messages"
     )
 
-demo.queue().launch()
+demo.queue(max_size=16).launch()
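For reference, the pattern this commit moves to (build a named Chroma collection once, persist it under ./chroma_db, and re-open it by collection name on later calls, instead of keeping a module-level global_rag_index) can be sketched as follows. This is a minimal illustrative sketch, not the Space's actual app.py: the build_collection/retrieve helpers and the HuggingFaceEmbeddings model name are hypothetical stand-ins, and only the collection_name, persist_directory, documents, and embedding arguments mirror what the diff does.

import uuid

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

# Stand-in embedding model; the Space defines its own embedding_model elsewhere in app.py.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


def build_collection(chunks):
    """Index pre-split text chunks under a fresh collection name and persist them to disk."""
    collection_name = "default_collection_" + uuid.uuid4().hex[:6]
    Chroma.from_documents(
        collection_name=collection_name,
        persist_directory="./chroma_db",
        documents=[Document(page_content=c) for c in chunks],
        embedding=embedding_model,
    )
    # Only the name needs to travel through Gradio state; the vectors live on disk.
    return collection_name


def retrieve(collection_name, query, k=4):
    """Re-open the persisted collection by name and return the top-k similar chunks."""
    vectorstore = Chroma(
        collection_name=collection_name,
        persist_directory="./chroma_db",
        embedding_function=embedding_model,  # constructor keyword differs from from_documents()
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    return [doc.page_content for doc in retriever.invoke(query)]

Passing only the collection name through state (rather than a live vectorstore object or a global) keeps the index reachable across separate @spaces.GPU calls, which on ZeroGPU Spaces typically run in a different worker process; that is presumably the motivation for dropping global_rag_index. One detail worth double-checking: the diff's run_naive_rag_query passes embedding= to the Chroma constructor, while current langchain releases expose that parameter as embedding_function (as in the sketch above).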