acres / app.py
ak3ra's picture
add csv export
7ddc93d
raw
history blame
14.2 kB
# app.py
import json
from typing import List, Tuple
import os
import logging
import gradio as gr
from dotenv import load_dotenv
from slugify import slugify
from rag.rag_pipeline import RAGPipeline
from utils.helpers import (
generate_follow_up_questions,
append_to_study_files,
add_study_files_to_chromadb,
chromadb_client,
)
from utils.prompts import (
highlight_prompt,
evidence_based_prompt,
sample_questions,
)
import openai
from config import STUDY_FILES, OPENAI_API_KEY
from utils.zotero_manager import ZoteroManager
import csv
import io
import datetime
# Load environment variables from a local .env file, if present.
load_dotenv()
logging.basicConfig(level=logging.INFO)
openai.api_key = OPENAI_API_KEY
# Runs once at import time, before the interface is built: the UI below reads
# its study list from the "study_files_collection" ChromaDB collection.
# NOTE(review): presumably this loads the entries of study_files.json into
# that collection -- confirm against utils.helpers.
add_study_files_to_chromadb("study_files.json", "study_files_collection")
# Cache of RAGPipeline instances, keyed by study name (see get_rag_pipeline)
rag_cache = {}
def process_zotero_library_items(
    zotero_library_id: str, zotero_api_access_key: str
) -> str:
    """Export not-yet-known Zotero collections to JSON and register them.

    For every non-empty collection in the user's Zotero library whose name is
    not already a key of ``STUDY_FILES``, the collection's items are written
    to ``data/<slug>_zotero_items.json``, recorded in ``study_files.json`` and
    added to the in-memory ``STUDY_FILES`` mapping. Afterwards the study files
    are pushed to the ``study_files_collection`` ChromaDB collection.

    Args:
        zotero_library_id: Zotero library ID entered by the user.
        zotero_api_access_key: Zotero API key entered by the user.

    Returns:
        A human-readable status message. Failures are reported through the
        message (and logged) rather than raised, so the UI always gets text.
    """
    if not zotero_library_id or not zotero_api_access_key:
        return "Please enter your zotero library Id and API Access Key"

    zotero_library_type = "user"  # or "group"

    try:
        zotero_manager = ZoteroManager(
            zotero_library_id, zotero_library_type, zotero_api_access_key
        )
        zotero_collections = zotero_manager.get_collections()
        zotero_collection_lists = zotero_manager.list_zotero_collections(
            zotero_collections
        )
        filtered_zotero_collection_lists = (
            zotero_manager.filter_and_return_collections_with_items(
                zotero_collection_lists
            )
        )

        for collection in filtered_zotero_collection_lists:
            collection_name = collection.get("name")
            if collection_name in STUDY_FILES:
                # Already exported in a previous run or earlier this session.
                continue

            collection_key = collection.get("key")
            # NOTE(review): the result of this call was never used; the call
            # is kept in case it has side effects inside ZoteroManager.
            # Remove it if it is a plain fetch.
            zotero_manager.get_collection_items(collection_key)
            zotero_collection_items = (
                zotero_manager.get_collection_zotero_items_by_key(collection_key)
            )

            # Export the collection's items to a JSON file under data/.
            zotero_items_json = zotero_manager.zotero_items_to_json(
                zotero_collection_items
            )
            export_path = f"data/{slugify(collection_name)}_zotero_items.json"
            zotero_manager.write_zotero_items_to_json_file(
                zotero_items_json, export_path
            )
            append_to_study_files("study_files.json", collection_name, export_path)

            # Update in-memory STUDY_FILES for reference in the current session.
            STUDY_FILES.update({collection_name: export_path})
            logging.info(f"STUDY_FILES: {STUDY_FILES}")

        # After the loop, add all collected data to ChromaDB.
        add_study_files_to_chromadb("study_files.json", "study_files_collection")
        return "Successfully processed items in your zotero library"
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs and show the
        # user a short message instead of crashing the handler.
        logging.exception("Failed to process Zotero library")
        return f"Error process your zotero library: {str(e)}"
def get_rag_pipeline(study_name: str) -> RAGPipeline:
    """Return the cached RAGPipeline for *study_name*, creating it if needed.

    The study's data file is looked up in the ``study_files_collection``
    ChromaDB collection, where the study name is the record ID and the path
    is stored under the ``file_path`` metadata key.

    Raises:
        ValueError: If no record exists for *study_name* or the record has no
            ``file_path`` metadata.
    """
    if study_name in rag_cache:
        return rag_cache[study_name]

    # Query ChromaDB for the study file path by ID.
    collection = chromadb_client.get_or_create_collection("study_files_collection")
    result = collection.get(ids=[study_name])

    metadatas = result.get("metadatas") if result else None
    if not metadatas:
        raise ValueError(f"Invalid study name: {study_name}")

    # A record stored without metadata may come back as None; treat that the
    # same as a missing file path instead of failing with AttributeError.
    study_file = (metadatas[0] or {}).get("file_path")
    if not study_file:
        raise ValueError(f"File path not found for study name: {study_name}")

    pipeline = RAGPipeline(study_file)
    rag_cache[study_name] = pipeline
    return pipeline
def chat_function(message: str, study_name: str, prompt_type: str) -> str:
    """Answer *message* for the given study via its RAG pipeline.

    A blank or whitespace-only message short-circuits with a hint to the
    user. ``prompt_type`` selects the prompt template; any value other than
    "Highlight" or "Evidence-based" passes ``None`` as the template.
    """
    if message.strip() == "":
        return "Please enter a valid query."

    pipeline = get_rag_pipeline(study_name)
    logging.info(f"rag: ==> {pipeline}")

    if prompt_type == "Highlight":
        template = highlight_prompt
    elif prompt_type == "Evidence-based":
        template = evidence_based_prompt
    else:
        template = None

    answer = pipeline.query(message, prompt_template=template)
    return answer.response
def get_study_info(study_name: str) -> str:
    """Return a markdown summary (document count) for the given study.

    The study's JSON file is located through the ``file_path`` metadata of the
    record whose ID is *study_name* in the ``study_files_collection`` ChromaDB
    collection.

    Raises:
        ValueError: If no record exists for *study_name* or the record has no
            ``file_path`` metadata.
    """
    collection = chromadb_client.get_or_create_collection("study_files_collection")
    result = collection.get(ids=[study_name])  # Query by study name (as a list)
    logging.info(f"Result: ======> {result}")

    metadatas = result.get("metadatas") if result else None
    if not metadatas:
        raise ValueError(f"Invalid study name: {study_name}")

    # A record stored without metadata may come back as None; treat that the
    # same as a missing file path instead of failing with AttributeError.
    study_file = (metadatas[0] or {}).get("file_path")
    logging.info(f"study_file: =======> {study_file}")
    if not study_file:
        raise ValueError(f"File path not found for study name: {study_name}")

    # JSON is UTF-8 by specification; do not rely on the platform default
    # encoding, which breaks on non-ASCII content on some systems.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    return f"### Number of documents: {len(data)}"
def markdown_table_to_csv(markdown_text: str) -> str:
    """Convert the markdown table rows found in *markdown_text* to CSV.

    Every line that starts with ``|`` (after stripping whitespace) is treated
    as a table row; all other lines are ignored. Delimiter rows such as
    ``|---|:-:|`` are skipped. Empty cells are preserved so that columns stay
    aligned, and escaped pipes (``\\|``) are kept as literal ``|`` characters
    inside a cell.

    Returns:
        The CSV text (rows terminated by ``\\r\\n``), or an empty string when
        no table rows are present.
    """
    import re  # local import: this module does not import re at file level

    unescaped_pipe = re.compile(r"(?<!\\)\|")
    # A delimiter-row cell: hyphens with an optional alignment colon each side.
    delimiter_cell = re.compile(r":?-+:?")

    rows = []
    for raw_line in markdown_text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("|"):
            continue

        # The first split piece is the empty text before the leading pipe.
        parts = unescaped_pipe.split(line)[1:]
        # Drop the empty piece after a trailing pipe, if the row has one.
        if parts and not parts[-1].strip():
            parts = parts[:-1]
        cells = [part.strip().replace("\\|", "|") for part in parts]

        if not any(cells):
            continue  # nothing but pipes and whitespace
        if all(delimiter_cell.fullmatch(cell) for cell in cells):
            continue  # header/body delimiter row
        rows.append(cells)

    if not rows:
        return ""

    output = io.StringIO()
    csv.writer(output).writerows(rows)
    return output.getvalue()
def update_interface(study_name: str) -> Tuple[str, gr.update, gr.update, gr.update]:
    """Build the study summary plus updates for the three question slots.

    Up to three sample questions for the study are shown; when the study has
    none, the "General" questions are used instead. Slots without a question
    are hidden, so exactly three updates always follow the summary.
    """
    info = get_study_info(study_name)

    picked = (
        sample_questions.get(study_name, [])[:3]
        or sample_questions.get("General", [])[:3]
    )

    slot_updates = []
    for text in picked:
        slot_updates.append(gr.update(visible=True, value=text))
    # Pad with hidden slots up to the fixed total of three.
    while len(slot_updates) < 3:
        slot_updates.append(gr.update(visible=False))

    return (info, *slot_updates)
def set_question(question: str) -> str:
    """Return *question* without any leading sparkle emoji or spaces."""
    # lstrip treats its argument as a set of characters, so every leading
    # "✨" and " " is removed, in any order and any number.
    decoration_chars = "✨ "
    cleaned = question.lstrip(decoration_chars)
    return cleaned
def process_multi_input(text, study_name, prompt_type):
    """Ask the RAG pipeline to tabulate the requested variables for a study.

    Args:
        text: Comma-separated variable names typed by the user.
        study_name: Study to query.
        prompt_type: Prompt style forwarded to ``chat_function``.

    Returns:
        A two-item list: the markdown answer, and a Gradio update controlling
        the visibility of the CSV download button.
    """
    # Split on commas, normalise to upper case, and drop blank entries such as
    # those left by a trailing comma or an empty textbox.
    variable_list = [
        name
        for name in (word.strip().upper() for word in (text or "").split(","))
        if name
    ]
    if not variable_list:
        # Nothing to extract: skip the model call and keep the button hidden.
        return ["Please enter at least one variable.", gr.update(visible=False)]

    user_message = f"Extract and present in a tabular format the following variables for each {study_name} study: {', '.join(variable_list)}"
    logging.info(f"User message: ==> {user_message}")
    response = chat_function(user_message, study_name, prompt_type)
    return [response, gr.update(visible=True)]
def create_gr_interface() -> gr.Blocks:
    """
    Create and configure the Gradio interface for the RAG platform.

    The interface contains:
    - Zotero credential inputs and a button to process the Zotero library
    - A study dropdown populated from the "study_files_collection" ChromaDB
      collection, with a markdown area for study details
    - Prompt type selection
    - A comma-separated study-variables input with a submit button
    - A markdown answer area and a "Download as CSV" button
    - Event handlers wiring these components together

    Returns:
        gr.Blocks: The configured Gradio interface ready for launching.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ACRES RAG Platform")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Zotero Credentials")
                zotero_library_id = gr.Textbox(
                    label="Zotero Library ID",
                    type="password",
                    placeholder="Enter Your Zotero Library ID here...",
                )
                zotero_api_access_key = gr.Textbox(
                    label="Zotero API Access Key",
                    type="password",
                    placeholder="Enter Your Zotero API Access Key...",
                )
                process_zotero_btn = gr.Button("Process your Zotero Library")
                zotero_output = gr.Markdown(label="Zotero")

                gr.Markdown("### Study Information")

                # Query ChromaDB for all document IDs in the
                # "study_files_collection" collection.
                collection = chromadb_client.get_or_create_collection(
                    "study_files_collection"
                )
                # Retrieve all documents by querying with an empty string and
                # specifying a high n_results.
                # NOTE(review): collection.get() would list every record
                # without an embedding query or the 1000 cap -- consider it.
                all_documents = collection.query(query_texts=[""], n_results=1000)
                logging.info(f"all_documents: =========> {all_documents}")

                # query() returns one list of IDs per query text. Check for an
                # empty or missing result BEFORE indexing into it.
                document_ids = all_documents.get("ids") or []
                study_choices = list(document_ids[0]) if document_ids else []
                logging.info(f"study_choices: ======> {study_choices}")

                # Update the Dropdown with choices from ChromaDB.
                study_dropdown = gr.Dropdown(
                    choices=study_choices,
                    label="Select Study",
                    value=(
                        study_choices[0] if study_choices else None
                    ),  # Set first choice as default, if available
                )

                study_info = gr.Markdown(label="Study Details")

                gr.Markdown("### Settings")
                prompt_type = gr.Radio(
                    ["Default", "Highlight", "Evidence-based"],
                    label="Prompt Type",
                    value="Default",
                )

            with gr.Column(scale=3):
                gr.Markdown("### Study Variables")
                with gr.Row():
                    study_variables = gr.Textbox(
                        show_label=False,
                        placeholder="Type your variables separated by commas e.g (Study ID, Study Title, Authors etc)",
                        scale=4,
                        lines=1,
                        autofocus=True,
                    )
                    submit_btn = gr.Button("Submit", scale=1)
                answer_output = gr.Markdown(label="Answer")

                # Hidden until an answer has been produced.
                download_btn = gr.DownloadButton(
                    "Download as CSV",
                    variant="primary",
                    size="sm",
                    scale=1,
                    visible=False,
                )

        def download_as_csv(markdown_content):
            """Convert markdown table to CSV and provide for download.

            Returns the path of the written CSV file, or None when there is
            no content or no table in it.
            """
            if not markdown_content:
                return None
            csv_content = markdown_table_to_csv(markdown_content)
            if not csv_content:
                return None

            # Write the export into the working directory; cleanup_temp_files
            # removes old exports by this name pattern.
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            temp_path = f"study_export_{timestamp}.csv"
            with open(temp_path, "w", newline="", encoding="utf-8") as f:
                f.write(csv_content)
            return temp_path

        def cleanup_temp_files():
            """Delete CSV exports older than five minutes (best effort)."""
            max_age_seconds = 5 * 60
            try:
                current_time = datetime.datetime.now()
                for file in os.listdir():
                    if file.startswith("study_export_") and file.endswith(".csv"):
                        file_time = datetime.datetime.fromtimestamp(
                            os.path.getmtime(file)
                        )
                        # total_seconds() covers the whole age; .seconds wraps
                        # around every 24 hours and would skip old files.
                        age = (current_time - file_time).total_seconds()
                        if age > max_age_seconds:
                            try:
                                os.remove(file)
                            except Exception as e:
                                logging.warning(
                                    f"Failed to remove temp file {file}: {e}"
                                )
            except Exception as e:
                logging.warning(f"Error during cleanup: {e}")

        study_dropdown.change(
            fn=get_study_info,
            inputs=study_dropdown,
            outputs=[study_info],
        )
        process_zotero_btn.click(
            process_zotero_library_items,
            inputs=[zotero_library_id, zotero_api_access_key],
            outputs=[zotero_output],
            queue=False,
        )
        submit_btn.click(
            process_multi_input,
            inputs=[study_variables, study_dropdown, prompt_type],
            outputs=[answer_output, download_btn],
            queue=False,
        )
        download_btn.click(
            fn=download_as_csv,
            inputs=[answer_output],
            outputs=[download_btn],
        ).then(
            fn=cleanup_temp_files, inputs=None, outputs=None  # Clean up after download
        )

    return demo
# Build the interface at import time so `demo` exists as a module attribute.
# NOTE(review): presumably the hosting environment imports this module and
# looks up `demo` -- confirm before moving this under the main guard.
demo = create_gr_interface()
if __name__ == "__main__":
    # Run directly: launch with a public share link and debug mode enabled.
    demo.launch(share=True, debug=True)