|
|
|
|
|
import getpass
import heapq
import logging
import os

import chardet
import gradio as gr
import pandas as pd
from langchain_community.retrievers import BM25Retriever
from sentence_transformers import CrossEncoder, SentenceTransformer, util
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    HfApiModel,
    LiteLLMModel,
    ManagedAgent,
    Tool,
)
|
|
|
|
|
|
|
|
|
# Verbose root logging; every message emitted by this file goes through
# `logger`.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("Daily Wellness AI Guru")

# Make sure a Hugging Face token is available before any model is built:
# prompt for it (without echo) when HF_API_TOKEN is missing or empty.
if not os.environ.get('HF_API_TOKEN'):
    os.environ['HF_API_TOKEN'] = getpass.getpass('Enter your Hugging Face API Token: ')
else:
    print("HF_API_TOKEN is already set.")

# The Hugging Face client libraries look the token up under HF_TOKEN, not
# HF_API_TOKEN, and the HfApiModel() instances in this file are created without
# an explicit token. Mirror the value so the token collected above is actually
# used; an HF_TOKEN that is already set takes precedence.
if os.environ['HF_API_TOKEN'] and not os.environ.get('HF_TOKEN'):
    os.environ['HF_TOKEN'] = os.environ['HF_API_TOKEN']
|
|
|
|
|
|
|
|
|
def load_csv(file_path):
    """
    Load a question/answer CSV file into two parallel lists.

    The file's encoding is detected with ``chardet`` before parsing. Rows whose
    'Question' or 'Answers' cell is missing or blank are dropped, and every
    remaining cell is returned as a whitespace-stripped string.

    Args:
        file_path: Path of the CSV file; it must contain 'Question' and
            'Answers' columns.

    Returns:
        A ``(questions, answers)`` tuple of equally long lists of strings.
        Any failure (missing file, unreadable data, missing columns) is logged
        with its traceback and reported as ``([], [])`` instead of raising.
    """
    try:
        # Detect the encoding from the raw bytes rather than assuming UTF-8.
        with open(file_path, 'rb') as f:
            result = chardet.detect(f.read())
        # chardet reports None when it cannot decide; fall back to UTF-8.
        encoding = result['encoding'] or 'utf-8'

        data = pd.read_csv(file_path, encoding=encoding)

        if 'Question' not in data.columns or 'Answers' not in data.columns:
            raise ValueError("The CSV file must contain 'Question' and 'Answers' columns.")

        data = data.dropna(subset=['Question', 'Answers'])

        # pandas parses numeric-looking cells as numbers; normalise everything
        # to stripped text so the returned lists only ever contain strings.
        question_col = data['Question'].astype(str).str.strip()
        answer_col = data['Answers'].astype(str).str.strip()

        # A pair is only useful when both sides still have content.
        has_text = (question_col != "") & (answer_col != "")
        questions = question_col[has_text].tolist()
        answers = answer_col[has_text].tolist()

        logger.info(f"Loaded {len(questions)} questions and {len(answers)} answers from {file_path}")
        return questions, answers
    except Exception as e:
        # Best-effort loader: callers detect failure through the empty lists.
        logger.exception(f"Error loading CSV file: {e}")
        return [], []
|
|
|
|
|
|
|
|
|
# Knowledge base: a CSV of question/answer pairs loaded once at startup.
file_path = "AIChatbot.csv"
corpus_questions, corpus_answers = load_csv(file_path)

# load_csv signals failure with empty lists; without questions there is
# nothing to retrieve from, so stop here.
if len(corpus_questions) == 0:
    raise ValueError(f"Failed to load questions from {file_path}.")
|
|
|
|
|
|
|
|
|
# Sentence-embedding model used to embed both the corpus questions and, later,
# each incoming query.
embedding_model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
embedding_model = SentenceTransformer(embedding_model_name)
logger.info("Loaded sentence embedding model: %s", embedding_model_name)

# Embed every corpus question once, up front, as a tensor.
question_embeddings = embedding_model.encode(
    corpus_questions,
    convert_to_tensor=True,
)

# Cross-encoder that scores (query, question) pairs for re-ranking.
cross_encoder_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(cross_encoder_model_name)
logger.info("Loaded cross-encoder model: %s", cross_encoder_model_name)
|
|
|
|
|
|
|
|
|
class EmbeddingRetriever:
    """Two-stage retriever over an in-memory question/answer corpus.

    Stage one scores the query against the stored question embeddings with
    cosine similarity and keeps the ``top_k`` best questions. Stage two
    re-scores those candidates with a cross-encoder and returns the answer
    paired with the highest-scoring question.
    """

    def __init__(self, questions, answers, embeddings, model, cross_encoder):
        """
        Args:
            questions: Corpus questions.
            answers: Answers, positionally paired with ``questions``.
            embeddings: Embeddings of ``questions``; compared against the
                query embedding produced by ``model``.
            model: Object whose ``encode(text, convert_to_tensor=True)``
                embeds the query.
            cross_encoder: Object whose ``predict(pairs)`` returns one score
                per ``[query, question]`` pair.
        """
        self.questions = questions
        self.answers = answers
        self.embeddings = embeddings
        self.model = model
        self.cross_encoder = cross_encoder

    def retrieve(self, query, top_k=3):
        """
        Return the answer whose question best matches ``query``.

        Args:
            query: The user's question.
            top_k: Number of embedding-similarity candidates passed on to the
                cross-encoder. ``None`` keeps every candidate.

        Returns:
            The answer paired with the best re-ranked question, or ``None``
            when there is nothing to choose from (empty corpus or a
            non-positive ``top_k``).
        """
        if top_k is None:
            top_k = len(self.questions)
        # Nothing to rank: report "no answer" instead of failing on an empty
        # candidate list further down.
        if top_k <= 0 or len(self.questions) == 0 or len(self.answers) == 0:
            return None

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.embeddings)[0].cpu().tolist()

        # Only the k best are needed, so select them instead of sorting the
        # whole corpus. nlargest is documented as equivalent to
        # sorted(..., reverse=True)[:k], so ties keep corpus order.
        top_candidates = heapq.nlargest(
            top_k,
            zip(self.questions, self.answers, scores),
            key=lambda item: item[2],
        )
        if not top_candidates:
            return None

        # Re-score the shortlisted questions against the query.
        cross_inputs = [[query, question] for question, _answer, _score in top_candidates]
        cross_scores = self.cross_encoder.predict(cross_inputs)

        # max() returns the first maximum, the same winner a stable
        # descending sort would put first.
        best_index = max(range(len(top_candidates)), key=lambda i: cross_scores[i])
        return top_candidates[best_index][1]
|
|
|
# Shared retriever instance built from the corpus, its precomputed question
# embeddings and the two models loaded above.
retriever = EmbeddingRetriever(
    corpus_questions,
    corpus_answers,
    question_embeddings,
    embedding_model,
    cross_encoder,
)
|
|
|
|
|
|
|
|
|
class AnswerExpander:
    """Turns a short knowledge-base answer into a longer, friendlier reply via an LLM."""

    def __init__(self, model: "HfApiModel"):
        """
        Args:
            model: Chat model used for the rewrite. It is invoked by calling
                it with a list of ``{"role", "content"}`` messages.
        """
        self.model = model

    def expand(self, question: str, short_answer: str) -> str:
        """
        Prompt the LLM to provide a more creative, brand-aligned answer.

        Args:
            question: The user's original question.
            short_answer: The answer to rephrase and expand.

        Returns:
            The stripped LLM output, or ``short_answer`` unchanged when the
            model call fails or produces no text.
        """
        prompt = (
            "You are Daily Wellness AI, a friendly and creative wellness expert. "
            "The user has a question about well-being. Provide an encouraging, day-to-day "
            "wellness perspective. Be gentle, uplifting, and brand-aligned.\n\n"
            f"Question: {question}\n"
            f"Current short answer: {short_answer}\n\n"
            "Please rephrase and expand with more detail, wellness tips, daily-life "
            "applications, and an optimistic tone. Keep it informal, friendly, and end "
            "with a short inspirational note.\n"
        )
        try:
            # smolagents models are used by calling them with chat messages;
            # `.run()` belongs to agents, so calling it on the model always
            # failed and the answer was never expanded.
            response = self.model([{"role": "user", "content": prompt}])
            # The call yields either plain text or a message object carrying
            # the text in `.content`, depending on the smolagents release.
            expanded_answer = getattr(response, "content", response)
            expanded_answer = (expanded_answer or "").strip()
        except Exception as e:
            # Expansion is a nicety; fall back to the original answer.
            logger.error(f"Failed to expand answer: {e}")
            return short_answer
        # An empty completion is no improvement over the short answer.
        return expanded_answer or short_answer
|
|
|
|
|
# Model instance used only for answer expansion, created with the library
# defaults.
expander_model = HfApiModel()
answer_expander = AnswerExpander(model=expander_model)
|
|
|
|
|
|
|
|
|
from smolagents import Tool |
|
|
|
class RetrieverTool(Tool):
    """Tool that answers a query from the local knowledge base.

    The retriever's best answer is returned as-is unless its stripped length
    is below 80 characters, in which case the expander is asked to elaborate
    on it first.
    """

    name = "retriever_tool"
    description = "Uses semantic search + cross-encoder re-ranking to retrieve the best answer."
    inputs = {
        "query": {
            "type": "string",
            "description": "User query for retrieving relevant information.",
        }
    }
    output_type = "string"

    def __init__(self, retriever, expander):
        """
        Args:
            retriever: Object whose ``retrieve(query, top_k=...)`` returns the
                best answer (or a falsy value when there is none).
            expander: Object whose ``expand(query, answer)`` returns a longer
                version of ``answer``.
        """
        super().__init__()
        self.retriever = retriever
        self.expander = expander

    def forward(self, query):
        """Return the best (possibly expanded) answer for ``query``."""
        answer = self.retriever.retrieve(query, top_k=3)

        # No usable answer from the retriever.
        if not answer:
            return "No relevant information found."

        # Long enough already: hand it back untouched.
        if len(answer.strip()) >= 80:
            return answer

        logger.info("Answer is short. Expanding with LLM.")
        return self.expander.expand(query, answer)
|
|
|
# Tool backed by the local knowledge base, with LLM expansion of short answers.
retriever_tool = RetrieverTool(retriever=retriever, expander=answer_expander)

# Web search tool used by the web-search agent.
search_tool = DuckDuckGoSearchTool()
|
|
|
|
|
|
|
|
|
from smolagents import ManagedAgent, CodeAgent, LiteLLMModel |
|
|
|
# Managed agent that answers from the local knowledge base through
# `retriever_tool`.
retriever_agent = ManagedAgent(
    name="retriever_agent",
    description="Retrieves answers from the local knowledge base (CSV file).",
    agent=CodeAgent(
        model=LiteLLMModel("groq/llama3-8b-8192"),
        tools=[retriever_tool],
    ),
)

# Managed agent that searches the web through `search_tool`.
web_agent = ManagedAgent(
    name="web_search_agent",
    description="Performs web searches if the local knowledge base doesn't have an answer.",
    agent=CodeAgent(
        model=HfApiModel(),
        tools=[search_tool],
    ),
)

# Top-level agent: it has no tools of its own and works through the two
# managed agents.
manager_agent = CodeAgent(
    model=HfApiModel(),
    tools=[],
    managed_agents=[retriever_agent, web_agent],
    verbose=True,
)
|
|
|
|
|
|
|
|
|
def _format_wellness_reply(body):
    """Wrap an answer in the standard greeting, disclaimer and sign-off."""
    return (
        "Hello! This is **Daily Wellness AI**.\n\n"
        f"{body}\n\n"
        "Disclaimer: This is general wellness information, "
        "not a substitute for professional medical advice.\n\n"
        "Wishing you a calm and wonderful day!"
    )


def gradio_interface(query):
    """
    Answer a user query for the Gradio UI.

    The local knowledge base is tried first through ``retriever_tool``. If it
    reports that nothing relevant was found, the web-search agent is run
    instead.

    Args:
        query: The text the user typed.

    Returns:
        A Markdown string: the formatted answer, a prompt to enter a question
        when ``query`` is empty or blank, an apology when no source produced
        an answer, or a generic error message if anything raised.
    """
    # A blank submission has nothing to retrieve; ask for a question instead
    # of matching the empty string against the knowledge base.
    if not query or not query.strip():
        return (
            "Hello! This is **Daily Wellness AI**.\n\n"
            "Please type a wellness question so I can help.\n\n"
            "Take care, and have a wonderful day!"
        )

    try:
        logger.info(f"User query: {query}")

        # 1) Local knowledge base.
        retriever_response = retriever_tool.forward(query)
        if retriever_response != "No relevant information found.":
            logger.info("Provided answer from local DB (possibly expanded).")
            return _format_wellness_reply(retriever_response)

        # 2) Web search fallback.
        logger.info("Falling back to web search.")
        # ManagedAgent is a wrapper around the real agent (kept in `.agent`)
        # and does not itself provide `.run()`, so run the wrapped agent.
        # An existing `.run` is still preferred if the wrapper has one.
        run_web_search = getattr(web_agent, "run", None) or web_agent.agent.run
        web_response = run_web_search(query)
        # The agent's final answer is not guaranteed to be a string.
        web_text = "" if web_response is None else str(web_response).strip()
        if web_text:
            logger.info("Response retrieved from the web.")
            return _format_wellness_reply(web_text)

        # 3) Nothing found anywhere.
        logger.info("No response found from any source.")
        return (
            "Hello! This is **Daily Wellness AI**.\n\n"
            "I'm sorry, I couldn't find an answer to your question. "
            "Please try rephrasing or ask something else.\n\n"
            "Take care, and have a wonderful day!"
        )
    except Exception as e:
        # UI boundary: never let an exception reach the user; keep the
        # traceback in the log.
        logger.exception(f"Error processing query: {e}")
        return "**An error occurred while processing your request. Please try again later.**"
|
|
|
|
|
|
|
|
|
# Gradio UI: one text box in, Markdown out, handled by `gradio_interface`.
# NOTE(review): confirm that the "compact" theme name still resolves with the
# installed Gradio version.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        label="Ask Daily Wellness AI",
        placeholder="e.g., What is box breathing?",
    ),
    outputs=gr.Markdown(label="Answer from Daily Wellness AI"),
    title="Daily Wellness AI Guru Chatbot",
    description=(
        "Ask wellness-related questions to get detailed, creative answers from "
        "our knowledge base—expanded by an LLM if needed—or from the web. "
        "We aim to bring calm and positivity to your day."
    ),
    theme="compact",
)
|
|
|
def main():
    """Launch the Gradio app on host 0.0.0.0, port 7860, with debug enabled."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    interface.launch(**launch_options)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|