import copy
import math
import os
import shutil
import threading
import time
import uuid
from threading import Thread

# NOTE: `spaces` is kept ahead of `torch`, matching the original import order.
import gradio as gr
import spaces
import torch
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, InputFormat, PdfFormatOption
from langchain.schema.document import Document
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_docling import DoclingLoader
from langchain_docling.loader import ExportType
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer, BitsAndBytesConfig
from transformers.models.llama.modeling_llama import rotate_half

from utils import (
    calculate_tokens_suggest_compression_ratio,
    repeat_kv,
    update_retrieval_context,
)

# ---------------------------------------------------------------------------
# Model, tokenizer and embedding model (loaded once at import time).
# An alternative configuration (google/gemma-3-27b-it with 8-bit
# BitsAndBytes quantization via Gemma3ForCausalLM) was tried previously.
# ---------------------------------------------------------------------------
api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
model_name = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=api_token,
    torch_dtype=torch.float16,
)

# Prefer the GPU when one is visible; both the LLM and the embedder use it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval().to(device)

embedding_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    query_instruction="",
    model_kwargs={"device": str(device)},
    encode_kwargs={"normalize_embeddings": True},
)
# Build the chat template once and split it around a placeholder so that
# document / question text can be spliced between the prefix and the suffix.
content_system = ""
content_user = "######"
user_template = [
    {"role": "system", "content": content_system},
    {"role": "user", "content": content_user},
]
user = tokenizer.apply_chat_template(user_template, add_generation_prompt=True, tokenize=False)
prefix, suffix = user.split(content_user)
# Tokens of the template prefix (at least 4) are always kept during compression.
sink_tokens = max(4, len(tokenizer.encode(prefix)))

# Default prompt content.
default_task_description = (
    "Answer the question based on the provided context. "
    "Provide only the answer and no additional explanations."
)
default_few_shot = """Examples
context: Climate change is primarily driven by human activities such as burning fossil fuels, deforestation, and industrial processes, which release large amounts of greenhouse gases into the atmosphere, causing global temperatures to rise.
question: What are the main human activities contributing to climate change?
answer: The main human activities contributing to climate change include burning fossil fuels, deforestation, and various industrial processes that emit greenhouse gases.
context: The Renaissance was a cultural movement spanning roughly from the 14th to the 17th century, marked by renewed interest in classical learning and values, advancements in art, literature, and scientific inquiry, and significant cultural developments.
question: What characterized the Renaissance period?
answer: The Renaissance period was characterized by a revival of classical learning, significant advancements in art and literature, and notable developments in scientific inquiry and cultural values.
context: The theory of evolution by natural selection, proposed by Charles Darwin, explains how species adapt and evolve over generations based on the survival and reproduction of individuals best suited to their environment.
question: How does Darwin's theory of evolution explain the adaptation of species?
answer: Darwin's theory explains that species adapt and evolve through natural selection, where individuals best suited to their environment are more likely to survive and reproduce.
"""

CHROMA_DB_DIR = "./chroma_db"
CACHE_DIR = "./cache_dir"
EXPIRATION_SECONDS = 3600


def _is_expired(name, now):
    """Return True if *name* looks like ``<tag>_<unix-ts>_<rest>`` and is older than EXPIRATION_SECONDS."""
    parts = name.split("_")
    return len(parts) >= 3 and parts[1].isdecimal() and now - int(parts[1]) > EXPIRATION_SECONDS


def background_cleanup():
    """Delete expired Chroma directories and saved KV-cache files every 600 seconds, forever.

    Meant to run in a daemon thread. Filesystem errors are logged and the sweep
    continues, so a single failed deletion cannot kill the sweeper thread.
    """
    while True:
        current_time = int(time.time())
        try:
            # Clean Chroma collections.
            # NOTE(review): Chroma may store collections inside chroma.sqlite3 and
            # UUID-named segment directories rather than directories named after the
            # collection -- confirm this branch actually matches anything.
            if os.path.exists(CHROMA_DB_DIR):
                for dirname in os.listdir(CHROMA_DB_DIR):
                    if _is_expired(dirname, current_time):
                        path = os.path.join(CHROMA_DB_DIR, dirname)
                        shutil.rmtree(path, ignore_errors=True)
                        print(f"[Cleanup] Deleted Chroma collection: {path}")
            # Clean cache files (names are cache_<ts>_<hex>.pt).
            if os.path.exists(CACHE_DIR):
                for filename in os.listdir(CACHE_DIR):
                    if _is_expired(filename, current_time):
                        path = os.path.join(CACHE_DIR, filename)
                        try:
                            os.remove(path)
                        except OSError as e:
                            print(f"[Cleanup] Could not delete cache file {path}: {e}")
                        else:
                            print(f"[Cleanup] Deleted cache file: {path}")
        except OSError as e:
            # e.g. a directory vanished between exists() and listdir(); retry next sweep.
            print(f"[Cleanup] Sweep failed: {e}")
        time.sleep(600)


cleanup_thread = threading.Thread(target=background_cleanup, daemon=True)
cleanup_thread.start()


class FinchCache(DynamicCache):
    """DynamicCache whose per-layer key/value tensors can be pruned to selected positions and persisted.

    ``compress_cache`` keeps only chosen token positions of one layer and shifts
    the rotary phase of the kept keys by (new index - old index). ``save`` /
    ``load`` round-trip the tensors through ``torch.save`` / ``torch.load``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        # Standard RoPE helper: (x1, x2) -> (-x2, x1) over the last dimension.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        """Build cos/sin tables that rotate each kept key by (its new index - its old position).

        ``important_pos_batch`` is (B, H, L): the original positions of the L kept
        tokens. Their new positions are 0..L-1. Results are cast to ``x``'s dtype.
        """
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        # Compute the angles in float32 regardless of any active autocast.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        """Gather ``states[..., indices, :]`` along dim 2 (states is rank-4; indices is rank-3)."""
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        """Keep only the positions in ``important_pos`` for one layer and re-rotate the kept keys."""
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        """Write the key/value tensors (moved to CPU) to *path*.

        Errors are printed rather than raised; callers should check that *path*
        exists afterwards.
        """
        try:
            directory = os.path.dirname(path)
            # os.makedirs("") raises, so only create a directory when the path has one.
            if directory:
                os.makedirs(directory, exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Rebuild a FinchCache from a file written by ``save``, placing tensors on *device*."""
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache


def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
    """Convert uploaded PDF file(s) -- or a PDF URL when nothing was uploaded -- to Markdown via Docling.

    Args:
        file_objs: Value of the upload component (a list when several files are
            allowed), or None / empty when nothing was uploaded.
        url: Source used when no file was uploaded.
        do_ocr: Enable OCR in the Docling PDF pipeline.
        do_table_structure: Enable table-structure parsing.

    Returns:
        The Markdown text. When several documents are produced, their Markdown is
        joined with blank lines instead of silently keeping only the first one.

    Raises:
        RuntimeError: if the conversion fails or yields no documents.
    """
    if file_objs:
        file_path = file_objs
    else:
        file_path = url.strip() if isinstance(url, str) else url
    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_ocr = do_ocr
    pipeline_options.do_table_structure = do_table_structure
    pdf_format_options = PdfFormatOption(
        pipeline_options=pipeline_options,
        backend=PyPdfiumDocumentBackend,
    )
    doc_converter = DocumentConverter(
        allowed_formats=[InputFormat.PDF],
        format_options={InputFormat.PDF: pdf_format_options},
    )
    loader = DoclingLoader(
        file_path=file_path,
        export_type=ExportType.MARKDOWN,
        converter=doc_converter,
    )
    try:
        docs = loader.load()
    except Exception as e:
        raise RuntimeError(f"Failed to convert document to markdown: {e}") from e
    if not docs:
        raise RuntimeError("Failed to convert document to markdown: no content was extracted.")
    return "\n\n".join(doc.page_content for doc in docs)


def create_rag_index(collection_name, text_no_prefix):
    """Split *text_no_prefix* into ~256-token chunks and index them in a persistent Chroma collection."""
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=256,
        chunk_overlap=0,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
    vectorstore = Chroma.from_documents(
        collection_name=collection_name,
        persist_directory=CHROMA_DB_DIR,
        documents=docs,
        embedding=embedding_model,
    )
    return vectorstore


@spaces.GPU
def auto_convert(file_objs, url, do_ocr, do_table_structure):
    """Convert a new file/URL, count its tokens, suggest a compression range and build its RAG index.

    Returns a 9-tuple matching the UI outputs: markdown, token-count label,
    slider update, post-compression label, raw token count, compress-button
    update, compression-done flag, new session state, chat status text.
    """
    # When a new file/URL is loaded, chat stays disabled until compression runs.
    chat_status = "Document not compressed yet. Please compress the document to enable chat."

    def _failure(markdown_update):
        return (
            markdown_update,
            "Number of tokens before compression: ",
            gr.update(),
            "Number of tokens after compression: ",
            0,
            gr.update(interactive=False),
            False,
            {},
            chat_status,
        )

    # An empty upload list counts as "no file", consistent with convert_to_markdown.
    if not file_objs and (url is None or url.strip() == ""):
        return _failure(gr.update(value=""))
    print("Converting to markdown")
    try:
        markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
    except RuntimeError as e:
        return _failure(gr.update(value=f"{str(e)} Please try uploading another document format."))
    print("Done")
    combined_text = prefix + markdown
    print("Suggestioning Compression ratio")
    token_count, suggestions, _ = calculate_tokens_suggest_compression_ratio(combined_text, tokenizer, model)
    print("Done")
    min_ratio = min(suggestions)
    max_ratio = max(suggestions)
    # Keep the default (4) inside the suggested range so the slider value and the
    # "tokens after compression" label describe the same ratio; never below 1.
    default_ratio = max(1, min(max(4, min_ratio), max_ratio))
    retrieval_tokens = int(token_count / default_ratio)
    token_count_str = f"Number of tokens before compression: {token_count}"
    retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
    slider_update = gr.update(value=default_ratio, minimum=min_ratio, maximum=max_ratio, step=1)
    # combined_text is prefix + markdown, so the RAG index is built from the raw markdown.
    collection_name = f"default_{int(time.time())}_{uuid.uuid4().hex[:6]}"
    create_rag_index(collection_name, markdown)
    state = {"rag_index": collection_name}
    print("Done")
    return (
        combined_text,
        token_count_str,
        slider_update,
        retrieval_str,
        token_count,
        gr.update(interactive=True),  # Enable compress button if conversion succeeds.
        False,
        state,
        chat_status,
    )


def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask,
                            question_ids, question_attention_mask):
    """Prefill the context chunk by chunk and keep only the cached tokens the question attends to most.

    Each chunk is run through the model followed by the question tokens. A forward
    hook on every ``q_proj`` captures the question's query states. Per layer,
    softmax(QK^T) scores are summed over question positions and over the query
    heads that share a KV head, and the top-k positions are kept via
    ``FinchCache.compress_cache``. The first ``sink_tokens`` positions and (on the
    last chunk) the question positions are always kept. The final cache holds at
    most ``sink_tokens + target_token_size + question_len`` tokens.

    Raises:
        RuntimeError: wrapping any failure during prefill or compression.
    """
    hooks = []
    try:
        device = model.device
        dtype = model.dtype
        num_chunks = step_size
        context_ids = context_ids.to(device)
        context_attention_mask = context_attention_mask.to(device)
        question_ids = question_ids.to(device)
        question_attention_mask = question_attention_mask.to(device)
        question_len = question_ids.size(1)
        total_len = context_ids.size(1)
        # Each chunk + question must fit within the model's position range.
        max_context_tokens_allowed = model.config.max_position_embeddings - question_len
        if total_len > max_context_tokens_allowed:
            num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
        if total_len <= sink_tokens or num_chunks == 1:
            context_ids_list = [context_ids]
            context_attention_mask_list = [context_attention_mask]
        else:
            # First chunk carries the sink tokens; the last one absorbs the remainder.
            remainder_len = total_len - sink_tokens
            base = remainder_len // num_chunks
            leftover = remainder_len % num_chunks
            chunk_sizes = [sink_tokens + base] + [base] * (num_chunks - 2) + [base + leftover]
            context_ids_list = []
            context_attention_mask_list = []
            offset = 0
            for size in chunk_sizes:
                end = offset + size
                context_ids_list.append(context_ids[:, offset:end])
                context_attention_mask_list.append(context_attention_mask[:, offset:end])
                offset = end
        len_rest = max(total_len - sink_tokens, 1)
        # Guard against a zero target (very short document / high ratio).
        compression_factor = max(1, len_rest // max(target_token_size, 1))
        tokenized_doc_chunks = [
            {"input_ids": ids_chunk, "attention_mask": mask_chunk}
            for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list)
        ]
        print("Number of chunks: ", len(tokenized_doc_chunks))
        rotary_emb = model.model.rotary_emb.to(device)
        inv_freq = rotary_emb.inv_freq
        batch_size = question_ids.size(0)
        ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
        cache = FinchCache()
        past_cache_len = 0
        past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
        num_chunks = len(tokenized_doc_chunks)
        query_context_matrices = {}

        def query_hook_fn(module, inputs, output):
            # Capture the (pre-RoPE) query states of the question positions only.
            layer_idx = getattr(module, "layer_idx", None)
            if layer_idx is not None:
                query_states = output.detach()
                bsz, seq_len, hidden_dim = query_states.size()
                num_query_heads = module.num_query_heads
                head_dim = hidden_dim // num_query_heads
                query_states = (
                    query_states.view(bsz, seq_len, num_query_heads, head_dim)
                    .transpose(1, 2)
                    .contiguous()
                )
                query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()

        for i, layer in enumerate(model.model.layers):
            layer.self_attn.q_proj.layer_idx = i
            layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
            hooks.append(layer.self_attn.q_proj.register_forward_hook(query_hook_fn))

        for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
            current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
            _current_chunk_offset = current_seq_length
            query_context_matrices.clear()
            chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
            chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
            segment_attention_mask = torch.cat(
                [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
            ).contiguous()
            current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
            current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
            past_seen_tokens = cache.get_seq_length() if cache is not None else 0
            cache_position = torch.arange(
                past_seen_tokens + chunk_input_ids.shape[1],
                past_seen_tokens + current_input_ids.shape[1],
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=question_ids.size(1),
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()
            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values
            len_question = question_ids.size(1)
            for layer_idx in range(len(model.model.layers)):
                key_matrix = cache.key_cache[layer_idx]
                query_matrix = query_context_matrices[layer_idx]
                # Apply RoPE to the captured queries at the question's positions.
                layer_cache_pos = torch.arange(
                    past_cache_len + current_seq_length,
                    past_cache_len + current_seq_length + len_question,
                    device=device,
                )
                position_ids = layer_cache_pos.unsqueeze(0)
                cos, sin = rotary_emb(query_matrix, position_ids)
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
                num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
                key_matrix = repeat_kv(key_matrix, num_repeats)
                scaling = math.sqrt(model.config.head_dim)
                attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
                causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
                attention_matrix = attention_matrix + causal_mask_sliced
                attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
                # Normalise each query row by its number of visible (unmasked) keys.
                tol = 1e-8
                binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
                non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
                non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
                attention_matrix = attention_matrix / non_zero_counts
                if j != num_chunks - 1:
                    attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
                else:
                    attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
                attention_matrix = torch.sum(attention_matrix, dim=-2)
                # Aggregate scores of query heads that share a KV head.
                attention_matrix = attention_matrix.view(
                    attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
                ).sum(dim=2)
                full_context_size = attention_matrix.size(-1)
                attention_matrix[..., :sink_tokens] = float("inf")
                if j == num_chunks - 1:
                    attention_matrix[..., -len_question:] = float("inf")
                # The last chunk (also when it is the only chunk) targets the final size.
                if j == num_chunks - 1:
                    desired_final = sink_tokens + target_token_size + len_question
                    k = desired_final if full_context_size >= desired_final else full_context_size
                elif j == 0:
                    k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
                    k = min(k + past_cache_len, full_context_size)
                else:
                    to_keep_new = int(current_seq_length // compression_factor)
                    k = min(past_cache_len + to_keep_new, full_context_size)
                # topk requires k <= number of candidates.
                k = min(max(k, sink_tokens), full_context_size)
                selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
                selected_indices, _ = torch.sort(selected_indices, dim=-1)
                cache.compress_cache(layer_idx, selected_indices, inv_freq)
            past_cache_len = cache._seen_tokens
            past_attention_mask = torch.ones(
                batch_size, past_cache_len, dtype=question_attention_mask.dtype, device=device
            )
        return cache
    except Exception as e:
        raise RuntimeError(f"Failed to compress KV cache: {e}") from e
    finally:
        # Always detach the hooks; otherwise they stay on q_proj and keep capturing
        # (and holding) query tensors on every later forward pass, including chat.
        for hook in hooks:
            hook.remove()


def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
    """Retrieve about ``rag_token_size`` tokens (256-token chunks) for *query* and build the prompt context.

    The result is ``prefix + "Retrieved context: \\n" + chunks``, followed by the
    non-empty *task* and *few_shot_examples*, each on its own line.
    """
    k = max(1, rag_token_size // 256)
    vectorstore = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embedding_model, collection_name=collection_name)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    retrieved_docs = retriever.invoke(query)
    for doc in retrieved_docs:
        print("=================")
        print(doc.page_content)
        print("=================")
    formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    # Separate the instructions from the last retrieved chunk (mirrors the
    # task + "\n" + few_shot layout used for KV compression).
    instructions = "\n".join(part for part in (task, few_shot_examples) if part)
    if instructions:
        formatted_context = formatted_context + "\n" + instructions
    return prefix + "Retrieved context: \n" + formatted_context


@spaces.GPU
def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot,
                                state, progress=gr.Progress()):
    """Compress the document into a saved KV cache (or an empty one in RAG mode) and record settings in *state*.

    Returns:
        ``(state, status_message, compression_done)``, matching the three UI outputs
        (session state, chat status text, compression flag), on success and on failure.
    """
    progress(0, desc="Starting compression process")
    # "RAG" sends the whole budget to retrieval; "KVCompress" sends it to the KV cache.
    percentage = 0 if global_local_value == "RAG" else 100
    progress(0.1, desc="Tokenizing text and preparing task")
    question_text = task_description + "\n" + few_shot
    context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
    question_encoding = tokenizer(question_text, return_tensors="pt").to(device)
    context_ids = context_encoding["input_ids"]
    context_attention_mask = context_encoding["attention_mask"]
    question_ids = question_encoding["input_ids"]
    question_attention_mask = question_encoding["attention_mask"]
    retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    progress(0.2, desc=f"Token breakdown computed: {kv_tokens} KV tokens, {rag_tokens} RAG tokens")
    if percentage > 0:
        target_token_size = int(retrieval_context_length * (percentage / 100))
        progress(0.3, desc="Starting KV cache compression")
        step_size = 2
        try:
            # The cache is freshly built per call, so no defensive deepcopy is needed.
            past_key_values = get_compressed_kv_cache(
                sink_tokens, step_size, target_token_size, context_ids, context_attention_mask,
                question_ids, question_attention_mask,
            )
        except Exception as e:
            progress(1, desc="Compression failed")
            print("Error during KV cache compression:", e)
            state["error"] = "Error during KV cache compression. Please try lowering the compression ratio and try again."
            return state, state["error"], False
        progress(0.6, desc="KV cache compression completed")
    else:
        past_key_values = FinchCache()
        progress(0.3, desc="Skipping compression as percentage is 0")
    cache_name = f"cache_{int(time.time())}_{uuid.uuid4().hex[:6]}.pt"
    os.makedirs(CACHE_DIR, exist_ok=True)
    save_path = os.path.join(CACHE_DIR, cache_name)
    past_key_values.save(save_path)
    # FinchCache.save only prints on failure; don't report success without a file.
    if not os.path.exists(save_path):
        progress(1, desc="Saving cache failed")
        state["error"] = "Error while saving the compressed cache. Please try again."
        return state, state["error"], False
    progress(0.8, desc="Cache saved successfully")
    collection_name = state.get("rag_index", None)
    if collection_name is None:
        print("Collection name not found; creating a new one.")
        rag_text = combined_text[len(prefix):] if combined_text.startswith(prefix) else combined_text
        collection_name = f"default_{int(time.time())}_{uuid.uuid4().hex[:6]}"
        create_rag_index(collection_name, rag_text)
    state.update({
        "compressed_cache": save_path,
        "rag_index": collection_name,
        "global_local": percentage,
        "task_description": task_description,
        "few_shot": few_shot,
        "retrieval_slider": retrieval_context_length,
    })
    progress(1, desc="Compression complete")
    return state, "Document compressed successfully. You can now chat.", True


@spaces.GPU
def chat_response_stream(message: str, history: list, state: dict, compression_done: bool):
    """Stream an answer to *message* using the saved compressed cache and/or RAG retrieval.

    Yields the accumulated answer text as new tokens arrive. Yields a single
    status message instead when the document is not compressed or its saved
    cache has already been removed.
    """
    # Check if the document is compressed before allowing chat.
    if not compression_done or "compressed_cache" not in state:
        yield "Document not compressed yet. Please compress the document first to enable chat."
        return
    user_message = message
    save_path = state["compressed_cache"]
    try:
        past_key_values = FinchCache.load(save_path, device=model.device)
    except FileNotFoundError:
        # background_cleanup deletes cache files older than EXPIRATION_SECONDS.
        yield "The compressed document has expired. Please compress the document again to continue chatting."
        return
    compressed_length = past_key_values.get_seq_length()
    collection_name = state["rag_index"]
    retrieval_slider_value = state["retrieval_slider"]
    percentage = state["global_local"]
    rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
    print("RAG retrieval size: ", rag_retrieval_size)
    print("Compressed cache: ", compressed_length)
    # In pure-RAG mode the prompt header and instructions are not in the cache,
    # so they are sent with the retrieved context instead.
    if percentage == 0:
        rag_prefix = prefix
        rag_task = state["task_description"]
        rag_few_shot = state["few_shot"]
    else:
        rag_prefix = ""
        rag_task = ""
        rag_few_shot = ""
    print("user message: ", user_message)
    if rag_retrieval_size != 0:
        print("Running RAG query")
        rag_context = run_naive_rag_query(collection_name, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
        new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
    else:
        new_input = "\nquestion: " + user_message + suffix + "answer:"
    # NOTE(review): `prefix` comes from apply_chat_template(tokenize=False) and may
    # already contain the BOS text; default tokenization may add another BOS --
    # confirm this is intended.
    tokenized_new_input = tokenizer(new_input, return_tensors="pt").to(device)
    # Placeholder ids for the already-cached positions; generate() only feeds
    # the uncached tail to the model.
    eos_block = torch.full((1, compressed_length), tokenizer.eos_token_id, device=device, dtype=torch.long)
    new_input_ids = torch.cat([eos_block, tokenized_new_input["input_ids"]], dim=-1)
    new_attention_mask = torch.cat(
        [
            torch.ones((1, compressed_length), dtype=tokenized_new_input["attention_mask"].dtype, device=device),
            tokenized_new_input["attention_mask"],
        ],
        dim=-1,
    )
    print("New input is: ", new_input)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=new_input_ids,
        attention_mask=new_attention_mask,
        past_key_values=past_key_values,
        streamer=streamer,
        use_cache=True,
        max_new_tokens=1024,
        num_beams=1,
        do_sample=False,
        top_p=1.0,
        top_k=None,
        temperature=1.0,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    full_output = ""
    for text in streamer:
        full_output += text
        time.sleep(0.05)
        yield full_output
    return full_output
def update_token_breakdown(token_count, retrieval_slider, global_local_value):
    """Describe how the post-compression token budget splits between KV compression and RAG.

    Args:
        token_count: Number of document tokens before compression.
        retrieval_slider: Compression ratio selected in the UI.
        global_local_value: Retrieval mode, "RAG" or "KVCompress".

    Returns:
        (breakdown Markdown, "Number of tokens after compression" Markdown).
    """
    retrieval_context_length = int(token_count / retrieval_slider)
    # "RAG" puts the whole budget into retrieval, anything else into the KV cache.
    percentage = 0 if global_local_value == "RAG" else 100
    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    return (
        f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)",
        f"Number of tokens after compression: {retrieval_context_length}",
    )


def _refresh_token_breakdown(token_count, retrieval_slider, global_local_value):
    """Return only the breakdown text; used right after a (re)conversion."""
    return update_token_breakdown(token_count, retrieval_slider, global_local_value)[0]


##########################################################################
# Gradio Interface
##########################################################################
CSS = """
.main-container {
    display: flex;
    align-items: stretch;
}
.upload-section, .chatbot-container {
    display: flex;
    flex-direction: column;
    height: 100%;
    overflow-y: auto;
}
.upload-section {
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 0;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
"""


def reset_chat_state():
    """Show the 'not compressed' status and clear the compression-done flag."""
    return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False


with gr.Blocks(css=CSS, theme=gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])) as demo:
    gr.HTML("<center><h1>Beyond RAG: Compress your document and chat with it.</h1></center>")

    # Created unrendered here so events can target it; rendered in the chat column below.
    chat_status_text = gr.Textbox(
        value="Document not compressed yet. Please compress the document to enable chat.",
        interactive=False,
        show_label=False,
        render=False,
        lines=5,
    )
    hidden_token_count = gr.State(value=0)
    compression_done = gr.State(value=False)
    compressed_doc_state = gr.State(value={})

    with gr.Row(elem_classes="main-container"):
        with gr.Column(elem_classes="upload-section"):
            gr.Markdown("### Document Preprocessing")
            with gr.Row():
                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area", height=120)
                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf", lines=2)
            with gr.Row():
                do_ocr = gr.Checkbox(label="Do OCR on Images", value=False, visible=False)
                do_table = gr.Checkbox(label="Parse Tables", value=False, visible=False)
            with gr.Accordion("Prompt Designer", open=False):
                task_description_input = gr.Textbox(label="Task Description", value=default_task_description, lines=3, elem_id="task-description")
                few_shot_input = gr.Textbox(label="Few-Shot Examples", value=default_few_shot, lines=10, elem_id="few-shot")
            with gr.Accordion("Show Markdown Output", open=False):
                markdown_output = gr.Textbox(label="Markdown Output", lines=20)
            token_count_text = gr.Markdown("Number of tokens before compression: ")
            retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
            retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
            tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
            global_local_slider = gr.Radio(
                label="Retrieval Mode",
                choices=["RAG", "KVCompress"],
                value="KVCompress",
            )
            compress_button = gr.Button("Compress Document", interactive=False, size="md", elem_classes="upload-button")

            conversion_inputs = [file_input, url_input, do_ocr, do_table]
            conversion_outputs = [
                markdown_output, token_count_text, retrieval_slider, retrieval_info_text, hidden_token_count,
                compress_button, compression_done, compressed_doc_state, chat_status_text,
            ]
            breakdown_inputs = [hidden_token_count, retrieval_slider, global_local_slider]
            reset_outputs = [chat_status_text, compression_done]

            # Any change of the document source reconverts, resets the chat state and
            # refreshes the breakdown explicitly: if the slider value happens not to
            # change, its own change listener would not fire and the text would stay stale.
            for source in conversion_inputs:
                source.change(
                    fn=auto_convert, inputs=conversion_inputs, outputs=conversion_outputs
                ).then(
                    fn=reset_chat_state, inputs=None, outputs=reset_outputs
                ).then(
                    fn=_refresh_token_breakdown, inputs=breakdown_inputs, outputs=[tokens_breakdown_text]
                )

            # Editing the prompt, the markdown or the compression settings invalidates
            # any previous compression.
            for component in (task_description_input, few_shot_input, markdown_output, retrieval_slider, global_local_slider):
                component.change(fn=reset_chat_state, inputs=None, outputs=reset_outputs)

            # Keep the token breakdown in sync with the compression settings.
            for component in (retrieval_slider, global_local_slider):
                component.change(
                    fn=update_token_breakdown,
                    inputs=breakdown_inputs,
                    outputs=[tokens_breakdown_text, retrieval_info_text],
                )

            compress_button.click(
                fn=prepare_compression_and_rag,
                inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
                outputs=[compressed_doc_state, chat_status_text, compression_done],
            )

        with gr.Column(elem_classes="chatbot-container"):
            chat_status_text.render()
            gr.Markdown("## Chat (LLama 3.1-8B-Instruct)")
            gr.Markdown("**Note:** There is currently no chat history available.")
            chat_interface = gr.ChatInterface(
                fn=chat_response_stream,
                additional_inputs=[compressed_doc_state, compression_done],
                type="messages",
                fill_height=True,
            )

demo.queue(max_size=16).launch()