# demo/app.py
# Source-viewer header: Kazel · lib · c575567
# (presumably author, commit message and short commit hash)
import gradio as gr
import tempfile
import os
import fitz # PyMuPDF
import uuid
import shutil
from pymilvus import MilvusClient
from middleware import Middleware
from rag import Rag
from pathlib import Path
import subprocess
import getpass
# importing necessary functions from dotenv library
from dotenv import load_dotenv, dotenv_values
import dotenv
import platform
import time
# Locate the project's .env file and load its variables into os.environ.
# The located path is kept in `dotenv_file` so the settings handlers can
# write updated values back to the same file.
dotenv_file = dotenv.find_dotenv()
load_dotenv(dotenv_file)

# Shared RAG helper used when answering search queries.
# NOTE(review): an earlier comment said this kickstarts the docker and ollama
# servers — confirm against rag.Rag.
rag = Rag()
def generate_uuid(state):
    """Return the session UUID stored in ``state``, creating it on first use.

    Args:
        state: Mutable mapping holding a ``"user_uuid"`` entry (None until
            a UUID has been assigned).

    Returns:
        The UUID string for this session. A new random UUID4 is written back
        into ``state`` only when none was present.
    """
    existing = state["user_uuid"]
    if existing is not None:
        return existing
    state["user_uuid"] = str(uuid.uuid4())
    return state["user_uuid"]
class PDFSearchApp:
    """Controller behind the Gradio UI: indexes PDFs and answers queries.

    The methods are bound as Gradio event handlers in ``create_ui``.
    """

    def __init__(self):
        # Maps an indexed document id (sanitized file stem) to True.
        self.indexed_docs = {}
        # Path of the most recently uploaded file, or None before any upload.
        self.current_pdf = None

    @staticmethod
    def _persist_setting(key, value):
        """Set ``key`` in ``os.environ`` and write it to the loaded .env file."""
        os.environ[key] = value
        dotenv.set_key(dotenv_file, key, value)

    def upload_and_convert(self, state, files, max_pages):
        """Index every uploaded file into its own Milvus collection.

        Args:
            state: Per-session Gradio state (currently unused).
            files: Uploaded file objects whose ``.name`` is a filesystem path,
                or None when the upload widget is cleared.
            max_pages: Maximum number of pages to extract and index per file.

        Returns:
            A status string for the "Indexing Status" textbox.
        """
        # None (widget cleared) and an empty list both mean nothing to index.
        if not files:
            return "No file uploaded"
        try:
            for file in files:
                pdf_path = file.name
                stem, _ext = os.path.splitext(os.path.basename(pdf_path))
                # The sanitized stem becomes the collection/document id.
                # Spaces and hyphens are replaced with underscores, presumably
                # because Milvus collection names disallow them (TODO confirm).
                doc_id = stem.replace(" ", "_").replace("-", "_")
                self.current_pdf = pdf_path
                print(f"Uploading file: {doc_id}")
                middleware = Middleware(doc_id, create_collection=True)
                middleware.index(pdf_path, id=doc_id, max_pages=max_pages)
                self.indexed_docs[doc_id] = True
            return "Uploaded and extracted all pages"
        except Exception as e:
            # UI boundary: report the failure in the status box, don't crash.
            return f"Error processing PDF: {e}"

    def display_file_list(self):
        """Return the names of the per-document folders under ``./pages``.

        Rendered pages are stored as ``pages/<collection>/page_<n>.png``, so
        the folder names double as collection names for the delete dropdown.

        Returns:
            A list of directory names. Returns an empty list when the folder is
            missing or unreadable. The problem is printed rather than returned,
            because the caller passes the result straight to ``gr.Dropdown``
            choices.
        """
        directory_path = "pages"
        try:
            directory_path = os.path.join(os.getcwd(), directory_path)
            entries = os.listdir(directory_path)
        except FileNotFoundError:
            print(f"The directory {directory_path} does not exist.")
            return []
        except PermissionError:
            print(f"Permission denied to access {directory_path}.")
            return []
        except OSError as e:
            print(str(e))
            return []
        return [
            entry
            for entry in entries
            if os.path.isdir(os.path.join(directory_path, entry))
        ]

    def search_documents(self, state, query, num_results=1):
        """Find the best-matching page for ``query`` and answer it with RAG.

        Querying is allowed without an upload in this session, so collections
        already persisted in the vector DB can be searched directly.

        Args:
            state: Per-session Gradio state (currently unused).
            query: The user's question.
            num_results: Currently unused; only the top-ranked page is used.

        Returns:
            A 3-tuple ``(page_path, image_path, rag_response)`` matching the
            ``[path, images, llm_answer]`` outputs wired in ``create_ui``.
            On failure the first element carries the message, the image is
            None (clears the image component) and the answer is "--".
        """
        print(f"Searching for query: {query}")
        if not query:
            print("Please enter a search query")
            return "Please enter a search query", None, "--"
        try:
            # Placeholder collection name ("not used anyway" per the original
            # author). NOTE(review): confirm Middleware.search ignores it.
            middleware = Middleware("test", create_collection=False)
            search_results = middleware.search([query])[0]
            if not search_results:
                return "No matching pages found", None, "--"
            # Each result is (score, doc_id, collection_name). doc_id + 1 is
            # the page number used in the saved page file names.
            top = search_results[0]
            page_num = top[1] + 1
            collection_name = top[2]
            print(f"Retrieved page number: {page_num}")
            img_path = f"pages/{collection_name}/page_{page_num}.png"
            path = f"pages/{collection_name}/page_{page_num}"
            print(f"Retrieved image path: {img_path}")
            rag_response = rag.get_answer_from_gemini(query, [img_path])
            return path, img_path, rag_response
        except Exception as e:
            return f"Error during search: {e}", None, "--"

    def delete(self, choice):
        """Delete a document's page folder and drop its Milvus collection.

        Args:
            choice: Folder/collection name selected in the UI.

        Returns:
            A status string for the "Deletion Status" textbox.
        """
        # Accept only a bare folder name. "", ".", ".." or anything with a
        # path component would make rmtree hit pages/ itself or escape it.
        if (
            not choice
            or choice in (".", "..")
            or os.path.basename(choice) != choice
        ):
            return "Directory not found"
        path = f"pages/{choice}"
        if not os.path.exists(path):
            return "Directory not found"
        shutil.rmtree(path)
        # Open the Milvus client only when there is something to drop.
        client = MilvusClient("./milvus_demo.db")
        client.drop_collection(collection_name=choice)
        return f"Deleted {choice}"

    def dbupdate(self, metric_type, m_num, ef_num, topk):
        """Persist vector-DB settings to ``os.environ`` and the .env file.

        The values are written under ``metrictype``, ``mnum``, ``efnum`` and
        ``topk``. They take effect after the app is restarted.
        """
        self._persist_setting('metrictype', metric_type)
        self._persist_setting('mnum', str(m_num))
        self._persist_setting('efnum', str(ef_num))
        self._persist_setting('topk', str(topk))
        return "DB Settings Updated, Restart App To Load"

    def list_downloaded_hf_models(self):
        """List model repos present in the local Hugging Face hub cache.

        The cache is ``$HF_HUB_CACHE`` if set, otherwise ``$HF_HOME/hub``,
        otherwise ``~/.cache/huggingface/hub``. Repos are cached as
        ``models--<org>--<name>`` folders and are returned, sorted, as
        ``"<org>/<name>"``.
        """
        hub_cache = os.getenv('HF_HUB_CACHE')
        hf_home = os.getenv('HF_HOME')
        if hub_cache:
            hf_cache_dir = Path(hub_cache)
        elif hf_home:
            # HF_HOME is the Hugging Face root directory; the hub cache is
            # its "hub" subfolder.
            hf_cache_dir = Path(hf_home) / 'hub'
        else:
            hf_cache_dir = Path.home() / '.cache' / 'huggingface' / 'hub'
        return sorted(
            repo_dir.name.split('--', 1)[-1].replace('--', '/')
            for repo_dir in hf_cache_dir.glob('models--*')
        )

    def list_downloaded_ollama_models(self):
        """List model names pulled by Ollama into the project-local store.

        The lookup is relative to the working directory, at
        ``NEW_PATH/manifests/registry.ollama.ai/library``. This assumes
        ``ollama pull`` was run from the project directory with the model
        store at ``NEW_PATH``. If pulls were run from the user's home instead,
        the store lives under the home directory.

        Returns:
            The directory names found there, or an empty list if the folder is
            missing or unreadable.
        """
        # os.path.join keeps the path valid on both Windows and POSIX.
        base_path = os.path.join(
            "NEW_PATH", "manifests", "registry.ollama.ai", "library"
        )
        try:
            with os.scandir(base_path) as entries:
                return [entry.name for entry in entries if entry.is_dir()]
        except FileNotFoundError:
            print(f"The directory {base_path} does not exist.")
        except PermissionError:
            print(f"Permission denied to access {base_path}.")
        except OSError as e:
            print(f"An error occurred: {e}")
        return []

    def model_settings(self, hfchoice, ollamachoice, flash, temp):
        """Persist model settings to ``os.environ`` and the .env file.

        ``flash`` is "Enabled"/"Disabled" and is stored as "1"/"0" under
        ``flashattn``. New settings take effect after a restart.
        """
        self._persist_setting('colpali', hfchoice)
        self._persist_setting('ollama', ollamachoice)
        self._persist_setting('flashattn', "1" if flash == "Enabled" else "0")
        self._persist_setting('temperature', str(temp))
        return "Models Updated, Restart App To Use New Settings"
def create_ui(enable_settings=False):
    """Build the Gradio Blocks interface for the multimodal RAG demo.

    Args:
        enable_settings: When True, wire the "Delete Document From DB",
            "Update DB Settings" and "Update Settings" buttons to their
            PDFSearchApp handlers. Off by default, matching the hosted demo,
            where settings are only available on a local offline setup.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    app = PDFSearchApp()
    with gr.Blocks(theme=gr.themes.Ocean(), css="footer{display:none !important}") as demo:
        # Per-session state; only "user_uuid" is stored here.
        state = gr.State(value={"user_uuid": None})
        gr.Markdown("# Collar Multimodal RAG Demo")
        gr.Markdown("Settings Available On Local Offline Setup")

        with gr.Tab("Upload PDF"):
            with gr.Column():
                max_pages_input = gr.Slider(
                    minimum=1,
                    maximum=10000,
                    value=20,
                    step=10,
                    label="Max pages to extract and index per document",
                )
                file_input = gr.Files(label="Upload PDFs")
                gr.Textbox(label="Uploaded Files", interactive=False, value="Available on Local Setup")
                status = gr.Textbox(label="Indexing Status", interactive=False)

        with gr.Tab("Query"):
            with gr.Column():
                query_input = gr.Textbox(label="Enter query")
                search_btn = gr.Button("Query")
                llm_answer = gr.Textbox(label="RAG Response", interactive=False)
                path = gr.Textbox(label="Link To Document Page", interactive=False)
                images = gr.Image(label="Top page matching query")

        with gr.Tab("Data Settings"):  # collection deletion and vector-DB parameters
            with gr.Column():
                # list() guards against a non-list return from display_file_list.
                choice = gr.Dropdown(list(app.display_file_list()), label="Choice")
                status1 = gr.Textbox(label="Deletion Status", interactive=False)
                delete_button = gr.Button("Delete Document From DB")
                metric_type = gr.Dropdown(choices=["IP", "L2", "COSINE"], value="IP", label="Metric Type (Mathematical function to measure similarity)")
                m_num = gr.Dropdown(
                    choices=["8", "16", "32", "64"], value="16", label="M Vectors (Maximum number of neighbors each node can connect to in the graph)")
                ef_num = gr.Slider(
                    minimum=50,
                    maximum=1000,
                    value=500,
                    step=10,
                    label="EF Construction (Number of candidate neighbors considered for connection during index construction)",
                )
                topk = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K (Maximum number of entities to return in a single search of a document)",
                )
                db_button = gr.Button("Update DB Settings")
                status3 = gr.Textbox(label="DB Update Status", interactive=False)

        with gr.Tab("AI Model Settings"):  # model choice and generation parameters
            with gr.Column():
                # .get(): a missing .env entry leaves the dropdown unselected
                # instead of crashing UI construction with KeyError.
                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(), value=os.environ.get('colpali'), label="Primary Visual Model")
                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(), value=os.environ.get('ollama'), label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
                flash = gr.Dropdown(["Enabled", "Disabled"], value="Enabled", label="Flash Attention 2.0 Acceleration")
                temp = gr.Slider(
                    minimum=0.1,
                    maximum=1,
                    value=0.8,
                    step=0.1,
                    label="RAG Temperature",
                )
                model_button = gr.Button("Update Settings")
                status2 = gr.Textbox(label="Update Status", interactive=False)

        # Event handlers
        file_input.change(
            fn=app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status],
        )
        # Querying without uploading first is allowed.
        search_btn.click(
            fn=app.search_documents,
            inputs=[state, query_input],
            outputs=[path, images, llm_answer],
        )
        if enable_settings:
            delete_button.click(
                fn=app.delete,
                inputs=[choice],
                outputs=[status1],
            )
            db_button.click(
                fn=app.dbupdate,
                inputs=[metric_type, m_num, ef_num, topk],
                outputs=[status3],
            )
            model_button.click(
                fn=app.model_settings,
                inputs=[hfchoice, ollamachoice, flash, temp],
                outputs=[status2],
            )
    return demo
if __name__ == "__main__":
    # Build the interface once, then serve it with Gradio's default options.
    # To put it behind a login page, pass credentials to launch, e.g.
    #   demo.launch(auth=("admin", "pass1234"))
    # preferably read from the environment instead of hard-coded.
    demo = create_ui()
    demo.launch()