invincible-jha committed on
Commit
9483f8f
·
verified ·
1 Parent(s): 0ea8c80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -94
app.py CHANGED
@@ -7,7 +7,9 @@ import numpy as np
7
  from sklearn.feature_extraction.text import CountVectorizer
8
  from sklearn.naive_bayes import MultinomialNB
9
  import asyncio
 
10
  from huggingface_hub import InferenceClient
 
11
  import json
12
  import warnings
13
 
@@ -19,12 +21,10 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
19
  logger = logging.getLogger(__name__)
20
 
21
  def get_huggingface_api_token():
22
- """ Retrieves the Hugging Face API token from environment variables or a config file. """
23
  token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
24
  if token:
25
  logger.info("Hugging Face API token found in environment variables.")
26
  return token
27
-
28
  try:
29
  with open('config.json', 'r') as config_file:
30
  config = json.load(config_file)
@@ -36,105 +36,68 @@ def get_huggingface_api_token():
36
  logger.warning("Config file not found.")
37
  except json.JSONDecodeError:
38
  logger.error("Error reading the config file. Please check its format.")
39
-
40
  logger.error("Hugging Face API token not found. Please set it up.")
41
  return None
42
 
43
def initialize_hf_client():
    """Build the shared Hugging Face Inference Client, exiting on failure.

    Fetches the API token via get_huggingface_api_token(); if no token is
    available, or the client cannot be constructed, the problem is logged
    and the process terminates with exit code 1.
    """
    try:
        hf_token = get_huggingface_api_token()
        if not hf_token:
            raise ValueError("Hugging Face API token is not set. Please set it up before running the application.")
        inference_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=hf_token)
    except Exception as e:
        # Any failure here (missing token, client construction error) is fatal.
        logger.error(f"Failed to initialize Hugging Face client: {e}")
        sys.exit(1)
    logger.info("Hugging Face Inference Client initialized successfully.")
    return inference_client

client = initialize_hf_client()
57
-
58
def sanitize_input(input_text):
    """Return input_text with the characters <, >, & and ' removed."""
    # str.translate performs the strip in one C-level pass.
    return input_text.translate(str.maketrans('', '', "<>&'"))
61
-
62
def setup_classifier():
    """Fit a tiny Naive-Bayes topic model over the approved support topics.

    Returns the fitted (vectorizer, classifier) pair; each topic is its own
    class, labelled 0..n-1 in list order.
    """
    approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
    topic_vectorizer = CountVectorizer()
    topic_matrix = topic_vectorizer.fit_transform(approved_topics)
    labels = np.arange(len(approved_topics))
    nb_model = MultinomialNB()
    nb_model.fit(topic_matrix, labels)
    return topic_vectorizer, nb_model

vectorizer, classifier = setup_classifier()
73
-
74
def is_relevant_topic(query):
    """Return True if `query` overlaps the approved support-topic vocabulary.

    Bug fixes:
    - The old check `prediction[0] in range(len(approved_topics))` raised
      NameError: `approved_topics` is a local of setup_classifier and is not
      in scope here.
    - Even with the name available, the test was vacuous — MultinomialNB can
      only ever predict one of the labels it was trained on, so the
      membership check was always True.
    Relevance is now defined as the query sharing at least one token with
    the CountVectorizer vocabulary fitted on the approved topics.
    """
    query_vector = vectorizer.transform([query])
    # nnz == 0 -> the query contains no known on-topic token at all.
    return query_vector.nnz > 0
79
-
80
def redact_sensitive_info(text):
    """Mask account-number-like digit runs and PAN-style codes in text."""
    # 10-12 digit standalone numbers, then 5-letter/4-digit/1-letter codes.
    for pattern in (r'\b\d{10,12}\b', r'[A-Z]{5}[0-9]{4}[A-Z]'):
        text = re.sub(pattern, '[REDACTED]', text)
    return text
85
-
86
def check_response_content(response):
    """Return False when `response` contains unauthorized claims or advice."""
    unauthorized_patterns = (
        r'\b(guarantee|assured|certain)\b.*\b(returns|profit)\b',
        r'\b(buy|sell)\b.*\b(specific stocks?|shares?)\b',
    )
    for pattern in unauthorized_patterns:
        if re.search(pattern, response, re.IGNORECASE):
            return False
    return True
93
-
94
async def generate_response(prompt):
    """Ask the shared HF inference client for a completion of `prompt`.

    On any failure the error is logged and a fixed apology string is
    returned instead of raising to the caller.
    """
    try:
        return await client.text_generation(prompt, max_new_tokens=500, temperature=0.7)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "I apologize, but I'm having trouble generating a response at the moment. Please try again later."
102
 
103
def post_process_response(response):
    """Soften insults, guarantee a closing line, and add a compliance note."""
    # Replace demeaning adjectives with a neutral word.
    response = re.sub(r'\b(stupid|dumb|idiotic|foolish)\b', 'mistaken', response, flags=re.IGNORECASE)

    closing_re = r'(Thank you|Is there anything else|Hope this helps|Let me know if you need more information)\s*$'
    if re.search(closing_re, response, re.IGNORECASE) is None:
        response += "\n\nIs there anything else I can help you with regarding Zerodha's services?"

    # Any mention of trading activity triggers the educational disclaimer.
    if re.search(r'\b(invest|trade|buy|sell|market)\b', response, re.IGNORECASE):
        response += "\n\nPlease note that this information is for educational purposes only and should not be considered as financial advice. Always do your own research and consider consulting with a qualified financial advisor before making investment decisions."

    return response
114
 
115
# Gradio interface setup
with gr.Blocks() as app:
    with gr.Row():
        username = gr.Textbox(label="Username")
        password = gr.Textbox(label="Password", type="password")
        login_button = gr.Button("Login")

    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        submit_button = gr.Button("Submit")
        response_output = gr.Textbox(label="Response")

    def _check_login(u, p):
        # SECURITY: hardcoded admin/admin credentials compared in plain text —
        # replace with a real credential store before any deployment.
        return "Login successful" if u == "admin" and p == "admin" else "Login failed"

    login_button.click(
        fn=_check_login,
        inputs=[username, password],
        outputs=[gr.Text(label="Login status")]
    )

    def _answer(x):
        return asyncio.run(generate_response(x))

    submit_button.click(
        fn=_answer,
        inputs=[query_input],
        outputs=[response_output]
    )

if __name__ == "__main__":
    app.launch()
 
7
  from sklearn.feature_extraction.text import CountVectorizer
8
  from sklearn.naive_bayes import MultinomialNB
9
  import asyncio
10
+ from crewai import Agent, Task, Crew
11
  from huggingface_hub import InferenceClient
12
+ import random
13
  import json
14
  import warnings
15
 
 
21
  logger = logging.getLogger(__name__)
22
 
23
  def get_huggingface_api_token():
 
24
  token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
25
  if token:
26
  logger.info("Hugging Face API token found in environment variables.")
27
  return token
 
28
  try:
29
  with open('config.json', 'r') as config_file:
30
  config = json.load(config_file)
 
36
  logger.warning("Config file not found.")
37
  except json.JSONDecodeError:
38
  logger.error("Error reading the config file. Please check its format.")
 
39
  logger.error("Hugging Face API token not found. Please set it up.")
40
  return None
41
 
42
# Resolve API credentials once at import time; the app cannot run without them.
api_token = get_huggingface_api_token()
if not api_token:
    logger.error("Hugging Face API token is not set. Exiting.")
    sys.exit(1)

# Shared client for all downstream text-generation calls.
hf_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=api_token)
48
+
49
# Tiny topic model: one class per approved support topic, labelled 0..n-1.
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
vectorizer = CountVectorizer()
X = vectorizer.fit_transform(approved_topics)
classifier = MultinomialNB().fit(X, np.arange(len(approved_topics)))
54
+
55
class CommunicationExpertAgent(Agent):
    """Sanitizes, relevance-checks, and empathetically rephrases a user query."""

    async def run(self, query):
        """Return an empathetic rephrasing of `query`, or a rejection message.

        Bug fix: the previous relevance check tested
        `classifier.predict(...)[0] in range(len(approved_topics))`, which is
        always True — MultinomialNB can only ever predict one of the labels
        it was trained on, so no query was ever rejected. A query is now
        considered relevant only when it shares at least one token with the
        approved-topic vocabulary.
        """
        # Strip characters usable for markup/quote injection.
        sanitized_query = re.sub(r'[<>&\']', '', query)
        # nnz == 0 -> no overlap with the CountVectorizer vocabulary, i.e.
        # the classifier has no evidence the query is on-topic.
        if vectorizer.transform([sanitized_query]).nnz == 0:
            return "Query not relevant to our services."
        emotional_context = "Identified emotional context"  # Simulate emotional context analysis
        return f"Rephrased with empathy: {sanitized_query} - {emotional_context}"
64
+
65
class ResponseExpertAgent(Agent):
    """Generates the model-backed answer via the shared HF inference client."""

    async def run(self, rephrased_query):
        """Return generated text for `rephrased_query`.

        Bug fixes:
        - `InferenceClient.text_generation` is synchronous; `await`ing its
          (string) return value raised TypeError.
        - With default arguments it returns the generated string itself, so
          the old `response['generated_text']` indexing also raised TypeError.

        NOTE(review): the call blocks the event loop for the duration of the
        request; move it to a thread/executor if latency matters.
        """
        response = hf_client.text_generation(rephrased_query, max_new_tokens=500, temperature=0.7)
        return response
69
+
70
class PostprocessingAgent(Agent):
    """Appends the standard closing line to every outgoing response."""

    def run(self, response):
        """Return `response` with the standard Zerodha sign-off appended."""
        closing = "\n\nThank you for contacting Zerodha. Is there anything else we can help with?"
        return response + closing
 
 
 
74
 
75
# Instantiate agents — one long-lived instance of each pipeline stage.
communication_expert = CommunicationExpertAgent()  # sanitize + rephrase
response_expert = ResponseExpertAgent()            # model-backed answer
postprocessing_agent = PostprocessingAgent()       # standard sign-off
 
 
 
 
 
 
 
79
 
80
async def handle_query(query):
    """Run `query` through the three-agent pipeline and return the reply.

    Bug fix: previously the fixed rejection message produced by
    CommunicationExpertAgent for an off-topic query was still fed to the
    response model; the pipeline now short-circuits and returns it directly.
    """
    rephrased_query = await communication_expert.run(query)
    # CommunicationExpertAgent signals an off-topic query with this exact
    # message — don't waste a model call generating a response to it.
    if rephrased_query == "Query not relevant to our services.":
        return rephrased_query
    response = await response_expert.run(rephrased_query)
    return postprocessing_agent.run(response)
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
# Gradio interface setup
def setup_interface():
    """Build the Gradio Blocks UI wired to the async agent pipeline.

    Bug fix: the submit handler previously wrapped handle_query in
    `asyncio.run(...)`, which raises RuntimeError when called while an event
    loop is already running (as it is inside Gradio's server). Gradio accepts
    coroutine functions directly, so handle_query is passed as-is.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            query_input = gr.Textbox(label="Enter your query")
            submit_button = gr.Button("Submit")
            response_output = gr.Textbox(label="Response")
        submit_button.click(
            fn=handle_query,
            inputs=[query_input],
            outputs=[response_output],
        )
    return demo


app = setup_interface()

if __name__ == "__main__":
    app.launch()