Spaces:
Sleeping
Sleeping
# app.py | |
import json | |
from typing import List, Tuple | |
import os | |
import logging | |
import gradio as gr | |
from dotenv import load_dotenv | |
from slugify import slugify | |
from rag.rag_pipeline import RAGPipeline | |
from utils.helpers import ( | |
generate_follow_up_questions, | |
append_to_study_files, | |
add_study_files_to_chromadb, | |
chromadb_client, | |
) | |
from utils.prompts import ( | |
highlight_prompt, | |
evidence_based_prompt, | |
sample_questions, | |
) | |
import openai | |
from config import STUDY_FILES, OPENAI_API_KEY | |
from utils.zotero_manager import ZoteroManager | |
import csv | |
import io | |
import datetime | |
# Load environment variables (python-dotenv) before anything reads them.
load_dotenv()
logging.basicConfig(level=logging.INFO)
openai.api_key = OPENAI_API_KEY

# At import time, push the entries listed in study_files.json into the
# "study_files_collection" ChromaDB collection via the project helper.
add_study_files_to_chromadb("study_files.json", "study_files_collection")

# RAGPipeline instances created so far, keyed by study name.
rag_cache = {}
def process_zotero_library_items(
    zotero_library_id: str, zotero_api_access_key: str
) -> str:
    """Export not-yet-known Zotero collections to JSON and register them as studies.

    For every non-empty collection whose name is not already a key of
    ``STUDY_FILES``, the collection's items are written to
    ``data/<slugified-name>_zotero_items.json``, recorded in
    ``study_files.json`` and added to the in-memory ``STUDY_FILES`` mapping.
    After the loop the study files are (re)loaded into ChromaDB.

    Args:
        zotero_library_id: Zotero library ID; treated as a "user" library.
        zotero_api_access_key: Zotero API key for that library.

    Returns:
        A human-readable status message. Failures are reported in the
        returned message rather than raised, so the UI can display them.
    """
    # Missing or whitespace-only credentials are treated as absent.
    zotero_library_id = (zotero_library_id or "").strip()
    zotero_api_access_key = (zotero_api_access_key or "").strip()
    if not zotero_library_id or not zotero_api_access_key:
        return "Please enter your zotero library Id and API Access Key"

    zotero_library_type = "user"  # or "group"

    try:
        zotero_manager = ZoteroManager(
            zotero_library_id, zotero_library_type, zotero_api_access_key
        )
        zotero_collections = zotero_manager.get_collections()
        zotero_collection_lists = zotero_manager.list_zotero_collections(
            zotero_collections
        )
        filtered_zotero_collection_lists = (
            zotero_manager.filter_and_return_collections_with_items(
                zotero_collection_lists
            )
        )

        for collection in filtered_zotero_collection_lists:
            collection_name = collection.get("name")
            if collection_name in STUDY_FILES:
                # Already exported in this or a previous session.
                continue

            collection_key = collection.get("key")
            zotero_collection_items = (
                zotero_manager.get_collection_zotero_items_by_key(collection_key)
            )

            # Export the collection's items to a JSON file under data/.
            zotero_items_json = zotero_manager.zotero_items_to_json(
                zotero_collection_items
            )
            export_path = f"data/{slugify(collection_name)}_zotero_items.json"
            zotero_manager.write_zotero_items_to_json_file(
                zotero_items_json, export_path
            )
            append_to_study_files("study_files.json", collection_name, export_path)

            # Keep the in-memory mapping in step for the current session.
            STUDY_FILES.update({collection_name: export_path})
            logging.info(f"STUDY_FILES: {STUDY_FILES}")

        # After the loop, load everything listed in study_files.json into ChromaDB.
        add_study_files_to_chromadb("study_files.json", "study_files_collection")
    except Exception as e:
        # UI boundary: keep the traceback in the log, show a short message.
        logging.exception("Failed to process Zotero library")
        return f"Error process your zotero library: {str(e)}"

    return "Successfully processed items in your zotero library"
def get_rag_pipeline(study_name: str) -> RAGPipeline:
    """Return the RAGPipeline for *study_name*, creating and caching it on first use.

    The study's data file is looked up in the "study_files_collection"
    ChromaDB collection: the record ID is the study name and the path is read
    from the record's ``file_path`` metadata field.

    Raises:
        ValueError: If no record exists for *study_name*, or the record has
            no ``file_path`` metadata.
    """
    if study_name not in rag_cache:
        collection = chromadb_client.get_or_create_collection("study_files_collection")
        result = collection.get(ids=[study_name])  # look the record up by ID

        metadatas = result.get("metadatas") if result else None
        if not metadatas:
            raise ValueError(f"Invalid study name: {study_name}")

        # A record's metadata entry may be None; treat that like a missing
        # file_path instead of failing with AttributeError.
        study_file = (metadatas[0] or {}).get("file_path")
        if not study_file:
            raise ValueError(f"File path not found for study name: {study_name}")

        rag_cache[study_name] = RAGPipeline(study_file)

    return rag_cache[study_name]
def chat_function(message: str, study_name: str, prompt_type: str) -> str:
    """Run *message* through the selected study's RAG pipeline and return the answer text.

    A blank message short-circuits with a reminder string. ``prompt_type``
    selects the prompt template; any value other than "Highlight" or
    "Evidence-based" passes ``None`` as the template.
    """
    if message.strip() == "":
        return "Please enter a valid query."

    pipeline = get_rag_pipeline(study_name)
    logging.info(f"rag: ==> {pipeline}")

    templates_by_type = {
        "Highlight": highlight_prompt,
        "Evidence-based": evidence_based_prompt,
    }
    template = templates_by_type.get(prompt_type)

    answer = pipeline.query(message, prompt_template=template)
    return answer.response
def get_study_info(study_name: str) -> str:
    """Return a markdown line with the number of documents in a study's data file.

    The file path is read from the ``file_path`` metadata of the record whose
    ID is *study_name* in the "study_files_collection" ChromaDB collection.

    Raises:
        ValueError: If *study_name* is empty, has no record in the
            collection, or the record has no ``file_path`` metadata.
    """
    # An empty dropdown yields None/""; reject it before querying ChromaDB.
    if not study_name:
        raise ValueError(f"Invalid study name: {study_name}")

    collection = chromadb_client.get_or_create_collection("study_files_collection")
    result = collection.get(ids=[study_name])  # look the record up by ID
    logging.info(f"Result: ======> {result}")

    metadatas = result.get("metadatas") if result else None
    if not metadatas:
        raise ValueError(f"Invalid study name: {study_name}")

    # A record's metadata entry may be None; treat that like a missing
    # file_path instead of failing with AttributeError.
    study_file = (metadatas[0] or {}).get("file_path")
    logging.info(f"study_file: =======> {study_file}")
    if not study_file:
        raise ValueError(f"File path not found for study name: {study_name}")

    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    return f"### Number of documents: {len(data)}"
def _is_separator_cell(cell: str) -> bool:
    """Return True for a header-delimiter cell such as ``---``, ``:---`` or ``:---:``."""
    core = cell[1:] if cell.startswith(":") else cell
    core = core[:-1] if core.endswith(":") else core
    return len(core) >= 3 and set(core) == {"-"}


def markdown_table_to_csv(markdown_text: str) -> str:
    """Convert the markdown table rows found in *markdown_text* to CSV.

    Only lines that start with ``|`` (after trimming) are considered. Header
    delimiter rows (every cell is dashes with optional alignment colons) and
    rows with no content at all are skipped. Empty cells inside a row are
    kept so that columns stay aligned.

    Returns:
        The CSV text (``\\r\\n`` line endings, as produced by ``csv.writer``),
        or an empty string when no table rows are found.
    """
    if not markdown_text:
        return ""

    rows = []
    for raw_line in markdown_text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("|"):
            continue

        # Remove only the border pipes; interior empty cells must survive.
        inner = line[1:]
        if inner.endswith("|"):
            inner = inner[:-1]
        cells = [cell.strip() for cell in inner.split("|")]

        if not any(cells):
            continue  # nothing but pipes and whitespace
        if all(_is_separator_cell(cell) for cell in cells):
            continue  # header delimiter row, e.g. |---|:---:|
        rows.append(cells)

    if not rows:
        return ""

    output = io.StringIO()
    csv.writer(output).writerows(rows)
    return output.getvalue()
def update_interface(study_name: str) -> Tuple[str, gr.update, gr.update, gr.update]:
    """Build the study-info text plus three question-button updates for a study.

    Up to three sample questions registered for *study_name* are shown; when
    the study has none, the "General" questions are used. Unused button slots
    are returned as hidden updates so the result always has three of them.
    """
    slot_count = 3
    info_markdown = get_study_info(study_name)

    prompts = sample_questions.get(study_name, [])[:slot_count]
    if len(prompts) == 0:
        prompts = sample_questions.get("General", [])[:slot_count]

    button_updates = []
    for prompt in prompts:
        button_updates.append(gr.update(visible=True, value=prompt))
    # Pad with hidden buttons up to the fixed number of slots.
    while len(button_updates) < slot_count:
        button_updates.append(gr.update(visible=False))

    return (info_markdown, *button_updates)
def set_question(question: str) -> str:
    """Return *question* without any leading "✨" or space characters."""
    decoration_chars = "✨ "
    cleaned = question.lstrip(decoration_chars)
    return cleaned
def process_multi_input(text, study_name, prompt_type):
    """Ask the RAG pipeline to tabulate the comma-separated variables in *text*.

    Variable names are trimmed and upper-cased; blank entries (empty input,
    doubled or trailing commas) are ignored.

    Returns:
        A two-item list: the response markdown and a ``gr.update`` for the
        download button. The button is made visible after a query and kept
        hidden when no variable was supplied.
    """
    variable_list = [
        name.strip().upper() for name in (text or "").split(",") if name.strip()
    ]
    if not variable_list:
        # Nothing to extract: skip the pipeline call and keep the button hidden.
        return ["Please enter at least one variable.", gr.update(visible=False)]

    user_message = f"Extract and present in a tabular format the following variables for each {study_name} study: {', '.join(variable_list)}"
    logging.info(f"User message: ==> {user_message}")
    response = chat_function(user_message, study_name, prompt_type)
    return [response, gr.update(visible=True)]
# CSV exports older than this many seconds are deleted by the cleanup step.
_EXPORT_MAX_AGE_SECONDS = 5 * 60


def _list_study_names() -> List[str]:
    """Return the IDs of all records in the "study_files_collection" collection."""
    collection = chromadb_client.get_or_create_collection("study_files_collection")
    # get() without filters returns every record, so no dummy similarity
    # query or arbitrary n_results cap is needed.
    records = collection.get()
    return list((records or {}).get("ids") or [])


def create_gr_interface() -> gr.Blocks:
    """
    Create and configure the Gradio interface for the RAG platform.

    The interface consists of:
    - Zotero credential inputs and a button to process the Zotero library
    - A study selection dropdown (choices read from ChromaDB) with study details
    - Prompt type selection
    - A study-variables input whose answer is rendered as markdown
    - A button to download the answer's table as CSV
    - Event handlers wiring these components together

    Returns:
        gr.Blocks: The configured Gradio interface ready for launching.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ACRES RAG Platform")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Zotero Credentials")
                zotero_library_id = gr.Textbox(
                    label="Zotero Library ID",
                    type="password",
                    placeholder="Enter Your Zotero Library ID here...",
                )
                zotero_api_access_key = gr.Textbox(
                    label="Zotero API Access Key",
                    type="password",
                    placeholder="Enter Your Zotero API Access Key...",
                )
                process_zotero_btn = gr.Button("Process your Zotero Library")
                zotero_output = gr.Markdown(label="Zotero")

                gr.Markdown("### Study Information")
                # Study names are the record IDs stored in ChromaDB.
                study_choices = _list_study_names()
                logging.info(f"study_choices: ======> {study_choices}")

                study_dropdown = gr.Dropdown(
                    choices=study_choices,
                    label="Select Study",
                    # Default to the first study, if there is one.
                    value=(study_choices[0] if study_choices else None),
                )
                study_info = gr.Markdown(label="Study Details")

                gr.Markdown("### Settings")
                prompt_type = gr.Radio(
                    ["Default", "Highlight", "Evidence-based"],
                    label="Prompt Type",
                    value="Default",
                )

            with gr.Column(scale=3):
                gr.Markdown("### Study Variables")
                with gr.Row():
                    study_variables = gr.Textbox(
                        show_label=False,
                        placeholder="Type your variables separated by commas e.g (Study ID, Study Title, Authors etc)",
                        scale=4,
                        lines=1,
                        autofocus=True,
                    )
                    submit_btn = gr.Button("Submit", scale=1)
                answer_output = gr.Markdown(label="Answer")
                # Hidden until an answer has been produced.
                download_btn = gr.DownloadButton(
                    "Download as CSV",
                    variant="primary",
                    size="sm",
                    scale=1,
                    visible=False,
                )

        def download_as_csv(markdown_content):
            """Convert markdown table to CSV and provide for download.

            Returns the path of the written CSV file, or None when there is
            no content or no table in it.
            """
            if not markdown_content:
                return None
            csv_content = markdown_table_to_csv(markdown_content)
            if not csv_content:
                return None

            # Microseconds keep exports made within the same second apart.
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            temp_path = f"study_export_{timestamp}.csv"
            with open(temp_path, "w", newline="", encoding="utf-8") as f:
                f.write(csv_content)
            return temp_path

        def cleanup_temp_files():
            """Best-effort removal of CSV exports older than the retention window."""
            try:
                current_time = datetime.datetime.now()
                for file in os.listdir():
                    if not (
                        file.startswith("study_export_") and file.endswith(".csv")
                    ):
                        continue
                    try:
                        file_time = datetime.datetime.fromtimestamp(
                            os.path.getmtime(file)
                        )
                        # total_seconds() covers the whole age; .seconds
                        # would drop any full days.
                        age = (current_time - file_time).total_seconds()
                        if age > _EXPORT_MAX_AGE_SECONDS:
                            os.remove(file)
                    except OSError as e:
                        # One file failing must not stop the rest.
                        logging.warning(f"Failed to remove temp file {file}: {e}")
            except Exception as e:
                logging.warning(f"Error during cleanup: {e}")

        study_dropdown.change(
            fn=get_study_info,
            inputs=study_dropdown,
            outputs=[study_info],
        )
        process_zotero_btn.click(
            process_zotero_library_items,
            inputs=[zotero_library_id, zotero_api_access_key],
            outputs=[zotero_output],
            queue=False,
        )
        submit_btn.click(
            process_multi_input,
            inputs=[study_variables, study_dropdown, prompt_type],
            outputs=[answer_output, download_btn],
            queue=False,
        )
        download_btn.click(
            fn=download_as_csv,
            inputs=[answer_output],
            outputs=[download_btn],
        ).then(
            fn=cleanup_temp_files, inputs=None, outputs=None  # Clean up after download
        )

    return demo
# Built at import time, so `demo` exists as a module attribute even when this
# file is imported rather than executed.
demo = create_gr_interface()

if __name__ == "__main__":
    demo.launch(share=True, debug=True)