import gradio as gr
import os
import torch
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# Environment variable for the Hugging Face API token.
# Validate before use so a missing token fails with a clear message
# instead of a TypeError when slicing None.
api_token = os.getenv("API_TOKEN")
if not api_token:
    raise ValueError("Environment variable 'API_TOKEN' not set. Please set the Hugging Face API token.")
print(f"API Token loaded: {api_token[:5]}...")  # Debug: show first 5 chars of token

# Available LLM models
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # Publicly accessible
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# -----------------------------------------------------------------------------
# Document Loading and Splitting
# -----------------------------------------------------------------------------
def load_doc(list_file_path, progress=gr.Progress()):
    """Load and split PDF documents into chunks."""
    if not list_file_path:
        raise ValueError("No files provided for processing.")
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for i, loader in enumerate(loaders):
        progress((i + 1) / len(loaders), "Loading PDFs...")
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(pages)

# -----------------------------------------------------------------------------
# Vector Database Creation
# -----------------------------------------------------------------------------
def create_chromadb(splits, persist_directory="chroma_db"):
    """Create a ChromaDB vector database from document splits."""
    embeddings = HuggingFaceEmbeddings()
    chromadb = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        persist_directory=persist_directory
    )
    return chromadb

# -----------------------------------------------------------------------------
# Retrievers
# -----------------------------------------------------------------------------
def create_bm25_retriever(splits):
    """Create a BM25 (sparse, keyword-based) retriever from document splits."""
    retriever = BM25Retriever.from_documents(splits)
    retriever.k = 3
    return retriever

def create_ensemble_retriever(vector_db, bm25_retriever):
    """Combine dense (vector DB) and sparse (BM25) retrieval in one ensemble."""
    return EnsembleRetriever(
        retrievers=[vector_db.as_retriever(), bm25_retriever],
        weights=[0.7, 0.3]
    )

# -----------------------------------------------------------------------------
# Initialize Database
# -----------------------------------------------------------------------------
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Initialize the document database with error handling."""
    try:
        list_file_path = [x.name for x in list_file_obj if x is not None]
        doc_splits = load_doc(list_file_path, progress)
        chromadb = create_chromadb(doc_splits)
        bm25_retriever = create_bm25_retriever(doc_splits)
        ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
        return ensemble_retriever, "Database created successfully!"
    except Exception as e:
        return None, f"Error initializing database: {str(e)}"

# -----------------------------------------------------------------------------
# Initialize LLM Chain
# -----------------------------------------------------------------------------
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
    """Initialize the language model chain with error handling."""
    if retriever is None:
        raise ValueError("Retriever is None. Please process documents first.")
    try:
        print(f"Initializing LLM: {llm_model} with token: {api_token[:5]}...")
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
            task="text-generation"
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            return_messages=True
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False
        )
        return qa_chain
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLM chain: {str(e)}")

# -----------------------------------------------------------------------------
# Initialize LLM
# -----------------------------------------------------------------------------
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
    """Initialize the language model."""
    if retriever is None:
        return None, "Error: No database initialized. Please process documents first."
    try:
        llm_name = list_llm[llm_option]
        print(f"Selected LLM model: {llm_name}")
        qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, retriever)
        return qa_chain, "Analysis Assistant initialized and ready!"
    except Exception as e:
        return None, f"Error initializing LLM: {str(e)}"

# -----------------------------------------------------------------------------
# Chat History Formatting
# -----------------------------------------------------------------------------
def format_chat_history(chat_history):
    """Format chat history for the model (the unused `message` parameter was dropped)."""
    return [f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in chat_history]

# -----------------------------------------------------------------------------
# Conversation Function
# -----------------------------------------------------------------------------
def conversation(qa_chain, message, history, lang):
    """Handle conversation and document analysis."""
    if not qa_chain:
        return None, gr.update(value="Assistant not initialized"), history, "", 0, "", 0, "", 0
    lang_instruction = " (Responda em Português)" if lang == "pt" else " (Respond in English)"
    query = message + lang_instruction
    try:
        formatted_chat_history = format_chat_history(history)
        response = qa_chain.invoke({"question": query, "chat_history": formatted_chat_history})
        answer = response["answer"]
        if "Helpful Answer:" in answer:
            answer = answer.split("Helpful Answer:")[-1].strip()
        sources = response["source_documents"]
        # Up to three source excerpts with their 1-indexed page numbers
        source_data = [("Unknown", 0)] * 3
        for i, doc in enumerate(sources[:3]):
            source_data[i] = (doc.page_content.strip(), doc.metadata["page"] + 1)
        new_history = history + [(message, answer)]
        return (
            qa_chain,
            gr.update(value=""),
            new_history,
            source_data[0][0], source_data[0][1],
            source_data[1][0], source_data[1][1],
            source_data[2][0], source_data[2][1]
        )
    except Exception as e:
        return qa_chain, gr.update(value=f"Error: {str(e)}"), history, "", 0, "", 0, "", 0
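
# -----------------------------------------------------------------------------
# Example usage (hypothetical, outside the Gradio UI)
# -----------------------------------------------------------------------------
# A minimal sketch of the retrieval pipeline above, assuming a local file
# "sample_report.pdf" exists; the file name and question are illustrative
# only. Kept commented out so importing this module has no side effects.
#
#   splits = load_doc(["sample_report.pdf"])
#   vector_db = create_chromadb(splits)
#   bm25 = create_bm25_retriever(splits)
#   retriever = create_ensemble_retriever(vector_db, bm25)  # 0.7 dense / 0.3 sparse
#   chain = initialize_llmchain(list_llm[0], temperature=0.5, max_tokens=512,
#                               top_k=3, retriever=retriever)
#   result = chain.invoke({"question": "Summarize the calibration results.",
#                          "chat_history": []})
#   print(result["answer"])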
# -----------------------------------------------------------------------------
# Gradio Demo
# -----------------------------------------------------------------------------
def demo():
    """Main demo application with enhanced layout."""
    theme = gr.themes.Default(primary_hue="indigo", secondary_hue="blue", neutral_hue="slate")
    custom_css = """
    .container {background: #ffffff; padding: 1rem; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);}
    .header {text-align: center; margin-bottom: 2rem;}
    .header h1 {color: #1a365d; font-size: 2.5rem; margin-bottom: 0.5rem;}
    .section {margin-bottom: 1.5rem; padding: 1rem; background: #f8fafc; border-radius: 8px;}
    """
    with gr.Blocks(theme=theme, css=custom_css) as demo:
        retriever = gr.State()
        qa_chain = gr.State()
        language = gr.State(value="en")
        gr.HTML(
            '<div class="header"><h1>Expert System for Metrology Report Analysis</h1></div>'
        )