File size: 5,632 Bytes
5b2aed4
 
 
f46801a
 
 
 
 
414a697
7d3b780
 
48c0d8d
 
0ea8c80
 
cfcb5f3
 
 
 
 
7d3b780
0ea8c80
7d3b780
48c0d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3b780
0ea8c80
 
 
 
 
 
 
 
 
 
 
 
a1b124b
0ea8c80
a1b124b
c4c6959
0ea8c80
c4c6959
 
0ea8c80
 
 
 
 
 
 
 
 
c4c6959
0ea8c80
c4c6959
 
0ea8c80
c4c6959
 
 
 
 
0ea8c80
c4c6959
 
 
 
 
0ea8c80
c4c6959
 
 
 
 
 
7d3b780
0ea8c80
7d3b780
 
 
 
 
 
 
06b60e9
0ea8c80
c4c6959
 
 
 
 
 
 
 
 
06b60e9
09bef47
0ea8c80
 
 
 
2f4ecd3
 
0ea8c80
 
 
 
 
 
 
 
 
2f4ecd3
 
0ea8c80
 
 
 
 
9a244e8
09bef47
0ea8c80
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import sys
import logging
import gradio as gr
import re
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import asyncio
from huggingface_hub import InferenceClient
import json
import warnings

# Suppress all deprecation warnings (process-wide; also hides deprecations
# from third-party libraries such as gradio/huggingface_hub).
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set up logging: root logger at INFO with timestamped messages, plus a
# module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def get_huggingface_api_token():
    """Look up the Hugging Face API token.

    The HUGGINGFACEHUB_API_TOKEN environment variable takes precedence;
    a local ``config.json`` file is consulted as a fallback.

    Returns:
        str | None: the token, or None when no usable token is found.
    """
    env_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if env_token:
        logger.info("Hugging Face API token found in environment variables.")
        return env_token

    file_token = None
    try:
        with open('config.json', 'r') as config_fh:
            file_token = json.load(config_fh).get('HUGGINGFACEHUB_API_TOKEN')
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except json.JSONDecodeError:
        logger.error("Error reading the config file. Please check its format.")

    if file_token:
        logger.info("Hugging Face API token found in config.json file.")
        return file_token

    logger.error("Hugging Face API token not found. Please set it up.")
    return None

def initialize_hf_client():
    """Build the Hugging Face Inference Client for the Mistral-7B model.

    Fetches the API token via get_huggingface_api_token(); a missing token
    is raised as ValueError so it funnels into the same failure path as any
    client construction error.

    Returns:
        InferenceClient: a ready-to-use client.

    Side effects:
        Terminates the process (exit code 1) if initialization fails —
        this is a startup-time boundary, not a recoverable error.
    """
    try:
        api_token = get_huggingface_api_token()
        if not api_token:
            raise ValueError("Hugging Face API token is not set. Please set it up before running the application.")
        hf_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=api_token)
    except Exception as e:
        logger.error(f"Failed to initialize Hugging Face client: {e}")
        sys.exit(1)
    logger.info("Hugging Face Inference Client initialized successfully.")
    return hf_client

# Module-level side effect: the client is created at import time, and the
# process exits (sys.exit(1) inside initialize_hf_client) if setup fails.
client = initialize_hf_client()

def sanitize_input(input_text):
    """Strip characters commonly used for markup/script injection.

    Removes every occurrence of ``<``, ``>``, ``&`` and ``'`` from
    *input_text* and returns the cleaned string.
    """
    forbidden = "<>&'"
    return ''.join(ch for ch in input_text if ch not in forbidden)

def setup_classifier():
    """Train a tiny bag-of-words Naive Bayes model over the approved topics.

    Each approved-topic string is one training document labelled with its
    own index, so the model maps free text to the closest topic label.

    Returns:
        tuple: (fitted CountVectorizer, fitted MultinomialNB).
    """
    approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
    topic_vectorizer = CountVectorizer()
    features = topic_vectorizer.fit_transform(approved_topics)
    labels = np.arange(len(approved_topics))
    topic_model = MultinomialNB()
    topic_model.fit(features, labels)
    return topic_vectorizer, topic_model

# Module-level globals consumed by is_relevant_topic(); trained once at import.
vectorizer, classifier = setup_classifier()

def is_relevant_topic(query):
    """Check whether *query* falls under one of the approved topics.

    Args:
        query: free-text user query.

    Returns:
        bool: True when the classifier assigns the query to a trained topic
        label.

    Bug fixed: the original tested ``prediction[0] in
    range(len(approved_topics))``, but ``approved_topics`` is local to
    setup_classifier(), so every call raised NameError. The trained label
    set is exposed by scikit-learn as ``classifier.classes_``, which is the
    equivalent membership test.
    """
    query_vector = vectorizer.transform([query])
    prediction = classifier.predict(query_vector)
    return prediction[0] in classifier.classes_

def redact_sensitive_info(text):
    """Mask likely account numbers and PAN-style identifiers in *text*.

    Replaces standalone 10-12 digit numbers and PAN-format codes
    (5 letters, 4 digits, 1 letter) with '[REDACTED]'.
    """
    sensitive_patterns = (r'\b\d{10,12}\b', r'[A-Z]{5}[0-9]{4}[A-Z]')
    for pattern in sensitive_patterns:
        text = re.sub(pattern, '[REDACTED]', text)
    return text

def check_response_content(response):
    """Return True when *response* contains no unauthorized financial claims.

    Rejects guaranteed-returns language and specific buy/sell stock
    recommendations (case-insensitive).
    """
    unauthorized_patterns = [
        r'\b(guarantee|assured|certain)\b.*\b(returns|profit)\b',
        r'\b(buy|sell)\b.*\b(specific stocks?|shares?)\b'
    ]
    for pattern in unauthorized_patterns:
        if re.search(pattern, response, re.IGNORECASE):
            return False
    return True

async def generate_response(prompt):
    """Generate a chatbot reply for *prompt* via the Hugging Face client.

    ``InferenceClient.text_generation`` is a synchronous call, so it is
    dispatched to a worker thread with ``asyncio.to_thread`` to avoid
    blocking the event loop. (The original ``await
    client.text_generation(...)`` awaited the returned plain string, which
    raises TypeError and therefore *always* fell through to the fallback
    message.)

    Args:
        prompt: the user's query text.

    Returns:
        str: the model output, or an apology string on any failure.
    """
    try:
        return await asyncio.to_thread(
            client.text_generation, prompt, max_new_tokens=500, temperature=0.7
        )
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "I apologize, but I'm having trouble generating a response at the moment. Please try again later."

def post_process_response(response):
    """Polish a model response before it is shown to the user.

    Three passes:
      1. Soften insulting words (stupid/dumb/idiotic/foolish -> "mistaken").
      2. Append a closing offer of help unless the text already ends with a
         sign-off phrase.
      3. Append a financial-advice disclaimer when investment-related terms
         appear anywhere in the text.
    """
    polished = re.sub(r'\b(stupid|dumb|idiotic|foolish)\b', 'mistaken', response, flags=re.IGNORECASE)

    ends_with_signoff = re.search(
        r'(Thank you|Is there anything else|Hope this helps|Let me know if you need more information)\s*$',
        polished, re.IGNORECASE)
    if not ends_with_signoff:
        polished += "\n\nIs there anything else I can help you with regarding Zerodha's services?"

    mentions_investing = re.search(r'\b(invest|trade|buy|sell|market)\b', polished, re.IGNORECASE)
    if mentions_investing:
        polished += "\n\nPlease note that this information is for educational purposes only and should not be considered as financial advice. Always do your own research and consider consulting with a qualified financial advisor before making investment decisions."

    return polished

# Gradio interface setup
with gr.Blocks() as app:
    # Login row. NOTE(review): credentials are hardcoded ("admin"/"admin")
    # in the click handler below, and the login result does not gate the
    # query form — any visitor can submit queries without logging in.
    # Replace with real authentication before deploying.
    with gr.Row():
        username = gr.Textbox(label="Username")
        password = gr.Textbox(label="Password", type="password")
        login_button = gr.Button("Login")
    
    # Query row: free-text input, submit button, and the response display.
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        submit_button = gr.Button("Submit")
        response_output = gr.Textbox(label="Response")

    login_button.click(
        fn=lambda u, p: "Login successful" if u == "admin" and p == "admin" else "Login failed",
        inputs=[username, password],
        outputs=[gr.Text(label="Login status")]
    )

    # Each submit drives the async generate_response to completion with
    # asyncio.run. NOTE(review): the guardrail helpers defined above
    # (sanitize_input, is_relevant_topic, redact_sensitive_info,
    # check_response_content, post_process_response) are never applied in
    # this pipeline — the raw model output is shown unfiltered. Confirm
    # whether that is intentional.
    submit_button.click(
        fn=lambda x: asyncio.run(generate_response(x)),
        inputs=[query_input],
        outputs=[response_output]
    )

# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    app.launch()