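"""Personal Data Assistant: a Gradio app for chatting with an uploaded PDF,
using a LangChain ConversationalRetrievalChain over a Chroma vector store."""
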
import os
import gradio as gr
import torch
import logging
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyMuPDFLoader  # stable PDF loader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline
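
# Dependency set implied by the imports above (e.g. for requirements.txt):
# gradio, torch, transformers, langchain, langchain-community,
# langchain-text-splitters, chromadb, pymupdf, sentence-transformers.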

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the Hugging Face cache directory
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Check for GPU availability
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global state
conversation_retrieval_chain = None
chat_history = []
llm_pipeline = None
embeddings = None
persist_directory = "/tmp/chroma_db"  # storage for the vector DB


def init_llm():
    """Initialize the LLM pipeline and the embedding model."""
    global llm_pipeline, embeddings

    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set in environment variables.")

    model_id = "tiiuae/falcon-rw-1b"  # compact 1B-parameter model; swap in a larger one if resources allow

    hf_pipeline = pipeline(
        "text-generation",
        model=model_id,
        token=hf_token,  # pass the token so gated/private models also work
        device=DEVICE,
        max_new_tokens=512,  # increase as needed
    )
    llm_pipeline = HuggingFacePipeline(pipeline=hf_pipeline)
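
    # all-MiniLM-L6-v2 produces compact 384-dimensional sentence embeddings
    # and is light enough to run on CPU if no GPU is available.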
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": DEVICE},
    )
    logger.info("✅ LLM and embeddings initialized successfully!")


def process_document(file):
    """Process an uploaded PDF and build the retrieval chain."""
    global conversation_retrieval_chain

    if not llm_pipeline or not embeddings:
        init_llm()

    try:
        file_path = file.name  # gr.File exposes the uploaded file's temporary path via .name
        logger.info(f"📄 Processing PDF: {file_path}")

        loader = PyMuPDFLoader(file_path)  # PyMuPDF is a fast, stable PDF parser
        documents = loader.load()
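
        # ~1024-character chunks with a 64-character overlap keep passages
        # self-contained while limiting duplicated text between neighbors.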
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        texts = text_splitter.split_documents(documents)

        # Load or create the Chroma vector store
        db = Chroma.from_documents(texts, embedding=embeddings, persist_directory=persist_directory)
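
        # MMR (maximal marginal relevance) trades off similarity to the query
        # against diversity among the k=6 retrieved chunks.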
        retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 6})

        conversation_retrieval_chain = ConversationalRetrievalChain.from_llm(
            llm=llm_pipeline, retriever=retriever
        )

        logger.info("✅ PDF processed successfully!")
        return "✅ PDF uploaded and processed successfully! You can now ask questions."
    except Exception as e:
        logger.error(f"❌ Error processing PDF: {e}")
        return f"❌ Error processing PDF: {e}"


def process_prompt(prompt, chat_history_display):
    """Generate a response using the retrieval chain."""
    global conversation_retrieval_chain, chat_history
    if not conversation_retrieval_chain:
        return chat_history_display + [("❌ No document uploaded.", "Please upload a PDF first.")]
    output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
    answer = output["answer"]
    chat_history.append((prompt, answer))
    # gr.Chatbot expects the full list of (user, assistant) pairs, not a bare string
    return chat_history


def clear_chat():
    """Reset the stored history and clear the chat display."""
    global chat_history
    chat_history = []
    return []


# Define the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Personal Data Assistant</h1>")

    with gr.Row():
        dark_mode = gr.Checkbox(label="🌙 Toggle light/dark mode")  # wired to a JS callback below

    with gr.Column():
        gr.Markdown(
            "Hello there! I'm your friendly data assistant, ready to answer any questions "
            "regarding your data. Could you please upload a PDF file for me to analyze?"
        )
        file_input = gr.File(label="Upload File")
        upload_button = gr.Button("📤 Upload File")
        status_output = gr.Textbox(label="Status", interactive=False)

    chat_history_display = gr.Chatbot(label="Chat History")

    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message here...", scale=4)
        submit_button = gr.Button("📩", scale=1)
        clear_button = gr.Button("🧹", scale=1)

    # Button click actions
    upload_button.click(process_document, inputs=file_input, outputs=status_output)
    submit_button.click(process_prompt, inputs=[user_input, chat_history_display], outputs=chat_history_display)
    clear_button.click(clear_chat, outputs=chat_history_display)
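
    # Hedged sketch: wire the dark-mode checkbox with a client-side JS hook.
    # Assumes Gradio 4's `js` parameter on event listeners (older releases used `_js`).
    dark_mode.change(
        None,
        inputs=None,
        outputs=None,
        js="() => { document.body.classList.toggle('dark'); }",
    )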

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)  # 0.0.0.0:7860 is the binding Hugging Face Spaces expects