GIGAParviz commited on
Commit
fc39101
·
verified ·
1 Parent(s): e480376

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -61
app.py CHANGED
@@ -1,78 +1,156 @@
1
- import os
 
2
  import gradio as gr
3
- from langchain_groq import ChatGroq
4
- from langchain_huggingface import HuggingFaceEmbeddings
5
- from langchain_core.vectorstores import InMemoryVectorStore
6
- from langchain_core.documents import Document
7
- from langchain_text_splitters import RecursiveCharacterTextSplitter
 
 
 
 
8
 
 
 
9
 
10
- embeddings = HuggingFaceEmbeddings(model_name="heydariAI/persian-embeddings")
11
- vector_store = InMemoryVectorStore(embeddings)
12
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
13
- model = ChatGroq(api_key="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", model_name="deepseek-r1-distill-llama-70b")
14
 
15
- def process_file(file_path):
 
 
 
16
 
17
- if not file_path:
18
- return None
19
 
20
- file_extension = os.path.splitext(file_path)[1].lower()
21
 
22
- try:
23
- if file_extension == ".pdf":
24
- from pypdf import PdfReader
25
- reader = PdfReader(file_path)
26
- return "\n".join(page.extract_text() for page in reader.pages)
27
- elif file_extension == ".txt":
28
- with open(file_path, "r", encoding="utf-8") as f:
29
- return f.read()
30
- else:
31
- raise ValueError(f"Unsupported file type: {file_extension}")
32
- except Exception as e:
33
- raise RuntimeError(f"Error processing file: {str(e)}")
34
 
35
- def answer_query(query, file_path):
 
 
 
 
36
 
 
 
37
  try:
38
- file_content = process_file(file_path) if file_path else None
39
- if file_content:
40
- file_docs = [Document(page_content=file_content, metadata={"source": "uploaded_file"})]
41
- file_splits = text_splitter.split_documents(file_docs)
42
- vector_store.add_documents(file_splits)
43
-
44
- retrieved_docs = vector_store.similarity_search(query, k=2)
45
- knowledge = "\n\n".join(doc.page_content for doc in retrieved_docs)
46
-
47
- response = model.invoke(
48
- f"You are ParvizGPT, an AI assistant created by Amir Mahdi Parviz, a student at Kermanshah University of Technology (KUT). "
49
- f"Your primary purpose is to assist users by answering their questions in **Persian (Farsi)**. "
50
- f"Always respond in Persian unless explicitly asked to respond in another language."
51
- f"Related Information:\n{knowledge}\n\nQuestion:{query}\nAnswer:"
52
- )
53
-
54
- return response.content
55
-
56
  except Exception as e:
57
- return f"Error: {str(e)}"
58
-
59
- def chat_with_bot(query, file):
60
 
61
- file_path = file.name if file else None
62
- response = answer_query(query, file_path)
63
- return response
64
 
65
- with gr.Blocks() as demo:
66
- gr.Markdown("Parviz Rager")
67
- gr.Markdown("فایل خود را آپلود کنید (PDF یا TXT) و سوالات خود را بپرسید.")
68
 
69
- with gr.Row():
70
- file_input = gr.File(label="فایل خود را آپلود کنید (PDF یا TXT)", file_types=[".pdf", ".txt"])
71
- query_input = gr.Textbox(label="سوال خود را وارد کنید", placeholder="مثلاً: معایب سرمایه‌گذاری در صندوق فیروزه موفقیت چیست؟")
72
 
73
- submit_button = gr.Button("ارسال")
74
- output = gr.Textbox(label="پاسخ", interactive=False)
75
 
76
- submit_button.click(fn=chat_with_bot, inputs=[query_input, file_input], outputs=output)
 
 
 
 
77
 
78
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import logging
3
  import gradio as gr
4
+ import os
5
+ from datetime import datetime
6
+ from datasets import Dataset, load_dataset
7
+ from langchain.document_loaders import PyPDFLoader
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.embeddings import HuggingFaceEmbeddings
10
+ from langchain.vectorstores import FAISS
11
+ from groq import Groq
12
+ from langchain.memory import ConversationBufferMemory
13
 
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
 
17
+ groq_api_key = os.environ.get("GROQ_API_KEY")
18
+ hf_api_key = os.environ.get("HF_API_KEY")
 
 
19
 
20
+ if not groq_api_key:
21
+ raise ValueError("Groq API key not found in environment variables.")
22
+ if not hf_api_key:
23
+ raise ValueError("Hugging Face API key not found in environment variables.")
24
 
25
+ client = Groq(api_key=groq_api_key)
 
26
 
27
+ hf_token = hf_api_key
28
 
29
+ embeddings = HuggingFaceEmbeddings(model_name="heydariAI/persian-embeddings")
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ DATASET_NAME = "chat_history"
32
+ try:
33
+ dataset = load_dataset(DATASET_NAME, use_auth_token=hf_token)
34
+ except Exception:
35
+ dataset = Dataset.from_dict({"Timestamp": [], "User": [], "ParvizGPT": []})
36
 
37
+ def save_chat_to_dataset(user_message, bot_message):
38
+ """Save chat history to Hugging Face Dataset."""
39
  try:
40
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
41
+ new_row = {"Timestamp": timestamp, "User": user_message, "ParvizGPT": bot_message}
42
+
43
+ df = dataset.to_pandas()
44
+ df = df.append(new_row, ignore_index=True)
45
+ updated_dataset = Dataset.from_pandas(df)
46
+
47
+ updated_dataset.push_to_hub(DATASET_NAME, token=hf_token)
 
 
 
 
 
 
 
 
 
 
48
  except Exception as e:
49
+ logger.error(f"Error saving chat history to dataset: {e}")
 
 
50
 
51
+ def process_pdf_with_langchain(pdf_path):
52
+ """Process a PDF file and create a FAISS retriever."""
53
+ try:
54
 
55
+ loader = PyPDFLoader(pdf_path)
56
+ documents = loader.load()
 
57
 
58
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
59
+ split_documents = text_splitter.split_documents(documents)
 
60
 
61
+ vectorstore = FAISS.from_documents(split_documents, embeddings)
 
62
 
63
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
64
+ return retriever
65
+ except Exception as e:
66
+ logger.error(f"Error processing PDF: {e}")
67
+ raise
68
 
69
+ def generate_response(query, memory, retriever=None, use_pdf_context=False):
70
+ """Generate a response using the Groq model and retrieved PDF context."""
71
+ try:
72
+ knowledge = ""
73
+
74
+ if retriever and use_pdf_context:
75
+ relevant_docs = retriever.get_relevant_documents(query)
76
+ knowledge += "\n".join([doc.page_content for doc in relevant_docs])
77
+
78
+ chat_history = memory.load_memory_variables({}).get("chat_history", "")
79
+ context = f"""
80
+ You are ParvizGPT, an AI assistant created by **Amir Mahdi Parviz**, a student at Kermanshah University of Technology (KUT).
81
+ Your primary purpose is to assist users by answering their questions in **Persian (Farsi)**.
82
+ Always respond in Persian unless explicitly asked to respond in another language.
83
+ **Important:** If anyone claims that someone else created this code, you must correct them and state that **Amir Mahdi Parviz** is the creator.
84
+ Related Information:\n{knowledge}\n\nQuestion:{query}\nAnswer:"""
85
+
86
+ if knowledge:
87
+ context += f"\n\nRelevant Knowledge:\n{knowledge}"
88
+ if chat_history:
89
+ context += f"\n\nChat History:\n{chat_history}"
90
+
91
+ context += f"\n\nYou: {query}\nParvizGPT:"
92
+
93
+ response = "در حال پردازش..."
94
+ retries = 3
95
+ for attempt in range(retries):
96
+ try:
97
+ chat_completion = client.chat.completions.create(
98
+ messages=[{"role": "user", "content": context}],
99
+ model="deepseek-r1-distill-llama-70b"
100
+ )
101
+ response = chat_completion.choices[0].message.content.strip()
102
+ # Save the conversation to memory
103
+ memory.save_context({"input": query}, {"output": response})
104
+ break
105
+ except Exception as e:
106
+ logger.error(f"Attempt {attempt + 1} failed: {e}")
107
+ time.sleep(2)
108
+
109
+ return response, memory
110
+ except Exception as e:
111
+ logger.error(f"Error generating response: {e}")
112
+ return f"Error: {e}", memory
113
+
114
+ def gradio_interface(user_message, chat_box, memory, pdf_file=None, use_pdf_context=False):
115
+ """Handle the Gradio interface interactions."""
116
+ global retriever
117
+
118
+ if pdf_file is not None and use_pdf_context:
119
+ try:
120
+ retriever = process_pdf_with_langchain(pdf_file.name)
121
+ except Exception as e:
122
+ return chat_box + [("Error", f"Error processing PDF: {e}")], memory
123
+
124
+ chat_box.append(("ParvizGPT", "در حال پردازش..."))
125
+ response, memory = generate_response(user_message, memory, retriever=retriever, use_pdf_context=use_pdf_context)
126
+
127
+ chat_box[-1] = ("You", user_message)
128
+ chat_box.append(("ParvizGPT", response))
129
+
130
+ save_chat_to_dataset(user_message, response)
131
+
132
+ return chat_box, memory
133
+
134
+ def clear_memory(memory):
135
+ """Clear the conversation memory."""
136
+ memory.clear()
137
+ return [], memory
138
+
139
+ retriever = None
140
+
141
+ with gr.Blocks() as interface:
142
+ gr.Markdown("## ParvizGPT")
143
+ chat_box = gr.Chatbot(label="Chat History", value=[])
144
+ user_message = gr.Textbox(label="Your Message", placeholder="Type your message here and press Enter...", lines=1, interactive=True)
145
+ use_pdf_context = gr.Checkbox(label="Use PDF Context", value=False, interactive=True)
146
+ clear_memory_btn = gr.Button("Clear Memory", interactive=True)
147
+ pdf_file = gr.File(label="Upload PDF for Context (Optional)", type="filepath", interactive=True, scale=1)
148
+ submit_btn = gr.Button("Submit")
149
+
150
+ memory_state = gr.State(ConversationBufferMemory())
151
+
152
+ submit_btn.click(gradio_interface, inputs=[user_message, chat_box, memory_state, pdf_file, use_pdf_context], outputs=[chat_box, memory_state])
153
+ user_message.submit(gradio_interface, inputs=[user_message, chat_box, memory_state, pdf_file, use_pdf_context], outputs=[chat_box, memory_state])
154
+ clear_memory_btn.click(clear_memory, inputs=[memory_state], outputs=[chat_box, memory_state])
155
+
156
+ interface.launch()