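"""Streamlit front-end for an adaptive RAG chat assistant.

Users provide a Pinecone API key, upload documents for ingestion, and then
chat over the indexed content via an adaptive RAG workflow, with optional
web search. Launch with `streamlit run <this file>` (the actual filename
depends on the repository layout).
"""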
import os
import sys
import tempfile
import time
from pathlib import Path

import streamlit as st
from pinecone import Pinecone
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langgraph.pregel import GraphRecursionError

# Make the repository root importable so the local `src` package resolves
# no matter where Streamlit is launched from.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.vectorstore.pinecone_db import ingest_data, get_retriever, load_documents, process_chunks, save_to_parquet
from src.agents.workflow import run_adaptive_rag
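
# Third-party dependencies (PyPI names inferred from the imports above):
# streamlit, pinecone, langchain-openai, langchain-ollama, langgraph.
# The local `src` package provides the ingestion pipeline and RAG workflow.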

# Page config
st.set_page_config(page_title="RAG Chat Assistant", layout="wide")

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "documents_processed" not in st.session_state:
    st.session_state.documents_processed = False
if "retriever" not in st.session_state:
    st.session_state.retriever = None
if "pinecone_client" not in st.session_state:
    st.session_state.pinecone_client = None

def initialize_pinecone(api_key):
    """Initialize the Pinecone client with the given API key."""
    try:
        return Pinecone(api_key=api_key)
    except Exception as e:
        st.error(f"Error initializing Pinecone: {str(e)}")
        return None
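
# NOTE: the non-OpenAI branch below assumes a local Ollama server is running
# and that the "llama3.2" model has been pulled (e.g. `ollama pull llama3.2`).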
def initialize_llm(llm_option, openai_api_key=None):
    """Initialize the LLM based on the user's selection."""
    if llm_option == "OpenAI":
        if not openai_api_key:
            st.sidebar.warning("Please enter your OpenAI API key.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-3.5-turbo")
    else:
        return ChatOllama(model="llama3.2", temperature=0.3, num_predict=512, top_p=0.6)
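
# NOTE: Pinecone's delete_index() drops the whole index, not just its vectors;
# re-ingesting afterwards is assumed to recreate it (presumably inside ingest_data).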
def clear_pinecone_index(pc, index_name="vector-index"):
    """Delete the Pinecone index and reset the related session state."""
    try:
        pc.delete_index(index_name)
        st.session_state.documents_processed = False
        st.session_state.retriever = None
        st.success("Database cleared successfully!")
    except Exception as e:
        st.error(f"Error clearing database: {str(e)}")

def process_documents(uploaded_files, pc):
    """Process uploaded documents and store their embeddings in Pinecone."""
    if not uploaded_files:
        st.warning("Please upload at least one document.")
        return False

    with st.spinner("Processing documents..."):
        # Write the uploads to a temporary directory so the loaders can work
        # with real file paths.
        temp_dir = tempfile.mkdtemp()
        file_paths = []
        markdown_path = Path(temp_dir) / "combined.md"
        parquet_path = Path(temp_dir) / "documents.parquet"

        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getvalue())
            file_paths.append(str(file_path))

        try:
            # Convert the files to one markdown document, split it into
            # chunks, and stage the chunks in a Parquet file for ingestion.
            markdown_path = load_documents(file_paths, output_path=markdown_path)
            chunks = process_chunks(markdown_path, chunk_size=256, threshold=0.6)
            print(f"Processed chunks: {chunks}")  # debug output to the server console
            parquet_path = save_to_parquet(chunks, parquet_path)
            ingest_data(
                pc=pc,
                parquet_path=parquet_path,
                text_column="text",
                pinecone_client=pc
            )
            st.session_state.retriever = get_retriever(pc)
            st.session_state.documents_processed = True
            return True
        except Exception as e:
            st.error(f"Error processing documents: {str(e)}")
            return False
        finally:
            # Best-effort cleanup; temp_dir may be left behind if derived
            # files (combined.md, documents.parquet) still live in it.
            for file_path in file_paths:
                try:
                    os.remove(file_path)
                except OSError:
                    pass
            try:
                os.rmdir(temp_dir)
            except OSError:
                pass
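
# run_adaptive_rag returns the complete answer in one call; the generator
# below simulates token-by-token streaming for a smoother chat UI.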
def run_rag_with_streaming(retriever, question, llm, enable_web_search=False):
    """Run the RAG workflow and yield the answer word by word."""
    try:
        response = run_adaptive_rag(
            retriever=retriever,
            question=question,
            llm=llm,
            top_k=5,
            enable_websearch=enable_web_search
        )
        for word in response.split():
            yield word + " "
            time.sleep(0.03)
    except GraphRecursionError:
        # Recursion limit hit: the graph could not converge on an answer.
        response = (
            "I apologize, but I cannot find a sufficient answer to your question "
            "in the provided documents. Please try rephrasing your question or "
            "ask something else about the content of the documents."
        )
        for word in response.split():
            yield word + " "
            time.sleep(0.03)
    except Exception as e:
        yield f"I encountered an error while processing your question: {str(e)}"
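
# Page flow: the sidebar collects API keys and options, then the app gates on
# a Pinecone client and an LLM before exposing document upload and the chat UI.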
def main():
    st.title("🤖 RAG Chat Assistant")
    # Sidebar configuration
    st.sidebar.title("Configuration")

    # API keys
    pinecone_api_key = st.sidebar.text_input("Enter Pinecone API Key:", type="password")

    # LLM selection; "Ollama" routes to the local model in initialize_llm
    llm_option = st.sidebar.selectbox("Select Language Model:", ["OpenAI", "Ollama"])
    openai_api_key = None
    if llm_option == "OpenAI":
        openai_api_key = st.sidebar.text_input("Enter OpenAI API Key:", type="password")

    # Web search tool
    st.sidebar.markdown("---")
    st.sidebar.markdown("### Tools")
    use_web_search = st.sidebar.checkbox("Web search")

    # Initialize Pinecone
    if pinecone_api_key:
        if st.session_state.pinecone_client is None:
            st.session_state.pinecone_client = initialize_pinecone(pinecone_api_key)
    else:
        st.sidebar.warning("Please enter a Pinecone API key to continue.")
        st.stop()
    # Initialize LLM
    llm = initialize_llm(llm_option, openai_api_key)
    if llm is None:
        st.stop()

    # Clear DB button
    st.sidebar.markdown("---")
    if st.sidebar.button("Clear Database"):
        if st.session_state.pinecone_client:
            clear_pinecone_index(st.session_state.pinecone_client)
            st.session_state.messages = []  # Clear chat history
    # Document upload section
    if not st.session_state.documents_processed:
        st.header("📄 Document Upload")
        uploaded_files = st.file_uploader(
            "Upload your documents",
            accept_multiple_files=True,
            type=["pdf", "docx", "txt", "pptx", "md"]
        )
        if st.button("Process Documents"):
            if process_documents(uploaded_files, st.session_state.pinecone_client):
                st.success("Documents processed successfully!")
    # Chat interface
    if st.session_state.documents_processed:
        st.header("💬 Chat")

        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Chat input
        if prompt := st.chat_input("Ask a question about your documents..."):
            # Display the user message, tagged when web search is enabled
            with st.chat_message("user"):
                if use_web_search:
                    st.markdown(prompt.strip() + " :red-background[Web Search]")
                else:
                    st.markdown(prompt)
            st.session_state.messages.append({"role": "user", "content": prompt})

            # Generate and stream the response
            with st.chat_message("assistant"):
                response_container = st.empty()
                full_response = ""
                # Show a spinner while the workflow runs
                with st.spinner("Thinking..."):
                    # Stream the response with a block cursor while generating
                    for chunk in run_rag_with_streaming(
                        retriever=st.session_state.retriever,
                        question=prompt,
                        llm=llm,
                        enable_web_search=use_web_search
                    ):
                        full_response += chunk
                        response_container.markdown(full_response + "▌")
                # Final update without the cursor
                response_container.markdown(full_response)

            # Save to chat history
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )


if __name__ == "__main__":
    main()