Update app.py
app.py
CHANGED
@@ -9,7 +9,7 @@ from langchain.memory import ConversationBufferMemory
 from langchain.chains import ConversationalRetrievalChain
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer
-import tavily
+import tavily
 
 class AdvancedRAGChatbot:
     def __init__(self,
@@ -22,8 +22,8 @@ class AdvancedRAGChatbot:
         os.environ["TAVILY_API_KEY"] = tavily_api_key
 
         # Correct Tavily Client initialization
-        self.tavily_client = tavily.
-
+        self.tavily_client = tavily.Client(api_key=tavily_api_key)
+
         # NLP Components
         self.embeddings = self._configure_embeddings(embedding_model)
         self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
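Reviewer note on the client initialization above: recent releases of the tavily-python package expose the client as TavilyClient rather than tavily.Client, so it is worth confirming that tavily.Client resolves with the pinned package version. A minimal sketch of the TavilyClient variant, using the package's documented search() call (query text and max_results below are illustrative, not taken from this commit):

# Hedged sketch, not part of this commit: TavilyClient as exposed by
# recent tavily-python releases; verify against the installed version.
import os
from tavily import TavilyClient

tavily_api_key = os.getenv("TAVILY_API_KEY")
tavily_client = TavilyClient(api_key=tavily_api_key)

# Basic web search; max_results is an illustrative parameter.
response = tavily_client.search("retrieval augmented generation", max_results=3)
for hit in response.get("results", []):
    print(hit.get("title"), "-", hit.get("url"))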
@@ -35,12 +35,11 @@ class AdvancedRAGChatbot:
 
         # Conversation Memory
         self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-
+
     def _configure_embeddings(self, model_name: str):
         """Configure embeddings with normalization"""
         encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
         return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
-
 
     def _configure_llm(self, model_name: str, temperature: float):
         """Configure the Language Model with Groq"""
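Side note on _configure_embeddings above: with normalize_embeddings=True the vectors come back unit-length, so a plain dot product equals cosine similarity. A small usage sketch, assuming HuggingFaceBgeEmbeddings imported from langchain_community and an illustrative BGE model name:

# Hedged sketch of exercising the embeddings helper; model name is illustrative.
import numpy as np
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5", encode_kwargs=encode_kwargs)

q = np.array(embeddings.embed_query("What is retrieval-augmented generation?"))
d = np.array(embeddings.embed_query("RAG pairs a retriever with a text generator."))
print(float(q @ d))  # cosine similarity, since both vectors are normalized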
@@ -81,7 +80,13 @@ class AdvancedRAGChatbot:
         # NLP Analysis
         semantic_score = self.semantic_model.encode([query])[0]
         sentiment_result = self.sentiment_analyzer(query)[0]
-
+
+        # Safe NER processing
+        try:
+            entities = self.ner_pipeline(query)
+        except Exception as e:
+            st.warning(f"NER processing error: {e}")
+            entities = []
 
         # Prepare prompt with web search context
         full_prompt = f"""
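The try/except added above guards a self.ner_pipeline that this diff does not show being constructed. A typical construction with the transformers pipeline would look like the sketch below (the model choice and aggregation_strategy are assumptions); with aggregation_strategy="simple" each result dict carries 'word', 'entity_group', and 'score', which matters for how entities are rendered later in the UI:

# Hedged sketch: a plausible NER pipeline setup, not taken from this diff.
from transformers import pipeline

ner_pipeline = pipeline(
    "ner",
    model="dslim/bert-base-NER",        # illustrative model choice
    aggregation_strategy="simple",      # merge sub-word tokens into whole entities
)

for ent in ner_pipeline("Tavily and LangChain power this Streamlit app."):
    # Aggregated output uses 'entity_group'; token-level output uses 'entity'.
    print(ent["word"], ent["entity_group"], round(float(ent["score"]), 3))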
@@ -116,10 +121,10 @@ def main():
     )
 
     # Retrieve Tavily API Key from Environment Variable
-    tavily_api_key = os.getenv("
+    tavily_api_key = os.getenv("TAVILY_API_KEY")
 
     if not tavily_api_key:
-        st.warning("Tavily API Key is missing. Please set the '
+        st.warning("Tavily API Key is missing. Please set the 'TAVILY_API_KEY' environment variable.")
         st.stop()
 
     # Sidebar Configuration
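On the key lookup above: on Hugging Face Spaces, repository secrets are injected as environment variables, so os.getenv("TAVILY_API_KEY") picks them up directly. For local development, a .env file loaded with python-dotenv is one common way to populate the same variable; the sketch below is an assumption about workflow, not something this commit adds:

# Hedged sketch: populate TAVILY_API_KEY locally from a .env file.
import os
from dotenv import load_dotenv

load_dotenv()  # reads TAVILY_API_KEY=... from ./.env if the file exists
tavily_api_key = os.getenv("TAVILY_API_KEY")
print("Tavily key configured:", bool(tavily_api_key))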
@@ -186,8 +191,13 @@ def main():
 
         # Named Entities
         st.markdown("#### Detected Entities")
-
-
+        if response['named_entities']:
+            for entity in response['named_entities']:
+                word = entity.get('word', 'Unknown')
+                entity_type = entity.get('entity_type', entity.get('entity', 'Unknown Type'))
+                st.text(f"{word} ({entity_type})")
+        else:
+            st.info("No entities detected")
 
         # Web Sources
         if response['web_sources']:
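One follow-up on the entity.get('entity_type', entity.get('entity', ...)) fallback above: transformers NER pipelines emit 'entity' for token-level output and 'entity_group' for aggregated output, so a lookup that also tries 'entity_group' covers both shapes. A small sketch of such a helper (hypothetical, not in this commit):

# Hedged sketch: label lookup tolerant of both NER output shapes.
def entity_label(entity: dict) -> str:
    for key in ("entity_type", "entity_group", "entity"):
        if key in entity:
            return entity[key]
    return "Unknown Type"

print(entity_label({"word": "Tavily", "entity_group": "ORG", "score": 0.99}))  # -> ORG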
@@ -204,4 +214,4 @@ def main():
     st.info("Enter a query to search the web and get an AI-powered response")
 
 if __name__ == "__main__":
-    main()
+    main()