import copy
import math
import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, InputFormat, PdfFormatOption
from langchain.schema.document import Document
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_docling import DoclingLoader
from langchain_docling.loader import ExportType
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
from transformers.models.llama.modeling_llama import rotate_half

from utils import (
    calculate_tokens_suggest_compression_ratio,
    repeat_kv,
    update_retrieval_context,
)

# Load the generation model and tokenizer; HF_TOKEN is needed for gated model access.
api_token = os.getenv("HF_TOKEN")
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval()
model.to(device)

# Embedding model used by the naive RAG retriever.
embedding_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": str(device)},
    encode_kwargs={"normalize_embeddings": True},
    query_instruction="",
)

# Build the chat-template prefix/suffix once: everything before the user placeholder
# becomes the prompt prefix, everything after it is appended at generation time.
# The prefix length defines the number of always-kept "sink" tokens during compression.
content_system = ""
content_user = "######"
user_template = [
    {"role": "system", "content": content_system},
    {"role": "user", "content": content_user},
]
user = tokenizer.apply_chat_template(user_template, add_generation_prompt=True, tokenize=False)
prefix, suffix = user.split(content_user)
sink_tokens = max(4, len(tokenizer.encode(prefix)))

default_task_description = (
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words."
)

default_few_shot = """Examples
question: Which case was brought to court first Miller v. California or Gates v. Collier ?
answer: Miller v. California
question: The actor that plays Phileas Fogg in "Around the World in 80 Days", co-starred with Gary Cooper in a 1939 Goldwyn Productions film based on a novel by what author?
answer: Charles L. Clifford
question: Prior to playing for Michigan State, Keith Nichol played football for a school located in what city?
answer: Norman
"""
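
# FinchCache extends Hugging Face's DynamicCache with in-place compression: each layer
# keeps only a selected set of key/value positions, and the kept keys are re-rotated so
# their rotary phases match their new, compacted positions.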

class FinchCache(DynamicCache):
    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        # Compute cos/sin for the positional delta between each kept token's new
        # (compacted) index and its original position, so keys can be re-rotated in place.
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)
        idx = idx[:, None, :].float().expand(B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)

        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        # Gather the selected positions along the sequence dimension for every head.
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        # Keep only the positions in `important_pos` for this layer; keys are
        # re-rotated to their new positions, values are simply gathered.
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        """Save the cache to disk, moving tensors to CPU."""
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load the cache from disk and move tensors to the specified device."""
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache
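
# Document ingestion: Docling converts the uploaded PDF (or a URL) to Markdown, which
# serves both as the context to compress and as the source text for the naive RAG index.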

def convert_to_markdown(file_objs, url, do_ocr, do_table_structure):
    """Convert an uploaded PDF (or a URL) to Markdown text via Docling."""
    file_path = file_objs if file_objs is not None else url
    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_ocr = do_ocr
    pipeline_options.do_table_structure = do_table_structure
    pdf_format_options = PdfFormatOption(
        pipeline_options=pipeline_options,
        backend=PyPdfiumDocumentBackend,
    )
    doc_converter = DocumentConverter(
        allowed_formats=[InputFormat.PDF],
        format_options={InputFormat.PDF: pdf_format_options},
    )

    loader = DoclingLoader(
        file_path=file_path,
        export_type=ExportType.MARKDOWN,
        converter=doc_converter,
    )
    docs = loader.load()
    return docs[0].page_content

def create_rag_index(text_no_prefix):
    """Split the document text into chunks and build a Chroma vectorstore for naive RAG."""
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=256,
        chunk_overlap=0,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)]
    vectorstore = Chroma.from_documents(documents=docs, embedding=embedding_model)
    return vectorstore

@spaces.GPU
def auto_convert(file_objs, url, do_ocr, do_table_structure):
    """Convert the input, count tokens, suggest compression ratios, and build the RAG index."""
    if file_objs is None and (url is None or url.strip() == ""):
        # Nothing to process: reset the UI outputs and disable the compress button.
        return (
            gr.update(value=""),
            "Number of tokens before compression: ",
            gr.update(),
            "Number of tokens after compression: ",
            0,
            gr.update(interactive=False),
            False,
            {},
        )

    print("Converting to markdown")
    markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure)
    print("Done")
    combined_text = prefix + markdown
    print("Calculating tokens")
    token_count, suggestions, _ = calculate_tokens_suggest_compression_ratio(combined_text, tokenizer, model)
    print("Done")
    min_ratio = min(suggestions)
    max_ratio = max(suggestions)
    default_ratio = suggestions[len(suggestions) // 2]
    retrieval_tokens = int(token_count / default_ratio)
    token_count_str = f"Number of tokens before compression: {token_count}"
    retrieval_str = f"Number of tokens after compression: {retrieval_tokens}"
    slider_update = gr.update(value=default_ratio, minimum=min_ratio, maximum=max_ratio, step=1)

    # The RAG index is built on the document text only, without the chat-template prefix.
    if combined_text.startswith(prefix):
        rag_text = combined_text[len(prefix):]
    else:
        rag_text = combined_text
    print("Creating RAG index")
    rag_index = create_rag_index(rag_text)
    print("Done")
    state = {"rag_index": rag_index}

    return (
        combined_text,
        token_count_str,
        slider_update,
        retrieval_str,
        token_count,
        gr.update(interactive=True),
        False,
        state,
    )
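
# KV-cache compression: the context is split into chunks and prefilled chunk by chunk
# with the task/few-shot "question" appended. A forward hook captures the question's
# query states, which are used to score how much attention each cached context token
# receives; only the top-scoring positions (plus the sink tokens and, on the last chunk,
# the question itself) are kept in the FinchCache.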

def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
    device = model.device
    dtype = model.dtype
    num_chunks = step_size
    context_ids = context_ids.to(device)
    context_attention_mask = context_attention_mask.to(device)
    question_ids = question_ids.to(device)
    question_attention_mask = question_attention_mask.to(device)
    question_len = question_ids.size(1)
    total_len = context_ids.size(1)
    # Use more chunks if the context would not fit into the model's position window
    # together with the question.
    max_context_tokens_allowed = model.config.max_position_embeddings - question_len
    if total_len > max_context_tokens_allowed:
        num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))

    if total_len <= sink_tokens or num_chunks == 1:
        # Context is small enough to be processed as a single chunk.
        context_ids_list = [context_ids]
        context_attention_mask_list = [context_attention_mask]
    else:
        # Split the context: the sink tokens stay in the first chunk, the remainder is
        # divided evenly, and any leftover tokens go to the last chunk.
        remainder_len = total_len - sink_tokens
        base = remainder_len // num_chunks
        leftover = remainder_len % num_chunks

        chunk_sizes = [sink_tokens + base]
        for _ in range(num_chunks - 2):
            chunk_sizes.append(base)
        if num_chunks > 1:
            chunk_sizes.append(base + leftover)

        context_ids_list = []
        context_attention_mask_list = []
        offset = 0
        for size in chunk_sizes:
            end = offset + size
            context_ids_list.append(context_ids[:, offset:end])
            context_attention_mask_list.append(context_attention_mask[:, offset:end])
            offset = end

    # Overall compression factor for the non-sink part of the context.
    len_rest = max(total_len - sink_tokens, 1)
    compression_factor = len_rest // target_token_size
    if compression_factor < 1:
        compression_factor = 1

    tokenized_doc_chunks = []
    for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
        tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})

    print("Number of chunks: ", len(tokenized_doc_chunks))

    rotary_emb = model.model.rotary_emb.to(device)
    inv_freq = rotary_emb.inv_freq
    batch_size = question_ids.size(0)
    ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)

    cache = FinchCache()
    past_cache_len = 0
    past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
    num_chunks = len(tokenized_doc_chunks)

    # Per-layer query states of the question tokens, filled by the forward hook below.
    query_context_matrices = {}

    def query_hook_fn(module, input, output):
        # Capture the query projections and keep only the question part of the sequence
        # (everything after `_current_chunk_offset`, which is set in the chunk loop below).
        layer_idx = getattr(module, "layer_idx", None)
        if layer_idx is not None:
            query_states = output.detach()
            bsz, seq_len, hidden_dim = query_states.size()
            num_query_heads = module.num_query_heads
            head_dim = hidden_dim // num_query_heads
            query_states = (
                query_states.view(bsz, seq_len, num_query_heads, head_dim)
                .transpose(1, 2)
                .contiguous()
            )
            query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()

    # Register the hook on every q_proj, tagging each module with its layer index.
    hooks = []
    for i, layer in enumerate(model.model.layers):
        layer.self_attn.q_proj.layer_idx = i
        layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
        hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
        hooks.append(hook)

    for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
        current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
        # Tell the hook where the question starts within this chunk's sequence.
        _current_chunk_offset = current_seq_length
        query_context_matrices.clear()

        # Each forward pass sees: [already-compressed cache] + [current chunk] + [question].
        chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
        chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
        segment_attention_mask = torch.cat(
            [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
        ).contiguous()
        current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
        current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()

        past_seen_tokens = cache.get_seq_length() if cache is not None else 0
        cache_position = torch.arange(
            past_seen_tokens + chunk_input_ids.shape[1],
            past_seen_tokens + current_input_ids.shape[1],
            device=device,
        )
        causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
            current_attention_mask,
            sequence_length=question_ids.size(1),
            target_length=current_attention_mask.size(-1),
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=current_input_ids.size(0),
        ).contiguous()

        with torch.no_grad():
            outputs = model.model(
                input_ids=current_input_ids,
                use_cache=True,
                past_key_values=cache,
            )
        cache = outputs.past_key_values

        len_question = question_ids.size(1)

        # Score every cached position by the attention it receives from the question tokens.
        for layer_idx in range(len(model.model.layers)):
            key_matrix = cache.key_cache[layer_idx]
            query_matrix = query_context_matrices[layer_idx]
            # Re-apply rotary embeddings to the captured question queries at their
            # absolute positions within the current cache.
            layer_cache_pos = torch.arange(
                past_cache_len + current_seq_length,
                past_cache_len + current_seq_length + len_question,
                device=device,
            )
            position_ids = layer_cache_pos.unsqueeze(0)
            cos, sin = rotary_emb(query_matrix, position_ids)
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
            query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
            num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
            key_matrix = repeat_kv(key_matrix, num_repeats)

            scaling = math.sqrt(model.config.head_dim)
            attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
            causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
            attention_matrix = attention_matrix + causal_mask_sliced
            attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)

            # Normalize each query row by the number of positions it can attend to
            # before summing attention over the query dimension.
            tol = 1e-8
            binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
            non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
            non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
            attention_matrix = attention_matrix / non_zero_counts

            # Intermediate chunks score only the context part; the last chunk also
            # includes the question tokens (which are always kept).
            if j != num_chunks - 1:
                attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
            else:
                attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
            attention_matrix = torch.sum(attention_matrix, dim=-2)
            # Aggregate scores over the query heads belonging to each KV head.
            attention_matrix = attention_matrix.view(
                attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
            ).sum(dim=2)
            full_context_size = attention_matrix.size(-1)
            # Sink tokens (and, on the last chunk, the question) always survive selection.
            attention_matrix[..., :sink_tokens] = float("inf")
            if j == num_chunks - 1:
                attention_matrix[..., -len_question:] = float("inf")
            # Budget of positions to keep for this chunk.
            if j == 0:
                k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
                k = min(k + past_cache_len, full_context_size)
            elif j < num_chunks - 1:
                to_keep_new = int(current_seq_length // compression_factor)
                k = min(past_cache_len + to_keep_new, full_context_size)
            else:
                desired_final = sink_tokens + target_token_size + len_question
                k = desired_final if full_context_size >= desired_final else full_context_size
            k = max(k, sink_tokens)
            selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
            selected_indices, _ = torch.sort(selected_indices, dim=-1)
            cache.compress_cache(layer_idx, selected_indices, inv_freq)

        past_cache_len = cache._seen_tokens
        past_attention_mask = torch.ones(1, past_cache_len, device=device)

    for hook in hooks:
        hook.remove()

    return cache
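
# Naive RAG path: part of the token budget can be served by similarity retrieval over
# the Chroma index instead of (or in addition to) the compressed KV cache; the split is
# controlled by the global/local setting in the UI.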

def run_naive_rag_query(vectorstore, query, rag_token_size, prefix, task, few_shot_examples):
    """
    For naive RAG, retrieves the top-k chunks (k derived from the target token size)
    and formats them into a retrieval context for generation.
    """
    k = max(1, rag_token_size // 256)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    retrieved_docs = retriever.invoke(query)
    for doc in retrieved_docs:
        print("=================")
        print(doc.page_content)
        print("=================")
    formatted_context = "\n\n".join([doc.page_content for doc in retrieved_docs])

    rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
    return rag_context

@spaces.GPU
def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state):
    """
    Prepares the compressed KV cache. Uses the precomputed rag_index from state.
    """
    percentage = int(global_local_value.replace('%', ''))
    question_text = task_description + "\n" + few_shot
    context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
    question_encoding = tokenizer(question_text, return_tensors="pt").to(device)
    context_ids = context_encoding["input_ids"]
    context_attention_mask = context_encoding["attention_mask"]
    question_ids = question_encoding["input_ids"]
    question_attention_mask = question_encoding["attention_mask"]
    retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)

    if percentage > 0:
        # Part of the budget goes to the compressed ("global") KV cache.
        target_token_size = int(retrieval_context_length * (percentage / 100))
        print("Target token size for compression: ", target_token_size)
        step_size = 2
        start_time_prefill = time.perf_counter()
        past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size,
                                                                context_ids, context_attention_mask,
                                                                question_ids, question_attention_mask))
        compressed_length = past_key_values.get_seq_length()
        print("Context size after compression: ", compressed_length)
        print("Compression rate: ", context_ids.size(1) / compressed_length)
    else:
        # All-RAG setting: no compression, start from an empty cache.
        start_time_prefill = time.perf_counter()
        target_token_size = 0
        past_key_values = FinchCache()
        compressed_length = past_key_values.get_seq_length()

    # Reuse the RAG index built during conversion; rebuild it only if it is missing.
    rag_index = state.get("rag_index", None)
    if rag_index is None:
        if combined_text.startswith(prefix):
            rag_text = combined_text[len(prefix):]
        else:
            rag_text = combined_text
        rag_index = create_rag_index(rag_text)

    state.update({
        "compressed_cache": past_key_values,
        "compressed_length": compressed_length,
        "rag_index": rag_index,
        "target_token_size": target_token_size,
        "global_local": percentage,
        "combined_text": combined_text,
        "task_description": task_description,
        "few_shot": few_shot,
        "retrieval_slider": retrieval_context_length,
        "prefill_time": time.perf_counter() - start_time_prefill,
    })
    return state, True

@spaces.GPU
def chat_response_stream(message: str, history: list, state: dict):
    """
    Generates a chat response with streaming output.
    Yields the accumulated response as a plain string for gr.ChatInterface.
    """
    user_message = message
    past_key_values = state["compressed_cache"]
    compressed_length = past_key_values.get_seq_length()
    rag_index = state["rag_index"]
    retrieval_slider_value = state["retrieval_slider"]
    percentage = state["global_local"]

    # Portion of the retrieval budget served by naive RAG rather than the compressed cache.
    rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
    print("RAG retrieval size: ", rag_retrieval_size)

    if percentage == 0:
        # All-RAG: the prefix, task, and few-shot examples go into the RAG prompt.
        rag_prefix = prefix
        rag_task = state["task_description"]
        rag_few_shot = state["few_shot"]
    else:
        # Otherwise they are already encoded in the compressed KV cache.
        rag_prefix = ""
        rag_task = ""
        rag_few_shot = ""
    print("user message: ", user_message)
    if rag_retrieval_size != 0:
        rag_context = run_naive_rag_query(rag_index, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
        new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
    else:
        new_input = "\nquestion: " + user_message + suffix + "answer:"
    tokenized_new_input = tokenizer(new_input, return_tensors="pt").to(device)

    # Pad the prompt with placeholder tokens so its length lines up with the positions
    # already stored in the compressed cache.
    eos_block = torch.full((1, compressed_length), tokenizer.eos_token_id, device=device, dtype=torch.long)
    new_input_ids = torch.cat([eos_block, tokenized_new_input["input_ids"]], dim=-1)
    new_attention_mask = torch.cat([torch.ones((1, compressed_length), device=device), tokenized_new_input["attention_mask"]], dim=-1)

    print("New input is: ", new_input)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=new_input_ids,
        attention_mask=new_attention_mask,
        past_key_values=past_key_values,
        streamer=streamer,
        use_cache=True,
        max_new_tokens=1024,
        num_beams=1,
        do_sample=False,
        temperature=1.0,
        top_p=1.0,
        top_k=None,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    full_output = ""
    for text in streamer:
        full_output += text
        time.sleep(0.05)
        yield full_output

    state["compressed_cache"] = past_key_values
    return full_output
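
# Gradio UI: custom CSS plus the Blocks layout that wires document preprocessing,
# compression, and chat together.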

CSS = """
body {
    font-family: "Times New Roman", Times, serif;
}
.upload-section {
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 20px;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
@media (min-width: 768px) {
    .main-container {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .upload-section {
        flex: 3;
    }
    .chatbot-container {
        flex: 1;
        margin-top: 0;
    }
}
"""

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1><center>Beyond RAG with Llama 3.1-8B-Instruct Model</center></h1>")
    gr.HTML("<center><p>Compress your document and chat with it.</p></center>")

    hidden_token_count = gr.State(value=0)
    compression_done = gr.State(value=False)
    compressed_doc_state = gr.State(value={})

    with gr.Row(elem_classes="main-container"):
        with gr.Column(elem_classes="upload-section"):
            gr.Markdown("## Document Preprocessing")
            with gr.Row():
                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area")
                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf")
            with gr.Row():
                do_ocr = gr.Checkbox(label="Do OCR", value=False)
                do_table = gr.Checkbox(label="Include Table Structure", value=False)
            with gr.Accordion("Prompt Designer", open=False):
                task_description_input = gr.Textbox(label="Task Description", value=default_task_description, lines=3, elem_id="task-description")
                few_shot_input = gr.Textbox(label="Few-Shot Examples", value=default_few_shot, lines=10, elem_id="few-shot")
            with gr.Accordion("Show Markdown Output", open=False):
                markdown_output = gr.Textbox(label="Markdown Output", lines=20)
            token_count_text = gr.Markdown("Number of tokens before compression: ")
            retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
            retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
            global_local_slider = gr.Radio(
                label="Global vs Local (0 is all RAG, 100 is all global)",
                choices=["0%", "25%", "50%", "75%", "100%"],
                value="75%",
            )
            compress_button = gr.Button("Compress Document", interactive=False, elem_classes="upload-button")

        # Re-run the conversion whenever the file, URL, or a conversion option changes.
        for conversion_trigger in (file_input, url_input, do_ocr, do_table):
            conversion_trigger.change(
                fn=auto_convert,
                inputs=[file_input, url_input, do_ocr, do_table],
                outputs=[markdown_output, token_count_text, retrieval_slider, retrieval_info_text,
                         hidden_token_count, compress_button, compression_done, compressed_doc_state],
            )
        retrieval_slider.change(
            fn=update_retrieval_context,
            inputs=[hidden_token_count, retrieval_slider],
            outputs=retrieval_info_text,
        )
        compress_button.click(
            fn=prepare_compression_and_rag,
            inputs=[markdown_output, retrieval_slider, global_local_slider, task_description_input, few_shot_input, compressed_doc_state],
            outputs=[compressed_doc_state, compression_done],
        )

        with gr.Column(elem_classes="chatbot-container"):
            gr.Markdown("## Chat")
            chat_interface = gr.ChatInterface(
                fn=chat_response_stream,
                additional_inputs=[compressed_doc_state],
                type="messages",
            )

demo.queue().launch()