Deepakraj2006 committed
Commit 6357e8e · verified · 1 Parent(s): 853da04

Create app.py

Files changed (1)
  1. app.py +117 -0
app.py ADDED
import os
import gradio as gr
import torch
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline

# Set Hugging Face cache directory
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Check for GPU availability
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global variables
conversation_retrieval_chain = None
chat_history = []
llm_pipeline = None
embeddings = None
persist_directory = "/tmp/chroma_db"  # Storage for the vector DB


def init_llm():
    """Initialize the LLM and embeddings."""
    global llm_pipeline, embeddings

    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set in environment variables.")

    model_id = "tiiuae/falcon-7b-instruct"
    # Pass the token so Hub downloads authenticate with the configured account.
    hf_pipeline = pipeline("text-generation", model=model_id, device=DEVICE, token=hf_token)
    llm_pipeline = HuggingFacePipeline(pipeline=hf_pipeline)

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": DEVICE}
    )


def process_document(file):
    """Process an uploaded PDF and build the retrieval chain."""
    global conversation_retrieval_chain

    if not llm_pipeline or not embeddings:
        init_llm()

    # Load the PDF and split it into chunks
    loader = PyPDFLoader(file.name)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(documents)

    # Load the existing ChromaDB collection or create a new one; in both cases
    # index the freshly uploaded document so its chunks are searchable.
    if os.path.exists(persist_directory):
        db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        db.add_documents(texts)
    else:
        db = Chroma.from_documents(texts, embedding=embeddings, persist_directory=persist_directory)

    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 6})

    # Initialize the ConversationalRetrievalChain
    conversation_retrieval_chain = ConversationalRetrievalChain.from_llm(
        llm=llm_pipeline, retriever=retriever
    )

    return "📄 PDF uploaded and processed successfully! You can now ask questions."


def process_prompt(prompt, chat_history_display):
    """Generate a response using the retrieval chain."""
    global conversation_retrieval_chain, chat_history

    if not conversation_retrieval_chain:
        return chat_history_display + [("❌ No document uploaded.", "Please upload a PDF first.")]

    output = conversation_retrieval_chain({"question": prompt, "chat_history": chat_history})
    answer = output["answer"]

    chat_history.append((prompt, answer))

    return chat_history


def clear_chat():
    """Reset both the stored history and the visible chat window."""
    global chat_history
    chat_history = []
    return []


# Define the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Personal Data Assistant</h1>")

    with gr.Row():
        # Display only; not wired to an event handler in this commit.
        dark_mode = gr.Checkbox(label="🌙 Toggle light/dark mode")

    with gr.Row():
        with gr.Group():  # gr.Box was removed in Gradio 4.x; gr.Group is its replacement
            gr.Markdown("Hello there! I'm your friendly data assistant, ready to answer any questions regarding your data. Could you please upload a PDF file for me to analyze?")
            file_input = gr.File(label="Upload File")
            upload_button = gr.Button("📂 Upload File")

    status_output = gr.Textbox(label="Status", interactive=False)

    chat_history_display = gr.Chatbot(label="Chat History")

    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message here...", scale=4)
        submit_button = gr.Button("📩", scale=1)
        clear_button = gr.Button("🔄", scale=1)

    # Button click actions
    upload_button.click(process_document, inputs=file_input, outputs=status_output)
    submit_button.click(process_prompt, inputs=[user_input, chat_history_display], outputs=chat_history_display)
    clear_button.click(clear_chat, outputs=chat_history_display)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)
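For a quick smoke test outside the browser, the module's functions can be driven directly. A minimal sketch, assuming the file is saved as app.py, HUGGINGFACEHUB_API_TOKEN is exported, and a local PDF exists; "sample.pdf" and the question are hypothetical placeholders, and SimpleNamespace just mimics the .name attribute that Gradio's file object exposes:

    # smoke_test.py - exercise the retrieval pipeline without the Gradio UI
    from types import SimpleNamespace

    import app  # defines the Blocks UI but only launches under __main__

    # Index a local PDF; this triggers init_llm() on first use.
    status = app.process_document(SimpleNamespace(name="sample.pdf"))
    print(status)  # "📄 PDF uploaded and processed successfully! ..."

    # Ask one question against the freshly built retriever.
    history = app.process_prompt("What is this document about?", [])
    print(history[-1][1])  # the model's answer

Note that the first run downloads falcon-7b-instruct, which needs a GPU with substantial memory (or a long wait on CPU), so this is a sketch of the call flow rather than a lightweight unit test.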