# TestAssistant / app.py
import os
import sys
import logging
import gradio as gr
import re
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import asyncio
from huggingface_hub import InferenceClient
import json
import warnings
# Suppress all deprecation warnings so third-party library noise does not clutter the logs
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Configure module-level logging with timestamped INFO output
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Shared logger used by every function in this module
logger = logging.getLogger(__name__)
def get_huggingface_api_token():
    """Return the Hugging Face API token, or None if it cannot be found.

    Checks the HUGGINGFACEHUB_API_TOKEN environment variable first, then
    falls back to the 'HUGGINGFACEHUB_API_TOKEN' key of a local config.json.
    """
    env_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if env_token:
        logger.info("Hugging Face API token found in environment variables.")
        return env_token
    file_token = None
    try:
        with open('config.json', 'r') as config_file:
            file_token = json.load(config_file).get('HUGGINGFACEHUB_API_TOKEN')
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except json.JSONDecodeError:
        logger.error("Error reading the config file. Please check its format.")
    if file_token:
        logger.info("Hugging Face API token found in config.json file.")
        return file_token
    logger.error("Hugging Face API token not found. Please set it up.")
    return None
def initialize_hf_client():
    """Create and return a Hugging Face InferenceClient for Mistral-7B.

    Terminates the process (exit code 1) when no API token is available
    or the client cannot be constructed.
    """
    try:
        api_token = get_huggingface_api_token()
        if not api_token:
            raise ValueError("Hugging Face API token is not set. Please set it up before running the application.")
        inference_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=api_token)
        logger.info("Hugging Face Inference Client initialized successfully.")
        return inference_client
    except Exception as e:
        logger.error(f"Failed to initialize Hugging Face client: {e}")
        sys.exit(1)


# Module-level client shared by the response-generation coroutine below.
client = initialize_hf_client()
def sanitize_input(input_text):
    """Strip potentially dangerous characters (<, >, &, ') from *input_text*."""
    cleaned = input_text
    for forbidden in ('<', '>', '&', "'"):
        cleaned = cleaned.replace(forbidden, '')
    return cleaned
def setup_classifier():
    """Train a tiny Naive Bayes model over the approved support topics.

    Returns the fitted (vectorizer, classifier) pair used by the
    relevance check below. Each topic string becomes its own class
    (labels 0..len-1).
    """
    approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
    topic_vectorizer = CountVectorizer()
    topic_features = topic_vectorizer.fit_transform(approved_topics)
    topic_labels = np.arange(len(approved_topics))
    topic_classifier = MultinomialNB()
    topic_classifier.fit(topic_features, topic_labels)
    return topic_vectorizer, topic_classifier


vectorizer, classifier = setup_classifier()
def is_relevant_topic(query):
    """Return True if *query* is classified as one of the approved topics.

    Bug fix: the original compared against ``approved_topics``, which is a
    local variable of ``setup_classifier`` and not visible here, so every
    call raised NameError. The prediction is now checked against the label
    set the trained classifier actually knows (``classifier.classes_``).
    """
    query_vector = vectorizer.transform([query])
    prediction = classifier.predict(query_vector)
    # NOTE(review): MultinomialNB only ever predicts a label it was trained
    # on, so this membership test is defensive rather than a real filter —
    # a stricter relevance check would use predict_proba with a threshold.
    return prediction[0] in classifier.classes_
def redact_sensitive_info(text):
    """Mask long digit runs and PAN-style identifiers in *text*."""
    # 10-12 digit standalone runs — presumably account/phone numbers; verify against caller
    without_numbers = re.sub(r'\b\d{10,12}\b', '[REDACTED]', text)
    # Indian PAN card format: five letters, four digits, one letter.
    return re.sub(r'[A-Z]{5}[0-9]{4}[A-Z]', '[REDACTED]', without_numbers)
def check_response_content(response):
    """Return True when *response* contains no unauthorized financial claims.

    Flags guaranteed-return language and concrete buy/sell stock tips.
    """
    unauthorized_patterns = (
        r'\b(guarantee|assured|certain)\b.*\b(returns|profit)\b',
        r'\b(buy|sell)\b.*\b(specific stocks?|shares?)\b',
    )
    for pattern in unauthorized_patterns:
        if re.search(pattern, response, re.IGNORECASE):
            return False
    return True
async def generate_response(prompt):
    """Generate model text for *prompt* via the Hugging Face Inference API.

    Returns the generated string, or a fixed apology message on any failure.
    """
    try:
        # Bug fix: InferenceClient.text_generation() is synchronous and
        # returns a str; awaiting that str raised TypeError, which the
        # except clause swallowed — so every call returned the fallback
        # message. (Use AsyncInferenceClient for a truly awaitable call.)
        response = client.text_generation(prompt, max_new_tokens=500, temperature=0.7)
        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "I apologize, but I'm having trouble generating a response at the moment. Please try again later."
def post_process_response(response):
    """Soften insults, ensure a friendly closing, and append a disclaimer.

    Replaces insulting words with 'mistaken', appends a follow-up question
    unless the text already ends with a closing phrase, and adds a financial
    disclaimer whenever investment-related keywords appear.
    """
    softened = re.sub(r'\b(stupid|dumb|idiotic|foolish)\b', 'mistaken', response, flags=re.IGNORECASE)
    closing_pattern = r'(Thank you|Is there anything else|Hope this helps|Let me know if you need more information)\s*$'
    if re.search(closing_pattern, softened, re.IGNORECASE) is None:
        softened += "\n\nIs there anything else I can help you with regarding Zerodha's services?"
    if re.search(r'\b(invest|trade|buy|sell|market)\b', softened, re.IGNORECASE):
        softened += "\n\nPlease note that this information is for educational purposes only and should not be considered as financial advice. Always do your own research and consider consulting with a qualified financial advisor before making investment decisions."
    return softened
# Gradio interface setup.
# Fix: the guardrail helpers (sanitize_input, check_response_content,
# redact_sensitive_info, post_process_response) were defined above but never
# called — raw model output was shown to the user. They are now wired into
# the submit handler.
def _handle_query(raw_query):
    """Run one user query through the full guardrail pipeline and return the reply."""
    query = sanitize_input(raw_query)
    reply = asyncio.run(generate_response(query))
    # Block replies that make unauthorized financial claims.
    if not check_response_content(reply):
        reply = "I'm sorry, but I can't provide that kind of advice."
    reply = redact_sensitive_info(reply)
    return post_process_response(reply)


with gr.Blocks() as app:
    with gr.Row():
        username = gr.Textbox(label="Username")
        password = gr.Textbox(label="Password", type="password")
        login_button = gr.Button("Login")
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        submit_button = gr.Button("Submit")
    response_output = gr.Textbox(label="Response")
    # SECURITY: hard-coded admin/admin credentials are a demo placeholder
    # only, and the login does not actually gate the query form — replace
    # with real authentication before deploying.
    login_button.click(
        fn=lambda u, p: "Login successful" if u == "admin" and p == "admin" else "Login failed",
        inputs=[username, password],
        outputs=[gr.Text(label="Login status")]
    )
    submit_button.click(
        fn=_handle_query,
        inputs=[query_input],
        outputs=[response_output]
    )

if __name__ == "__main__":
    app.launch()