chatbot / app.py
Phoenix21's picture
Updated app.py with multiple feature
a004b34 verified
raw
history blame
12.3 kB
# app.py
import os
import getpass
import pandas as pd
import chardet
import logging
import gradio as gr
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from langchain_community.retrievers import BM25Retriever
from smolagents import (
CodeAgent,
HfApiModel,
DuckDuckGoSearchTool,
Tool,
ManagedAgent,
LiteLLMModel
)
# --------------------------------------------------------------------------------
# Logging configuration
# --------------------------------------------------------------------------------
# Configure the root logger at DEBUG level, then create the named logger that
# the rest of this module uses for its messages.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(name="Daily Wellness AI Guru")
# --------------------------------------------------------------------------------
# Ensure Hugging Face API Token
# --------------------------------------------------------------------------------
# In a Hugging Face Space, set HF_API_TOKEN (or HF_TOKEN) as a secret variable.
# When neither is present we fall back to an interactive prompt, which is only
# useful when running locally from a terminal.
if not os.environ.get('HF_API_TOKEN'):
    # Reuse an existing HF_TOKEN before bothering the user with a prompt.
    os.environ['HF_API_TOKEN'] = (
        os.environ.get('HF_TOKEN')
        or getpass.getpass('Enter your Hugging Face API Token: ')
    )
else:
    print("HF_API_TOKEN is already set.")
# The models below are created as `HfApiModel()` without an explicit token, and
# the Hugging Face client libraries look the token up under HF_TOKEN, not
# HF_API_TOKEN. Mirror it so the token configured above is actually used.
if not os.environ.get('HF_TOKEN'):
    os.environ['HF_TOKEN'] = os.environ['HF_API_TOKEN']
# --------------------------------------------------------------------------------
# CSV Loading and Processing
# --------------------------------------------------------------------------------
def load_csv(file_path):
    """
    Load a question/answer CSV file into two parallel lists of strings.

    The file's encoding is guessed with chardet before pandas parses it. Rows
    missing either a question or an answer are dropped.

    Args:
        file_path: Path to a CSV file containing 'Question' and 'Answers' columns.

    Returns:
        A ``(questions, answers)`` tuple of equally long lists of ``str``.
        On any failure (unreadable file, missing columns, parse error) the
        error is logged with its traceback and ``([], [])`` is returned.
    """
    try:
        # Detect the encoding of the file
        with open(file_path, 'rb') as f:
            raw_bytes = f.read()
        # chardet reports None when it cannot identify the encoding; fall back
        # to UTF-8 explicitly instead of passing None through.
        encoding = chardet.detect(raw_bytes).get('encoding') or 'utf-8'
        # Load the CSV using the detected encoding
        data = pd.read_csv(file_path, encoding=encoding)
        # Validate that the required columns are present
        if 'Question' not in data.columns or 'Answers' not in data.columns:
            raise ValueError("The CSV file must contain 'Question' and 'Answers' columns.")
        # Drop any rows with missing values in 'Question' or 'Answers'
        data = data.dropna(subset=['Question', 'Answers'])
        # Coerce to str: pandas infers numeric/bool dtypes for some columns, but
        # downstream code embeds these values and calls .strip() on them.
        questions = data['Question'].astype(str).tolist()
        answers = data['Answers'].astype(str).tolist()
        logger.info(f"Loaded {len(questions)} questions and {len(answers)} answers from {file_path}")
        return questions, answers
    except Exception as e:
        # Deliberate best-effort: callers detect failure via the empty lists.
        logger.exception(f"Error loading CSV file: {e}")
        return [], []
# --------------------------------------------------------------------------------
# Load the AIChatbot.csv file
# --------------------------------------------------------------------------------
# The CSV is expected to sit in the same directory as app.py. Resolve it
# relative to this file rather than the current working directory, so the app
# also starts when it is launched from another folder.
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "AIChatbot.csv")
corpus_questions, corpus_answers = load_csv(file_path)
if not corpus_questions:
    # load_csv signals every failure with empty lists; without a corpus the
    # retriever cannot work, so stop at startup.
    raise ValueError(f"Failed to load questions from {file_path}.")
# --------------------------------------------------------------------------------
# Embedding Model
# --------------------------------------------------------------------------------
# Sentence-embedding model used for both the corpus questions (below) and the
# incoming user queries (inside EmbeddingRetriever.retrieve).
# NOTE(review): the model name suggests it was tuned for dot-product scoring,
# while retrieve() ranks with cosine similarity — confirm this is intended.
embedding_model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
embedding_model = SentenceTransformer(embedding_model_name)
logger.info(f"Loaded sentence embedding model: {embedding_model_name}")
# Encode Questions (for retrieval)
# Done once at import time; the result is handed to the retriever as a tensor.
question_embeddings = embedding_model.encode(corpus_questions, convert_to_tensor=True)
# --------------------------------------------------------------------------------
# Cross-Encoder for Re-Ranking
# --------------------------------------------------------------------------------
# Scores (query, candidate question) pairs to re-order the shortlist produced
# by the embedding similarity search.
cross_encoder_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(cross_encoder_model_name)
logger.info(f"Loaded cross-encoder model: {cross_encoder_model_name}")
# --------------------------------------------------------------------------------
# Retrieval + Re-ranking Class
# --------------------------------------------------------------------------------
class EmbeddingRetriever:
    """Two-stage retriever over a list of question/answer pairs.

    Stage 1 shortlists the ``top_k`` stored questions most similar to the
    query by cosine similarity of their embeddings. Stage 2 re-scores that
    shortlist with a cross-encoder and returns the answer belonging to the
    best-scoring question.
    """

    def __init__(self, questions, answers, embeddings, model, cross_encoder):
        """
        Args:
            questions: Stored questions, parallel to ``answers``.
            answers: Stored answers, parallel to ``questions``.
            embeddings: Pre-computed embeddings of ``questions``.
            model: Object whose ``encode`` embeds the query.
            cross_encoder: Object whose ``predict`` scores [query, question] pairs.
        """
        self.questions = questions
        self.answers = answers
        self.embeddings = embeddings
        self.model = model
        self.cross_encoder = cross_encoder

    def retrieve(self, query, top_k=3, min_score=None):
        """Return the stored answer that best matches ``query``.

        Args:
            query: The user's question.
            top_k: Number of embedding-similarity candidates passed to the
                cross-encoder.
            min_score: Optional cosine-similarity floor. Candidates scoring
                below it are discarded. ``None`` (the default) keeps every
                candidate, matching the previous behaviour.

        Returns:
            The best answer, or ``None`` when there is nothing to rank (empty
            corpus, non-positive ``top_k``, or no candidate reaches
            ``min_score``).
        """
        # Nothing to rank: report "no match" instead of failing on an empty shortlist.
        if not self.questions or not self.answers or top_k <= 0:
            return None

        # Compute query embedding and its similarity to every stored question
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.embeddings)[0].cpu().tolist()

        # Rank indices rather than (question, answer, score) tuples; the sort
        # is stable, so ties keep corpus order. Only positions present in all
        # three sequences are considered.
        usable = min(len(self.questions), len(self.answers), len(scores))
        shortlist = sorted(range(usable), key=scores.__getitem__, reverse=True)[:top_k]
        if min_score is not None:
            shortlist = [idx for idx in shortlist if scores[idx] >= min_score]
        if not shortlist:
            return None

        # Cross-encode re-rank
        cross_inputs = [[query, self.questions[idx]] for idx in shortlist]
        cross_scores = self.cross_encoder.predict(cross_inputs)
        # max() returns the first maximum, so ties resolve to the higher
        # embedding-ranked candidate.
        best_position = max(range(len(shortlist)), key=lambda pos: cross_scores[pos])
        return self.answers[shortlist[best_position]]
# Module-wide retriever built from the loaded corpus, its pre-computed question
# embeddings, and the two models. Arguments follow the constructor's order:
# questions, answers, embeddings, model, cross_encoder.
retriever = EmbeddingRetriever(
    corpus_questions,
    corpus_answers,
    question_embeddings,
    embedding_model,
    cross_encoder,
)
# --------------------------------------------------------------------------------
# Simple Answer Expander (Without custom sampling parameters)
# --------------------------------------------------------------------------------
class AnswerExpander:
    """Turns a short knowledge-base answer into a longer, friendlier reply via an LLM."""

    def __init__(self, model: "HfApiModel"):
        """
        Args:
            model: The LLM used for expansion. It may either be callable with a
                list of chat messages or expose a ``run(prompt)`` method.
        """
        self.model = model

    def expand(self, question: str, short_answer: str) -> str:
        """
        Prompt the LLM to provide a more creative, brand-aligned answer.

        Args:
            question: The user's original question.
            short_answer: The terse answer found in the knowledge base.

        Returns:
            The expanded answer with surrounding whitespace removed. If the
            LLM call fails or produces only whitespace, ``short_answer`` is
            returned unchanged.
        """
        prompt = (
            "You are Daily Wellness AI, a friendly and creative wellness expert. "
            "The user has a question about well-being. Provide an encouraging, day-to-day "
            "wellness perspective. Be gentle, uplifting, and brand-aligned.\n\n"
            f"Question: {question}\n"
            f"Current short answer: {short_answer}\n\n"
            "Please rephrase and expand with more detail, wellness tips, daily-life "
            "applications, and an optimistic tone. Keep it informal, friendly, and end "
            "with a short inspirational note.\n"
        )
        try:
            expanded_answer = self._generate(prompt).strip()
        except Exception as e:
            # Best-effort: a failed expansion must not lose the original answer.
            logger.error(f"Failed to expand answer: {e}")
            return short_answer
        # A blank generation is not an improvement over the short answer.
        return expanded_answer or short_answer

    def _generate(self, prompt: str) -> str:
        """Send ``prompt`` to the model and return its reply as plain text."""
        run = getattr(self.model, "run", None)
        if callable(run):
            # Agent-like objects expose run(); keep supporting them.
            response = run(prompt)
        else:
            # Model objects are invoked by calling them with chat messages.
            response = self.model([{"role": "user", "content": prompt}])
        # The reply is either a plain string or a message object carrying the
        # text in `.content`.
        content = getattr(response, "content", response)
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)
# NOTE: We are using a basic HfApiModel here (no custom sampling).
# The model is constructed with no arguments, so it runs on the library's defaults.
expander_model = HfApiModel()
# Expander instance that is handed to RetrieverTool further down.
answer_expander = AnswerExpander(expander_model)
# --------------------------------------------------------------------------------
# Enhanced Retriever Tool
# --------------------------------------------------------------------------------
from smolagents import Tool
class RetrieverTool(Tool):
    """Agent tool that answers a query from the local knowledge base.

    The retriever supplies the best stored answer; answers shorter than 80
    characters (after trimming) are sent to the expander for elaboration.
    """

    name = "retriever_tool"
    description = "Uses semantic search + cross-encoder re-ranking to retrieve the best answer."
    inputs = {
        "query": {
            "type": "string",
            "description": "User query for retrieving relevant information.",
        }
    }
    output_type = "string"

    def __init__(self, retriever, expander):
        super().__init__()
        self.retriever = retriever
        self.expander = expander

    def forward(self, query):
        """Return the (possibly expanded) best answer, or a fixed "not found" message."""
        answer = self.retriever.retrieve(query, top_k=3)

        # Guard clause: nothing usable came back from the retriever.
        if not answer:
            return "No relevant information found."

        is_short = len(answer.strip()) < 80
        if is_short:
            logger.info("Answer is short. Expanding with LLM.")
            answer = self.expander.expand(query, answer)
        return answer
# Tool instance combining the module-level retriever and answer expander; used
# directly by gradio_interface and given to the retriever agent.
retriever_tool = RetrieverTool(retriever, answer_expander)
# --------------------------------------------------------------------------------
# DuckDuckGo (Web) Fallback
# --------------------------------------------------------------------------------
# Web search tool given to the web-search agent below.
search_tool = DuckDuckGoSearchTool()
# --------------------------------------------------------------------------------
# Managed Agents
# --------------------------------------------------------------------------------
from smolagents import ManagedAgent, CodeAgent, LiteLLMModel
# Agent answering from the local CSV knowledge base through retriever_tool.
_retriever_llm = LiteLLMModel("groq/llama3-8b-8192")
_retriever_worker = CodeAgent(tools=[retriever_tool], model=_retriever_llm)
retriever_agent = ManagedAgent(
    agent=_retriever_worker,
    name="retriever_agent",
    description="Retrieves answers from the local knowledge base (CSV file).",
)

# Agent performing web searches through search_tool.
_web_llm = HfApiModel()
_web_worker = CodeAgent(tools=[search_tool], model=_web_llm)
web_agent = ManagedAgent(
    agent=_web_worker,
    name="web_search_agent",
    description="Performs web searches if the local knowledge base doesn't have an answer.",
)

# --------------------------------------------------------------------------------
# Manager Agent to Orchestrate
# --------------------------------------------------------------------------------
# Has no tools of its own; it is given the two managed agents above.
_manager_llm = HfApiModel()
manager_agent = CodeAgent(
    tools=[],
    model=_manager_llm,
    managed_agents=[retriever_agent, web_agent],
    verbose=True,
)
# --------------------------------------------------------------------------------
# Gradio Interface
# --------------------------------------------------------------------------------
def _format_wellness_reply(body):
    """Wrap an answer body in the standard greeting, disclaimer and sign-off."""
    return (
        "Hello! This is **Daily Wellness AI**.\n\n"
        f"{body}\n\n"
        "Disclaimer: This is general wellness information, "
        "not a substitute for professional medical advice.\n\n"
        "Wishing you a calm and wonderful day!"
    )


def _run_web_search(query):
    """Ask the web-search agent for an answer and return it as stripped text."""
    # Managed agents are invoked by calling them; plain agents expose run().
    # Support both so the fallback works with either kind of object.
    run = getattr(web_agent, "run", None)
    response = run(query) if callable(run) else web_agent(query)
    # Agents may return non-string results (numbers, lists, ...).
    return "" if response is None else str(response).strip()


def gradio_interface(query):
    """Answer a user query for the Gradio UI.

    Resolution order:
      1. the local knowledge base (answer possibly expanded by the LLM),
      2. the web-search agent, when the knowledge base has nothing relevant,
      3. a fixed apology when neither source yields an answer.

    Args:
        query: The text typed by the user.

    Returns:
        A Markdown-formatted reply. Blank input yields a prompt to ask a
        question, and any unexpected error yields a generic error message
        instead of propagating.
    """
    # Blank input: don't run retrieval (it would return an arbitrary "best" match).
    if query is None or not str(query).strip():
        return (
            "Hello! This is **Daily Wellness AI**.\n\n"
            "Please type a wellness question so I can help.\n\n"
            "Take care, and have a wonderful day!"
        )
    try:
        query = str(query).strip()
        logger.info(f"User query: {query}")

        # 1) Query local knowledge base
        retriever_response = retriever_tool.forward(query)
        if retriever_response and retriever_response != "No relevant information found.":
            logger.info("Provided answer from local DB (possibly expanded).")
            return _format_wellness_reply(retriever_response)

        # 2) Fallback to Web if no relevant local info
        logger.info("Falling back to web search.")
        web_response = _run_web_search(query)
        if web_response:
            logger.info("Response retrieved from the web.")
            return _format_wellness_reply(web_response)

        # 3) Default fallback
        logger.info("No response found from any source.")
        return (
            "Hello! This is **Daily Wellness AI**.\n\n"
            "I'm sorry, I couldn't find an answer to your question. "
            "Please try rephrasing or ask something else.\n\n"
            "Take care, and have a wonderful day!"
        )
    except Exception as e:
        # UI boundary: log the full traceback, show the user a generic message.
        logger.exception(f"Error processing query: {e}")
        return "**An error occurred while processing your request. Please try again later.**"
# --------------------------------------------------------------------------------
# Launch Gradio App
# --------------------------------------------------------------------------------
# Input widget: a single text box for the user's question.
_question_box = gr.Textbox(
    label="Ask Daily Wellness AI",
    placeholder="e.g., What is box breathing?",
)
# Output widget: the reply is rendered as Markdown.
_answer_panel = gr.Markdown(label="Answer from Daily Wellness AI")
_app_description = (
    "Ask wellness-related questions to get detailed, creative answers from "
    "our knowledge base—expanded by an LLM if needed—or from the web. "
    "We aim to bring calm and positivity to your day."
)
# The Gradio app object; gradio_interface handles every submitted query.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_question_box,
    outputs=_answer_panel,
    title="Daily Wellness AI Guru Chatbot",
    description=_app_description,
    theme="compact",
)
def main():
    """Launch the Gradio app on 0.0.0.0:7860 with debug mode enabled."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    interface.launch(**launch_options)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()