giulio98 committed
Commit 5f059ed · verified · 1 Parent(s): d748a9d

Update app.py

Files changed (1): app.py +48 -21
app.py CHANGED
@@ -29,7 +29,7 @@ from utils import (
 )
 
 # Initialize the model and tokenizer.
-api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
+api_token = os.getenv("HF_TOKEN")
 model_name = "meta-llama/Llama-3.1-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
 model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
@@ -456,7 +456,9 @@ def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, fe
 def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state, progress=gr.Progress()):
     progress(0, desc="Starting compression process")
 
-    percentage = int(global_local_value.replace('%', ''))
+    # percentage = int(global_local_value.replace('%', ''))
+    percentage = 0 if global_local_value == "RAG" else 100
+
     progress(0.1, desc="Tokenizing text and preparing task")
     question_text = task_description + "\n" + few_shot
     context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
@@ -538,6 +540,7 @@ def chat_response_stream(message: str, history: list, state: dict, compression_d
     percentage = state["global_local"]
     rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
     print("RAG retrieval size: ", rag_retrieval_size)
+    print("Compressed cache: ", compressed_length)
     if percentage == 0:
         rag_prefix = prefix
         rag_task = state["task_description"]
@@ -583,7 +586,9 @@ def chat_response_stream(message: str, history: list, state: dict, compression_d
 
 def update_token_breakdown(token_count, retrieval_slider, global_local_value):
     retrieval_context_length = int(token_count / retrieval_slider)
-    percentage = int(global_local_value.replace('%', ''))
+    # percentage = int(global_local_value.replace('%', ''))
+    percentage = 0 if global_local_value == "RAG" else 100
+
     rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
     kv_tokens = retrieval_context_length - rag_tokens
     return f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)"
@@ -592,36 +597,51 @@ def update_token_breakdown(token_count, retrieval_slider, global_local_value):
 # Gradio Interface
 ##########################################################################
 CSS = """
-body {
-    font-family: "Times New Roman", Times, serif;
+.main-container {
+    display: flex;
+    align-items: stretch;
+}
+
+.upload-section, .chatbot-container {
+    display: flex;
+    flex-direction: column;
+    height: 100%;
+    overflow-y: auto;
 }
+
 .upload-section {
     padding: 10px;
     border: 2px dashed #ccc;
     border-radius: 10px;
 }
+
 .upload-button {
     background: #34c759 !important;
     color: white !important;
     border-radius: 25px !important;
 }
+
 .chatbot-container {
-    margin-top: 20px;
+    margin-top: 0;
 }
+
 .status-output {
     margin-top: 10px;
     font-size: 14px;
 }
+
 .processing-info {
     margin-top: 5px;
     font-size: 12px;
     color: #666;
 }
+
 .info-container {
     margin-top: 10px;
     padding: 10px;
     border-radius: 5px;
 }
+
 .file-list {
     margin-top: 0;
     max-height: 200px;
@@ -630,12 +650,14 @@ body {
     border: 1px solid #eee;
     border-radius: 5px;
 }
+
 .stats-box {
     margin-top: 10px;
     padding: 10px;
     border-radius: 5px;
     font-size: 12px;
 }
+
 .submit-btn {
     background: #1a73e8 !important;
     color: white !important;
@@ -644,18 +666,18 @@ body {
     padding: 5px 10px;
     font-size: 16px;
 }
+
 .input-row {
     display: flex;
     align-items: center;
 }
-
 """
 def reset_chat_state():
     return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False
 
-with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
-    gr.HTML("<center><p>Compress your document and chat with it.</p></center>")
+with gr.Blocks(css=CSS, theme=gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])) as demo:
+    # gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
+    gr.HTML("<h1><center>Beyond RAG: Compress your document and chat with it.</center></h1>")
 
     # Define chat_status_text as a Textbox with a set elem_id for custom styling.
     chat_status_text = gr.Textbox(value="Document not compressed yet. Please compress the document to enable chat.", interactive=False, show_label=False, render=False, lines=5)
@@ -666,13 +688,13 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
 
     with gr.Row(elem_classes="main-container"):
         with gr.Column(elem_classes="upload-section"):
-            gr.Markdown("## Document Preprocessing")
+            gr.Markdown("### Document Preprocessing")
             with gr.Row():
-                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area")
-                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf")
+                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area", height=120)
+                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf", lines=2)
             with gr.Row():
-                do_ocr = gr.Checkbox(label="Do OCR", value=False)
-                do_table = gr.Checkbox(label="Include Table Structure", value=False)
+                do_ocr = gr.Checkbox(label="Do OCR on Images", value=True, visible=False)
+                do_table = gr.Checkbox(label="Parse Tables", value=True, visible=False)
             with gr.Accordion("Prompt Designer", open=False):
                 task_description_input = gr.Textbox(label="Task Description", value=default_task_description, lines=3, elem_id="task-description")
                 few_shot_input = gr.Textbox(label="Few-Shot Examples", value=default_few_shot, lines=10, elem_id="few-shot")
@@ -682,9 +704,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
             retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
             tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
-            global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
-                                           choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
-            compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")
+            # global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
+            #                                choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
+            global_local_slider = gr.Radio(
+                label="Retrieval Mode",
+                choices=["RAG", "KVCompress"],
+                value="KVCompress"
+            )
+
+            compress_button = gr.Button("Compress Document", interactive=False, size="md", elem_classes="upload-button")
 
     # File input: Run auto_convert then chain reset_chat_state.
     file_input.change(
@@ -785,7 +813,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         )
         with gr.Column(elem_classes="chatbot-container"):
             chat_status_text.render()
-            gr.Markdown("## Chat")
+            gr.Markdown("## Chat (LLama 3.1-8B-Instruct)")
             chat_interface = gr.ChatInterface(
                 fn=chat_response_stream,
                 additional_inputs=[compressed_doc_state, compression_done],
@@ -793,5 +821,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                 fill_height=True
             )
 
-demo.queue(max_size=16).launch()
-
+demo.queue(max_size=16).launch()
 
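Note: a minimal sketch of the behavioral change in this commit, from the percentage-based hybrid radio to the two-mode "RAG"/"KVCompress" radio. The logic is copied from update_token_breakdown in the diff above; the wrapper name token_breakdown and the token_count of 32000 are illustrative, not part of the app.

def token_breakdown(token_count, retrieval_slider, global_local_value):
    retrieval_context_length = int(token_count / retrieval_slider)
    # New mapping: "RAG" -> percentage 0 (all budget to retrieval),
    # "KVCompress" -> percentage 100 (all budget to the compressed KV cache).
    percentage = 0 if global_local_value == "RAG" else 100
    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    return kv_tokens, rag_tokens

print(token_breakdown(32000, 2, "RAG"))         # -> (0, 16000): entire budget goes to RAG retrieval
print(token_breakdown(32000, 2, "KVCompress"))  # -> (16000, 0): entire budget goes to KV compression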