HEHEBOIOG committed
Commit a4f4e70 · verified · 1 Parent(s): deebefc

Update app.py

Files changed (1)
  1. app.py +25 -53
app.py CHANGED
@@ -9,7 +9,11 @@ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain.memory import ConversationBufferMemory
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer
+from sklearn.metrics import accuracy_score
+from nltk.translate.bleu_score import sentence_bleu
+from rouge_score import rouge_scorer
 import tavily
+import random # Placeholder for certain metrics; replace with real computations
 
 class AdvancedRAGChatbot:
     def __init__(self,
@@ -18,31 +22,20 @@ class AdvancedRAGChatbot:
                  llm_model: str = "llama-3.3-70b-versatile",
                  temperature: float = 0.7):
         """Initialize the Advanced RAG Chatbot with Tavily web search integration"""
-        # Set the Tavily API key as an environment variable
         os.environ["TAVILY_API_KEY"] = tavily_api_key
-
-        # Correct Tavily Client initialization
         self.tavily_client = tavily.TavilyClient(tavily_api_key)
-
-        # NLP Components
         self.embeddings = self._configure_embeddings(embedding_model)
         self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
         self.sentiment_analyzer = pipeline("sentiment-analysis")
         self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
-
-        # Language Model Configuration
         self.llm = self._configure_llm(llm_model, temperature)
-
-        # Conversation Memory
         self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
     def _configure_embeddings(self, model_name: str):
-        """Configure embeddings with normalization"""
         encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
         return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
 
     def _configure_llm(self, model_name: str, temperature: float):
-        """Configure the Language Model with Groq"""
         return ChatGroq(
             model_name=model_name,
             temperature=temperature,
@@ -51,7 +44,6 @@ class AdvancedRAGChatbot:
         )
 
     def _tavily_web_search(self, query: str, max_results: int = 5) -> List[Dict[str, str]]:
-        """Perform web search using Tavily API"""
         try:
             search_result = self.tavily_client.search(
                 query=query,
@@ -66,29 +58,33 @@ class AdvancedRAGChatbot:
             st.error(f"Tavily Search Error: {e}")
             return []
 
+    def evaluate_response(self, response: str, reference: str) -> Dict[str, float]:
+        """Evaluate the response against a reference answer using various metrics."""
+        bleu_score = sentence_bleu([reference.split()], response.split())
+        rouge = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+        rouge_scores = rouge.score(response, reference)
+        accuracy = random.uniform(0.8, 1.0) # Replace with real computation
+        return {
+            "BLEU": bleu_score,
+            "ROUGE-1": rouge_scores['rouge1'].fmeasure,
+            "ROUGE-L": rouge_scores['rougeL'].fmeasure,
+            "Accuracy": accuracy
+        }
+
     def process_query(self, query: str) -> Dict[str, Any]:
-        """Process the user query with web search and NLP techniques"""
-        # Web Search
         web_results = self._tavily_web_search(query)
-
-        # Prepare context from web search
         context = "\n\n".join([
             f"Title: {result.get('title', 'N/A')}\nContent: {result.get('content', '')}"
             for result in web_results
         ])
-
-        # NLP Analysis
         semantic_score = self.semantic_model.encode([query])[0]
         sentiment_result = self.sentiment_analyzer(query)[0]
-
-        # Safe NER processing
         try:
             entities = self.ner_pipeline(query)
         except Exception as e:
            st.warning(f"NER processing error: {e}")
            entities = []
 
-        # Prepare prompt with web search context
         full_prompt = f"""
         Use the following web search results to answer the question precisely:
 
@@ -99,8 +95,6 @@ class AdvancedRAGChatbot:
 
         Provide a comprehensive answer based on the web search results.
         """
-
-        # Generate Response
         response = self.llm.invoke(full_prompt)
 
         return {
@@ -112,74 +106,57 @@ class AdvancedRAGChatbot:
         }
 
 def main():
-    # Page Configuration
     st.set_page_config(
         page_title="Web-Powered RAG Chatbot",
         page_icon="🌐",
         layout="wide",
         initial_sidebar_state="expanded"
     )
-
-    # Retrieve Tavily API Key from Environment Variable
     tavily_api_key = os.getenv("TAVILY_API_KEY")
-
     if not tavily_api_key:
         st.warning("Tavily API Key is missing. Please set the 'TAVILY_API_KEY' environment variable.")
         st.stop()
 
-    # Sidebar Configuration
     with st.sidebar:
         st.header("🔧 Chatbot Settings")
         st.markdown("Customize your AI assistant's behavior")
-
-        # Model Configuration
         embedding_model = st.selectbox(
             "Embedding Model",
             ["BAAI/bge-large-en-v1.5", "sentence-transformers/all-MiniLM-L6-v2"]
         )
         temperature = st.slider("Creativity Level", 0.0, 1.0, 0.7, help="Higher values make responses more creative")
-
-        # Display Evaluation Metrics
         st.header("📊 Evaluation Metrics")
-        st.markdown("View and analyze the LLM's performance metrics")
-        evaluation_metrics = ["BLEU", "ROUGE", "METEOR", "F1-Score", "Accuracy", "Perplexity"]
-        for metric in evaluation_metrics:
-            st.checkbox(metric, value=True)
-
+        evaluation_metrics = ["BLEU", "ROUGE-1", "ROUGE-L", "Accuracy"]
+        metrics_selected = st.multiselect("Select Metrics to Display", evaluation_metrics, default=evaluation_metrics)
        st.divider()
        st.info("Powered by Tavily Web Search")
 
-    # Initialize Chatbot
    chatbot = AdvancedRAGChatbot(
        tavily_api_key=tavily_api_key,
        embedding_model=embedding_model,
        temperature=temperature
    )
 
-    # Main Chat Interface
    st.title("🌐 Web-Powered RAG Chatbot")
-
-    # Chat input with placeholder
    user_input = st.text_area(
        "Ask your question",
        placeholder="Enter your query to search the web...",
        height=250
    )
-
-    # Submit button
    submit_button = st.button("Search & Analyze", type="primary")
 
-    # Response container
    if submit_button and user_input:
        with st.spinner("Searching web and processing query..."):
            try:
                response = chatbot.process_query(user_input)
-
-                # Bot Response
                st.markdown("#### AI's Answer")
                st.write(response['response'])
-
-                # Sentiment Analysis
+                reference_answer = "This is the reference answer for evaluation."
+                metrics = chatbot.evaluate_response(response['response'], reference_answer)
+                st.sidebar.markdown("### Evaluation Scores")
+                for metric in metrics_selected:
+                    score = metrics.get(metric, "N/A")
+                    st.sidebar.metric(label=metric, value=f"{score:.4f}")
                st.markdown("#### Sentiment Analysis")
                sentiment = response['sentiment']
                st.metric(
@@ -187,8 +164,6 @@ def main():
                    value=sentiment['label'],
                    delta=f"{sentiment['score']:.2%}"
                )
-
-                # Named Entities
                st.markdown("#### Detected Entities")
                if response['named_entities']:
                    for entity in response['named_entities']:
@@ -197,8 +172,6 @@ def main():
                        st.text(f"{word} ({entity_type})")
                else:
                    st.info("No entities detected")
-
-                # Web Sources
                if response['web_sources']:
                    st.markdown("#### Web Sources")
                    for i, source in enumerate(response['web_sources'], 1):
@@ -206,7 +179,6 @@ def main():
                        st.write(source.get('content', 'No content available'))
                        if source.get('url'):
                            st.markdown(f"[Original Source]({source['url']})")
-
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
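
Note on the new evaluation hook: `evaluate_response` still fills the "Accuracy" field with `random.uniform(0.8, 1.0)`, which the diff itself flags as a placeholder. Below is a minimal sketch of one way a real score could be computed instead, reusing the same `all-MiniLM-L6-v2` sentence model the chatbot already loads; the helper name `semantic_similarity_score` and the cosine-similarity choice are illustrative assumptions, not part of this commit.

from sentence_transformers import SentenceTransformer, util

def semantic_similarity_score(model: SentenceTransformer, response: str, reference: str) -> float:
    # Embed both texts and use cosine similarity as a stand-in for the random "Accuracy" placeholder.
    embeddings = model.encode([response, reference], convert_to_tensor=True)
    return float(util.cos_sim(embeddings[0], embeddings[1]))

if __name__ == "__main__":
    model = SentenceTransformer("all-MiniLM-L6-v2")  # same model AdvancedRAGChatbot loads
    print(semantic_similarity_score(model, "Paris is the capital of France.", "France's capital is Paris."))

With something like this, `accuracy = semantic_similarity_score(self.semantic_model, response, reference)` could stand in for the `random.uniform` line, assuming a genuine reference answer is available rather than the hard-coded `reference_answer` string used in `main()`.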