# Page residue from the Hugging Face Spaces viewer: "Spaces: Sleeping Sleeping"
import asyncio
import json
import logging
import os
import random
import re
import sys
import warnings
from typing import Literal

import gradio as gr
import numpy as np
from crewai import Agent
from huggingface_hub import InferenceClient
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
# Silence DeprecationWarning noise for the whole process.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

# Root logging: INFO and above, timestamped one-line records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def get_huggingface_api_token(config_path='config.json'):
    """Return the Hugging Face API token, or None if it cannot be found.

    Lookup order:
    1. The ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
    2. The ``HUGGINGFACEHUB_API_TOKEN`` key of the JSON object stored in
       *config_path*.

    Problems reading the config file are logged rather than raised, so the
    caller only needs to check the return value for None.

    Args:
        config_path: JSON config file consulted when the environment variable
            is unset or empty. Defaults to ``config.json`` in the current
            working directory.

    Returns:
        The token string, or None when neither source provides a non-empty one.
    """
    token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if token:
        logger.info("Hugging Face API token found in environment variables.")
        return token
    try:
        with open(config_path, 'r', encoding='utf-8') as config_file:
            config = json.load(config_file)
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.error("Error reading the config file. Please check its format.")
    except OSError as err:
        # Permission denied, path is a directory, etc. These were previously
        # uncaught and crashed the app at import time.
        logger.error("Could not open the config file: %s", err)
    else:
        # json.load can return a list, string or number; only a JSON object has
        # keys. Calling .get on anything else used to raise AttributeError.
        if isinstance(config, dict):
            token = config.get('HUGGINGFACEHUB_API_TOKEN')
        else:
            logger.error("Config file must contain a JSON object.")
            token = None
        if token:
            logger.info("Hugging Face API token found in config.json file.")
            return token
    logger.error("Hugging Face API token not found. Please set it up.")
    return None
# Resolve credentials before building the inference client; without a token
# the process stops here with exit status 1.
token = get_huggingface_api_token()
if not token:
    logger.error("Hugging Face API token is not set. Exiting.")
    raise SystemExit(1)

hf_client = InferenceClient(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    token=token,
)
# Minimal topic model: each approved topic string is one training document
# and is also its own class label (0 .. len(approved_topics) - 1).
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
vectorizer = CountVectorizer()
X = vectorizer.fit_transform(approved_topics)
classifier = MultinomialNB().fit(X, np.arange(len(approved_topics)))
class CommunicationExpertAgent(Agent):
    """First pipeline stage: screens a user query and wraps it for the responder."""

    role: Literal["Communication Expert"] = "Communication Expert"
    goal: Literal["To interpret and rephrase user queries with empathy and respect"] = "To interpret and rephrase user queries with empathy and respect"
    backstory: Literal["You are an expert in communication, specializing in understanding and rephrasing queries to ensure they are interpreted in the most positive and constructive light. Your role is crucial in setting the tone for respectful and empathetic interactions."] = \
        "You are an expert in communication, specializing in understanding and rephrasing queries to ensure they are interpreted in the most positive and constructive light. Your role is crucial in setting the tone for respectful and empathetic interactions."

    async def run(self, query):
        """Sanitize *query*, reject it if off-topic, otherwise return a rephrased prompt.

        Args:
            query: Raw user text. None is treated as an empty query.

        Returns:
            "Query not relevant to our services." when the query shares no
            vocabulary with ``approved_topics``; otherwise the rephrased
            query string.
        """
        # Strip <, >, & and single quotes from the user text.
        sanitized_query = re.sub(r"[<>&']", "", query or "")
        # The previous check, `classifier.predict(...)[0] in range(len(approved_topics))`,
        # was always True: MultinomialNB.predict can only return one of its
        # training labels, so the rejection branch was unreachable. A query is
        # now relevant only if it contains at least one approved-topic term,
        # i.e. its count vector over the fitted vocabulary is not all zeros.
        topic_counts = vectorizer.transform([sanitized_query])
        if topic_counts.nnz == 0:
            return "Query not relevant to our services."
        # NOTE(review): placeholder; no emotion analysis is actually performed.
        emotional_context = "Identified emotional context"
        return f"Rephrased with empathy: {sanitized_query} - {emotional_context}"
class ResponseExpertAgent(Agent):
    """Second pipeline stage: asks the hosted LLM to answer the rephrased query."""

    role: Literal["Response Expert"] = "Response Expert"
    goal: Literal["To provide accurate, helpful, and emotionally intelligent responses to user queries"] = "To provide accurate, helpful, and emotionally intelligent responses to user queries"
    backstory: Literal["You are an expert in Zerodha's services and policies, with a keen ability to provide comprehensive and empathetic responses. Your role is to ensure that all user queries are addressed accurately while maintaining a respectful and supportive tone."] = \
        "You are an expert in Zerodha's services and policies, with a keen ability to provide comprehensive and empathetic responses. Your role is to ensure that all user queries are addressed accurately while maintaining a respectful and supportive tone."

    async def run(self, rephrased_query):
        """Generate a model reply for *rephrased_query*.

        Args:
            rephrased_query: Prompt text sent to ``hf_client.text_generation``.

        Returns:
            The generated text, or "Error in generating response. Please try
            again." if the model call fails for any reason.
        """
        try:
            logger.info("Sending query for generation: %s", rephrased_query)
            # hf_client is the synchronous InferenceClient: text_generation()
            # returns its result directly, not an awaitable, so the previous
            # `await hf_client.text_generation(...)` raised TypeError on every
            # call. Run the blocking HTTP request in a worker thread instead so
            # the event loop is not blocked.
            response = await asyncio.to_thread(
                hf_client.text_generation,
                rephrased_query,
                max_new_tokens=500,
                temperature=0.7,
            )
        except Exception as e:
            # Boundary around the remote call: log with traceback and degrade
            # to a user-facing message rather than failing the UI handler.
            logger.exception("Failed to generate text due to: %s", e)
            return "Error in generating response. Please try again."
        # With the default details=False, text_generation returns a plain str,
        # so the old response['generated_text'] indexing raised TypeError.
        # Fall back to the attribute only for a detailed output object.
        return getattr(response, "generated_text", response)
class PostprocessingAgent(Agent):
    """Final pipeline stage: appends the standard customer-service closing line."""

    role: Literal["Postprocessing Expert"] = "Postprocessing Expert"
    goal: Literal["To enhance and finalize responses ensuring quality and completeness"] = "To enhance and finalize responses ensuring quality and completeness"
    backstory: Literal["You are responsible for finalizing communications, adding polite terminations, and ensuring that the responses meet the quality standards expected in customer interactions."] = \
        "You are responsible for finalizing communications, adding polite terminations, and ensuring that the responses meet the quality standards expected in customer interactions."

    def run(self, response):
        """Return *response* followed by a blank line and the closing message.

        This ``run`` is synchronous, so callers invoke it without ``await``.
        """
        closing = "\n\nThank you for contacting Zerodha. Is there anything else we can help with?"
        return response + closing
# One shared instance of each pipeline stage, created in pipeline order.
communication_expert, response_expert, postprocessing_agent = (
    CommunicationExpertAgent(),
    ResponseExpertAgent(),
    PostprocessingAgent(),
)
async def handle_query(query):
    """Run *query* through the three-agent pipeline and return the reply text.

    Stages: communication expert (screen and rephrase), then response expert
    (LLM call), then post-processing (closing line). If the communication
    expert rejects the query as off-topic, its rejection message is
    post-processed and returned without calling the model.

    Args:
        query: Raw user text from the UI.

    Returns:
        The final reply string shown to the user.
    """
    rephrased_query = await communication_expert.run(query)
    if rephrased_query == "Query not relevant to our services.":
        # Previously the rejection text itself was sent to the LLM as a prompt.
        return postprocessing_agent.run(rephrased_query)
    response = await response_expert.run(rephrased_query)
    return postprocessing_agent.run(response)
# Gradio interface setup
def setup_interface():
    """Build and return the Gradio Blocks UI.

    Layout: a row with the query textbox and a Submit button, followed by a
    read-out textbox. Clicking Submit runs ``handle_query`` on the query text.
    """
    with gr.Blocks() as ui:
        with gr.Row():
            query_box = gr.Textbox(label="Enter your query")
            send_button = gr.Button("Submit")
        answer_box = gr.Textbox(label="Response")
        # The handler is a plain (sync) callable that drives the async
        # pipeline to completion with a fresh event loop via asyncio.run.
        send_button.click(
            inputs=[query_box],
            outputs=[answer_box],
            fn=lambda x: asyncio.run(handle_query(x)),
        )
    return ui
# Build the UI at import time so `app` is available as a module attribute;
# start the server only when this file is executed as a script.
app = setup_interface()

if __name__ == "__main__":
    app.launch()