|
import gradio as gr |
|
import tempfile |
|
import os |
|
import fitz |
|
import uuid |
|
import shutil |
|
from pymilvus import MilvusClient |
|
|
|
from middleware import Middleware |
|
from rag import Rag |
|
from pathlib import Path |
|
import subprocess |
|
import getpass |
|
|
|
from dotenv import load_dotenv, dotenv_values |
|
import dotenv |
|
import platform |
|
import time |
|
|
|
|
|
# Locate the project's .env file and load its values into os.environ.
# The resolved path is kept so settings handlers can write updates back to it.
dotenv_file = dotenv.find_dotenv()
load_dotenv(dotenv_file)

# Module-wide Rag instance; search_documents uses it to answer queries.
rag = Rag()
|
|
|
|
|
|
|
def generate_uuid(state):
    """Return the session's user UUID, creating and storing one on first use.

    Args:
        state: Mutable mapping with a ``"user_uuid"`` entry; ``None`` means
            no id has been assigned yet.

    Returns:
        The UUID string now stored in ``state["user_uuid"]``.
    """
    current = state["user_uuid"]
    if current is not None:
        return current

    fresh = str(uuid.uuid4())
    state["user_uuid"] = fresh
    return fresh
|
|
|
|
|
class PDFSearchApp:
    """Gradio event handlers for indexing PDFs, querying them, and editing local settings."""

    def __init__(self):
        # Document ids that have been passed through Middleware.index, mapped to True.
        self.indexed_docs = {}
        # Temp path of the most recently processed upload (None until the first upload).
        self.current_pdf = None

    def upload_and_convert(self, state, files, max_pages):
        """Index every uploaded PDF.

        Args:
            state: Per-session ``gr.State`` dict (currently unused).
            files: Uploaded file objects; each exposes ``.name`` (its local path).
                ``None`` or empty when nothing was uploaded.
            max_pages: Page limit per document, forwarded to ``Middleware.index``.

        Returns:
            A status message for the "Indexing Status" textbox.
        """
        if not files:
            return "No file uploaded"

        try:
            for file in files:
                pdf_path = file.name
                name = os.path.splitext(os.path.basename(pdf_path))[0]
                # Sanitized name is handed to Middleware as the document id; presumably it
                # also names the collection (create_collection=True) — confirm in Middleware.
                doc_id = name.replace(" ", "_").replace("-", "_")
                self.current_pdf = pdf_path

                print(f"Uploading file: {doc_id}")
                middleware = Middleware(doc_id, create_collection=True)
                middleware.index(pdf_path, id=doc_id, max_pages=max_pages)
                self.indexed_docs[doc_id] = True

            return "Uploaded and extracted all pages"
        except Exception as e:
            # UI boundary: report the failure in the status box rather than crash the handler.
            return f"Error processing PDF: {e}"

    def display_file_list(self):
        """List indexed documents, i.e. the sub-directories of ``./pages``.

        Returns:
            Sorted directory names. On failure (missing or unreadable directory)
            the reason is printed and an empty list is returned, so callers can
            always treat the result as a list of choices.
        """
        directory_path = "pages"
        try:
            directory_path = os.path.join(os.getcwd(), directory_path)
            with os.scandir(directory_path) as entries:
                return sorted(entry.name for entry in entries if entry.is_dir())
        except FileNotFoundError:
            print(f"The directory {directory_path} does not exist.")
        except PermissionError:
            print(f"Permission denied to access {directory_path}.")
        except OSError as e:
            print(e)
        return []

    def search_documents(self, state, query, num_results=1):
        """Find the best-matching page for ``query`` and ask the RAG model about it.

        Args:
            state: Per-session ``gr.State`` dict (currently unused).
            query: User query text.
            num_results: Currently unused; only the top hit is returned.

        Returns:
            A 3-tuple ``(page_path, image_path, rag_answer)`` matching the
            ``[path, images, llm_answer]`` outputs it is wired to. On failure the
            first element is the error message, the image is ``None`` and the
            answer is ``"--"``, so Gradio always receives three values.
        """
        print(f"Searching for query: {query}")

        if not query:
            print("Please enter a search query")
            return "Please enter a search query", None, "--"

        # NOTE(review): the collection id is hard-coded; confirm that
        # Middleware.search spans every indexed collection.
        collection_id = "test"
        try:
            middleware = Middleware(collection_id, create_collection=False)
            search_results = middleware.search([query])[0]
            top_hit = search_results[0]

            # Hit field 1 is assumed to be a 0-based page index; page files are 1-based.
            page_num = top_hit[1] + 1
            coll_num = top_hit[2]
            print(f"Retrieved page number: {page_num}")

            path = f"pages/{coll_num}/page_{page_num}"
            img_path = f"{path}.png"
            print(f"Retrieved image path: {img_path}")

            rag_response = rag.get_answer_from_gemini(query, [img_path])
            return path, img_path, rag_response
        except Exception as e:
            return f"Error during search: {e}", None, "--"

    def delete(self, choice):
        """Remove a document's page directory and drop its Milvus collection.

        Args:
            choice: Document name, as listed by ``display_file_list``.

        Returns:
            ``"Deleted <choice>"`` on success. ``"Directory not found"`` when
            ``choice`` is empty, is not a single plain path component, or does
            not exist under ``pages``.
        """
        # An empty choice would make the target "pages/" (wiping every document),
        # and "../x" would escape the pages tree, so only accept a bare name.
        if not choice or choice in (".", "..") or os.path.basename(choice) != choice:
            return "Directory not found"

        path = f"pages/{choice}"
        if not os.path.isdir(path):
            return "Directory not found"

        client = MilvusClient("./milvus_demo.db")
        try:
            shutil.rmtree(path)
            client.drop_collection(collection_name=choice)
        finally:
            client.close()
        return f"Deleted {choice}"

    def _persist_setting(self, key, value):
        """Set ``key`` in this process's environment and write it to the .env file."""
        os.environ[key] = value
        dotenv.set_key(dotenv_file, key, value)

    def dbupdate(self, metric_type, m_num, ef_num, topk):
        """Persist vector-index settings; the app must be restarted to apply them.

        Args:
            metric_type: Similarity metric name (stored as ``metrictype``).
            m_num: Graph neighbor count (stored as ``mnum``).
            ef_num: Construction candidate count (stored as ``efnum``).
            topk: Maximum hits per search (stored as ``topk``).

        Returns:
            A status message for the UI.
        """
        self._persist_setting("metrictype", metric_type)
        self._persist_setting("mnum", str(m_num))
        self._persist_setting("efnum", str(ef_num))
        self._persist_setting("topk", str(topk))
        return "DB Settings Updated, Restart App To Load"

    @staticmethod
    def _hf_hub_cache_dir():
        """Resolve the Hugging Face hub cache directory.

        Follows huggingface_hub's env-var precedence: HF_HUB_CACHE (or the legacy
        HUGGINGFACE_HUB_CACHE), then ``$HF_HOME/hub``, then
        ``$XDG_CACHE_HOME/huggingface/hub``, then ``~/.cache/huggingface/hub``.
        """
        explicit = os.getenv("HF_HUB_CACHE") or os.getenv("HUGGINGFACE_HUB_CACHE")
        if explicit:
            return Path(explicit).expanduser()

        hf_home = os.getenv("HF_HOME")
        if hf_home:
            home = Path(hf_home).expanduser()
            hub = home / "hub"
            # HF_HOME is the cache *root*; fall back to HF_HOME itself only for
            # setups that pointed it straight at a hub directory.
            return hub if hub.is_dir() else home

        xdg_cache = os.getenv("XDG_CACHE_HOME")
        base = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
        return base / "huggingface" / "hub"

    def list_downloaded_hf_models(self):
        """Return ``org/name`` ids of models in the local Hugging Face cache, sorted."""
        hub_dir = self._hf_hub_cache_dir()
        # Cached repos are stored as "models--<org>--<name>"; map back to "<org>/<name>".
        return sorted(
            repo_dir.name.split("--", 1)[-1].replace("--", "/")
            for repo_dir in hub_dir.glob("models--*")
            if repo_dir.is_dir()
        )

    def list_downloaded_ollama_models(self):
        """Return the names of locally pulled Ollama library models, sorted.

        The models root is taken from ``OLLAMA_MODELS`` and defaults to
        ``~/.ollama/models``. Set OLLAMA_MODELS when Ollama stores models
        elsewhere (e.g. a system-service install). On failure the reason is
        printed and an empty list is returned.
        """
        models_root = os.getenv("OLLAMA_MODELS")
        root = Path(models_root).expanduser() if models_root else Path.home() / ".ollama" / "models"
        base_path = root / "manifests" / "registry.ollama.ai" / "library"

        try:
            with os.scandir(base_path) as entries:
                return sorted(entry.name for entry in entries if entry.is_dir())
        except FileNotFoundError:
            print(f"The directory {base_path} does not exist.")
        except PermissionError:
            print(f"Permission denied to access {base_path}.")
        except OSError as e:
            print(f"An error occurred: {e}")
        return []

    def model_settings(self, hfchoice, ollamachoice, flash, temp):
        """Persist model choices; the app must be restarted to use them.

        Args:
            hfchoice: Hugging Face model id (stored as ``colpali``).
            ollamachoice: Ollama model name (stored as ``ollama``).
            flash: ``"Enabled"`` stores ``flashattn=1``; anything else stores ``0``.
            temp: Generation temperature (stored as ``temperature``).

        Returns:
            A status message for the UI.
        """
        self._persist_setting("colpali", hfchoice)
        self._persist_setting("ollama", ollamachoice)
        self._persist_setting("flashattn", "1" if flash == "Enabled" else "0")
        self._persist_setting("temperature", str(temp))
        return "Models Updated, Restart App To Use New Settings"
|
|
|
|
|
|
|
def create_ui(enable_local_settings=False):
    """Build the Gradio Blocks interface for the multimodal RAG demo.

    Args:
        enable_local_settings: When True, wire the "Data Settings" and
            "AI Model Settings" buttons to their handlers. Defaults to False,
            which leaves those buttons inert as in the hosted demo
            ("Settings Available On Local Offline Setup").

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    app = PDFSearchApp()

    # display_file_list may report an error as a plain string; list() of a string
    # would turn every character into a dropdown choice, so discard it.
    doc_choices = app.display_file_list()
    doc_choices = [] if isinstance(doc_choices, str) else list(doc_choices)
    hf_models = app.list_downloaded_hf_models() or []
    ollama_models = app.list_downloaded_ollama_models() or []

    with gr.Blocks(theme=gr.themes.Ocean(), css="footer{display:none !important}") as demo:
        state = gr.State(value={"user_uuid": None})

        gr.Markdown("# Collar Multimodal RAG Demo")
        gr.Markdown("Settings Available On Local Offline Setup")

        with gr.Tab("Upload PDF"):
            with gr.Column():
                max_pages_input = gr.Slider(
                    minimum=1,
                    maximum=10000,
                    value=20,
                    step=10,
                    label="Max pages to extract and index per document",
                )
                file_input = gr.Files(label="Upload PDFs")
                file_list = gr.Textbox(label="Uploaded Files", interactive=False, value="Available on Local Setup")
                status = gr.Textbox(label="Indexing Status", interactive=False)

        with gr.Tab("Query"):
            with gr.Column():
                query_input = gr.Textbox(label="Enter query")
                search_btn = gr.Button("Query")
                llm_answer = gr.Textbox(label="RAG Response", interactive=False)
                path = gr.Textbox(label="Link To Document Page", interactive=False)
                images = gr.Image(label="Top page matching query")

        with gr.Tab("Data Settings"):
            with gr.Column():
                choice = gr.Dropdown(doc_choices, label="Choice")
                status1 = gr.Textbox(label="Deletion Status", interactive=False)
                delete_button = gr.Button("Delete Document From DB")

                metric_type = gr.Dropdown(
                    choices=["IP", "L2", "COSINE"],
                    value="IP",
                    label="Metric Type (Mathematical function to measure similarity)",
                )
                m_num = gr.Dropdown(
                    choices=["8", "16", "32", "64"],
                    value="16",
                    label="M Vectors (Maximum number of neighbors each node can connect to in the graph)",
                )
                ef_num = gr.Slider(
                    minimum=50,
                    maximum=1000,
                    value=500,
                    step=10,
                    label="EF Construction (Number of candidate neighbors considered for connection during index construction)",
                )
                topk = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K (Maximum number of entities to return in a single search of a document)",
                )
                db_button = gr.Button("Update DB Settings")
                status3 = gr.Textbox(label="DB Update Status", interactive=False)

        with gr.Tab("AI Model Settings"):
            with gr.Column():
                # .get(): a missing env var should leave the dropdown unselected, not crash startup.
                hfchoice = gr.Dropdown(hf_models, value=os.environ.get('colpali'), label="Primary Visual Model")
                ollamachoice = gr.Dropdown(
                    ollama_models,
                    value=os.environ.get('ollama'),
                    label="Secondary Visual Retrieval-Augmented Generation (RAG) Model",
                )
                flash = gr.Dropdown(["Enabled", "Disabled"], value="Enabled", label="Flash Attention 2.0 Acceleration")
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1,
                    value=0.8,
                    step=0.1,
                    label="RAG Temperature",
                )
                model_button = gr.Button("Update Settings")
                status2 = gr.Textbox(label="Update Status", interactive=False)

        file_input.change(
            fn=app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status],
        )
        search_btn.click(
            fn=app.search_documents,
            inputs=[state, query_input],
            outputs=[path, images, llm_answer],
        )

        # Destructive / config-writing handlers are only wired for local setups.
        if enable_local_settings:
            delete_button.click(
                fn=app.delete,
                inputs=[choice],
                outputs=[status1],
            )
            db_button.click(
                fn=app.dbupdate,
                inputs=[metric_type, m_num, ef_num, topk],
                outputs=[status3],
            )
            model_button.click(
                fn=app.model_settings,
                inputs=[hfchoice, ollamachoice, flash, temp],
                outputs=[status2],
            )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: assemble the Blocks app, then serve it.
    demo = create_ui()
    demo.launch()
|
|
|
|