# Spaces:
# Building
# Building
import base64
import os
import shutil
import subprocess
import time
import uuid
from collections.abc import Iterator
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

import httpx
import streamlit as st
from loguru import logger
from openai import OpenAI
from tqdm import tqdm

from app.config import app_settings
from app.qdrant_db import MyQdrantClient
from app.vdr_utils import (
    get_text_embedding,
    get_image_embedding,
    pdf_folder_to_images,
    scale_image,
    pil_image_to_base64,
    load_images,
)
class VDRSession:
    """Per-user visual document retrieval (VDR) session.

    Bundles an OpenAI-compatible client, a session-scoped working folder
    (``./temp_data/<session_id>``), a collection in a ``MyQdrantClient`` stored
    in that folder, and the list of images indexed so far. Search hits are
    mapped back to images by their position in ``indexed_images``.
    """

    # Embedding model name -> (vector_dim, vector_type) passed to create_collection().
    EMBED_MODEL_CONFIGS: Dict[str, Tuple[int, str]] = {
        "tsi-embedding-colqwen2-2b-v1": (128, "colbert"),
        "jina-embedding-clip-v1": (768, "dense"),
    }

    def __init__(self):
        # Plain defaults first, so clear_context()/__del__ find every attribute
        # even if reading the settings below fails.
        self.client = None
        self.api_key = None
        self.SAVE_DIR = None
        self.SAVE_IMAGE_DIR = None
        self.db_collection = None
        self.embed_model = None
        self.session_id = str(uuid.uuid4())[:5]
        self.indexed_images = []
        self.vector_db_client = None
        self.base_url = app_settings.GLOBAL_API_BASE

    def set_api_key(self, api_key: str) -> bool:
        """Validate *api_key* by listing models and keep the client on success.

        A key is only tried when it is longer than 10 characters; it is
        stripped before use.

        Returns:
            True when listing models succeeded and returned a truthy result;
            otherwise False, with ``self.client`` reset to None.
        """
        if api_key is not None and len(api_key) > 10:
            try:
                api_key = api_key.strip()
                client = OpenAI(api_key=api_key, base_url=self.base_url)
                models = client.models.list()
                if models:
                    self.api_key = api_key
                    self.client = client
                    return True
            except Exception as e:
                logger.debug(f'Incorrect API Key: {e}')
        self.client = None
        return False

    def set_context(self, embed_model: str) -> bool:
        """Prepare the working folder, vector DB client and collection.

        Each resource is created only if it does not exist yet. Collection
        creation errors (including an unsupported *embed_model*) are logged,
        not raised.

        Returns:
            Always True.
        """
        self.embed_model = embed_model
        if not self.SAVE_DIR:
            self.SAVE_DIR = os.path.join('./temp_data', self.session_id)
            os.makedirs(self.SAVE_DIR, exist_ok=True)
            self.SAVE_IMAGE_DIR = os.path.join(self.SAVE_DIR, 'images')
            logger.debug(f'Created folder: {self.SAVE_DIR} and {self.SAVE_IMAGE_DIR}')
        if not self.vector_db_client:
            self.vector_db_client = MyQdrantClient(path=self.SAVE_DIR)
        if not self.db_collection:
            self.db_collection = f"qd-{embed_model}-{self.session_id}"
            try:
                config = self.EMBED_MODEL_CONFIGS.get(self.embed_model)
                if config is None:
                    raise ValueError(f"Embedding model {self.embed_model} not supported")
                vector_dim, vector_type = config
                self.vector_db_client.create_collection(
                    self.db_collection, vector_dim=vector_dim, vector_type=vector_type
                )
            except Exception as e:
                logger.error(f"Error while creating collection: {e}")
        return True

    def _list_models_matching(self, substrings: List[str]) -> List[str]:
        """Return ids of served models whose id contains any of *substrings*."""
        assert self.client is not None
        try:
            models = self.client.models.list()
            return [
                model.id
                for model in models.data
                if any(substring in model.id for substring in substrings)
            ]
        except Exception as e:
            logger.error(f"Error while query all models: {e}")
            raise

    def get_available_vlms(self) -> List[str]:
        """List served vision-language models ('gemini' / 'QWEN-VL2-7B' in the id).

        ``gemini-2.0-flash-exp-US`` is moved to the front, but only when the
        API actually serves it.
        """
        model_name_list = self._list_models_matching(['gemini', 'QWEN-VL2-7B'])
        priority_item = "gemini-2.0-flash-exp-US"
        if priority_item in model_name_list:
            model_name_list.remove(priority_item)
            model_name_list.insert(0, priority_item)
        return model_name_list

    def get_available_image_embeds(self) -> List[str]:
        """List served image-embedding models ('tsi-embedding' / 'clip' in the id)."""
        return self._list_models_matching(['tsi-embedding', 'clip'])

    def search_images(self, text: str, top_k: int = 5) -> list[str]:
        """Retrieve the *top_k* indexed images most similar to *text*.

        Returns:
            ``data:image/png;base64,...`` URLs of the hits, or an empty list
            when the stripped query is shorter than 2 characters.

        Raises:
            Exception: if nothing has been indexed yet; errors from embedding
                or querying are logged and re-raised.
        """
        assert self.client is not None
        assert self.vector_db_client is not None
        try:
            if not self.indexed_images:
                raise Exception("No indexed images found. You need to click on 'Add selected context' button to index images.")
            text = text.strip()
            if len(text) < 2:
                # Empty list (not False) so callers can iterate the result safely.
                return []
            embeddings = get_text_embedding(
                texts=text,
                openai_client=self.client,
                model=self.embed_model,
            )[0]
            index_results = self.vector_db_client.query_multivector(
                multivector_input=embeddings,
                collection_name=self.db_collection,
                top_k=top_k,
            )
            # Result ids are positions in self.indexed_images (see index_from_images).
            return [
                f"data:image/png;base64,{pil_image_to_base64(self.indexed_images[i])}"
                for i in index_results
            ]
        except Exception as e:
            logger.error(f"Error while generating image: {e}")
            raise

    def ask(self, query: str, model: str, prompt_template: str, retrieved_context: Any, modality: str = "image", stream: bool = False) -> Union[str, Iterator[str]]:
        """Ask *model* about *query* using the retrieved context.

        The prompt is ``prompt_template.format(user_question=query)``. For
        ``modality == "image"`` every item of *retrieved_context* is sent as an
        ``image_url`` part; other modalities send the prompt text only.

        Returns:
            The answer text when ``stream`` is False, otherwise an iterator of
            text deltas. The request is sent when ``ask`` is called, not when
            the stream is first iterated.
        """
        assert self.client is not None
        assert query is not None
        assert prompt_template is not None
        assert retrieved_context is not None
        try:
            prompt = prompt_template.format(user_question=query)
            content = [{"type": "text", "text": prompt}]
            if modality == "image":
                content += [
                    {"type": "image_url", "image_url": {"url": base64_image}}
                    for base64_image in retrieved_context
                ]
            messages = [{"role": "user", "content": content}]
            chat_response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
                stream=stream,
            )
            if not stream:
                return chat_response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error while asking: {e}")
            raise
        # Streaming lives in a separate generator: a `yield` inside ask() itself
        # would turn ask() into a generator and break the non-streaming return.
        return self._iter_stream(chat_response)

    @staticmethod
    def _iter_stream(chat_response) -> Iterator[str]:
        """Yield the non-None text deltas of a streamed chat completion."""
        try:
            for chunk in chat_response:
                if chunk.choices and chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            logger.error(f"Error while asking: {e}")
            raise

    def indexing(self, uploaded_files: List[Any], embed_model: str, indexing_bar: Optional[st.progress] = None) -> bool:
        """Save uploaded files, convert them to images and index the images.

        Args:
            uploaded_files: objects exposing ``.name`` and ``.getvalue()``
                (e.g. Streamlit uploads); files whose name already exists in
                the session folder are not rewritten.
            embed_model: embedding model used for the collection and images.
            indexing_bar: optional progress widget updated per batch.

        Returns:
            Always True; indexed images are appended to ``indexed_images``.
        """
        self.set_context(embed_model)
        assert self.client is not None
        assert self.db_collection is not None
        assert self.SAVE_DIR is not None
        assert self.embed_model is not None
        assert len(uploaded_files) > 0
        for file in uploaded_files:
            # The filename comes from the client: keep only its last component so
            # it cannot escape SAVE_DIR (e.g. "../../x.pdf").
            path = os.path.join(self.SAVE_DIR, os.path.basename(file.name))
            if os.path.exists(path):
                logger.debug("File existed, skip")
                continue
            with open(path, "wb") as f:
                f.write(file.getvalue())
        # NOTE(review): the whole folder is converted, so PDFs from earlier calls
        # are presumably converted and indexed again — TODO confirm / dedupe.
        image_path_list = pdf_folder_to_images(pdf_folder=self.SAVE_DIR, output_folder=self.SAVE_IMAGE_DIR)
        logger.debug(f"Extracted {len(image_path_list)} images from {len(uploaded_files)} files.")
        indexed_images = self.index_from_images(image_path_list, indexing_bar=indexing_bar)
        logger.debug(f"Indexed {len(indexed_images)} images.")
        self.indexed_images.extend(indexed_images)
        return True

    def clear_context(self) -> bool:
        """Drop indexed images, the collection and the session folder.

        Safe to call when no context was ever set, or more than once. A
        failure to delete the collection is logged, and the folder is still
        removed.

        Returns:
            Always True.
        """
        self.indexed_images = []
        if self.vector_db_client is not None and self.db_collection is not None:
            try:
                self.vector_db_client.delete_collection(self.db_collection)
            except Exception as e:
                logger.error(f"Error while deleting collection: {e}")
        self.db_collection = None
        self.vector_db_client = None
        if self.SAVE_DIR:
            if os.path.exists(self.SAVE_DIR):
                # Same semantics as `rm -rf` (errors ignored) without a subprocess.
                shutil.rmtree(self.SAVE_DIR, ignore_errors=True)
                logger.debug(f'Removed folder: {self.SAVE_DIR}')
            self.SAVE_DIR = None
        return True

    def __del__(self):
        # Finalizers may run on half-initialised objects or at interpreter
        # shutdown; never let cleanup errors escape __del__.
        try:
            self.clear_context()
            logger.debug('VDR session is cleaned up.')
        except Exception as e:
            logger.debug(f'VDR session cleanup failed: {e}')

    def index_from_images(self,
                          images_path_list: list,
                          batch_size: int = 5,
                          indexing_bar: Optional[st.progress] = None
                          ) -> list:
        """Scale, embed and upsert images in batches of *batch_size*.

        A failing batch is logged and skipped. Point ids are assigned so that
        they equal each image's future position in ``self.indexed_images``,
        given that the caller appends the returned list to it.

        Returns:
            The scaled images that were indexed successfully, in id order.

        Raises:
            Exception: wrapping any error outside the per-batch handling.
        """
        try:
            indexed_images = []
            total_len = len(images_path_list)
            start_id = len(self.indexed_images)
            with tqdm(total=total_len, desc="Indexing Progress") as pbar:
                for i in range(0, total_len, batch_size):
                    batch = images_path_list[i:i + batch_size]
                    try:
                        batch = [scale_image(x, 768) for x in batch]
                        embeddings = get_image_embedding(
                            image_list=batch,
                            openai_client=self.client,
                            model=self.embed_model,
                        )
                        # Assumes upsert_multivector assigns ids index, index+1, ...
                        # (search_images maps ids to positions the same way).
                        # Derive the id from what has actually been indexed, so a
                        # failed batch or an earlier indexing call cannot shift it.
                        self.vector_db_client.upsert_multivector(
                            index=start_id + len(indexed_images),
                            multivector_input_list=embeddings,
                            collection_name=self.db_collection,
                        )
                        indexed_images.extend(batch)
                    except Exception as e:
                        logger.exception(f"Error during indexing: {e}")
                    done = min(i + batch_size, total_len)
                    pbar.update(done - i)
                    if indexing_bar is not None:
                        indexing_bar.progress(done / total_len, text=f"Indexing {done}/{total_len}")
            logger.debug("Indexing complete!")
            return indexed_images
        except Exception as e:
            raise Exception(f"Error during indexing: {e}") from e