import httpx
import os
import time
import subprocess
import uuid
from loguru import logger
from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict
import streamlit as st
from openai import OpenAI
import base64
from tqdm import tqdm

from app.config import app_settings
from app.qdrant_db import MyQdrantClient
from app.vdr_utils import (
    get_text_embedding,
    get_image_embedding,
    pdf_folder_to_images,
    scale_image,
    pil_image_to_base64,
    load_images,
)
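

# VDRSession bundles the OpenAI-compatible API client, a local Qdrant store and
# the per-session state for visual document retrieval (VDR): uploaded PDFs are
# rendered to page images, embedded, indexed, and later retrieved as image
# context for a vision-language model.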
class VDRSession:
    def __init__(self):
        self.client = None
        self.api_key = None
        self.base_url = app_settings.GLOBAL_API_BASE
        self.SAVE_DIR = None
        self.SAVE_IMAGE_DIR = None
        self.embed_model = None
        self.db_collection = None
        # Short random id used to namespace this session's folders and collection.
        self.session_id = str(uuid.uuid4())[:5]
        self.indexed_images = []
        self.model_name_list = []
        self.vector_db_client = None

    def set_api_key(self, api_key: str) -> bool:
        if api_key is not None and len(api_key) > 10:
            try:
                api_key = api_key.strip()
                client = OpenAI(api_key=api_key, base_url=self.base_url)
                # Validate the key by listing the available models.
                models = client.models.list()
                if models:
                    self.api_key = api_key
                    self.client = client
                    return True
            except Exception as e:
                logger.debug(f'Incorrect API key: {e}')

        self.client = None
        return False

    def set_context(self, embed_model: str):
        self.embed_model = embed_model

        if not self.SAVE_DIR:
            self.SAVE_DIR = os.path.join('./temp_data', self.session_id)
            self.SAVE_IMAGE_DIR = os.path.join(self.SAVE_DIR, 'images')
            os.makedirs(self.SAVE_DIR, exist_ok=True)
            os.makedirs(self.SAVE_IMAGE_DIR, exist_ok=True)
            logger.debug(f'Created folders: {self.SAVE_DIR} and {self.SAVE_IMAGE_DIR}')

        if not self.vector_db_client:
            self.vector_db_client = MyQdrantClient(path=self.SAVE_DIR)

        if not self.db_collection:
            self.db_collection = f"qd-{embed_model}-{self.session_id}"
            try:
                # ColQwen2 produces multi-vector (ColBERT-style) embeddings of dimension 128;
                # the CLIP model produces a single 768-dimensional dense vector per image.
                if self.embed_model == "tsi-embedding-colqwen2-2b-v1":
                    self.vector_db_client.create_collection(self.db_collection, vector_dim=128, vector_type="colbert")
                elif self.embed_model == "jina-embedding-clip-v1":
                    self.vector_db_client.create_collection(self.db_collection, vector_dim=768, vector_type="dense")
                else:
                    raise ValueError(f"Embedding model {self.embed_model} is not supported")
            except Exception as e:
                logger.error(f"Error while creating collection: {e}")

        return True

    def get_available_vlms(self) -> List[str]:
        assert self.client is not None

        if self.model_name_list:
            return self.model_name_list
        try:
            models = self.client.models.list()
            for model in models.data:
                model_name = model.id
                # Only expose vision-language models.
                substrings = ['gemini-2.0', 'claude', 'Qwen2.5-VL-72B-Instruct']
                if any(substring in model_name for substring in substrings):
                    self.model_name_list.append(model.id)
        except Exception as e:
            logger.error(f"Error while querying available models: {e}")
            raise e

        return self.model_name_list

    def get_available_image_embeds(self) -> List[str]:
        assert self.client is not None

        model_name_list = []
        try:
            models = self.client.models.list()
            for model in models.data:
                model_name = model.id
                # Only expose image embedding models.
                substrings = ['tsi-embedding', 'clip']
                if any(substring in model_name for substring in substrings):
                    model_name_list.append(model.id)
        except Exception as e:
            logger.error(f"Error while querying available models: {e}")
            raise e

        return model_name_list

    def search_images(self, text: str, top_k: int = 5) -> list[str]:
        assert self.client is not None
        assert self.vector_db_client is not None

        try:
            if not self.indexed_images:
                raise Exception("No indexed images found. You need to click on 'Add selected context' button to index images.")
            text = text.strip()
            if len(text) < 2:
                return []

            # Embed the query text with the same model used for indexing.
            embeddings = get_text_embedding(
                texts=text,
                openai_client=self.client,
                model=self.embed_model
            )[0]

            index_results = self.vector_db_client.query_multivector(
                multivector_input=embeddings,
                collection_name=self.db_collection,
                top_k=top_k
            )
            image_list = [self.indexed_images[i] for i in index_results]

            # Return the retrieved pages as base64 data URLs so they can be passed
            # directly as image_url content to a vision-language model.
            images = []
            for img in image_list:
                encoded = pil_image_to_base64(img)
                images.append(f"data:image/png;base64,{encoded}")
            return images
        except Exception as e:
            logger.error(f"Error while searching images: {e}")
            raise e

    def ask(self, query: str, model: str, prompt_template: str, retrieved_context: Any, modality: str = "image", stream: bool = False):
        assert self.client is not None
        assert query is not None
        assert prompt_template is not None
        assert retrieved_context is not None

        try:
            prompt = prompt_template.format(user_question=query)
            if modality == "image":
                # Attach every retrieved page as an image_url part after the text prompt.
                context = [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": base64_image
                        }
                    } for base64_image in retrieved_context
                ]
            else:
                raise ValueError(f"Modality {modality} is not supported")

            content = [
                {
                    "type": "text",
                    "text": prompt
                }
            ]
            content = content + context

            messages = [
                {
                    "role": "user",
                    "content": content,
                }
            ]

            chat_response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
                stream=stream,
            )
            if not stream:
                # Non-streaming: return the full answer as a string.
                return chat_response.choices[0].message.content

            # Streaming: return a generator over the response chunks. Keeping the
            # `yield` inside an inner function avoids turning `ask` itself into a
            # generator, so the non-streaming branch can still return a plain string.
            def stream_chunks():
                for chunk in chat_response:
                    if chunk.choices and chunk.choices[0].delta.content is not None:
                        yield chunk.choices[0].delta.content

            return stream_chunks()

        except Exception as e:
            logger.error(f"Error while asking: {e}")
            raise e

    def indexing(self, uploaded_files: list, embed_model: str, indexing_bar: Optional[Any] = None) -> bool:
        self.set_context(embed_model)

        assert self.client is not None
        assert self.db_collection is not None
        assert self.SAVE_DIR is not None
        assert self.embed_model is not None
        assert len(uploaded_files) > 0

        # Persist the uploaded PDFs (Streamlit UploadedFile objects) to the session folder.
        for file in uploaded_files:
            path = os.path.join(self.SAVE_DIR, file.name)
            if os.path.exists(path):
                logger.debug(f"File {file.name} already exists, skipping")
                continue
            with open(path, "wb") as f:
                f.write(file.getvalue())

        image_path_list = pdf_folder_to_images(pdf_folder=self.SAVE_DIR, output_folder=self.SAVE_IMAGE_DIR)
        logger.debug(f"Extracted {len(image_path_list)} images from {len(uploaded_files)} files.")

        indexed_images = self.index_from_images(image_path_list, indexing_bar=indexing_bar)
        logger.debug(f"Indexed {len(indexed_images)} images.")

        self.indexed_images.extend(indexed_images)
        return True

    def clear_context(self):
        self.indexed_images = []
        # Guard against clearing a session that never created a collection
        # (clear_context is also called from __del__).
        if self.vector_db_client and self.db_collection:
            self.vector_db_client.delete_collection(self.db_collection)
        self.db_collection = None
        self.vector_db_client = None

        if self.SAVE_DIR:
            if os.path.exists(self.SAVE_DIR):
                subprocess.run(['rm', '-rf', self.SAVE_DIR])
                logger.debug(f'Removed folder: {self.SAVE_DIR}')
            self.SAVE_DIR = None
        return True

    def __del__(self):
        self.clear_context()
        logger.debug('VDR session is cleaned up.')

    def index_from_images(self,
                          images_path_list: list,
                          batch_size: int = 5,
                          indexing_bar: Optional[Any] = None  # progress bar returned by st.progress(), if any
                          ):
        try:
            indexed_images = []
            total_len = len(images_path_list)
            with tqdm(total=total_len, desc="Indexing Progress") as pbar:
                for i in range(0, total_len, batch_size):
                    try:
                        batch = images_path_list[i:min(i + batch_size, total_len)]

                        # Downscale the page images before embedding them.
                        batch = [scale_image(x, 768) for x in batch]

                        embeddings = get_image_embedding(
                            image_list=batch,
                            openai_client=self.client,
                            model=self.embed_model
                        )
                        self.vector_db_client.upsert_multivector(
                            index=i,
                            multivector_input_list=embeddings,
                            collection_name=self.db_collection
                        )

                        indexed_images.extend(batch)

                        done = min(i + batch_size, total_len)
                        pbar.update(len(batch))
                        if indexing_bar is not None:
                            indexing_bar.progress(done / total_len, text=f"Indexing {done}/{total_len}")
                    except Exception as e:
                        logger.exception(f"Error during indexing: {e}")
                        continue

            logger.debug("Indexing complete!")
            return indexed_images
        except Exception as e:
            raise Exception(f"Error during indexing: {e}")
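

# Minimal usage sketch (assumes a Streamlit app supplying `uploaded_files` via
# st.file_uploader and an API key, here read from an assumed OPENAI_API_KEY
# environment variable; the prompt template and query strings are placeholders):
#
#   session = VDRSession()
#   if session.set_api_key(os.environ["OPENAI_API_KEY"]):
#       session.indexing(uploaded_files, embed_model="jina-embedding-clip-v1")
#       pages = session.search_images("revenue by quarter", top_k=3)
#       answer = session.ask(
#           query="What was the revenue in Q3?",
#           model=session.get_available_vlms()[0],
#           prompt_template="Answer using the attached pages.\n\nQuestion: {user_question}",
#           retrieved_context=pages,
#       )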