saving collection
app.py CHANGED
@@ -18,6 +18,7 @@ from langchain_docling.loader import ExportType
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
 from transformers.models.llama.modeling_llama import rotate_half
+import uuid
 
 from utils import (
     calculate_tokens_suggest_compression_ratio,
@@ -28,7 +29,7 @@ from utils import (
 
 
 # Initialize the model and tokenizer.
-api_token = os.getenv("
+api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
 model_name = "meta-llama/Llama-3.1-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
 model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
@@ -68,8 +69,6 @@ question: Prior to playing for Michigan State, Keith Nichol played football for
 answer: Norman
 """
 
-global_rag_index = None
-
 class FinchCache(DynamicCache):
     def __init__(self) -> None:
         super().__init__()
@@ -169,7 +168,8 @@ def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
     docs = loader.load()
     return docs[0].page_content
 
-def create_rag_index(text_no_prefix):
+
+def create_rag_index(collection_name, text_no_prefix):
     """Loads the PDF, splits its text, and builds a vectorstore for naive RAG."""
     text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
         tokenizer,
@@ -181,13 +181,12 @@ def create_rag_index(text_no_prefix):
     )
     # Concatenate pages and create Document objects.
     docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
-    vectorstore = Chroma.from_documents(documents=docs, embedding=embedding_model)
+    vectorstore = Chroma.from_documents(collection_name=collection_name, persist_directory="./chroma_db", documents=docs, embedding=embedding_model)
     return vectorstore
 
 
 @spaces.GPU
 def auto_convert(file_objs, url, do_ocr, do_table_structure):
-    global global_rag_index
     if file_objs is None and (url is None or url.strip() == ""):
         return (
             gr.update(value=""),
@@ -204,7 +203,7 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
     markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
     print("Done")
     combined_text = prefix + markdown
-    print("
+    print("Suggestioning Compression ratio")
     token_count, suggestions, _ = calculate_tokens_suggest_compression_ratio(combined_text, tokenizer, model)
     print("Done")
     min_ratio = min(suggestions)
@@ -220,10 +219,10 @@ def auto_convert(file_objs, url, do_ocr, do_table_structure):
         rag_text = combined_text[len(prefix):]
     else:
         rag_text = combined_text
-
-
+    collection_name = "default_collection_" + uuid.uuid4().hex[:6]
+    rag_index = create_rag_index(collection_name, rag_text)
+    state = {"rag_index": collection_name}
     print("Done")
-    state = {}
 
     return (
         combined_text,
@@ -441,14 +440,14 @@ def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_i
     return cache
 
 
-def run_naive_rag_query(query, rag_token_size, prefix, task, few_shot_examples):
+def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
     """
     For naive RAG, retrieves top-k chunks (k based on target token size)
     and generates an answer using those chunks.
     """
-    global global_rag_index
     k = max(1, rag_token_size // 256)
-
+    vectorstore = Chroma(persist_directory="./chroma_db", embedding=embedding_model, collection_name=collection_name)
+    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
     retrieved_docs = retriever.invoke(query)
     for doc in retrieved_docs:
         print("=================")
@@ -466,7 +465,6 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     """
     Prepares the compressed KV cache. Uses the precomputed rag_index from state.
    """
-    global global_rag_index
     percentage = int(global_local_value.replace('%', ''))
     question_text = task_description + "\n" + few_shot
     context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
@@ -482,11 +480,9 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     print("Target token size for compression: ", target_token_size)
     step_size = 2
     start_time_prefill = time.perf_counter()
-    print("Compressing KV Cache")
     past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
                                                             context_ids, context_attention_mask,
                                                             question_ids, question_attention_mask))
-    print("Done")
     compressed_length = past_key_values.get_seq_length()
     print("Context size after compression: ", compressed_length)
     print("Compression rate: ", context_ids.size(1) / compressed_length)
@@ -497,17 +493,21 @@ def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_lo
     compressed_length = past_key_values.get_seq_length()
 
 
-
-
+    # Use the precomputed rag_index from state.
+    collection_name = state.get("rag_index", None)
+    if collection_name is None:
+        print("Collection name not found creating a new one.")
     if combined_text.startswith(prefix):
         rag_text = combined_text[len(prefix):]
     else:
         rag_text = combined_text
-
-
+    collection_name = "default_collection_" + uuid.uuid4().hex[:6]
+    rag_index = create_rag_index(collection_name, rag_text)
+
     state.update({
         "compressed_cache": past_key_values,
         "compressed_length": compressed_length,
+        "rag_index": collection_name,
         "target_token_size": target_token_size,
         "global_local": percentage,
         "combined_text": combined_text,
@@ -528,6 +528,7 @@ def chat_response_stream(message: str, history: list, state: dict):
     user_message = message
     past_key_values = state["compressed_cache"]
     compressed_length = past_key_values.get_seq_length()
+    collection_name = state["rag_index"]
     retrieval_slider_value = state["retrieval_slider"]
     percentage = state["global_local"]
 
@@ -544,7 +545,8 @@ def chat_response_stream(message: str, history: list, state: dict):
     rag_few_shot = ""
     print("user message: ", user_message)
     if rag_retrieval_size != 0:
-
+        print("Running RAG query")
+        rag_context = run_naive_rag_query(collection_name, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
         new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
     else:
         new_input = "\nquestion: " + user_message + suffix + "answer:"
@@ -724,4 +726,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         type="messages"
     )
 
-demo.queue().launch()
+demo.queue(max_size=16).launch()
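
For reference, below is a minimal, self-contained sketch of the pattern this commit moves to: persist each uploaded document as a uniquely named Chroma collection on disk, keep only the collection name (not the vectorstore object) in session state, and reopen the collection by name at query time. It is an illustration under assumptions, not the app's code: the embedding model, chunk size, and import paths (langchain_chroma, langchain_huggingface) are placeholders chosen for the sketch.

import uuid

from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Assumed embedding model; app.py defines its own embedding_model elsewhere.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


def build_collection(text: str, persist_dir: str = "./chroma_db") -> str:
    """Split the text, embed the chunks, and persist them under a unique collection name."""
    collection_name = "default_collection_" + uuid.uuid4().hex[:6]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
    docs = [Document(page_content=chunk) for chunk in splitter.split_text(text)]
    Chroma.from_documents(
        documents=docs,
        embedding=embedding_model,
        collection_name=collection_name,
        persist_directory=persist_dir,
    )
    return collection_name  # only the name needs to live in session state


def query_collection(collection_name: str, query: str, k: int = 4, persist_dir: str = "./chroma_db"):
    """Reopen the persisted collection by name and retrieve the top-k chunks."""
    vectorstore = Chroma(
        collection_name=collection_name,
        embedding_function=embedding_model,
        persist_directory=persist_dir,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    return retriever.invoke(query)


if __name__ == "__main__":
    name = build_collection("Some document text to index for naive RAG retrieval.")
    for doc in query_collection(name, "What does the document describe?", k=1):
        print(doc.page_content)

Storing just the collection name in the Gradio state dict appears to be what lets auto_convert, prepare_compression_and_rag, and chat_response_stream run as separate @spaces.GPU calls without sharing the removed global_rag_index.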