import os
import sys
import logging
import gradio as gr
import re
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import asyncio
from huggingface_hub import InferenceClient
import json
import warnings

# Suppress all deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_huggingface_api_token():
    """ Retrieves the Hugging Face API token from environment variables or a config file. """
    token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if token:
        logger.info("Hugging Face API token found in environment variables.")
        return token
    try:
        with open('config.json', 'r') as config_file:
            config = json.load(config_file)
            token = config.get('HUGGINGFACEHUB_API_TOKEN')
            if token:
                logger.info("Hugging Face API token found in config.json file.")
                return token
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except json.JSONDecodeError:
        logger.error("Error reading the config file. Please check its format.")
    logger.error("Hugging Face API token not found. Please set it up.")
    return None
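
# The expected shape of config.json is not shown in this file; based on the
# lookup above, a minimal example would presumably be:
#   {"HUGGINGFACEHUB_API_TOKEN": "<your token>"}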
def initialize_hf_client():
    """ Initializes the Hugging Face Inference Client with the API token. """
    try:
        hf_token = get_huggingface_api_token()
        if not hf_token:
            raise ValueError("Hugging Face API token is not set. Please set it up before running the application.")
        client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=hf_token)
        logger.info("Hugging Face Inference Client initialized successfully.")
        return client
    except Exception as e:
        logger.error(f"Failed to initialize Hugging Face client: {e}")
        sys.exit(1)

client = initialize_hf_client()
def sanitize_input(input_text):
    """ Sanitizes input text by removing specific characters. """
    return re.sub(r'[<>&\']', '', input_text)
# Topics the assistant is allowed to answer about (module-level so that
# is_relevant_topic can reference it as well).
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']

def setup_classifier():
    """ Sets up and trains a classifier for checking the relevance of a query. """
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(approved_topics)
    y = np.arange(len(approved_topics))
    classifier = MultinomialNB()
    classifier.fit(X, y)
    return vectorizer, classifier

vectorizer, classifier = setup_classifier()

def is_relevant_topic(query):
    """ Checks if the query is relevant based on the pre-defined approved topics. """
    # Note: MultinomialNB always predicts one of the trained classes, so this
    # check is permissive as written.
    query_vector = vectorizer.transform([query])
    prediction = classifier.predict(query_vector)
    return prediction[0] in range(len(approved_topics))
def redact_sensitive_info(text):
    """ Redacts sensitive information from the text. """
    # Mask 10-12 digit numbers (e.g. phone or account numbers).
    text = re.sub(r'\b\d{10,12}\b', '[REDACTED]', text)
    # Mask PAN-style identifiers (five letters, four digits, one letter).
    text = re.sub(r'[A-Z]{5}[0-9]{4}[A-Z]', '[REDACTED]', text)
    return text
def check_response_content(response):
    """ Checks response content for unauthorized claims or advice. """
    unauthorized_patterns = [
        r'\b(guarantee|assured|certain)\b.*\b(returns|profit)\b',
        r'\b(buy|sell)\b.*\b(specific stocks?|shares?)\b'
    ]
    return not any(re.search(pattern, response, re.IGNORECASE) for pattern in unauthorized_patterns)
async def generate_response(prompt):
    """ Generates a response using the Hugging Face inference client. """
    try:
        # InferenceClient.text_generation is synchronous, so run it in a
        # worker thread rather than awaiting the call directly.
        response = await asyncio.to_thread(
            client.text_generation, prompt, max_new_tokens=500, temperature=0.7
        )
        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "I apologize, but I'm having trouble generating a response at the moment. Please try again later."
def post_process_response(response):
    """ Post-processes the response to ensure it ends with helpful suggestions. """
    response = re.sub(r'\b(stupid|dumb|idiotic|foolish)\b', 'mistaken', response, flags=re.IGNORECASE)
    if not re.search(r'(Thank you|Is there anything else|Hope this helps|Let me know if you need more information)\s*$', response, re.IGNORECASE):
        response += "\n\nIs there anything else I can help you with regarding Zerodha's services?"
    if re.search(r'\b(invest|trade|buy|sell|market)\b', response, re.IGNORECASE):
        response += "\n\nPlease note that this information is for educational purposes only and should not be considered as financial advice. Always do your own research and consider consulting with a qualified financial advisor before making investment decisions."
    return response
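
# The guardrail helpers above are defined but never chained together in this
# file. The sketch below is one assumed way to combine them into a single
# guarded handler; the name `handle_query`, the ordering, and the refusal
# messages are illustrative, not part of the original code.
async def handle_query(query):
    """ Runs a query through the sanitization, relevance, and content guards. """
    query = sanitize_input(query)
    if not is_relevant_topic(query):
        return "I can only help with questions about Zerodha's services (accounts, trading, fees, platforms, funds, regulations, support)."
    response = await generate_response(query)
    response = redact_sensitive_info(response)
    if not check_response_content(response):
        return "I'm sorry, but I can't provide that kind of recommendation. Please consult a qualified financial advisor."
    return post_process_response(response)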
# Gradio interface setup
with gr.Blocks() as app:
    with gr.Row():
        username = gr.Textbox(label="Username")
        password = gr.Textbox(label="Password", type="password")
        login_button = gr.Button("Login")
    login_status = gr.Textbox(label="Login status")
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        submit_button = gr.Button("Submit")
    response_output = gr.Textbox(label="Response")

    # Demo-only login check with hardcoded credentials; it does not gate
    # access to the query handler.
    login_button.click(
        fn=lambda u, p: "Login successful" if u == "admin" and p == "admin" else "Login failed",
        inputs=[username, password],
        outputs=[login_status]
    )
    # Gradio supports async handlers, so the coroutine can be passed directly.
    submit_button.click(
        fn=generate_response,
        inputs=[query_input],
        outputs=[response_output]
    )
if __name__ == "__main__":
    app.launch()