import os
import sys
import logging
import gradio as gr
import re
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import asyncio
from crewai import Agent
from huggingface_hub import InferenceClient
import random
import json
import warnings
from typing import Literal

# Suppress all deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_huggingface_api_token(config_path='config.json'):
    """Return the Hugging Face API token, or None if it cannot be found.

    The token is looked up in the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable first, then under the same key in the JSON file *config_path*.

    Args:
        config_path: JSON config file used as a fallback source. Defaults to
            ``'config.json'``, resolved against the current working directory.

    Returns:
        The token string, or None when neither source yields a non-empty value.
        Problems reading or parsing the config file are logged, never raised.
    """
    token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    if token:
        logger.info("Hugging Face API token found in environment variables.")
        return token

    try:
        # Explicit encoding so the result does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as config_file:
            config = json.load(config_file)
    except FileNotFoundError:
        logger.warning("Config file not found.")
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.error("Error reading the config file. Please check its format.")
    except OSError as err:
        # e.g. permission denied, or config_path is a directory: report and fall through.
        logger.error("Could not open the config file %s: %s", config_path, err)
    else:
        # Valid JSON need not be an object; calling .get() on a list/str would crash.
        if isinstance(config, dict):
            token = config.get('HUGGINGFACEHUB_API_TOKEN')
            if token:
                logger.info("Hugging Face API token found in config.json file.")
                return token
        else:
            logger.error("Error reading the config file. Please check its format.")

    logger.error("Hugging Face API token not found. Please set it up.")
    return None


# The app cannot do anything useful without a token, so stop at import time.
token = get_huggingface_api_token()
if not token:
    logger.error("Hugging Face API token is not set. Exiting.")
    sys.exit(1)

# Inference client for the hosted Mistral-7B-Instruct model.
# NOTE(review): not referenced by the visible code — presumably used by the agent classes.
hf_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=token)

# Bag-of-words Naive Bayes classifier trained on one example per approved topic;
# each topic's class label is its index in approved_topics.
# NOTE(review): not referenced by the visible code — presumably used by the agent classes.
vectorizer = CountVectorizer()
approved_topics = ['account opening', 'trading', 'fees', 'platforms', 'funds', 'regulations', 'support']
X = vectorizer.fit_transform(approved_topics)
classifier = MultinomialNB()
classifier.fit(X, np.arange(len(approved_topics)))

# [Include the updated agent class definitions here]
# NOTE(review): CommunicationExpertAgent, ResponseExpertAgent and PostprocessingAgent are
# not defined in this file. Until their definitions are placed here, the instantiations
# below raise NameError at import time.

# Instantiate agents
communication_expert = CommunicationExpertAgent()
response_expert = ResponseExpertAgent()
postprocessing_agent = PostprocessingAgent()


async def handle_query(query):
    """Run *query* through the agent pipeline and return the final response.

    The communication expert rephrases the query, the response expert answers
    the rephrased query, and the post-processing agent transforms that answer.
    Blank input returns a prompt message without calling any agent.
    """
    # An empty textbox would otherwise cost a full round of model calls.
    if not query or not query.strip():
        return "Please enter a query."
    rephrased_query = await communication_expert.run(query)
    response = await response_expert.run(rephrased_query)
    # PostprocessingAgent.run is called synchronously (not awaited), as before.
    final_response = postprocessing_agent.run(response)
    return final_response


def _handle_query_sync(query):
    """Synchronous Gradio callback that drives handle_query via asyncio.run().

    Each call gets its own event loop. NOTE(review): this relies on Gradio
    running synchronous callbacks in a worker thread (no loop already running
    there), which also keeps any blocking calls inside the agents off Gradio's
    own event loop — confirm for the installed Gradio version.
    """
    try:
        return asyncio.run(handle_query(query))
    except Exception:
        # UI boundary: keep the traceback in the logs, then let Gradio report the error.
        logger.exception("Error while handling query")
        raise


# Gradio interface setup
def setup_interface():
    """Build the Gradio UI: a query textbox, a Submit button and a response textbox."""
    with gr.Blocks() as app:
        with gr.Row():
            query_input = gr.Textbox(label="Enter your query")
            submit_button = gr.Button("Submit")
            response_output = gr.Textbox(label="Response")
        submit_button.click(
            fn=_handle_query_sync,
            inputs=[query_input],
            outputs=[response_output],
        )
    return app


app = setup_interface()

if __name__ == "__main__":
    app.launch()