import os
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.retrievers import EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Chroma

# Environment variable for API token.
# Validate BEFORE using the value: slicing a missing (None) token would raise
# TypeError and hide the intended, descriptive ValueError.
api_token = os.getenv("API_TOKEN")
if not api_token:
    raise ValueError("Environment variable 'API_TOKEN' not set.")
# Do not echo any part of the secret to logs.
print("API Token loaded.")  # Debug

# Available LLM models (Hugging Face repo ids)
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat",
]
# Short display names (the part after the last "/") shown in the UI.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


# -----------------------------------------------------------------------------
# Document Loading and Splitting (Optimized with Threading)
# -----------------------------------------------------------------------------
def load_single_pdf(file_path):
    """Load one PDF and return its pages as LangChain documents.

    Args:
        file_path: Path to the PDF file.

    Returns:
        The list of documents produced by ``PyPDFLoader.load()``.
    """
    loader = PyPDFLoader(file_path)
    return loader.load()


def load_doc(list_file_path, progress=gr.Progress()):
    """Load several PDFs in parallel and split them into chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        progress: Gradio progress tracker.

    Returns:
        The list of document chunks produced by the text splitter.

    Raises:
        ValueError: If ``list_file_path`` is empty.
    """
    if not list_file_path:
        raise ValueError("No files provided for processing.")

    # PDF loading is I/O-heavy, so a thread pool lets the files load concurrently.
    with ThreadPoolExecutor() as executor:
        per_file_pages = list(executor.map(load_single_pdf, list_file_path))
    pages = [page for file_pages in per_file_pages for page in file_pages]  # flatten

    progress(0.5, "Splitting documents...")
    # Increased chunk size
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    return text_splitter.split_documents(pages)
# -----------------------------------------------------------------------------
# Vector Database Creation (Optimized with Lightweight Embeddings)
# -----------------------------------------------------------------------------
def create_chromadb(splits, persist_directory="chroma_db", progress=gr.Progress(), collection_name=None):
    """Create a Chroma vector store from document chunks.

    Args:
        splits: Document chunks to embed and index.
        persist_directory: Directory where Chroma persists its data.
        progress: Gradio progress tracker.
        collection_name: Chroma collection to write into. When None (default),
            a fresh, uniquely named collection is created. Otherwise every call
            would append to the same default collection in the shared persist
            directory. Chunks from earlier uploads (or other users) would then
            leak into retrieval results.

    Returns:
        The populated Chroma vector store.
    """
    import uuid  # only needed to generate a unique collection name

    # Use a lighter embedding model
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    if collection_name is None:
        # NOTE: each call leaves a new collection on disk; nothing here prunes old ones.
        collection_name = f"docs_{uuid.uuid4().hex}"
    progress(0.7, "Creating vector database...")
    return Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        persist_directory=persist_directory,
        collection_name=collection_name,
    )


# -----------------------------------------------------------------------------
# Retrievers
# -----------------------------------------------------------------------------
def create_bm25_retriever(splits):
    """Create a keyword (BM25) retriever over the document chunks, returning 2 hits."""
    retriever = BM25Retriever.from_documents(splits)
    retriever.k = 2  # Reduced to 2 documents for faster retrieval
    return retriever


def create_ensemble_retriever(vector_db, bm25_retriever):
    """Combine vector search (weight 0.7) and BM25 (weight 0.3), 2 docs each."""
    return EnsembleRetriever(
        retrievers=[vector_db.as_retriever(search_kwargs={"k": 2}), bm25_retriever],  # Limit to 2 docs
        weights=[0.7, 0.3],
    )


# -----------------------------------------------------------------------------
# Initialize Database
# -----------------------------------------------------------------------------
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Build the hybrid retriever from uploaded PDFs.

    Args:
        list_file_obj: Value of the ``gr.Files`` component. It may be None
            (nothing uploaded), a list of file wrappers exposing ``.name``, or a
            list of plain path strings, depending on Gradio version/config.
        progress: Gradio progress tracker.

    Returns:
        ``(retriever, status_message)``. On failure the retriever is None and
        the message starts with "Error initializing database:".
    """
    try:
        list_file_path = [getattr(x, "name", x) for x in (list_file_obj or []) if x is not None]
        progress(0.1, "Loading documents...")
        doc_splits = load_doc(list_file_path, progress)
        chromadb = create_chromadb(doc_splits, progress=progress)
        bm25_retriever = create_bm25_retriever(doc_splits)
        ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
        progress(1.0, "Database creation complete!")
        return ensemble_retriever, "Database created successfully!"
    except Exception as e:  # UI boundary: report the failure in the status box
        return None, f"Error initializing database: {str(e)}"


# -----------------------------------------------------------------------------
# Initialize LLM Chain
# -----------------------------------------------------------------------------
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
    """Create a ConversationalRetrievalChain backed by a Hugging Face endpoint.

    Args:
        llm_model: Hugging Face repo id of the model.
        temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k sampling parameter.
        retriever: Retriever that supplies context documents.

    Returns:
        The configured chain, with buffer memory and source documents enabled.

    Raises:
        ValueError: If ``retriever`` is None.
        RuntimeError: If the LLM or chain cannot be constructed (cause chained).
    """
    if retriever is None:
        raise ValueError("Retriever is None. Please process documents first.")
    try:
        print(f"Initializing LLM: {llm_model}...")  # never log the API token
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
            task="text-generation",
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            return_messages=True,
        )
        return ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLM chain: {str(e)}") from e


# -----------------------------------------------------------------------------
# Initialize LLM
# -----------------------------------------------------------------------------
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
    """UI handler: build the QA chain for the model selected by index.

    Returns:
        ``(qa_chain, status_message)``; ``qa_chain`` is None on failure.
    """
    if retriever is None:
        return None, "Error: No database initialized. Please process documents first."
    try:
        llm_name = list_llm[llm_option]
        print(f"Selected LLM model: {llm_name}")
        # Gradio sliders may deliver floats (e.g. 1024.0); the endpoint expects ints here.
        qa_chain = initialize_llmchain(llm_name, llm_temperature, int(max_tokens), int(top_k), retriever)
        return qa_chain, "Analysis Assistant initialized and ready!"
    except Exception as e:  # UI boundary: report the failure in the status box
        return None, f"Error initializing LLM: {str(e)}"


# -----------------------------------------------------------------------------
# Chat History Formatting
# -----------------------------------------------------------------------------
def format_chat_history(message, chat_history):
    """Render (user, assistant) pairs as "User: ...\\nAssistant: ..." strings.

    ``message`` is accepted for interface compatibility but is not used.
    NOTE(review): the chain has ConversationBufferMemory attached. LangChain
    merges memory variables over caller inputs, so this list is presumably
    superseded by the memory's ``chat_history``. Confirm for the installed version.
    """
    return [f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in chat_history]


# -----------------------------------------------------------------------------
# Conversation Function
# -----------------------------------------------------------------------------
def conversation(qa_chain, message, history, lang):
    """UI handler: answer ``message`` and surface up to three source passages.

    Returns a 9-tuple matching the Gradio outputs:
    ``(qa_chain, msg_update, history, src1, page1, src2, page2, src3, page3)``.
    Page numbers are 1-based; 0 means "no page available".
    """
    history = history or []
    if not qa_chain:
        return None, gr.update(value="Assistant not initialized"), history, "", 0, "", 0, "", 0

    lang_instruction = " (Responda em Português)" if lang == "pt" else " (Respond in English)"
    query = message + lang_instruction

    try:
        formatted_chat_history = format_chat_history(message, history)
        response = qa_chain.invoke({"question": query, "chat_history": formatted_chat_history})

        raw_answer = response["answer"]
        # Some models echo the prompt; keep only the text after the last marker.
        answer = raw_answer.split("Helpful Answer:")[-1].strip() if "Helpful Answer:" in raw_answer else raw_answer

        source_data = [("Unknown", 0)] * 3
        for i, doc in enumerate(response["source_documents"][:3]):
            page = doc.metadata.get("page")
            source_data[i] = (doc.page_content.strip(), page + 1 if page is not None else 0)

        new_history = history + [(message, answer)]
        return (
            qa_chain,
            gr.update(value=""),
            new_history,
            source_data[0][0], source_data[0][1],
            source_data[1][0], source_data[1][1],
            source_data[2][0], source_data[2][1],
        )
    except Exception as e:  # UI boundary: show the error in the query box
        return qa_chain, gr.update(value=f"Error: {str(e)}"), history, "", 0, "", 0, "", 0


def _reset_chain_memory(chain):
    """Clear the chain's conversation memory (if any) and return the chain.

    Wired to the Clear button so that wiping the chat UI also makes the model
    forget the conversation stored in its ConversationBufferMemory.
    """
    memory = getattr(chain, "memory", None)
    if memory is not None:
        memory.clear()
    return chain


# -----------------------------------------------------------------------------
# Gradio Demo
# -----------------------------------------------------------------------------
def demo():
    """Build and launch the MetroAssist AI Gradio application."""
    theme = gr.themes.Default(primary_hue="indigo", secondary_hue="blue", neutral_hue="slate")
    custom_css = """
    .container {background: #ffffff; padding: 1rem; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);}
    .header {text-align: center; margin-bottom: 2rem;}
    .header h1 {color: #1a365d; font-size: 2.5rem; margin-bottom: 0.5rem;}
    .section {margin-bottom: 1.5rem; padding: 1rem; background: #f8fafc; border-radius: 8px;}
    """

    # Named `app` (not `demo`) so the Blocks object does not shadow this function.
    with gr.Blocks(theme=theme, css=custom_css) as app:
        retriever = gr.State()
        qa_chain = gr.State()
        language = gr.State(value="en")

        # NOTE(review): the original header markup was lost, leaving an invalid
        # multi-line single-quoted string. It is reconstructed here to match the
        # `.header` / `.header h1` CSS rules above. Confirm against the design.
        gr.HTML(
            """
            <div class="header">
                <h1>MetroAssist AI</h1>
                <p>Expert System for Metrology Report Analysis</p>
            </div>
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Document Processing")
                with gr.Column(elem_classes="section"):
                    document = gr.Files(label="Metrology Reports (PDF)", file_count="multiple", file_types=["pdf"])
                    db_btn = gr.Button("Process Documents")
                    db_progress = gr.Textbox(value="Ready for documents", label="Processing Status")

                gr.Markdown("## Model Configuration")
                with gr.Column(elem_classes="section"):
                    llm_btn = gr.Radio(choices=list_llm_simple, label="Select AI Model", value=list_llm_simple[0], type="index")
                    language_btn = gr.Radio(choices=["English", "Português"], label="Response Language", value="English")
                    with gr.Accordion("Advanced Settings", open=False):
                        slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
                        slider_maxtokens = gr.Slider(128, 2048, value=1024, step=128, label="Response Length")  # Reduced max_tokens
                        slider_topk = gr.Slider(1, 5, value=3, step=1, label="Analysis Diversity")  # Reduced range
                    qachain_btn = gr.Button("Initialize Assistant", interactive=False)
                    llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")

            with gr.Column(scale=2):
                gr.Markdown("## Interactive Analysis")
                chatbot = gr.Chatbot(height=400, label="Analysis Conversation")
                with gr.Row():
                    msg = gr.Textbox(placeholder="Ask about your metrology report...", label="Query")
                    submit_btn = gr.Button("Send")
                    clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
                with gr.Accordion("Document References", open=False):
                    with gr.Row():
                        doc_source1, source1_page = gr.Textbox(label="Reference 1", lines=2), gr.Number(label="Page")
                        doc_source2, source2_page = gr.Textbox(label="Reference 2", lines=2), gr.Number(label="Page")
                        doc_source3, source3_page = gr.Textbox(label="Reference 3", lines=2), gr.Number(label="Page")

        # Event Handlers
        language_btn.change(lambda x: "en" if x == "English" else "pt", inputs=language_btn, outputs=language)

        def enable_qachain_btn(retriever, status):
            # Only allow building the assistant once a database exists.
            return gr.update(interactive=retriever is not None and "successfully" in status)

        db_btn.click(
            initialize_database, inputs=[document], outputs=[retriever, db_progress]
        ).then(
            enable_qachain_btn, inputs=[retriever, db_progress], outputs=[qachain_btn]
        )

        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, retriever],
            outputs=[qa_chain, llm_progress],
        )

        chat_inputs = [qa_chain, msg, chatbot, language]
        chat_outputs = [
            qa_chain, msg, chatbot,
            doc_source1, source1_page,
            doc_source2, source2_page,
            doc_source3, source3_page,
        ]
        submit_btn.click(conversation, inputs=chat_inputs, outputs=chat_outputs)
        msg.submit(conversation, inputs=chat_inputs, outputs=chat_outputs)
        # ClearButton wipes the UI; also reset the chain's memory so the model forgets too.
        clear_btn.click(_reset_chain_memory, inputs=[qa_chain], outputs=[qa_chain])

    app.launch(debug=True)


if __name__ == "__main__":
    demo()