File size: 3,714 Bytes
5b2aed4
 
 
f46801a
 
 
 
 
414a697
9483f8f
7d3b780
9483f8f
7d3b780
48c0d8d
 
0ea8c80
 
cfcb5f3
 
 
 
 
7d3b780
 
48c0d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3b780
9483f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3b780
 
9483f8f
 
 
 
06b60e9
9483f8f
 
 
 
 
2f4ecd3
9483f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a244e8
09bef47
0ea8c80
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import logging
import gradio as gr
import re
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import asyncio
from crewai import Agent, Task, Crew
from huggingface_hub import InferenceClient
import random
import json
import warnings

# Install an "ignore" filter so DeprecationWarning messages are not shown.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

# Configure root logging at INFO level with a "time - level - message" layout,
# then create the logger used throughout this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def get_huggingface_api_token():
    """Return the Hugging Face API token, or ``None`` if none is configured.

    Lookup order:

    1. The ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
    2. The ``HUGGINGFACEHUB_API_TOKEN`` key of ``config.json`` in the current
       working directory.

    A missing, unreadable-as-text, malformed, or non-object ``config.json`` is
    logged and treated as "no token" rather than raising.
    """
    token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if token:
        logger.info("Hugging Face API token found in environment variables.")
        return token
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open('config.json', 'r', encoding='utf-8') as config_file:
            config = json.load(config_file)
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.error("Error reading the config file. Please check its format.")
    else:
        # Valid JSON is not necessarily an object (it may be a list, string,
        # number, ...); only a dict supports the key lookup.
        token = config.get('HUGGINGFACEHUB_API_TOKEN') if isinstance(config, dict) else None
        if token:
            logger.info("Hugging Face API token found in config.json file.")
            return token
    logger.error("Hugging Face API token not found. Please set it up.")
    return None

# Resolve the API token at import time; without one the process logs an error
# and exits with status 1 before anything else is set up.
token = get_huggingface_api_token()
if not token:
    logger.error("Hugging Face API token is not set. Exiting.")
    sys.exit(1)

# Inference client bound to the Mistral-7B-Instruct model; used by
# ResponseExpertAgent to generate answers.
hf_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=token)

# Bag-of-words topic classifier. Each approved topic phrase is a single
# training sample whose label is its index in approved_topics.
vectorizer = CountVectorizer()
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
X = vectorizer.fit_transform(approved_topics)
classifier = MultinomialNB()
classifier.fit(X, np.arange(len(approved_topics)))

class CommunicationExpertAgent(Agent):
    """Agent that cleans up a raw user query and rewrites it as a prompt."""

    async def run(self, query):
        """Sanitize *query* and return the rephrased prompt string.

        The characters ``<``, ``>``, ``&`` and ``'`` are removed from the
        query. The cleaned text is then classified against the approved
        topics; if the predicted label is not a valid topic index, a fixed
        "not relevant" message is returned instead of a prompt.
        """
        cleaned = re.sub(r'[<>&\']', '', query)
        predicted_label = classifier.predict(vectorizer.transform([cleaned]))[0]

        # NOTE(review): the classifier is trained with labels
        # 0..len(approved_topics)-1, so the predicted label always falls in
        # this range and the rejection below looks unreachable -- confirm
        # whether a real relevance threshold was intended.
        if predicted_label not in range(len(approved_topics)):
            return "Query not relevant to our services."

        # Placeholder standing in for a real emotional-context analysis.
        emotional_context = "Identified emotional context"
        return f"Rephrased with empathy: {cleaned} - {emotional_context}"

class ResponseExpertAgent(Agent):
    """Agent that produces the answer text using the hosted language model."""

    async def run(self, rephrased_query):
        """Generate and return the model's reply to *rephrased_query*.

        ``hf_client`` is a synchronous ``InferenceClient``: its
        ``text_generation`` method blocks and, with the default
        ``details``/``stream`` settings used here, returns the generated text
        as a plain ``str``. It therefore cannot be awaited directly, and its
        result has no ``'generated_text'`` key to index.

        The blocking call is run in the event loop's default executor so this
        coroutine stays awaitable without stalling the loop.
        """
        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(
            None,
            lambda: hf_client.text_generation(
                rephrased_query, max_new_tokens=500, temperature=0.7
            ),
        )
        return generated_text

class PostprocessingAgent(Agent):
    """Agent that applies the final touches to a generated response."""

    def run(self, response):
        """Return *response* with the standard closing message appended."""
        closing_message = "\n\nThank you for contacting Zerodha. Is there anything else we can help with?"
        response += closing_message
        return response

# Instantiate agents
# crewai's Agent is a pydantic model on which role, goal and backstory are
# required fields; constructing an agent without them fails validation, so
# each agent is given an explicit description of its job in the pipeline.
# NOTE(review): depending on the installed crewai version, Agent construction
# may also set up a default LLM from environment settings -- confirm that this
# works in the deployment environment.
communication_expert = CommunicationExpertAgent(
    role="Communication Expert",
    goal="Sanitize customer queries and rephrase them with empathy",
    backstory=(
        "A support specialist who reads each customer query, removes unsafe "
        "characters and restates it in an empathetic tone."
    ),
)
response_expert = ResponseExpertAgent(
    role="Response Expert",
    goal="Generate a helpful answer to the rephrased customer query",
    backstory=(
        "A knowledgeable support agent who answers customer questions using "
        "a hosted language model."
    ),
)
postprocessing_agent = PostprocessingAgent(
    role="Postprocessing Agent",
    goal="Finalize the generated answer before it is shown to the customer",
    backstory=(
        "An editor who appends the standard closing message to every "
        "response."
    ),
)

async def handle_query(query):
    """Run *query* through the three-agent pipeline and return the final text.

    Steps, in order: the communication expert rephrases the query, the
    response expert generates an answer for the rephrased text, and the
    postprocessing agent finalizes that answer.
    """
    prompt = await communication_expert.run(query)
    generated = await response_expert.run(prompt)
    return postprocessing_agent.run(generated)

# Gradio interface setup
def setup_interface():
    """Build and return the Gradio Blocks app for submitting queries.

    The layout is a single row holding the query textbox, a submit button and
    the response textbox. Clicking the button sends the query text through
    ``handle_query`` and writes the result to the response textbox.
    """
    with gr.Blocks() as app:
        with gr.Row():
            question_box = gr.Textbox(label="Enter your query")
            ask_button = gr.Button("Submit")
            answer_box = gr.Textbox(label="Response")

        # handle_query is a coroutine function; asyncio.run drives it to
        # completion and hands back its result to the click handler.
        ask_button.click(
            fn=lambda x: asyncio.run(handle_query(x)),
            inputs=[question_box],
            outputs=[answer_box],
        )
    return app

# Build the UI at import time so `app` is available as a module-level name.
app = setup_interface()

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()