import copy
import math
import os
import shutil
import threading
import time
import uuid
from threading import Thread

import gradio as gr
import spaces
import torch
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, InputFormat, PdfFormatOption
from langchain.schema.document import Document
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_docling import DoclingLoader
from langchain_docling.loader import ExportType
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer, BitsAndBytesConfig
from transformers.models.llama.modeling_llama import rotate_half

from utils import (
    calculate_tokens_suggest_compression_ratio,
    repeat_kv,
    update_retrieval_context,
)
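
# Gradio demo: convert an uploaded document (or URL) to markdown with Docling,
# then let the user trade off two ways of fitting it into the LLM context:
# a FINCH-style compressed KV cache ("KVCompress") or chunk retrieval ("RAG")
# over a Chroma index, before chatting with Llama-3.1-8B-Instruct.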

api_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
model_name = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)

model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval()
model.to(device)
embedding_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": str(device)},
    encode_kwargs={"normalize_embeddings": True},
    query_instruction="",
)

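# Recover the chat-template scaffolding by formatting a placeholder user
# message ("######") and splitting on it: `prefix` is everything the template
# emits before the user content (system prompt and headers), `suffix` is
# everything after it up to the assistant header. The tokenized prefix length
# also sets how many attention-sink tokens are always preserved during KV
# compression.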
content_system = ""
content_user = "######"
user_template = [
    {"role": "system", "content": content_system},
    {"role": "user", "content": content_user},
]
user = tokenizer.apply_chat_template(user_template, add_generation_prompt=True, tokenize=False)
prefix, suffix = user.split(content_user)
sink_tokens = max(4, len(tokenizer.encode(prefix)))

default_task_description = (
    "Answer the question based on the provided context. "
    "Provide only the answer and no additional explanations."
)
default_few_shot = """Examples
context: Climate change is primarily driven by human activities such as burning fossil fuels, deforestation, and industrial processes, which release large amounts of greenhouse gases into the atmosphere, causing global temperatures to rise.
question: What are the main human activities contributing to climate change?
answer: The main human activities contributing to climate change include burning fossil fuels, deforestation, and various industrial processes that emit greenhouse gases.

context: The Renaissance was a cultural movement spanning roughly from the 14th to the 17th century, marked by renewed interest in classical learning and values, advancements in art, literature, and scientific inquiry, and significant cultural developments.
question: What characterized the Renaissance period?
answer: The Renaissance period was characterized by a revival of classical learning, significant advancements in art and literature, and notable developments in scientific inquiry and cultural values.

context: The theory of evolution by natural selection, proposed by Charles Darwin, explains how species adapt and evolve over generations based on the survival and reproduction of individuals best suited to their environment.
question: How does Darwin's theory of evolution explain the adaptation of species?
answer: Darwin's theory explains that species adapt and evolve through natural selection, where individuals best suited to their environment are more likely to survive and reproduce.
"""

CHROMA_DB_DIR = "./chroma_db"
CACHE_DIR = "./cache_dir"
EXPIRATION_SECONDS = 3600

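# Background housekeeping: Chroma collections and cache files are named
# "<label>_<unix timestamp>_<short uuid>", so the timestamp embedded in each
# name is enough to decide when an artifact is older than EXPIRATION_SECONDS
# and can be deleted.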
def background_cleanup():
    while True:
        current_time = int(time.time())

        if os.path.exists(CHROMA_DB_DIR):
            for dirname in os.listdir(CHROMA_DB_DIR):
                parts = dirname.split("_")
                if len(parts) >= 3 and parts[1].isdigit():
                    timestamp = int(parts[1])
                    if current_time - timestamp > EXPIRATION_SECONDS:
                        path = os.path.join(CHROMA_DB_DIR, dirname)
                        shutil.rmtree(path, ignore_errors=True)
                        print(f"[Cleanup] Deleted Chroma collection: {path}")

        if os.path.exists(CACHE_DIR):
            for filename in os.listdir(CACHE_DIR):
                parts = filename.split("_")
                if len(parts) >= 3 and parts[1].isdigit():
                    timestamp = int(parts[1])
                    if current_time - timestamp > EXPIRATION_SECONDS:
                        path = os.path.join(CACHE_DIR, filename)
                        os.remove(path)
                        print(f"[Cleanup] Deleted cache file: {path}")

        time.sleep(600)


cleanup_thread = threading.Thread(target=background_cleanup, daemon=True)
cleanup_thread.start()

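# FinchCache holds the compressed key/value tensors. Compression keeps only
# the most-attended positions per layer; because Llama stores keys with rotary
# position embeddings (RoPE) already applied, the surviving keys are
# re-rotated so their encoded positions match their new, contiguous slots in
# the shortened cache (_rerotate_cos_sin computes the positional deltas,
# _apply_key_rotary_pos_emb applies them).
# Persistence sketch (illustrative paths, not executed here):
#     cache.save("./cache_dir/example_cache.pt")
#     cache = FinchCache.load("./cache_dir/example_cache.pt", device="cuda")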
class FinchCache(DynamicCache):
    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)
        idx = idx[:, None, :].float().expand(B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)

        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache

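# Document ingestion: Docling converts the uploaded PDF (or a URL) to markdown
# through the pypdfium2 backend. OCR and table-structure parsing are optional
# and controlled by the (currently hidden) checkboxes in the UI.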
def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
    file_path = file_objs if file_objs is not None else url
    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_ocr = do_ocr
    pipeline_options.do_table_structure = do_table_structure
    pdf_format_options = PdfFormatOption(
        pipeline_options=pipeline_options,
        backend=PyPdfiumDocumentBackend,
    )
    doc_converter = DocumentConverter(
        allowed_formats=[InputFormat.PDF],
        format_options={InputFormat.PDF: pdf_format_options},
    )
    loader = DoclingLoader(
        file_path=file_path,
        export_type=ExportType.MARKDOWN,
        converter=doc_converter,
    )
    try:
        docs = loader.load()
        return docs[0].page_content
    except Exception as e:
        raise RuntimeError(f"Failed to convert document to markdown: {e}")

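# The RAG index splits the markdown into ~256-token chunks (no overlap) with
# the model tokenizer, embeds them with BGE, and persists them in a Chroma
# collection on disk so later queries can retrieve them by similarity.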
def create_rag_index(collection_name, text_no_prefix):
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=256,
        chunk_overlap=0,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
    vectorstore = Chroma.from_documents(
        collection_name=collection_name,
        persist_directory=CHROMA_DB_DIR,
        documents=docs,
        embedding=embedding_model,
    )
    return vectorstore

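# auto_convert runs whenever the file, URL, or parsing options change. It
# converts the document, counts tokens, suggests compression ratios for the
# slider, and builds the RAG index. Its return tuple must stay in sync with
# the `outputs=` lists wired up in the Gradio event handlers below.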
@spaces.GPU
def auto_convert(file_objs, url, do_ocr, do_table_structure):

    chat_status = "Document not compressed yet. Please compress the document to enable chat."
    if file_objs is None and (url is None or url.strip() == ""):
        return (
            gr.update(value=""),
            "Number of tokens before compression: ",
            gr.update(),
            "Number of tokens after compression: ",
            0,
            gr.update(interactive=False),
            False,
            {},
            chat_status,
        )
    print("Converting to markdown")
    try:
        markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
    except RuntimeError as e:
        return (
            gr.update(value=f"{str(e)} Please try uploading another document format."),
            "Number of tokens before compression: ",
            gr.update(),
            "Number of tokens after compression: ",
            0,
            gr.update(interactive=False),
            False,
            {},
            chat_status,
        )

    print("Done")
    combined_text = prefix + markdown
    print("Suggesting compression ratio")
    token_count, suggestions, _ = calculate_tokens_suggest_compression_ratio(combined_text, tokenizer, model)
    print("Done")
    min_ratio = min(suggestions)
    max_ratio = max(suggestions)
    default_ratio = 4
    retrieval_tokens = int(token_count / default_ratio)
    token_count_str = f"Number of tokens before compression: {token_count}"
    retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
    slider_update = gr.update(value=default_ratio, minimum=min_ratio, maximum=max_ratio, step=1)
    if combined_text.startswith(prefix):
        rag_text = combined_text[len(prefix):]
    else:
        rag_text = combined_text
    current_timestamp = int(time.time())
    collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
    create_rag_index(collection_name, rag_text)
    state = {"rag_index": collection_name}
    print("Done")
    return (
        combined_text,
        token_count_str,
        slider_update,
        retrieval_str,
        token_count,
        gr.update(interactive=True),
        False,
        state,
        chat_status,
    )

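# get_compressed_kv_cache builds a compressed KV cache for the document:
#   1. The context is split into roughly equal chunks (after reserving the
#      first `sink_tokens` tokens, which stay attached to the first chunk),
#      using more chunks when the document exceeds the model's context window.
#   2. Each chunk is prefilled together with the task prompt (`question_ids`).
#      Forward hooks on every q_proj capture the prompt's query vectors, which
#      are used to score how much attention each cached key receives.
#   3. Per layer, only the top-k scoring positions are kept (sink tokens are
#      always kept, and the prompt tokens are kept on the last chunk);
#      FinchCache.compress_cache gathers them and re-rotates the keys.
# The returned cache ends up with roughly sink_tokens + target_token_size
# document positions plus the task-prompt tokens appended by the last chunk.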
def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
    try:
        device = model.device
        dtype = model.dtype
        num_chunks = step_size
        context_ids = context_ids.to(device)
        context_attention_mask = context_attention_mask.to(device)
        question_ids = question_ids.to(device)
        question_attention_mask = question_attention_mask.to(device)
        question_len = question_ids.size(1)
        total_len = context_ids.size(1)
        max_context_tokens_allowed = model.config.max_position_embeddings - question_len
        if total_len > max_context_tokens_allowed:
            num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
        if total_len <= sink_tokens or num_chunks == 1:
            context_ids_list = [context_ids]
            context_attention_mask_list = [context_attention_mask]
        else:
            remainder_len = total_len - sink_tokens
            base = remainder_len // num_chunks
            leftover = remainder_len % num_chunks
            chunk_sizes = [sink_tokens + base]
            for _ in range(num_chunks - 2):
                chunk_sizes.append(base)
            if num_chunks > 1:
                chunk_sizes.append(base + leftover)
            context_ids_list = []
            context_attention_mask_list = []
            offset = 0
            for size in chunk_sizes:
                end = offset + size
                context_ids_list.append(context_ids[:, offset:end])
                context_attention_mask_list.append(context_attention_mask[:, offset:end])
                offset = end
        len_rest = max(total_len - sink_tokens, 1)
        compression_factor = len_rest // target_token_size
        if compression_factor < 1:
            compression_factor = 1
        tokenized_doc_chunks = []
        for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
            tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
        print("Number of chunks: ", len(tokenized_doc_chunks))
        rotary_emb = model.model.rotary_emb.to(device)
        inv_freq = rotary_emb.inv_freq
        batch_size = question_ids.size(0)
        ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
        cache = FinchCache()
        past_cache_len = 0
        past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
        num_chunks = len(tokenized_doc_chunks)
        query_context_matrices = {}

        def query_hook_fn(module, input, output):
            layer_idx = getattr(module, "layer_idx", None)
            if layer_idx is not None:
                query_states = output.detach()
                bsz, seq_len, hidden_dim = query_states.size()
                num_query_heads = module.num_query_heads
                head_dim = hidden_dim // num_query_heads
                query_states = (
                    query_states.view(bsz, seq_len, num_query_heads, head_dim)
                    .transpose(1, 2)
                    .contiguous()
                )
                query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()

        hooks = []
        for i, layer in enumerate(model.model.layers):
            layer.self_attn.q_proj.layer_idx = i
            layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
            hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
            hooks.append(hook)
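        # Prefill chunk by chunk: each pass sees the already-compressed cache,
        # the new document chunk, and the task prompt. `_current_chunk_offset`
        # tells the hooks where the prompt tokens start so only their query
        # vectors are kept for scoring.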
        for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
            current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
            _current_chunk_offset = current_seq_length
            query_context_matrices.clear()
            chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
            chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
            segment_attention_mask = torch.cat(
                [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
            ).contiguous()
            current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
            current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
            past_seen_tokens = cache.get_seq_length() if cache is not None else 0
            cache_position = torch.arange(
                past_seen_tokens + chunk_input_ids.shape[1],
                past_seen_tokens + current_input_ids.shape[1],
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=question_ids.size(1),
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()
            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values
            len_question = question_ids.size(1)
            for layer_idx in range(len(model.model.layers)):
                key_matrix = cache.key_cache[layer_idx]
                query_matrix = query_context_matrices[layer_idx]
                layer_cache_pos = torch.arange(
                    past_cache_len + current_seq_length,
                    past_cache_len + current_seq_length + len_question,
                    device=device,
                )
                position_ids = layer_cache_pos.unsqueeze(0)
                cos, sin = rotary_emb(query_matrix, position_ids)
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
                num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
                key_matrix = repeat_kv(key_matrix, num_repeats)
                scaling = math.sqrt(model.config.head_dim)
                attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
                causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
                attention_matrix = attention_matrix + causal_mask_sliced
                attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
                tol = 1e-8
                binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
                non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
                non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
                attention_matrix = attention_matrix / non_zero_counts
                if j != num_chunks - 1:
                    attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
                else:
                    attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
                attention_matrix = torch.sum(attention_matrix, dim=-2)
                attention_matrix = attention_matrix.view(
                    attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
                ).sum(dim=2)
                full_context_size = attention_matrix.size(-1)
                attention_matrix[..., :sink_tokens] = float("inf")
                if j == num_chunks - 1:
                    attention_matrix[..., -len_question:] = float("inf")
                if j == 0:
                    k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
                    k = min(k + past_cache_len, full_context_size)
                elif j < num_chunks - 1:
                    to_keep_new = int(current_seq_length // compression_factor)
                    k = min(past_cache_len + to_keep_new, full_context_size)
                else:
                    desired_final = sink_tokens + target_token_size + len_question
                    k = desired_final if full_context_size >= desired_final else full_context_size
                k = max(k, sink_tokens)
                selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
                selected_indices, _ = torch.sort(selected_indices, dim=-1)
                cache.compress_cache(layer_idx, selected_indices, inv_freq)
            past_cache_len = cache._seen_tokens
            past_attention_mask = torch.ones(1, past_cache_len, device=device)
        for hook in hooks:
            hook.remove()
        return cache
    except Exception as e:
        raise RuntimeError(f"Failed to compress KV cache: {e}")

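# Naive RAG: with ~256-token chunks and no overlap, a retrieval budget of
# `rag_token_size` tokens maps to roughly rag_token_size // 256 chunks.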
def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
    k = max(1, rag_token_size // 256)
    vectorstore = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embedding_model, collection_name=collection_name)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    retrieved_docs = retriever.invoke(query)
    for doc in retrieved_docs:
        print("=================")
        print(doc.page_content)
        print("=================")
    formatted_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
    rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
    return rag_context

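# prepare_compression_and_rag splits the post-compression token budget between
# the KV cache and RAG retrieval according to the selected mode ("KVCompress"
# puts everything in the compressed cache, "RAG" relies purely on retrieval),
# compresses the cache if needed, saves it to disk, and records everything the
# chat handler needs in the session state.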
@spaces.GPU
def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state, progress=gr.Progress()):
    progress(0, desc="Starting compression process")

    percentage = 0 if global_local_value == "RAG" else 100

    progress(0.1, desc="Tokenizing text and preparing task")
    question_text = task_description + "\n" + few_shot
    context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
    question_encoding = tokenizer(question_text, return_tensors="pt").to(device)
    context_ids = context_encoding["input_ids"]
    context_attention_mask = context_encoding["attention_mask"]
    question_ids = question_encoding["input_ids"]
    question_attention_mask = question_encoding["attention_mask"]

    retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    progress(0.2, desc=f"Token breakdown computed: {kv_tokens} KV tokens, {rag_tokens} RAG tokens")

    if percentage > 0:
        target_token_size = int(retrieval_context_length * (percentage / 100))
        progress(0.3, desc="Starting KV cache compression")
        step_size = 2
        try:
            past_key_values = copy.deepcopy(
                get_compressed_kv_cache(
                    sink_tokens, step_size, target_token_size,
                    context_ids, context_attention_mask,
                    question_ids, question_attention_mask,
                )
            )
        except Exception as e:
            progress(1, desc="Compression failed")
            print("Error during KV cache compression:", e)
            state["error"] = "Error during KV cache compression. Please try lowering the compression ratio and try again."
            return state, state["error"], False
        compressed_length = past_key_values.get_seq_length()
        progress(0.6, desc="KV cache compression completed")
    else:
        target_token_size = 0
        past_key_values = FinchCache()
        compressed_length = past_key_values.get_seq_length()
        progress(0.3, desc="Skipping compression as percentage is 0")

    current_timestamp = int(time.time())
    cache_name = f"cache_{current_timestamp}_{uuid.uuid4().hex[:6]}.pt"
    save_dir = CACHE_DIR
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, cache_name)
    past_key_values.save(save_path)
    progress(0.8, desc="Cache saved successfully")

    collection_name = state.get("rag_index", None)
    if collection_name is None:
        print("Collection name not found; creating a new one.")
        if combined_text.startswith(prefix):
            rag_text = combined_text[len(prefix):]
        else:
            rag_text = combined_text
        current_timestamp = int(time.time())
        collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
        create_rag_index(collection_name, rag_text)

    state.update({
        "compressed_cache": save_path,
        "rag_index": collection_name,
        "global_local": percentage,
        "task_description": task_description,
        "few_shot": few_shot,
        "retrieval_slider": retrieval_context_length,
    })
    progress(1, desc="Compression complete")
    return state, "Document compressed successfully. You can now chat.", True

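# Chat handler: loads the compressed cache from disk, optionally prepends
# freshly retrieved RAG context, and streams the generation. The prompt is
# padded on the left with EOS placeholder tokens matching the compressed cache
# length so input_ids and attention_mask stay aligned with the cached
# key/value positions during generation.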
@spaces.GPU
def chat_response_stream(message: str, history: list, state: dict, compression_done: bool):

    if not compression_done or "compressed_cache" not in state:
        yield "Document not compressed yet. Please compress the document first to enable chat."
        return
    user_message = message
    save_path = state["compressed_cache"]
    past_key_values = FinchCache.load(save_path, device=model.device)
    compressed_length = past_key_values.get_seq_length()
    collection_name = state["rag_index"]
    retrieval_slider_value = state["retrieval_slider"]
    percentage = state["global_local"]
    rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
    print("RAG retrieval size: ", rag_retrieval_size)
    print("Compressed cache: ", compressed_length)
    if percentage == 0:
        rag_prefix = prefix
        rag_task = state["task_description"]
        rag_few_shot = state["few_shot"]
    else:
        rag_prefix = ""
        rag_task = ""
        rag_few_shot = ""
    print("user message: ", user_message)
    if rag_retrieval_size != 0:
        print("Running RAG query")
        rag_context = run_naive_rag_query(collection_name, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
        new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
    else:
        new_input = "\nquestion: " + user_message + suffix + "answer:"
    tokenized_new_input = tokenizer(new_input, return_tensors="pt").to(device)
    eos_block = torch.full((1, compressed_length), tokenizer.eos_token_id, device=device, dtype=torch.long)
    new_input_ids = torch.cat([eos_block, tokenized_new_input["input_ids"]], dim=-1)
    new_attention_mask = torch.cat([torch.ones((1, compressed_length), device=device), tokenized_new_input["attention_mask"]], dim=-1)
    print("New input is: ", new_input)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=new_input_ids,
        attention_mask=new_attention_mask,
        past_key_values=past_key_values,
        streamer=streamer,
        use_cache=True,
        max_new_tokens=1024,
        num_beams=1,
        do_sample=False,
        top_p=1.0,
        top_k=None,
        temperature=1.0,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    full_output = ""
    for text in streamer:
        full_output += text
        time.sleep(0.05)
        yield full_output

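# Keep the KV/RAG token breakdown shown in the UI in sync with the slider and
# the retrieval mode.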
def update_token_breakdown(token_count, retrieval_slider, global_local_value):
    retrieval_context_length = int(token_count / retrieval_slider)

    percentage = 0 if global_local_value == "RAG" else 100

    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    return (
        f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)",
        f"Number of tokens after compression: {retrieval_context_length}",
    )

CSS = """
.main-container {
    display: flex;
    align-items: stretch;
}

.upload-section, .chatbot-container {
    display: flex;
    flex-direction: column;
    height: 100%;
    overflow-y: auto;
}

.upload-section {
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}

.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}

.chatbot-container {
    margin-top: 0;
}

.status-output {
    margin-top: 10px;
    font-size: 14px;
}

.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}

.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}

.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}

.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}

.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}

.input-row {
    display: flex;
    align-items: center;
}
"""


def reset_chat_state():
    return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False

with gr.Blocks(css=CSS, theme=gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])) as demo:

    gr.HTML("<h1><center>Beyond RAG: Compress your document and chat with it.</center></h1>")

    chat_status_text = gr.Textbox(value="Document not compressed yet. Please compress the document to enable chat.", interactive=False, show_label=False, render=False, lines=5)

    hidden_token_count = gr.State(value=0)
    compression_done = gr.State(value=False)
    compressed_doc_state = gr.State(value={})

    with gr.Row(elem_classes="main-container"):
        with gr.Column(elem_classes="upload-section"):
            gr.Markdown("### Document Preprocessing")
            with gr.Row():
                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area", height=120)
                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf", lines=2)
            with gr.Row():
                do_ocr = gr.Checkbox(label="Do OCR on Images", value=False, visible=False)
                do_table = gr.Checkbox(label="Parse Tables", value=False, visible=False)
            with gr.Accordion("Prompt Designer", open=False):
                task_description_input = gr.Textbox(label="Task Description", value=default_task_description, lines=3, elem_id="task-description")
                few_shot_input = gr.Textbox(label="Few-Shot Examples", value=default_few_shot, lines=10, elem_id="few-shot")
            with gr.Accordion("Show Markdown Output", open=False):
                markdown_output = gr.Textbox(label="Markdown Output", lines=20)
            token_count_text = gr.Markdown("Number of tokens before compression: ")
            retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
            retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
            tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")

            global_local_slider = gr.Radio(
                label="Retrieval Mode",
                choices=["RAG", "KVCompress"],
                value="KVCompress",
            )

            compress_button = gr.Button("Compress Document", interactive=False, size="md", elem_classes="upload-button")

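        # Any change to the document, URL, parsing options, prompts, slider, or
        # retrieval mode re-runs the conversion and/or resets the chat state,
        # so the user has to re-compress before chatting again.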
        file_input.change(
            fn=auto_convert,
            inputs=[file_input, url_input, do_ocr, do_table],
            outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
                     hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
        ).then(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        url_input.change(
            fn=auto_convert,
            inputs=[file_input, url_input, do_ocr, do_table],
            outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
                     hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
        ).then(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        do_ocr.change(
            fn=auto_convert,
            inputs=[file_input, url_input, do_ocr, do_table],
            outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
                     hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
        ).then(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        do_table.change(
            fn=auto_convert,
            inputs=[file_input, url_input, do_ocr, do_table],
            outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
                     hidden_token_count, compress_button, compression_done, compressed_doc_state, chat_status_text]
        ).then(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        task_description_input.change(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )
        few_shot_input.change(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        markdown_output.change(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        retrieval_slider.change(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )
        global_local_slider.change(
            fn=reset_chat_state,
            inputs=None,
            outputs=[chat_status_text, compression_done]
        )

        retrieval_slider.change(
            fn=update_token_breakdown,
            inputs=[hidden_token_count, retrieval_slider, global_local_slider],
            outputs=[tokens_breakdown_text, retrieval_info_text]
        )
        global_local_slider.change(
            fn=update_token_breakdown,
            inputs=[hidden_token_count, retrieval_slider, global_local_slider],
            outputs=[tokens_breakdown_text, retrieval_info_text]
        )

        compress_button.click(
            fn=prepare_compression_and_rag,
            inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
            outputs=[compressed_doc_state, chat_status_text, compression_done]
        )

        with gr.Column(elem_classes="chatbot-container"):
            chat_status_text.render()
            gr.Markdown("## Chat (Llama 3.1-8B-Instruct)")
            gr.Markdown("**Note:** Chat history is not used; each message is answered independently.")
            chat_interface = gr.ChatInterface(
                fn=chat_response_stream,
                additional_inputs=[compressed_doc_state, compression_done],
                type="messages",
                fill_height=True,
            )

demo.queue(max_size=16).launch()