HEHEBOIOG committed on
Commit
6c15522
·
verified ·
1 Parent(s): cf3027b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -141
app.py CHANGED
@@ -7,28 +7,21 @@ from langchain_core.prompts import ChatPromptTemplate
7
  from langchain_groq import ChatGroq
8
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
9
  from langchain.memory import ConversationBufferMemory
10
- from langchain.chains import ConversationalRetrievalChain
11
  from transformers import pipeline
12
  from sentence_transformers import SentenceTransformer
13
  import tavily
14
 
15
- # Evaluation Metrics Libraries
16
- from rouge_score import rouge_scorer
17
- from nltk.translate.bleu_score import sentence_bleu
18
- from nltk.tokenize import word_tokenize
19
- from sklearn.metrics.pairwise import cosine_similarity
20
- from textstat import flesch_reading_ease, flesch_kincaid_grade
21
-
22
  class AdvancedRAGChatbot:
23
  def __init__(self,
24
  tavily_api_key: str,
25
  embedding_model: str = "BAAI/bge-large-en-v1.5",
26
  llm_model: str = "llama-3.3-70b-versatile",
27
  temperature: float = 0.7):
28
- """Initialize the Advanced RAG Chatbot with Enhanced Metrics"""
 
29
  os.environ["TAVILY_API_KEY"] = tavily_api_key
30
 
31
- # Tavily Client
32
  self.tavily_client = tavily.TavilyClient(tavily_api_key)
33
 
34
  # NLP Components
@@ -37,69 +30,44 @@ class AdvancedRAGChatbot:
37
  self.sentiment_analyzer = pipeline("sentiment-analysis")
38
  self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
39
 
40
- # Evaluation Components
41
- self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
42
-
43
  # Language Model Configuration
44
  self.llm = self._configure_llm(llm_model, temperature)
45
 
46
  # Conversation Memory
47
  self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
48
-
49
- def _calculate_comprehensive_metrics(self, query: str, response: str, web_sources: List[Dict]) -> Dict[str, Any]:
50
- """Calculate comprehensive evaluation metrics"""
51
- metrics = {}
52
-
53
- # Readability Metrics
54
- metrics['flesch_reading_ease'] = flesch_reading_ease(response)
55
- metrics['flesch_kincaid_grade'] = flesch_kincaid_grade(response)
56
-
57
- # Length Metrics
58
- metrics['query_length'] = len(word_tokenize(query))
59
- metrics['response_length'] = len(word_tokenize(response))
60
-
61
- # BLEU Score (compared against web sources)
62
- reference_texts = [word_tokenize(source.get('content', '')) for source in web_sources]
63
- candidate_tokens = word_tokenize(response)
64
-
65
- bleu_scores = []
66
- for ref in reference_texts:
67
- try:
68
- bleu_score = sentence_bleu([ref], candidate_tokens)
69
- bleu_scores.append(bleu_score)
70
- except Exception:
71
- pass
72
-
73
- metrics['average_bleu_score'] = np.mean(bleu_scores) if bleu_scores else 0.0
74
-
75
- # ROUGE Scores
76
- reference_text = ' '.join([source.get('content', '') for source in web_sources])
77
- rouge_scores = self.rouge_scorer.score(reference_text, response)
78
- metrics['rouge_scores'] = {
79
- 'rouge1': rouge_scores['rouge1'].fmeasure,
80
- 'rouge2': rouge_scores['rouge2'].fmeasure,
81
- 'rougeL': rouge_scores['rougeL'].fmeasure
82
- }
83
-
84
- # Semantic Similarity
85
  try:
86
- web_source_embeddings = self.semantic_model.encode([source.get('content', '') for source in web_sources])
87
- response_embedding = self.semantic_model.encode([response])[0]
88
-
89
- semantic_similarities = cosine_similarity([response_embedding], web_source_embeddings)[0]
90
- metrics['semantic_similarity'] = {
91
- 'mean': np.mean(semantic_similarities),
92
- 'max': np.max(semantic_similarities),
93
- 'min': np.min(semantic_similarities)
94
- }
95
  except Exception as e:
96
- st.warning(f"Semantic similarity calculation error: {e}")
97
- metrics['semantic_similarity'] = {'mean': 0, 'max': 0, 'min': 0}
98
-
99
- return metrics
100
-
101
  def process_query(self, query: str) -> Dict[str, Any]:
102
- """Process the user query with comprehensive evaluation"""
103
  # Web Search
104
  web_results = self._tavily_web_search(query)
105
 
@@ -134,91 +102,109 @@ class AdvancedRAGChatbot:
134
 
135
  # Generate Response
136
  response = self.llm.invoke(full_prompt)
137
- response_content = response.content
138
-
139
- # Calculate Comprehensive Metrics
140
- evaluation_metrics = self._calculate_comprehensive_metrics(
141
- query,
142
- response_content,
143
- web_results
144
- )
145
 
146
  return {
147
- "response": response_content,
148
  "web_sources": web_results,
149
  "semantic_similarity": semantic_score.tolist(),
150
  "sentiment": sentiment_result,
151
- "named_entities": entities,
152
- "evaluation_metrics": evaluation_metrics
153
  }
154
 
155
  def main():
156
- # [Previous main function code remains the same]
157
- # Add a new section to display comprehensive metrics
158
- with col2:
159
- st.header("Response & Metrics")
160
- if submit_button and user_input:
161
- with st.spinner("Searching web and processing query..."):
162
- try:
163
- response = chatbot.process_query(user_input)
164
-
165
- # Existing response display code...
166
-
167
- # Comprehensive Metrics Display
168
- st.markdown("### 📊 Comprehensive Evaluation Metrics")
169
-
170
- # Readability Metrics
171
- col_read1, col_read2 = st.columns(2)
172
- with col_read1:
173
- st.metric(
174
- "Flesch Reading Ease",
175
- f"{response['evaluation_metrics']['flesch_reading_ease']:.2f}",
176
- help="Higher scores indicate easier readability"
177
- )
178
- with col_read2:
179
- st.metric(
180
- "Flesch-Kincaid Grade",
181
- f"{response['evaluation_metrics']['flesch_kincaid_grade']:.2f}",
182
- help="US grade level required to understand the text"
183
- )
184
-
185
- # Length and BLEU Metrics
186
- col_len1, col_len2, col_len3 = st.columns(3)
187
- with col_len1:
188
- st.metric("Query Length", response['evaluation_metrics']['query_length'])
189
- with col_len2:
190
- st.metric("Response Length", response['evaluation_metrics']['response_length'])
191
- with col_len3:
192
- st.metric(
193
- "BLEU Score",
194
- f"{response['evaluation_metrics']['average_bleu_score']:.4f}",
195
- help="Measures similarity to reference texts"
196
- )
197
-
198
- # ROUGE Scores
199
- st.markdown("#### 📈 ROUGE Scores")
200
- rouge_metrics = response['evaluation_metrics']['rouge_scores']
201
- col_rouge1, col_rouge2, col_rouge3 = st.columns(3)
202
- with col_rouge1:
203
- st.metric("ROUGE-1", f"{rouge_metrics['rouge1']:.4f}")
204
- with col_rouge2:
205
- st.metric("ROUGE-2", f"{rouge_metrics['rouge2']:.4f}")
206
- with col_rouge3:
207
- st.metric("ROUGE-L", f"{rouge_metrics['rougeL']:.4f}")
208
-
209
- # Semantic Similarity
210
- st.markdown("#### 🔍 Semantic Similarity")
211
- sem_sim = response['evaluation_metrics']['semantic_similarity']
212
- col_sem1, col_sem2, col_sem3 = st.columns(3)
213
- with col_sem1:
214
- st.metric("Mean Similarity", f"{sem_sim['mean']:.4f}")
215
- with col_sem2:
216
- st.metric("Max Similarity", f"{sem_sim['max']:.4f}")
217
- with col_sem3:
218
- st.metric("Min Similarity", f"{sem_sim['min']:.4f}")
219
 
220
- except Exception as e:
221
- st.error(f"An error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  if __name__ == "__main__":
224
  main()
 
7
  from langchain_groq import ChatGroq
8
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
9
  from langchain.memory import ConversationBufferMemory
 
10
  from transformers import pipeline
11
  from sentence_transformers import SentenceTransformer
12
  import tavily
13
 
 
 
 
 
 
 
 
14
  class AdvancedRAGChatbot:
15
  def __init__(self,
16
  tavily_api_key: str,
17
  embedding_model: str = "BAAI/bge-large-en-v1.5",
18
  llm_model: str = "llama-3.3-70b-versatile",
19
  temperature: float = 0.7):
20
+ """Initialize the Advanced RAG Chatbot with Tavily web search integration"""
21
+ # Set the Tavily API key as an environment variable
22
  os.environ["TAVILY_API_KEY"] = tavily_api_key
23
 
24
+ # Correct Tavily Client initialization
25
  self.tavily_client = tavily.TavilyClient(tavily_api_key)
26
 
27
  # NLP Components
 
30
  self.sentiment_analyzer = pipeline("sentiment-analysis")
31
  self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
32
 
 
 
 
33
  # Language Model Configuration
34
  self.llm = self._configure_llm(llm_model, temperature)
35
 
36
  # Conversation Memory
37
  self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
38
+
39
+ def _configure_embeddings(self, model_name: str):
40
+ """Configure embeddings with normalization"""
41
+ encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
42
+ return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
43
+
44
+ def _configure_llm(self, model_name: str, temperature: float):
45
+ """Configure the Language Model with Groq"""
46
+ return ChatGroq(
47
+ model_name=model_name,
48
+ temperature=temperature,
49
+ max_tokens=4096,
50
+ streaming=True
51
+ )
52
+
53
+ def _tavily_web_search(self, query: str, max_results: int = 5) -> List[Dict[str, str]]:
54
+ """Perform web search using Tavily API"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  try:
56
+ search_result = self.tavily_client.search(
57
+ query=query,
58
+ max_results=max_results,
59
+ search_depth="advanced",
60
+ include_domains=[],
61
+ exclude_domains=[],
62
+ include_answer=True
63
+ )
64
+ return search_result.get('results', [])
65
  except Exception as e:
66
+ st.error(f"Tavily Search Error: {e}")
67
+ return []
68
+
 
 
69
  def process_query(self, query: str) -> Dict[str, Any]:
70
+ """Process the user query with web search and NLP techniques"""
71
  # Web Search
72
  web_results = self._tavily_web_search(query)
73
 
 
102
 
103
  # Generate Response
104
  response = self.llm.invoke(full_prompt)
 
 
 
 
 
 
 
 
105
 
106
  return {
107
+ "response": response.content,
108
  "web_sources": web_results,
109
  "semantic_similarity": semantic_score.tolist(),
110
  "sentiment": sentiment_result,
111
+ "named_entities": entities
 
112
  }
113
 
114
  def main():
115
+ # Page Configuration
116
+ st.set_page_config(
117
+ page_title="Web-Powered RAG Chatbot",
118
+ page_icon="🌐",
119
+ layout="wide",
120
+ initial_sidebar_state="expanded"
121
+ )
122
+
123
+ # Retrieve Tavily API Key from Environment Variable
124
+ tavily_api_key = os.getenv("TAVILY_API_KEY")
125
+
126
+ if not tavily_api_key:
127
+ st.warning("Tavily API Key is missing. Please set the 'TAVILY_API_KEY' environment variable.")
128
+ st.stop()
129
+
130
+ # Sidebar Configuration
131
+ with st.sidebar:
132
+ st.header("🔧 Chatbot Settings")
133
+ st.markdown("Customize your AI assistant's behavior")
134
+
135
+ # Model Configuration
136
+ embedding_model = st.selectbox(
137
+ "Embedding Model",
138
+ ["BAAI/bge-large-en-v1.5", "sentence-transformers/all-MiniLM-L6-v2"]
139
+ )
140
+ temperature = st.slider("Creativity Level", 0.0, 1.0, 0.7, help="Higher values make responses more creative")
141
+
142
+ # Additional Controls
143
+ st.divider()
144
+ st.info("Powered by Tavily Web Search")
145
+
146
+ # Initialize Chatbot
147
+ chatbot = AdvancedRAGChatbot(
148
+ tavily_api_key=tavily_api_key,
149
+ embedding_model=embedding_model,
150
+ temperature=temperature
151
+ )
152
+
153
+ # Main Chat Interface
154
+ st.title("🌐 Web-Powered RAG Chatbot")
155
+
156
+ # Chat input with placeholder
157
+ user_input = st.text_area(
158
+ "Ask your question",
159
+ placeholder="Enter your query to search the web...",
160
+ height=250
161
+ )
162
+
163
+ # Submit button
164
+ submit_button = st.button("Search & Analyze", type="primary")
165
+
166
+ # Response container
167
+ if submit_button and user_input:
168
+ with st.spinner("Searching web and processing query..."):
169
+ try:
170
+ response = chatbot.process_query(user_input)
171
+
172
+ # Bot Response
173
+ st.markdown("#### AI's Answer")
174
+ st.write(response['response'])
 
 
 
175
 
176
+ # Sentiment Analysis
177
+ st.markdown("#### Sentiment Analysis")
178
+ sentiment = response['sentiment']
179
+ st.metric(
180
+ label="Sentiment",
181
+ value=sentiment['label'],
182
+ delta=f"{sentiment['score']:.2%}"
183
+ )
184
+
185
+ # Named Entities
186
+ st.markdown("#### Detected Entities")
187
+ if response['named_entities']:
188
+ for entity in response['named_entities']:
189
+ word = entity.get('word', 'Unknown')
190
+ entity_type = entity.get('entity_type', entity.get('entity', 'Unknown Type'))
191
+ st.text(f"{word} ({entity_type})")
192
+ else:
193
+ st.info("No entities detected")
194
+
195
+ # Web Sources
196
+ if response['web_sources']:
197
+ st.markdown("#### Web Sources")
198
+ for i, source in enumerate(response['web_sources'], 1):
199
+ with st.expander(f"Source {i}: {source.get('title', 'Untitled')}"):
200
+ st.write(source.get('content', 'No content available'))
201
+ if source.get('url'):
202
+ st.markdown(f"[Original Source]({source['url']})")
203
+
204
+ except Exception as e:
205
+ st.error(f"An error occurred: {e}")
206
+ else:
207
+ st.info("Enter a query to search the web and get an AI-powered response")
208
 
209
  if __name__ == "__main__":
210
  main()