import os
import json
import logging
import re

import chardet
import gradio as gr
import pandas as pd
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
# Enable logging for debugging (INFO keeps output readable without DEBUG noise)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Strip non-ASCII characters from the API key; invisible characters picked up
# during copy/paste can break the HTTP Authorization header
def clean_api_key(key):
    return ''.join(c for c in key if ord(c) < 128)
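# Example: a zero-width space (U+200B) pasted into the key is silently removed:
#   clean_api_key("gsk_abc\u200b123") -> "gsk_abc123"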
# Load the GROQ API key from environment variables (set as a secret in the Space)
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()  # Clean and strip whitespace
# Clean text by removing non-ASCII characters
def clean_text(text):
    return text.encode("ascii", errors="ignore").decode()
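# Example: non-ASCII characters are dropped, not transliterated:
#   clean_text("café") -> "caf"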
# Load and clean documents from multiple file formats (CSV, JSON, TXT)
def load_documents(file_paths):
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # Detect the file encoding first, falling back to UTF-8 if
                # chardet cannot determine one
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                encoding = result['encoding'] or 'utf-8'
                data = pd.read_csv(file_path, encoding=encoding)
                # Each row becomes its own Document so retrieval stays fine-grained
                for index, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
            logger.debug("Exception details:", exc_info=True)
    return docs
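# Example: load_documents(["AIChatbot.csv"]) yields one Document per CSV row,
# each tagged with metadata={"source": "AIChatbot.csv"} for traceability.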
# Ensure the response ends with complete sentences
def ensure_complete_sentences(text):
    # Capture every span that ends with terminal punctuation (., !, ?)
    sentences = re.findall(r'[^.!?]*[.!?]', text)
    if sentences:
        # Join the complete sentences to form the final answer
        return ' '.join(sentences).strip()
    return text  # Return as-is if no complete sentence is found
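# Example: a truncated generation loses only its dangling fragment, so
# "Drink water. Sleep well. Also try" keeps the two full sentences and drops "Also try".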
# Check whether the input is meaningful
def is_valid_input(text):
    """
    Checks if the input text is meaningful.
    Returns True if the text contains alphabetic characters and is of sufficient length.
    """
    if not text or text.strip() == "":
        return False
    # Require at least one alphabetic character
    if not re.search('[A-Za-z]', text):
        return False
    # Additional check: minimum length
    if len(text.strip()) < 5:
        return False
    return True
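# Examples: is_valid_input("   ") -> False (blank); is_valid_input("12345") -> False
# (no letters); is_valid_input("Hi!") -> False (under 5 characters);
# is_valid_input("What is box breathing?") -> True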
# Initialize the LLM using ChatGroq with GROQ's API
def initialize_llm(model, temperature, max_tokens):
    try:
        # Reserve a portion of the token budget (20%) for the prompt
        prompt_allocation = int(max_tokens * 0.2)
        response_max_tokens = max_tokens - prompt_allocation
        if response_max_tokens <= 50:
            raise ValueError("max_tokens is too small to allocate for the response.")
        llm = ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=response_max_tokens,  # Response cap after the prompt reserve
            api_key=api_key  # Ensure the API key is passed correctly
        )
        logger.info("LLM initialized successfully.")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise
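# Optional connectivity probe (commented out; assumes the GROQ_API_KEY secret is
# set and the model name is available to your Groq account):
#   _llm = initialize_llm("llama3-8b-8192", 0.0, 256)
#   logger.info(_llm.invoke("Reply with OK.").content)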
# Create the RAG pipeline
def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    try:
        llm = initialize_llm(model, temperature, max_tokens)
        docs = load_documents(file_paths)
        if not docs:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."
        # Overlapping chunks keep context intact across chunk boundaries
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        # Initialize the embedding model
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        # Use a temporary directory for the Chroma vectorstore to prevent caching issues on Hugging Face Spaces
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory="/tmp/chroma_db"  # Temporary storage directory
        )
        vectorstore.persist()  # Save the database to disk
        logger.info("Vectorstore initialized and persisted successfully.")
        retriever = vectorstore.as_retriever()
        custom_prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.

Context:
{context}

Question:
{question}

Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
"""
        )
        # "stuff" packs all retrieved chunks into a single prompt
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": custom_prompt_template}
        )
        logger.info("RAG pipeline created successfully.")
        return rag_chain, "Pipeline created successfully."
    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        logger.debug("Exception details:", exc_info=True)
        return None, f"Error creating RAG pipeline: {e}"
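# Illustrative invocation (hypothetical file name; any supported .csv/.json/.txt works):
#   chain, status = create_rag_pipeline(["wellness_notes.txt"], "llama3-8b-8192", 0.7, 500)
#   if chain is not None:
#       print(chain.run("How can I build a morning routine?"))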
# Positive and negative word lists for rule-based sentiment analysis
POSITIVE_WORDS = {
    "good", "great", "excellent", "amazing", "wonderful", "fantastic", "positive",
    "helpful", "satisfied", "happy", "love", "liked", "enjoyed", "beneficial",
    "superb", "awesome", "nice", "brilliant", "favorable", "pleased"
}

NEGATIVE_WORDS = {
    "bad", "terrible", "awful", "poor", "disappointed", "unsatisfied", "hate",
    "hated", "dislike", "dislikes", "worst", "negative", "not helpful", "frustrated",
    "unhappy", "dissatisfied", "unfortunate", "horrible", "annoyed", "problem", "issues"
}
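# Note: entries may be multi-word phrases ("not helpful"); they match because
# handle_feedback below uses substring containment rather than token comparison.
# A side effect: feedback containing "not helpful" also contains "helpful", so it
# scores one point on each side and can tie out as neutral.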
# Handle feedback with rule-based sentiment analysis
def handle_feedback(feedback_text):
    """
    Handles user feedback by analyzing its sentiment and providing a dynamic response.
    Stores the feedback in a temporary file for persistence during the session.

    Parameters:
    - feedback_text (str): The feedback provided by the user.

    Returns:
    - str: Acknowledgment message based on feedback sentiment.
    """
    if feedback_text and feedback_text.strip() != "":
        # Normalize feedback text to lowercase for comparison
        feedback_lower = feedback_text.lower()
        # Count matches via substring containment; each entry counts at most once
        positive_count = sum(word in feedback_lower for word in POSITIVE_WORDS)
        negative_count = sum(word in feedback_lower for word in NEGATIVE_WORDS)
        # Determine sentiment based on counts
        if positive_count > negative_count:
            sentiment = "positive"
            acknowledgment = "Thank you for your positive feedback! We're glad to hear that you found our service helpful."
        elif negative_count > positive_count:
            sentiment = "negative"
            acknowledgment = "We're sorry to hear that you're not satisfied. Your feedback is valuable to us, and we'll strive to improve."
        else:
            sentiment = "neutral"
            acknowledgment = "Thank you for your feedback. We appreciate your input."
        # Log the feedback with sentiment
        logger.info(f"User Feedback: {feedback_text} | Sentiment: {sentiment}")
        # Store feedback in a temporary file (/tmp is writable on Spaces but not permanent)
        try:
            with open("/tmp/user_feedback.txt", "a") as f:
                f.write(f"{feedback_text} | Sentiment: {sentiment}\n")
            logger.debug("Feedback stored successfully in /tmp/user_feedback.txt.")
        except Exception as e:
            logger.error(f"Error storing feedback: {e}")
        return acknowledgment
    else:
        return "No feedback provided."
# Initialize the RAG pipeline once at startup.
# Define the file paths (ensure 'AIChatbot.csv' is in the root directory of your Space)
file_paths = ['AIChatbot.csv']
model = "llama3-8b-8192"  # Default model name
temperature = 0.7  # Default temperature
max_tokens = 500  # Default max tokens

rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
if rag_chain is None:
    logger.error("Failed to initialize RAG pipeline at startup.")
    # Depending on your preference, you might want to exit or continue. Here, we continue
    # so the UI still loads and can surface the error to users.
# Answer questions with input validation and post-processing
def answer_question(model, temperature, max_tokens, question, feedback):
    # Validate input
    if not is_valid_input(question):
        logger.info("Received invalid input from user.")
        return "Please provide a valid question or input containing meaningful text.", ""
    # Check that the RAG pipeline is initialized
    if rag_chain is None:
        logger.error("RAG pipeline is not initialized.")
        return "The system is currently unavailable. Please try again later.", ""
    try:
        answer = rag_chain.run(question)
        logger.info("Question answered successfully.")
        # Post-process to ensure the answer ends with complete sentences
        complete_answer = ensure_complete_sentences(answer)
        # Handle feedback
        feedback_response = handle_feedback(feedback)
        return complete_answer, feedback_response
    except Exception as e_inner:
        logger.error(f"Error during RAG pipeline execution: {e_inner}")
        logger.debug("Exception details:", exc_info=True)
        return f"Error during RAG pipeline execution: {e_inner}", ""
# Gradio interface with feedback mechanism
def gradio_interface(model, temperature, max_tokens, question, feedback):
    # The pipeline is built once at startup; model, temperature, and max_tokens
    # changed in the UI afterwards are currently ignored. Rebuilding the pipeline
    # on parameter change would be a possible extension.
    return answer_question(model, temperature, max_tokens, question, feedback)
# Define the Gradio UI
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Model Name",
            value=model,
            placeholder="e.g., llama3-8b-8192"
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=1,
            step=0.01,
            value=temperature,
            info="Controls the randomness of the response. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
        ),
        gr.Slider(
            label="Max Tokens",
            minimum=200,
            maximum=2048,
            step=1,
            value=max_tokens,
            info="Maximum number of tokens in the response. Higher values allow longer answers."
        ),
        gr.Textbox(
            label="Question",
            placeholder="e.g., What is box breathing and how does it help reduce anxiety?"
        ),
        gr.Textbox(
            label="Feedback",
            placeholder="Provide your feedback here...",
            lines=2
        )
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Feedback Response")
    ],
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?", "Great explanation!"],
        ["llama3-8b-8192", 0.6, 600, "Provide a daily wellness schedule incorporating box breathing techniques.", "Very helpful, thank you!"]
    ],
    allow_flagging="never"  # Disable default flagging; feedback is handled explicitly above
)
# Launch the Gradio app without share=True (not needed on Hugging Face Spaces)
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)