Update app.py
app.py
CHANGED
@@ -29,7 +29,7 @@ from utils import (
 )
 
 # Initialize the model and tokenizer.
-api_token = os.getenv("
+api_token = os.getenv("HF_TOKEN")
 model_name = "meta-llama/Llama-3.1-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
 model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
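Context for this hunk: meta-llama/Llama-3.1-8B-Instruct is a gated repository, so from_pretrained needs a valid access token. A minimal sketch of the pattern the new line lands on, with an added fail-fast guard that is not in app.py:

import os

api_token = os.getenv("HF_TOKEN")
if api_token is None:
    # Assumption: failing early beats a confusing 401 mid-download.
    raise RuntimeError("Set HF_TOKEN to a token that has accepted the Llama 3.1 license.")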
@@ -456,7 +456,9 @@ def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, fe
 def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state, progress=gr.Progress()):
     progress(0, desc="Starting compression process")
 
-    percentage = int(global_local_value.replace('%', ''))
+    # percentage = int(global_local_value.replace('%', ''))
+    percentage = 0 if global_local_value == "RAG" else 100
+
     progress(0.1, desc="Tokenizing text and preparing task")
     question_text = task_description + "\n" + few_shot
     context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
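The UI no longer exposes percentage strings, so the old replace('%', '') parser would break on the new values; the hunk swaps it for a binary mapping. A quick illustration of how the new expression behaves (names match the diff):

for global_local_value in ("RAG", "KVCompress"):
    percentage = 0 if global_local_value == "RAG" else 100
    print(global_local_value, "->", percentage)  # RAG -> 0, KVCompress -> 100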
@@ -538,6 +540,7 @@ def chat_response_stream(message: str, history: list, state: dict, compression_d
     percentage = state["global_local"]
     rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
     print("RAG retrieval size: ", rag_retrieval_size)
+    print("Compressed cache: ", compressed_length)
     if percentage == 0:
         rag_prefix = prefix
         rag_task = state["task_description"]
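The added print logs the compressed-cache length next to the RAG retrieval size; compressed_length is assumed to be computed earlier in chat_response_stream. Note that with the two-mode mapping, rag_retrieval_size collapses to all-or-nothing, as a worked example with a hypothetical budget shows:

retrieval_slider_value = 4096  # hypothetical token budget
for percentage in (0, 100):
    rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
    print(percentage, "->", rag_retrieval_size)  # 0 -> 4096 (all RAG), 100 -> 0 (all KV)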
@@ -583,7 +586,9 @@ def chat_response_stream(message: str, history: list, state: dict, compression_d
 
 def update_token_breakdown(token_count, retrieval_slider, global_local_value):
     retrieval_context_length = int(token_count / retrieval_slider)
-    percentage = int(global_local_value.replace('%', ''))
+    # percentage = int(global_local_value.replace('%', ''))
+    percentage = 0 if global_local_value == "RAG" else 100
+
     rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
     kv_tokens = retrieval_context_length - rag_tokens
     return f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)"
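Same parser swap as in prepare_compression_and_rag. Tracing update_token_breakdown with hypothetical inputs makes the split explicit:

token_count, retrieval_slider = 65536, 2  # hypothetical document size and compression rate
retrieval_context_length = int(token_count / retrieval_slider)            # 32768
percentage = 100                                                          # "KVCompress" selected
rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))   # 0
kv_tokens = retrieval_context_length - rag_tokens                         # 32768
# -> "Token Breakdown: 32768 tokens (KV compression), 0 tokens (RAG retrieval)"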
@@ -592,36 +597,51 @@ def update_token_breakdown(token_count, retrieval_slider, global_local_value):
 # Gradio Interface
 ##########################################################################
 CSS = """
-body {
-
+.main-container {
+    display: flex;
+    align-items: stretch;
+}
+
+.upload-section, .chatbot-container {
+    display: flex;
+    flex-direction: column;
+    height: 100%;
+    overflow-y: auto;
 }
+
 .upload-section {
     padding: 10px;
     border: 2px dashed #ccc;
     border-radius: 10px;
 }
+
 .upload-button {
     background: #34c759 !important;
     color: white !important;
     border-radius: 25px !important;
 }
+
 .chatbot-container {
-    margin-top:
+    margin-top: 0;
 }
+
 .status-output {
     margin-top: 10px;
     font-size: 14px;
 }
+
 .processing-info {
     margin-top: 5px;
     font-size: 12px;
     color: #666;
 }
+
 .info-container {
     margin-top: 10px;
     padding: 10px;
     border-radius: 5px;
 }
+
 .file-list {
     margin-top: 0;
     max-height: 200px;
@@ -630,12 +650,14 @@ body {
     border: 1px solid #eee;
     border-radius: 5px;
 }
+
 .stats-box {
     margin-top: 10px;
     padding: 10px;
     border-radius: 5px;
     font-size: 12px;
 }
+
 .submit-btn {
     background: #1a73e8 !important;
     color: white !important;
@@ -644,18 +666,18 @@ body {
     padding: 5px 10px;
     font-size: 16px;
 }
+
 .input-row {
     display: flex;
     align-items: center;
 }
-
 """
 def reset_chat_state():
     return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False
 
-with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
-    gr.HTML("<center
+with gr.Blocks(css=CSS, theme=gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])) as demo:
+    # gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
+    gr.HTML("<h1><center>Beyond RAG: Compress your document and chat with it.</center></h1>")
 
     # Define chat_status_text as a Textbox with a set elem_id for custom styling.
     chat_status_text = gr.Textbox(value="Document not compressed yet. Please compress the document to enable chat.", interactive=False, show_label=False, render=False, lines=5)
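Besides the CSS spacing, this hunk retitles the page and passes a font stack into the Soft theme. The theme change in isolation, as a minimal sketch:

import gradio as gr

theme = gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])
with gr.Blocks(theme=theme) as demo:
    gr.HTML("<h1><center>Beyond RAG: Compress your document and chat with it.</center></h1>")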
@@ -666,13 +688,13 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
 
     with gr.Row(elem_classes="main-container"):
         with gr.Column(elem_classes="upload-section"):
-            gr.Markdown("
+            gr.Markdown("### Document Preprocessing")
             with gr.Row():
-                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area")
-                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf")
+                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area", height=120)
+                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf", lines=2)
             with gr.Row():
-                do_ocr = gr.Checkbox(label="Do OCR", value=False)
-                do_table = gr.Checkbox(label="
+                do_ocr = gr.Checkbox(label="Do OCR on Images", value=True, visible=False)
+                do_table = gr.Checkbox(label="Parse Tables", value=True, visible=False)
             with gr.Accordion("Prompt Designer", open=False):
                 task_description_input = gr.Textbox(label="Task Description", value=default_task_description, lines=3, elem_id="task-description")
                 few_shot_input = gr.Textbox(label="Few-Shot Examples", value=default_few_shot, lines=10, elem_id="few-shot")
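Note the behavioral effect here: value=True combined with visible=False hard-enables OCR and table parsing while removing the controls from the UI, since a hidden component's value still flows to event handlers. Minimal illustration:

import gradio as gr

do_ocr = gr.Checkbox(label="Do OCR on Images", value=True, visible=False)
print(do_ocr.value)  # True, the default still reaches the callbacks that consume it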
@@ -682,9 +704,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
             retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
             tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
-            global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
-                                           choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
-
+            # global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
+            #                                choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
+            global_local_slider = gr.Radio(
+                label="Retrieval Mode",
+                choices=["RAG", "KVCompress"],
+                value="KVCompress"
+            )
+
+            compress_button = gr.Button("Compress Document", interactive=False, size="md", elem_classes="upload-button")
 
     # File input: Run auto_convert then chain reset_chat_state.
     file_input.change(
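The radio's string value is exactly what prepare_compression_and_rag and update_token_breakdown receive as global_local_value, so the choices list and the 0/100 mapping have to stay in sync. One hedged way to keep them in a single place (MODE_TO_PERCENTAGE is a hypothetical helper, not in app.py):

import gradio as gr

MODE_TO_PERCENTAGE = {"RAG": 0, "KVCompress": 100}  # hypothetical, not in app.py

global_local_slider = gr.Radio(
    label="Retrieval Mode",
    choices=list(MODE_TO_PERCENTAGE),
    value="KVCompress",
)
# downstream: percentage = MODE_TO_PERCENTAGE[global_local_value]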
@@ -785,7 +813,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         )
         with gr.Column(elem_classes="chatbot-container"):
             chat_status_text.render()
-            gr.Markdown("## Chat")
+            gr.Markdown("## Chat (LLama 3.1-8B-Instruct)")
             chat_interface = gr.ChatInterface(
                 fn=chat_response_stream,
                 additional_inputs=[compressed_doc_state, compression_done],
@@ -793,5 +821,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                 fill_height=True
             )
 
-demo.queue(max_size=16).launch()
-
+demo.queue(max_size=16).launch()
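The entry point itself is unchanged apart from the trailing blank line. queue(max_size=16) caps how many requests may wait in line, which matters once compression jobs take long enough to stack up; the chained call is equivalent to the explicit form:

demo.queue(max_size=16)  # new requests are rejected once 16 are already waiting
demo.launch()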