invincible-jha commited on
Commit
4c650d8
·
verified ·
1 Parent(s): 4c88486

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -57
app.py CHANGED
@@ -1,63 +1,329 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
2
  from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
  )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ import os
4
+ from fastapi import FastAPI, HTTPException, Depends
5
+ from fastapi.security import OAuth2PasswordBearer
6
+ from sqlalchemy.orm import Session
7
+ from pydantic import BaseModel
8
+ from typing import List
9
+ import autogen
10
+ from crewai import Agent, Task, Crew, Process
11
  from huggingface_hub import InferenceClient
12
+ import redis
13
+ import json
14
+ import logging
15
+
16
+ from database import SessionLocal, engine, Base
17
+ from models import User, Query, Response
18
+ from auth import create_access_token, get_current_user
19
+
20
+ # Initialize FastAPI app
21
+ app = FastAPI(title="Zerodha Support System MVP")
22
+
23
+ # Set up logging
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Initialize LLM client
28
+ hf_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
29
+
30
+ # Initialize Redis client
31
+ redis_client = redis.Redis(host='localhost', port=6379, db=0)
32
+
33
+ # AutoGen configuration
34
+ config_list = [{"model": "gpt-3.5-turbo"}]
35
+
36
+ # AutoGen Agents
37
+ query_analyzer = autogen.AssistantAgent(
38
+ name="QueryAnalyzer",
39
+ system_message="Analyze and categorize incoming customer queries for Zerodha support. Determine query priority and complexity.",
40
+ llm_config={"config_list": config_list},
41
+ )
42
+
43
+ compliance_agent = autogen.AssistantAgent(
44
+ name="ComplianceAgent",
45
+ system_message="Ensure all responses comply with financial regulations and Zerodha policies.",
46
+ llm_config={"config_list": config_list},
47
+ )
48
+
49
+ kb_manager = autogen.AssistantAgent(
50
+ name="KnowledgeBaseManager",
51
+ system_message="Update and organize Zerodha's knowledge base based on customer interactions.",
52
+ llm_config={"config_list": config_list},
53
+ )
54
+
55
+ sentiment_analyzer = autogen.AssistantAgent(
56
+ name="SentimentAnalyzer",
57
+ system_message="Analyze customer sentiment from interactions.",
58
+ llm_config={"config_list": config_list},
59
+ )
60
 
61
+ coordinator = autogen.AssistantAgent(
62
+ name="Coordinator",
63
+ system_message="Coordinate responses from different agents and synthesize a final response.",
64
+ llm_config={"config_list": config_list},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
67
+ # CrewAI Agents
68
+ account_specialist = Agent(
69
+ role='Account Specialist',
70
+ goal='Handle account-related queries and processes',
71
+ backstory='Expert in Zerodha\'s account management systems and procedures.',
72
+ verbose=True
73
+ )
74
+
75
+ trading_expert = Agent(
76
+ role='Trading Expert',
77
+ goal='Assist with trading-related questions and provide market insights',
78
+ backstory='Seasoned trader with deep knowledge of Zerodha\'s trading platforms.',
79
+ verbose=True
80
+ )
81
+
82
+ technical_support = Agent(
83
+ role='Technical Support',
84
+ goal='Troubleshoot platform issues and provide technical guidance',
85
+ backstory='Technical expert familiar with all Zerodha platforms and common issues.',
86
+ verbose=True
87
+ )
88
+
89
+ learning_dev = Agent(
90
+ role='Learning and Development',
91
+ goal='Design educational content and trading tutorials',
92
+ backstory='Educational expert specializing in financial literacy and trading education.',
93
+ verbose=True
94
+ )
95
+
96
+ product_specialist = Agent(
97
+ role='Product Specialist',
98
+ goal='Provide information on Zerodha\'s products and compare with competitors',
99
+ backstory='Expert in Zerodha\'s product line and the broader financial services market.',
100
+ verbose=True
101
+ )
102
+
103
+ # CrewAI Tasks
104
+ account_task = Task(
105
+ description='Handle account-related query and provide detailed guidance',
106
+ agent=account_specialist
107
+ )
108
+
109
+ trading_task = Task(
110
+ description='Address trading-related question and offer market insights',
111
+ agent=trading_expert
112
+ )
113
+
114
+ tech_support_task = Task(
115
+ description='Troubleshoot technical issue and provide step-by-step guidance',
116
+ agent=technical_support
117
+ )
118
+
119
+ learning_task = Task(
120
+ description='Create educational content based on user query and skill level',
121
+ agent=learning_dev
122
+ )
123
+
124
+ product_task = Task(
125
+ description='Provide product information and recommendations',
126
+ agent=product_specialist
127
+ )
128
+
129
+ # Create CrewAI Crew
130
+ zerodha_crew = Crew(
131
+ agents=[account_specialist, trading_expert, technical_support, learning_dev, product_specialist],
132
+ tasks=[account_task, trading_task, tech_support_task, learning_task, product_task],
133
+ verbose=2
134
+ )
135
+
136
+ # Pydantic models
137
+ class QueryInput(BaseModel):
138
+ text: str
139
+
140
+ class QueryOutput(BaseModel):
141
+ response: str
142
+ sentiment: str
143
+
144
+ # Dependency to get the database session
145
+ def get_db():
146
+ db = SessionLocal()
147
+ try:
148
+ yield db
149
+ finally:
150
+ db.close()
151
 
152
+ # Helper function to generate LLM response
153
+ def generate_llm_response(prompt):
154
+ return hf_client.text_generation(prompt, max_new_tokens=200, temperature=0.7)
155
+
156
+ # Helper function to check cache
157
+ def check_cache(query):
158
+ cached_response = redis_client.get(query)
159
+ if cached_response:
160
+ return json.loads(cached_response)
161
+ return None
162
+
163
+ # Helper function to update cache
164
+ def update_cache(query, response):
165
+ redis_client.setex(query, 3600, json.dumps(response)) # Cache for 1 hour
166
+
167
+ # Main query processing function
168
+ async def process_query(query: str, db: Session):
169
+ try:
170
+ # Check cache
171
+ cached_result = check_cache(query)
172
+ if cached_result:
173
+ logger.info(f"Cache hit for query: {query[:50]}...")
174
+ return cached_result
175
+
176
+ # Step 1: Query Analysis
177
+ analysis = query_analyzer.generate_response(f"Analyze this query: {query}")
178
+
179
+ # Step 2: Route to Appropriate Specialist Agents
180
+ specialist_responses = {}
181
+ if "account" in analysis.lower():
182
+ specialist_responses['account'] = account_specialist.execute(account_task, {"query": query})
183
+ if "trading" in analysis.lower():
184
+ specialist_responses['trading'] = trading_expert.execute(trading_task, {"query": query})
185
+ if "technical" in analysis.lower():
186
+ specialist_responses['technical'] = technical_support.execute(tech_support_task, {"query": query})
187
+ if "product" in analysis.lower():
188
+ specialist_responses['product'] = product_specialist.execute(product_task, {"query": query})
189
+
190
+ # Step 3: Compliance Check
191
+ for key in specialist_responses:
192
+ specialist_responses[key] = compliance_agent.generate_response(f"Ensure this response is compliant: {specialist_responses[key]}")
193
+
194
+ # Step 4: Coordinate Final Response
195
+ final_response = coordinator.generate_response(f"Synthesize these responses into a final answer: {specialist_responses}")
196
+
197
+ # Step 5: Sentiment Analysis
198
+ sentiment = sentiment_analyzer.generate_response(f"Analyze the sentiment of this interaction: Query: {query}, Response: {final_response}")
199
+
200
+ # Step 6: Update Knowledge Base
201
+ kb_manager.generate_response(f"Update knowledge base based on: Query: {query}, Response: {final_response}")
202
+
203
+ # Step 7: Generate Learning Content (if needed)
204
+ if "educational" in analysis.lower():
205
+ learning_dev.execute(learning_task, {"query": query, "response": final_response})
206
+
207
+ # Save query and response to database
208
+ db_query = Query(text=query)
209
+ db.add(db_query)
210
+ db.commit()
211
+ db.refresh(db_query)
212
+
213
+ db_response = Response(text=final_response, query_id=db_query.id)
214
+ db.add(db_response)
215
+ db.commit()
216
+
217
+ result = {"response": final_response, "sentiment": sentiment}
218
+
219
+ # Update cache
220
+ update_cache(query, result)
221
+
222
+ return result
223
+
224
+ except Exception as e:
225
+ logger.error(f"Error processing query: {str(e)}", exc_info=True)
226
+ raise HTTPException(status_code=500, detail="An error occurred while processing your query")
227
+
228
+ # API Endpoints
229
+ @app.post("/query", response_model=QueryOutput)
230
+ async def handle_query(query: QueryInput, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
231
+ result = await process_query(query.text, db)
232
+ return QueryOutput(**result)
233
+
234
+ # Run the application
235
  if __name__ == "__main__":
236
+ import uvicorn
237
+ uvicorn.run(app, host="0.0.0.0", port=8000)
238
+
239
+ # models.py
240
+
241
+ from sqlalchemy import Column, Integer, String, ForeignKey
242
+ from sqlalchemy.orm import relationship
243
+ from database import Base
244
+
245
+ class User(Base):
246
+ __tablename__ = "users"
247
+
248
+ id = Column(Integer, primary_key=True, index=True)
249
+ username = Column(String, unique=True, index=True)
250
+ hashed_password = Column(String)
251
+
252
+ class Query(Base):
253
+ __tablename__ = "queries"
254
+
255
+ id = Column(Integer, primary_key=True, index=True)
256
+ text = Column(String)
257
+ responses = relationship("Response", back_populates="query")
258
+
259
+ class Response(Base):
260
+ __tablename__ = "responses"
261
+
262
+ id = Column(Integer, primary_key=True, index=True)
263
+ text = Column(String)
264
+ query_id = Column(Integer, ForeignKey("queries.id"))
265
+ query = relationship("Query", back_populates="responses")
266
+
267
+ # database.py
268
+
269
+ from sqlalchemy import create_engine
270
+ from sqlalchemy.ext.declarative import declarative_base
271
+ from sqlalchemy.orm import sessionmaker
272
+
273
+ SQLALCHEMY_DATABASE_URL = "sqlite:///./zerodha_support.db"
274
+
275
+ engine = create_engine(
276
+ SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
277
+ )
278
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
279
+
280
+ Base = declarative_base()
281
+
282
+ # auth.py
283
+
284
+ from datetime import datetime, timedelta
285
+ from jose import JWTError, jwt
286
+ from passlib.context import CryptContext
287
+ from fastapi import Depends, HTTPException, status
288
+ from fastapi.security import OAuth2PasswordBearer
289
+ from sqlalchemy.orm import Session
290
+ from models import User
291
+ from database import get_db
292
+
293
+ SECRET_KEY = "your-secret-key"
294
+ ALGORITHM = "HS256"
295
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
296
+
297
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
298
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
299
+
300
+ def verify_password(plain_password, hashed_password):
301
+ return pwd_context.verify(plain_password, hashed_password)
302
+
303
+ def get_password_hash(password):
304
+ return pwd_context.hash(password)
305
+
306
+ def create_access_token(data: dict):
307
+ to_encode = data.copy()
308
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
309
+ to_encode.update({"exp": expire})
310
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
311
+ return encoded_jwt
312
+
313
+ def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
314
+ credentials_exception = HTTPException(
315
+ status_code=status.HTTP_401_UNAUTHORIZED,
316
+ detail="Could not validate credentials",
317
+ headers={"WWW-Authenticate": "Bearer"},
318
+ )
319
+ try:
320
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
321
+ username: str = payload.get("sub")
322
+ if username is None:
323
+ raise credentials_exception
324
+ except JWTError:
325
+ raise credentials_exception
326
+ user = db.query(User).filter(User.username == username).first()
327
+ if user is None:
328
+ raise credentials_exception
329
+ return user