# NOTE: the lines "Spaces: / Sleeping / Sleeping" below were a Hugging Face
# Space status banner captured by the page scraper, not part of the source.
# Standard library
import os
import sys
import re
import json
import random
import asyncio
import logging
import warnings
from typing import Literal

# Third-party
import gradio as gr
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from crewai import Agent
from huggingface_hub import InferenceClient

# Suppress all deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_huggingface_api_token():
    """Return the Hugging Face API token, or None if it cannot be found.

    Lookup order:
      1. The HUGGINGFACEHUB_API_TOKEN environment variable.
      2. The 'HUGGINGFACEHUB_API_TOKEN' key of a local config.json file.

    Returns:
        The token string if found, otherwise None (after logging an error).
    """
    token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if token:
        logger.info("Hugging Face API token found in environment variables.")
        return token
    try:
        with open('config.json', 'r') as config_file:
            config = json.load(config_file)
        token = config.get('HUGGINGFACEHUB_API_TOKEN')
        if token:
            logger.info("Hugging Face API token found in config.json file.")
            return token
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except json.JSONDecodeError:
        logger.error("Error reading the config file. Please check its format.")
    logger.error("Hugging Face API token not found. Please set it up.")
    return None
# Resolve the API token at import time; the app cannot run without it.
token = get_huggingface_api_token()
if not token:
    logger.error("Hugging Face API token is not set. Exiting.")
    sys.exit(1)

# Client for the hosted Mistral instruction model that answers user queries.
hf_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=token)

# Tiny bag-of-words topic classifier: each approved topic string doubles as
# the single training document for its own class (labels 0..len-1).
vectorizer = CountVectorizer()
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
X = vectorizer.fit_transform(approved_topics)
classifier = MultinomialNB()
classifier.fit(X, np.arange(len(approved_topics)))
# [Include the updated agent class definitions here]
# NOTE(review): CommunicationExpertAgent, ResponseExpertAgent and
# PostprocessingAgent are not defined anywhere in this file. Until the class
# definitions referenced by the comment above are filled in, these three
# lines raise NameError at import time.
# Instantiate agents
communication_expert = CommunicationExpertAgent()
response_expert = ResponseExpertAgent()
postprocessing_agent = PostprocessingAgent()
async def handle_query(query):
    """Run a query through the rephrase -> answer -> postprocess pipeline.

    Args:
        query: Raw user query string from the UI.

    Returns:
        The final response string produced by the postprocessing agent.
    """
    # Rephrase first so the response agent receives a normalized query.
    rephrased_query = await communication_expert.run(query)
    response = await response_expert.run(rephrased_query)
    # Postprocessing is synchronous in the original flow (no await).
    final_response = postprocessing_agent.run(response)
    return final_response
# Gradio interface setup
def setup_interface():
    """Build and return the Gradio Blocks UI for the query assistant.

    Layout: a row with the query input and submit button, with the response
    textbox below it.
    """
    with gr.Blocks() as app:
        with gr.Row():
            query_input = gr.Textbox(label="Enter your query")
            submit_button = gr.Button("Submit")
        response_output = gr.Textbox(label="Response")
        submit_button.click(
            # Pass the coroutine function directly: Gradio supports async
            # callbacks natively, whereas wrapping in asyncio.run() raises
            # RuntimeError when an event loop is already running (as it is
            # inside Gradio's server).
            fn=handle_query,
            inputs=[query_input],
            outputs=[response_output]
        )
    return app
# Build the UI at import time so hosting platforms that import this module
# (e.g. Hugging Face Spaces) can find `app`.
app = setup_interface()

if __name__ == "__main__":
    app.launch()