Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,63 +1,329 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from huggingface_hub import InferenceClient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
""
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
def respond(
|
11 |
-
message,
|
12 |
-
history: list[tuple[str, str]],
|
13 |
-
system_message,
|
14 |
-
max_tokens,
|
15 |
-
temperature,
|
16 |
-
top_p,
|
17 |
-
):
|
18 |
-
messages = [{"role": "system", "content": system_message}]
|
19 |
-
|
20 |
-
for val in history:
|
21 |
-
if val[0]:
|
22 |
-
messages.append({"role": "user", "content": val[0]})
|
23 |
-
if val[1]:
|
24 |
-
messages.append({"role": "assistant", "content": val[1]})
|
25 |
-
|
26 |
-
messages.append({"role": "user", "content": message})
|
27 |
-
|
28 |
-
response = ""
|
29 |
-
|
30 |
-
for message in client.chat_completion(
|
31 |
-
messages,
|
32 |
-
max_tokens=max_tokens,
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
-
|
39 |
-
response += token
|
40 |
-
yield response
|
41 |
-
|
42 |
-
"""
|
43 |
-
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
44 |
-
"""
|
45 |
-
demo = gr.ChatInterface(
|
46 |
-
respond,
|
47 |
-
additional_inputs=[
|
48 |
-
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
49 |
-
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
50 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
51 |
-
gr.Slider(
|
52 |
-
minimum=0.1,
|
53 |
-
maximum=1.0,
|
54 |
-
value=0.95,
|
55 |
-
step=0.05,
|
56 |
-
label="Top-p (nucleus sampling)",
|
57 |
-
),
|
58 |
-
],
|
59 |
)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
if __name__ == "__main__":
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# main.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
from fastapi import FastAPI, HTTPException, Depends
|
5 |
+
from fastapi.security import OAuth2PasswordBearer
|
6 |
+
from sqlalchemy.orm import Session
|
7 |
+
from pydantic import BaseModel
|
8 |
+
from typing import List
|
9 |
+
import autogen
|
10 |
+
from crewai import Agent, Task, Crew, Process
|
11 |
from huggingface_hub import InferenceClient
|
12 |
+
import redis
|
13 |
+
import json
|
14 |
+
import logging
|
15 |
+
|
16 |
+
from database import SessionLocal, engine, Base
|
17 |
+
from models import User, Query, Response
|
18 |
+
from auth import create_access_token, get_current_user
|
19 |
+
|
20 |
+
# Initialize FastAPI app
|
21 |
+
app = FastAPI(title="Zerodha Support System MVP")
|
22 |
+
|
23 |
+
# Set up logging
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
# Initialize LLM client
|
28 |
+
hf_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
|
29 |
+
|
30 |
+
# Initialize Redis client
|
31 |
+
redis_client = redis.Redis(host='localhost', port=6379, db=0)
|
32 |
+
|
33 |
+
# AutoGen configuration
|
34 |
+
config_list = [{"model": "gpt-3.5-turbo"}]
|
35 |
+
|
36 |
+
# AutoGen Agents
|
37 |
+
query_analyzer = autogen.AssistantAgent(
|
38 |
+
name="QueryAnalyzer",
|
39 |
+
system_message="Analyze and categorize incoming customer queries for Zerodha support. Determine query priority and complexity.",
|
40 |
+
llm_config={"config_list": config_list},
|
41 |
+
)
|
42 |
+
|
43 |
+
compliance_agent = autogen.AssistantAgent(
|
44 |
+
name="ComplianceAgent",
|
45 |
+
system_message="Ensure all responses comply with financial regulations and Zerodha policies.",
|
46 |
+
llm_config={"config_list": config_list},
|
47 |
+
)
|
48 |
+
|
49 |
+
kb_manager = autogen.AssistantAgent(
|
50 |
+
name="KnowledgeBaseManager",
|
51 |
+
system_message="Update and organize Zerodha's knowledge base based on customer interactions.",
|
52 |
+
llm_config={"config_list": config_list},
|
53 |
+
)
|
54 |
+
|
55 |
+
sentiment_analyzer = autogen.AssistantAgent(
|
56 |
+
name="SentimentAnalyzer",
|
57 |
+
system_message="Analyze customer sentiment from interactions.",
|
58 |
+
llm_config={"config_list": config_list},
|
59 |
+
)
|
60 |
|
61 |
+
coordinator = autogen.AssistantAgent(
|
62 |
+
name="Coordinator",
|
63 |
+
system_message="Coordinate responses from different agents and synthesize a final response.",
|
64 |
+
llm_config={"config_list": config_list},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
)
|
66 |
|
67 |
+
# CrewAI Agents
|
68 |
+
account_specialist = Agent(
|
69 |
+
role='Account Specialist',
|
70 |
+
goal='Handle account-related queries and processes',
|
71 |
+
backstory='Expert in Zerodha\'s account management systems and procedures.',
|
72 |
+
verbose=True
|
73 |
+
)
|
74 |
+
|
75 |
+
trading_expert = Agent(
|
76 |
+
role='Trading Expert',
|
77 |
+
goal='Assist with trading-related questions and provide market insights',
|
78 |
+
backstory='Seasoned trader with deep knowledge of Zerodha\'s trading platforms.',
|
79 |
+
verbose=True
|
80 |
+
)
|
81 |
+
|
82 |
+
technical_support = Agent(
|
83 |
+
role='Technical Support',
|
84 |
+
goal='Troubleshoot platform issues and provide technical guidance',
|
85 |
+
backstory='Technical expert familiar with all Zerodha platforms and common issues.',
|
86 |
+
verbose=True
|
87 |
+
)
|
88 |
+
|
89 |
+
learning_dev = Agent(
|
90 |
+
role='Learning and Development',
|
91 |
+
goal='Design educational content and trading tutorials',
|
92 |
+
backstory='Educational expert specializing in financial literacy and trading education.',
|
93 |
+
verbose=True
|
94 |
+
)
|
95 |
+
|
96 |
+
product_specialist = Agent(
|
97 |
+
role='Product Specialist',
|
98 |
+
goal='Provide information on Zerodha\'s products and compare with competitors',
|
99 |
+
backstory='Expert in Zerodha\'s product line and the broader financial services market.',
|
100 |
+
verbose=True
|
101 |
+
)
|
102 |
+
|
103 |
+
# CrewAI Tasks
|
104 |
+
account_task = Task(
|
105 |
+
description='Handle account-related query and provide detailed guidance',
|
106 |
+
agent=account_specialist
|
107 |
+
)
|
108 |
+
|
109 |
+
trading_task = Task(
|
110 |
+
description='Address trading-related question and offer market insights',
|
111 |
+
agent=trading_expert
|
112 |
+
)
|
113 |
+
|
114 |
+
tech_support_task = Task(
|
115 |
+
description='Troubleshoot technical issue and provide step-by-step guidance',
|
116 |
+
agent=technical_support
|
117 |
+
)
|
118 |
+
|
119 |
+
learning_task = Task(
|
120 |
+
description='Create educational content based on user query and skill level',
|
121 |
+
agent=learning_dev
|
122 |
+
)
|
123 |
+
|
124 |
+
product_task = Task(
|
125 |
+
description='Provide product information and recommendations',
|
126 |
+
agent=product_specialist
|
127 |
+
)
|
128 |
+
|
129 |
+
# Create CrewAI Crew
|
130 |
+
zerodha_crew = Crew(
|
131 |
+
agents=[account_specialist, trading_expert, technical_support, learning_dev, product_specialist],
|
132 |
+
tasks=[account_task, trading_task, tech_support_task, learning_task, product_task],
|
133 |
+
verbose=2
|
134 |
+
)
|
135 |
+
|
136 |
+
# Pydantic models
|
137 |
+
class QueryInput(BaseModel):
|
138 |
+
text: str
|
139 |
+
|
140 |
+
class QueryOutput(BaseModel):
|
141 |
+
response: str
|
142 |
+
sentiment: str
|
143 |
+
|
144 |
+
# Dependency to get the database session
|
145 |
+
def get_db():
|
146 |
+
db = SessionLocal()
|
147 |
+
try:
|
148 |
+
yield db
|
149 |
+
finally:
|
150 |
+
db.close()
|
151 |
|
152 |
+
# Helper function to generate LLM response
|
153 |
+
def generate_llm_response(prompt):
|
154 |
+
return hf_client.text_generation(prompt, max_new_tokens=200, temperature=0.7)
|
155 |
+
|
156 |
+
# Helper function to check cache
|
157 |
+
def check_cache(query):
|
158 |
+
cached_response = redis_client.get(query)
|
159 |
+
if cached_response:
|
160 |
+
return json.loads(cached_response)
|
161 |
+
return None
|
162 |
+
|
163 |
+
# Helper function to update cache
|
164 |
+
def update_cache(query, response):
|
165 |
+
redis_client.setex(query, 3600, json.dumps(response)) # Cache for 1 hour
|
166 |
+
|
167 |
+
# Main query processing function
|
168 |
+
async def process_query(query: str, db: Session):
|
169 |
+
try:
|
170 |
+
# Check cache
|
171 |
+
cached_result = check_cache(query)
|
172 |
+
if cached_result:
|
173 |
+
logger.info(f"Cache hit for query: {query[:50]}...")
|
174 |
+
return cached_result
|
175 |
+
|
176 |
+
# Step 1: Query Analysis
|
177 |
+
analysis = query_analyzer.generate_response(f"Analyze this query: {query}")
|
178 |
+
|
179 |
+
# Step 2: Route to Appropriate Specialist Agents
|
180 |
+
specialist_responses = {}
|
181 |
+
if "account" in analysis.lower():
|
182 |
+
specialist_responses['account'] = account_specialist.execute(account_task, {"query": query})
|
183 |
+
if "trading" in analysis.lower():
|
184 |
+
specialist_responses['trading'] = trading_expert.execute(trading_task, {"query": query})
|
185 |
+
if "technical" in analysis.lower():
|
186 |
+
specialist_responses['technical'] = technical_support.execute(tech_support_task, {"query": query})
|
187 |
+
if "product" in analysis.lower():
|
188 |
+
specialist_responses['product'] = product_specialist.execute(product_task, {"query": query})
|
189 |
+
|
190 |
+
# Step 3: Compliance Check
|
191 |
+
for key in specialist_responses:
|
192 |
+
specialist_responses[key] = compliance_agent.generate_response(f"Ensure this response is compliant: {specialist_responses[key]}")
|
193 |
+
|
194 |
+
# Step 4: Coordinate Final Response
|
195 |
+
final_response = coordinator.generate_response(f"Synthesize these responses into a final answer: {specialist_responses}")
|
196 |
+
|
197 |
+
# Step 5: Sentiment Analysis
|
198 |
+
sentiment = sentiment_analyzer.generate_response(f"Analyze the sentiment of this interaction: Query: {query}, Response: {final_response}")
|
199 |
+
|
200 |
+
# Step 6: Update Knowledge Base
|
201 |
+
kb_manager.generate_response(f"Update knowledge base based on: Query: {query}, Response: {final_response}")
|
202 |
+
|
203 |
+
# Step 7: Generate Learning Content (if needed)
|
204 |
+
if "educational" in analysis.lower():
|
205 |
+
learning_dev.execute(learning_task, {"query": query, "response": final_response})
|
206 |
+
|
207 |
+
# Save query and response to database
|
208 |
+
db_query = Query(text=query)
|
209 |
+
db.add(db_query)
|
210 |
+
db.commit()
|
211 |
+
db.refresh(db_query)
|
212 |
+
|
213 |
+
db_response = Response(text=final_response, query_id=db_query.id)
|
214 |
+
db.add(db_response)
|
215 |
+
db.commit()
|
216 |
+
|
217 |
+
result = {"response": final_response, "sentiment": sentiment}
|
218 |
+
|
219 |
+
# Update cache
|
220 |
+
update_cache(query, result)
|
221 |
+
|
222 |
+
return result
|
223 |
+
|
224 |
+
except Exception as e:
|
225 |
+
logger.error(f"Error processing query: {str(e)}", exc_info=True)
|
226 |
+
raise HTTPException(status_code=500, detail="An error occurred while processing your query")
|
227 |
+
|
228 |
+
# API Endpoints
|
229 |
+
@app.post("/query", response_model=QueryOutput)
|
230 |
+
async def handle_query(query: QueryInput, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
|
231 |
+
result = await process_query(query.text, db)
|
232 |
+
return QueryOutput(**result)
|
233 |
+
|
234 |
+
# Run the application
|
235 |
if __name__ == "__main__":
|
236 |
+
import uvicorn
|
237 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
238 |
+
|
239 |
+
# models.py
|
240 |
+
|
241 |
+
from sqlalchemy import Column, Integer, String, ForeignKey
|
242 |
+
from sqlalchemy.orm import relationship
|
243 |
+
from database import Base
|
244 |
+
|
245 |
+
class User(Base):
|
246 |
+
__tablename__ = "users"
|
247 |
+
|
248 |
+
id = Column(Integer, primary_key=True, index=True)
|
249 |
+
username = Column(String, unique=True, index=True)
|
250 |
+
hashed_password = Column(String)
|
251 |
+
|
252 |
+
class Query(Base):
|
253 |
+
__tablename__ = "queries"
|
254 |
+
|
255 |
+
id = Column(Integer, primary_key=True, index=True)
|
256 |
+
text = Column(String)
|
257 |
+
responses = relationship("Response", back_populates="query")
|
258 |
+
|
259 |
+
class Response(Base):
|
260 |
+
__tablename__ = "responses"
|
261 |
+
|
262 |
+
id = Column(Integer, primary_key=True, index=True)
|
263 |
+
text = Column(String)
|
264 |
+
query_id = Column(Integer, ForeignKey("queries.id"))
|
265 |
+
query = relationship("Query", back_populates="responses")
|
266 |
+
|
267 |
+
# database.py
|
268 |
+
|
269 |
+
from sqlalchemy import create_engine
|
270 |
+
from sqlalchemy.ext.declarative import declarative_base
|
271 |
+
from sqlalchemy.orm import sessionmaker
|
272 |
+
|
273 |
+
SQLALCHEMY_DATABASE_URL = "sqlite:///./zerodha_support.db"
|
274 |
+
|
275 |
+
engine = create_engine(
|
276 |
+
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
277 |
+
)
|
278 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
279 |
+
|
280 |
+
Base = declarative_base()
|
281 |
+
|
282 |
+
# auth.py
|
283 |
+
|
284 |
+
from datetime import datetime, timedelta
|
285 |
+
from jose import JWTError, jwt
|
286 |
+
from passlib.context import CryptContext
|
287 |
+
from fastapi import Depends, HTTPException, status
|
288 |
+
from fastapi.security import OAuth2PasswordBearer
|
289 |
+
from sqlalchemy.orm import Session
|
290 |
+
from models import User
|
291 |
+
from database import get_db
|
292 |
+
|
293 |
+
SECRET_KEY = "your-secret-key"
|
294 |
+
ALGORITHM = "HS256"
|
295 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
296 |
+
|
297 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
298 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
299 |
+
|
300 |
+
def verify_password(plain_password, hashed_password):
|
301 |
+
return pwd_context.verify(plain_password, hashed_password)
|
302 |
+
|
303 |
+
def get_password_hash(password):
|
304 |
+
return pwd_context.hash(password)
|
305 |
+
|
306 |
+
def create_access_token(data: dict):
|
307 |
+
to_encode = data.copy()
|
308 |
+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
309 |
+
to_encode.update({"exp": expire})
|
310 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
311 |
+
return encoded_jwt
|
312 |
+
|
313 |
+
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
|
314 |
+
credentials_exception = HTTPException(
|
315 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
316 |
+
detail="Could not validate credentials",
|
317 |
+
headers={"WWW-Authenticate": "Bearer"},
|
318 |
+
)
|
319 |
+
try:
|
320 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
321 |
+
username: str = payload.get("sub")
|
322 |
+
if username is None:
|
323 |
+
raise credentials_exception
|
324 |
+
except JWTError:
|
325 |
+
raise credentials_exception
|
326 |
+
user = db.query(User).filter(User.username == username).first()
|
327 |
+
if user is None:
|
328 |
+
raise credentials_exception
|
329 |
+
return user
|