import os
import logging
import nltk
import spacy
import traceback
from nltk.tokenize import sent_tokenize
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import chardet
import gradio as gr
import pandas as pd
import json
# Enable logging for debugging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Set NLTK data path to the local 'nltk_data' directory
nltk.data.path.append(os.path.join(os.path.dirname(__file__), 'nltk_data'))
logger.debug("Configured NLTK data path to local 'nltk_data' directory.")
# Load the spaCy model used by is_valid_input_nlp(); without this line the
# validator below would raise a NameError. en_core_web_sm must be installed.
nlp = spacy.load("en_core_web_sm")
# Function to clean the API key
def clean_api_key(key):
    return ''.join(c for c in key if ord(c) < 128)
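# Illustrative: strips invisible/non-ASCII characters that would otherwise
# corrupt HTTP headers, e.g. clean_api_key("gsk_abc\u200b123") -> "gsk_abc123"
# (hypothetical key value).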
# Load the GROQ API key from environment variables (set as a secret in the Space)
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()  # Clean and strip whitespace
# Function to clean text by removing non-ASCII characters
def clean_text(text):
    return text.encode("ascii", errors="ignore").decode()
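# Illustrative: accented characters are dropped rather than transliterated,
# e.g. clean_text("Café résumé") -> "Caf rsum".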
# Function to load and clean documents from multiple file formats
def load_documents(file_paths):
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # Handle CSV files: detect the encoding first, then load one Document per row
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                encoding = result['encoding'] or 'utf-8'  # Fall back to UTF-8 if detection fails
                data = pd.read_csv(file_path, encoding=encoding)
                for index, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                # Handle JSON files: one Document per list entry, or one for a single object
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                # Handle TXT files: the whole file becomes a single Document
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
            logger.error(traceback.format_exc())
    return docs
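# Example usage (paths are illustrative): each CSV row, JSON entry, or TXT file
# becomes one Document tagged with its source path:
#   docs = load_documents(["AIChatbot.csv", "notes.txt"])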
# Function to ensure the response ends with complete sentences using NLTK
def ensure_complete_sentences(text):
    logger.debug("Ensuring complete sentences for the given text.")
    try:
        sentences = sent_tokenize(text)
        if not sentences:
            return text  # Return as is if no sentences are found
        # Drop a trailing fragment that lacks terminal punctuation
        if not sentences[-1].rstrip().endswith(('.', '!', '?')):
            sentences = sentences[:-1]
        return ' '.join(sentences).strip() if sentences else text
    except LookupError:
        logger.error("NLTK resource 'punkt' not found. Attempting to download again.")
        try:
            nltk.download('punkt', download_dir=os.path.join(os.path.dirname(__file__), 'nltk_data'))
            nltk.data.path.append(os.path.join(os.path.dirname(__file__), 'nltk_data'))
            sentences = sent_tokenize(text)
            return ' '.join(sentences).strip()
        except Exception as e_inner:
            logger.error("Failed to download 'punkt' resource.")
            logger.error(traceback.format_exc())
            raise e_inner
    except Exception as e:
        logger.error("Unexpected error during sentence tokenization.")
        logger.error(traceback.format_exc())
        raise e
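# Illustrative behavior of the fragment check above: a truncated trailing
# sentence is dropped so the reply ends cleanly, e.g.
#   ensure_complete_sentences("Box breathing helps. It involves four")
#   -> "Box breathing helps."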
# Advanced input validation using spaCy (Section 8a)
def is_valid_input_nlp(text, threshold=0.5):
    """
    Validates input text using spaCy's NLP capabilities.

    Parameters:
    - text (str): The input text to validate.
    - threshold (float): The minimum ratio of meaningful tokens required.

    Returns:
    - bool: True if the input is valid, False otherwise.
    """
    if not text or text.strip() == "":
        logger.debug("Input text is empty or contains only whitespace.")
        return False
    doc = nlp(text)
    meaningful_tokens = [token for token in doc if token.is_alpha]
    if not meaningful_tokens:
        logger.debug("No meaningful (alphabetic) tokens found in input.")
        return False
    ratio = len(meaningful_tokens) / len(doc)
    logger.debug(f"Meaningful tokens ratio: {ratio}")
    return ratio >= threshold
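# Illustrative behavior with the default threshold of 0.5:
#   is_valid_input_nlp("What is box breathing?")  # True  (4 of 5 tokens are alphabetic)
#   is_valid_input_nlp("???!!!")                  # False (no alphabetic tokens)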
# Function to estimate prompt tokens (simple word count approximation)
def estimate_prompt_tokens(prompt):
    """
    Estimates the number of tokens in the prompt.
    This is a placeholder function. Replace it with actual token estimation logic.

    Parameters:
    - prompt (str): The prompt text.

    Returns:
    - int: Estimated number of tokens.
    """
    return len(prompt.split())
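# Illustrative: estimate_prompt_tokens("What is box breathing?") -> 4. Real
# tokenizers usually emit more tokens than words, so treat this as a rough
# lower bound; a model-specific tokenizer would be more accurate if available.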
# Initialize the LLM using ChatGroq with GROQ's API
def initialize_llm(model, temperature, max_tokens, prompt_template):
    try:
        # Estimate prompt tokens
        estimated_prompt_tokens = estimate_prompt_tokens(prompt_template)
        logger.debug(f"Estimated prompt tokens: {estimated_prompt_tokens}")
        # Allocate the remaining tokens to the response
        response_max_tokens = max_tokens - estimated_prompt_tokens
        logger.debug(f"Response max tokens: {response_max_tokens}")
        if response_max_tokens <= 100:
            raise ValueError("max_tokens is too small to allocate for the response.")
        llm = ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=response_max_tokens,  # Adjusted max_tokens
            api_key=api_key  # Ensure the API key is passed correctly
        )
        logger.debug("LLM initialized successfully.")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        logger.error(traceback.format_exc())
        raise e
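# Token budget illustration (numbers are hypothetical): with max_tokens=500 and
# a prompt estimated at 60 tokens, the response is capped at 500 - 60 = 440
# tokens; a computed budget of 100 or fewer raises ValueError before any API call.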
# Create the RAG pipeline
def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    try:
        # Define the prompt template first to estimate tokens
        custom_prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.

Context:
{context}

Question:
{question}

Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
"""
        )
        # Estimate prompt tokens
        estimated_prompt_tokens = estimate_prompt_tokens(custom_prompt_template.template)
        logger.debug(f"Estimated prompt tokens from template: {estimated_prompt_tokens}")
        # Initialize the LLM with token allocation
        llm = initialize_llm(model, temperature, max_tokens, custom_prompt_template.template)
        # Load and process documents
        docs = load_documents(file_paths)
        if not docs:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."
        # Split documents into chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        logger.debug(f"Documents split into {len(splits)} chunks.")
        # Initialize the embedding model
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        logger.debug("Embedding model initialized successfully.")
        # Use a temporary directory for the Chroma vectorstore to prevent caching issues on Hugging Face Spaces
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory="/tmp/chroma_db"  # Temporary storage directory
        )
        vectorstore.persist()  # Save the database to disk
        logger.debug("Vectorstore initialized and persisted successfully.")
        retriever = vectorstore.as_retriever()
        # Create the RetrievalQA chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": custom_prompt_template}
        )
        logger.debug("RAG pipeline created successfully.")
        return rag_chain, "Pipeline created successfully."
    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        logger.error(traceback.format_exc())
        return None, f"Error creating RAG pipeline: {e}"
# Function to handle feedback (Section 8d)
def handle_feedback(feedback_text):
    """
    Handles user feedback by logging it.
    In a production environment, consider storing feedback in a database or external service.

    Parameters:
    - feedback_text (str): The feedback provided by the user.

    Returns:
    - str: Acknowledgment message.
    """
    if feedback_text and feedback_text.strip() != "":
        # For demonstration, we log the feedback. Replace this with database storage if needed.
        logger.info(f"User Feedback: {feedback_text}")
        return "Thank you for your feedback!"
    else:
        return "No feedback provided."
# Function to answer questions with input validation and post-processing
def answer_question(file_paths, model, temperature, max_tokens, question, feedback):
    try:
        # Validate input using spaCy-based validation
        if not is_valid_input_nlp(question):
            logger.debug("Invalid input detected.")
            return "Please provide a valid question or input containing meaningful text.", ""
        rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
        if rag_chain is None:
            logger.debug("RAG pipeline creation failed.")
            return message, ""
        try:
            answer = rag_chain.run(question)
            logger.debug("Question answered successfully.")
            # Post-process to ensure the answer ends with complete sentences
            complete_answer = ensure_complete_sentences(answer)
            # Handle feedback
            feedback_response = handle_feedback(feedback)
            return complete_answer, feedback_response
        except Exception as e_inner:
            logger.error(f"Error during RAG pipeline execution: {e_inner}")
            logger.error(traceback.format_exc())
            return f"Error during RAG pipeline execution: {e_inner}", ""
    except Exception as e_outer:
        logger.error(f"Unexpected error in answer_question: {e_outer}")
        logger.error(traceback.format_exc())
        return f"Unexpected error: {e_outer}", ""
# Gradio Interface with Feedback Mechanism (Section 8d)
def gradio_interface(model, temperature, max_tokens, question, feedback):
    file_paths = ['AIChatbot.csv']  # Ensure this file is present in your Space root directory
    return answer_question(file_paths, model, temperature, max_tokens, question, feedback)

# Define the Gradio UI
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Model Name",
            value="llama3-8b-8192",
            placeholder="e.g., llama3-8b-8192"
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.7,
            info="Controls the randomness of the response. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic."
        ),
        gr.Slider(
            label="Max Tokens",
            minimum=200,
            maximum=2048,
            step=1,
            value=500,
            info="Determines the maximum number of tokens in the response. Higher values allow for longer answers."
        ),
        gr.Textbox(
            label="Question",
            placeholder="e.g., What is box breathing and how does it help reduce anxiety?"
        ),
        gr.Textbox(
            label="Feedback",
            placeholder="Provide your feedback here...",
            lines=2
        )
    ],
    outputs=[
        "text",
        "text"
    ],
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?", "Great explanation!"],
        ["llama3-8b-8192", 0.6, 600, "Provide a daily wellness schedule incorporating box breathing techniques.", "Very helpful, thank you!"]
    ],
    allow_flagging="never"  # Disable default flagging; feedback is handled via the custom textbox
)

# Launch the Gradio app without share=True (not supported on Hugging Face Spaces)
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)