import os
import sys
import re
import json
import asyncio
import logging
import warnings

import gradio as gr
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from crewai import Agent
from huggingface_hub import InferenceClient
# Suppress all deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_huggingface_api_token():
    token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if token:
        logger.info("Hugging Face API token found in environment variables.")
        return token
    try:
        with open('config.json', 'r') as config_file:
            config = json.load(config_file)
            token = config.get('HUGGINGFACEHUB_API_TOKEN')
            if token:
                logger.info("Hugging Face API token found in config.json file.")
                return token
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except json.JSONDecodeError:
        logger.error("Error reading the config file. Please check its format.")
    logger.error("Hugging Face API token not found. Please set it up.")
    return None
token = get_huggingface_api_token()
if not token:
    logger.error("Hugging Face API token is not set. Exiting.")
    sys.exit(1)

hf_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=token)
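# Note: InferenceClient sends requests to the hosted Inference API for the given
# model id; another instruct-tuned text-generation model could be swapped in by
# changing the model string, provided the Inference API serves it.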
# Lightweight topic filter: each approved topic doubles as a one-phrase training
# document, and MultinomialNB learns one class per topic.
vectorizer = CountVectorizer()
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
X = vectorizer.fit_transform(approved_topics)
classifier = MultinomialNB()
classifier.fit(X, np.arange(len(approved_topics)))
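# Illustrative behaviour (not from the original source): a query sharing vocabulary
# with a topic, e.g. "What are your trading fees?", concentrates predict_proba mass
# on that topic's label, while an out-of-vocabulary query yields a near-uniform
# distribution (~1/7 per class), which the threshold check below treats as off-topic.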
class CommunicationExpertAgent(Agent):
    async def run(self, query):
        sanitized_query = re.sub(r'[<>&\']', '', query)
        # predict() always returns one of the trained labels, so checking membership
        # in range(len(approved_topics)) was vacuously true. Use the predicted class
        # probability instead; 0.2 is a heuristic cutoff (a uniform distribution over
        # the seven classes gives ~0.14 per class).
        probabilities = classifier.predict_proba(vectorizer.transform([sanitized_query]))[0]
        if probabilities.max() < 0.2:
            return "Query not relevant to our services."
        emotional_context = "Identified emotional context"  # Simulate emotional context analysis
        rephrased_query = f"Rephrased with empathy: {sanitized_query} - {emotional_context}"
        return rephrased_query
class ResponseExpertAgent(Agent):
    async def run(self, rephrased_query):
        # text_generation() is synchronous and returns the generated string directly
        # (not a dict with 'generated_text'), so run it in a worker thread rather
        # than awaiting it.
        response = await asyncio.to_thread(
            hf_client.text_generation, rephrased_query, max_new_tokens=500, temperature=0.7
        )
        return response
class PostprocessingAgent(Agent):
    def run(self, response):
        response += "\n\nThank you for contacting Zerodha. Is there anything else we can help with?"
        return response
# Instantiate agents. crewai's Agent is a Pydantic model whose role, goal, and
# backstory fields are required at construction (field values here are placeholders).
communication_expert = CommunicationExpertAgent(role="Communication Expert", goal="Rephrase customer queries with empathy", backstory="Specialist in empathetic customer communication.")
response_expert = ResponseExpertAgent(role="Response Expert", goal="Answer customer queries accurately", backstory="Zerodha support specialist.")
postprocessing_agent = PostprocessingAgent(role="Postprocessing Agent", goal="Polish final responses", backstory="Quality-control specialist.")
async def handle_query(query):
    rephrased_query = await communication_expert.run(query)
    response = await response_expert.run(rephrased_query)
    final_response = postprocessing_agent.run(response)
    return final_response
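# Example (illustrative, for local testing outside Gradio):
#   print(asyncio.run(handle_query("How do I open a trading account?")))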
# Gradio interface setup
def setup_interface():
    with gr.Blocks() as app:
        with gr.Row():
            query_input = gr.Textbox(label="Enter your query")
            submit_button = gr.Button("Submit")
            response_output = gr.Textbox(label="Response")
        submit_button.click(
            # Gradio runs coroutine handlers on its own event loop, so pass the
            # async function directly instead of wrapping it in asyncio.run(),
            # which would fail inside an already-running loop.
            fn=handle_query,
            inputs=[query_input],
            outputs=[response_output]
        )
    return app
app = setup_interface()

if __name__ == "__main__":
    app.launch()
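# Note (assumption): when this file is a Hugging Face Space's entry point, launch()
# defaults are enough; calling app.queue() before launch() is a common option if
# many users may submit queries concurrently.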