diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,4 +1,2785 @@
-import streamlit as st
-
-x = st.slider('Select a value')
-st.write(x, 'squared is', x * x)
+# app.py
+
+import streamlit as st
+
+# Set page config first, before any other st commands
+st.set_page_config(page_title="SNAP", layout="wide")
+
+# Add warning filters
+import warnings
+# More specific warning filters for torch.classes
+warnings.filterwarnings('ignore', message='.*torch.classes.*__path__._path.*')
+warnings.filterwarnings('ignore', message='.*torch.classes.*registered via torch::class_.*')
+
+import pandas as pd
+import numpy as np
+import os
+import io
+import time
+from datetime import datetime
+import base64
+import re
+import pickle
+from typing import List, Dict, Any, Tuple
+import plotly.express as px
+import torch
+
+# For parallelism
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+# Import necessary libraries for embeddings, clustering, and summarization
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+from bertopic import BERTopic
+from hdbscan import HDBSCAN
+import nltk
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+
+# For summarization and chat
+from langchain.chains import LLMChain
+from langchain_community.chat_models import ChatOpenAI
+from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
+from openai import OpenAI
+from transformers import GPT2TokenizerFast
+
+# Initialize OpenAI client and tokenizer
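+# OpenAI() reads the API key from the OPENAI_API_KEY environment variable by default.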
+client = OpenAI()
+
+###############################################################################
+# Helper: Attempt to get this file's directory or fallback to current working dir
+###############################################################################
+def get_base_dir():
+ try:
+ base_dir = os.path.dirname(__file__)
+ if not base_dir:
+ return os.getcwd()
+ return base_dir
+ except NameError:
+ # In case __file__ is not defined (some environments)
+ return os.getcwd()
+
+BASE_DIR = get_base_dir()
+
+# Function to get or create model directory
+def get_model_dir():
+ base_dir = get_base_dir()
+ model_dir = os.path.join(base_dir, 'models')
+ os.makedirs(model_dir, exist_ok=True)
+ return model_dir
+
+# Function to load tokenizer from local storage or download
+def load_tokenizer():
+ model_dir = get_model_dir()
+ tokenizer_dir = os.path.join(model_dir, 'tokenizer')
+ os.makedirs(tokenizer_dir, exist_ok=True)
+
+ try:
+ # Try to load from local directory first
+ tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir)
+ #st.success("Loaded tokenizer from local storage")
+ except Exception as e:
+ #st.warning("Downloading tokenizer (one-time operation)...")
+ try:
+ # Download and save to local directory
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") # Use standard GPT2 tokenizer
+ tokenizer.save_pretrained(tokenizer_dir)
+ #st.success("Downloaded and saved tokenizer")
+ except Exception as download_e:
+ #st.error(f"Error downloading tokenizer: {str(download_e)}")
+ raise
+
+ return tokenizer
+
+# Load tokenizer
+try:
+ tokenizer = load_tokenizer()
+except Exception as e:
+ #st.error("Failed to load tokenizer. Some functionality may be limited.")
+ tokenizer = None
+
+MAX_CONTEXT_WINDOW = 128000 # GPT-4o context window size
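+# Note: token counts are approximated with the GPT-2 tokenizer loaded above, which differs from the GPT-4o tokenizer.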
+
+# Initialize chat history in session state if not exists
+if 'chat_history' not in st.session_state:
+ st.session_state.chat_history = []
+
+###############################################################################
+# Helper: Get chat response from OpenAI
+###############################################################################
+def get_chat_response(messages):
+ try:
+ response = client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=messages,
+ temperature=0,
+ )
+ return response.choices[0].message.content.strip()
+ except Exception as e:
+ st.error(f"Error querying OpenAI: {e}")
+ return None
+
+###############################################################################
+# Helper: Generate raw summary for a cluster (without references)
+###############################################################################
+def generate_raw_cluster_summary(
+ topic_val: int,
+ cluster_df: pd.DataFrame,
+ llm: Any,
+ chat_prompt: Any
+) -> Dict[str, Any]:
+ """Generate a summary for a single cluster without reference enhancement,
+ automatically trimming text if it exceeds a safe token limit."""
+ cluster_text = " ".join(cluster_df['text'].tolist())
+ if not cluster_text.strip():
+ return None
+
+ # Define a safe limit (95% of max context window to leave room for prompts)
+ safe_limit = int(MAX_CONTEXT_WINDOW * 0.95)
+
+    # Trim the text only if the tokenizer loaded successfully (it may be None if loading failed above)
+    if tokenizer is not None:
+        # Encode the text into tokens
+        encoded_text = tokenizer.encode(cluster_text, add_special_tokens=False)
+
+        # If the text is too large, slice it
+        if len(encoded_text) > safe_limit:
+            #st.warning(f"Cluster {topic_val} text is too large ({len(encoded_text)} tokens). Trimming to {safe_limit} tokens.")
+            encoded_text = encoded_text[:safe_limit]
+            cluster_text = tokenizer.decode(encoded_text)
+
+ user_prompt_local = f"**Text to summarize**: {cluster_text}"
+ try:
+ local_chain = LLMChain(llm=llm, prompt=chat_prompt)
+ summary_local = local_chain.run(user_prompt=user_prompt_local).strip()
+ return {'Topic': topic_val, 'Summary': summary_local}
+ except Exception as e:
+ st.error(f"Error generating summary for cluster {topic_val}: {str(e)}")
+ return None
+
+###############################################################################
+# Helper: Enhance a summary with references
+###############################################################################
+def enhance_summary_with_references(
+ summary_dict: Dict[str, Any],
+ df_scope: pd.DataFrame,
+ reference_id_column: str,
+ url_column: str = None,
+ llm: Any = None
+) -> Dict[str, Any]:
+ """Add references to a summary."""
+ if not summary_dict or 'Summary' not in summary_dict:
+ return summary_dict
+
+ try:
+ cluster_df = df_scope[df_scope['Topic'] == summary_dict['Topic']]
+ enhanced = add_references_to_summary(
+ summary_dict['Summary'],
+ cluster_df,
+ reference_id_column,
+ url_column,
+ llm
+ )
+ summary_dict['Enhanced_Summary'] = enhanced
+ return summary_dict
+ except Exception as e:
+ st.error(f"Error enhancing summary for cluster {summary_dict.get('Topic')}: {str(e)}")
+ return summary_dict
+
+###############################################################################
+# Helper: Process summaries in parallel
+###############################################################################
+def process_summaries_in_parallel(
+ df_scope: pd.DataFrame,
+ unique_selected_topics: List[int],
+ llm: Any,
+ chat_prompt: Any,
+ enable_references: bool = False,
+ reference_id_column: str = None,
+ url_column: str = None,
+ max_workers: int = 16
+) -> List[Dict[str, Any]]:
+ """Process multiple cluster summaries in parallel using ThreadPoolExecutor."""
+ summaries = []
+ total_topics = len(unique_selected_topics)
+
+ # Create progress placeholders
+ progress_text = st.empty()
+ progress_bar = st.progress(0)
+
+ try:
+ # Phase 1: Generate raw summaries in parallel
+ progress_text.text(f"Phase 1/3: Generating cluster summaries in parallel (0/{total_topics} completed)")
+ completed_summaries = 0
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit summary generation tasks
+ future_to_topic = {
+ executor.submit(
+ generate_raw_cluster_summary,
+ topic_val,
+ df_scope[df_scope['Topic'] == topic_val],
+ llm,
+ chat_prompt
+ ): topic_val
+ for topic_val in unique_selected_topics
+ }
+
+ # Process completed summary tasks
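+                # Note: future.result() blocks until each task finishes, so progress advances in submission order.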
+ for future in future_to_topic:
+ try:
+ result = future.result()
+ if result:
+ summaries.append(result)
+ completed_summaries += 1
+ # Update progress
+ progress = completed_summaries / total_topics
+ progress_bar.progress(progress)
+ progress_text.text(
+ f"Phase 1/3: Generating cluster summaries in parallel ({completed_summaries}/{total_topics} completed)"
+ )
+ except Exception as e:
+ topic_val = future_to_topic[future]
+ st.error(f"Error in summary generation for cluster {topic_val}: {str(e)}")
+ completed_summaries += 1
+ continue
+
+ # Phase 2: Enhance summaries with references in parallel (if enabled)
+ if enable_references and reference_id_column and summaries:
+ total_to_enhance = len(summaries)
+ completed_enhancements = 0
+ progress_text.text(f"Phase 2/3: Adding references to summaries (0/{total_to_enhance} completed)")
+ progress_bar.progress(0)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit reference enhancement tasks
+ future_to_summary = {
+ executor.submit(
+ enhance_summary_with_references,
+ summary_dict,
+ df_scope,
+ reference_id_column,
+ url_column,
+ llm
+ ): summary_dict.get('Topic')
+ for summary_dict in summaries
+ }
+
+ # Process completed enhancement tasks
+ enhanced_summaries = []
+ for future in future_to_summary:
+ try:
+ result = future.result()
+ if result:
+ enhanced_summaries.append(result)
+ completed_enhancements += 1
+ # Update progress
+ progress = completed_enhancements / total_to_enhance
+ progress_bar.progress(progress)
+ progress_text.text(
+ f"Phase 2/3: Adding references to summaries ({completed_enhancements}/{total_to_enhance} completed)"
+ )
+ except Exception as e:
+ topic_val = future_to_summary[future]
+ st.error(f"Error in reference enhancement for cluster {topic_val}: {str(e)}")
+ completed_enhancements += 1
+ continue
+
+ summaries = enhanced_summaries
+
+ # Phase 3: Generate cluster names in parallel
+ if summaries:
+ total_to_name = len(summaries)
+ completed_names = 0
+ progress_text.text(f"Phase 3/3: Generating cluster names (0/{total_to_name} completed)")
+ progress_bar.progress(0)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit cluster naming tasks
+ future_to_summary = {
+ executor.submit(
+ generate_cluster_name,
+ summary_dict.get('Enhanced_Summary', summary_dict['Summary']),
+ llm
+ ): summary_dict.get('Topic')
+ for summary_dict in summaries
+ }
+
+ # Process completed naming tasks
+ named_summaries = []
+ for future in future_to_summary:
+ try:
+ cluster_name = future.result()
+ topic_val = future_to_summary[future]
+ # Find the corresponding summary dict
+ summary_dict = next(s for s in summaries if s['Topic'] == topic_val)
+ summary_dict['Cluster_Name'] = cluster_name
+ named_summaries.append(summary_dict)
+ completed_names += 1
+ # Update progress
+ progress = completed_names / total_to_name
+ progress_bar.progress(progress)
+ progress_text.text(
+ f"Phase 3/3: Generating cluster names ({completed_names}/{total_to_name} completed)"
+ )
+ except Exception as e:
+ topic_val = future_to_summary[future]
+ st.error(f"Error in cluster naming for cluster {topic_val}: {str(e)}")
+ completed_names += 1
+ continue
+
+ summaries = named_summaries
+ finally:
+ # Clean up progress indicators
+ progress_text.empty()
+ progress_bar.empty()
+
+ return summaries
+
+###############################################################################
+# Helper: Generate cluster name
+###############################################################################
+def generate_cluster_name(summary_text: str, llm: Any) -> str:
+ """Generate a concise, descriptive name for a cluster based on its summary."""
+ system_prompt = """You are a cluster naming expert. Your task is to generate a very concise (3-6 words) but descriptive name for a cluster based on its summary. The name should capture the main theme or focus of the cluster.
+
+Rules:
+1. Keep it between 3-6 words
+2. Be specific but concise
+3. Capture the main theme/focus
+4. Use title case
+5. Do not include words like "Cluster", "Topic", or "Theme"
+6. Focus on the content, not metadata
+
+Example good names:
+- Agricultural Water Management Innovation
+- Gender Equality in Farming
+- Climate-Smart Village Implementation
+- Sustainable Livestock Practices"""
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": f"Generate a concise cluster name based on this summary:\n\n{summary_text}"}
+ ]
+
+ try:
+ response = get_chat_response(messages)
+ # Clean up response (remove quotes, newlines, etc.)
+ cluster_name = response.strip().strip('"').strip("'").strip()
+ return cluster_name
+ except Exception as e:
+ st.error(f"Error generating cluster name: {str(e)}")
+ return "Unnamed Cluster"
+
+###############################################################################
+# NLTK Resource Initialization
+###############################################################################
+def init_nltk_resources():
+ """Initialize NLTK resources with better error handling and less verbose output"""
+ nltk.data.path.append('/home/appuser/nltk_data') # Ensure consistent data path
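+    # '/home/appuser/nltk_data' matches the home directory typically used on Streamlit Community Cloud deployments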
+
+ resources = {
+        'tokenizers/punkt_tab': 'punkt_tab',  # punkt_tab replaces the older punkt resource
+ 'corpora/stopwords': 'stopwords'
+ }
+
+ for resource_path, resource_name in resources.items():
+ try:
+ nltk.data.find(resource_path)
+ except LookupError:
+ try:
+ nltk.download(resource_name, quiet=True)
+ except Exception as e:
+ st.warning(f"Error downloading NLTK resource {resource_name}: {e}")
+
+ # Test tokenizer silently
+ try:
+ from nltk.tokenize import PunktSentenceTokenizer
+ tokenizer = PunktSentenceTokenizer()
+ tokenizer.tokenize("Test sentence.")
+ except Exception as e:
+ st.error(f"Error initializing NLTK tokenizer: {e}")
+ try:
+ nltk.download('punkt_tab', quiet=True) # Updated to use punkt_tab
+ except Exception as e:
+ st.error(f"Failed to download punkt_tab tokenizer: {e}")
+
+# Initialize NLTK resources
+init_nltk_resources()
+
+###############################################################################
+# Function: add_references_to_summary
+###############################################################################
+def add_references_to_summary(summary, source_df, reference_column, url_column=None, llm=None):
+ """
+ Add references to a summary by identifying which parts of the summary come
+ from which source documents. References will be appended as [ID],
+ optionally linked if a URL column is provided.
+
+ Args:
+ summary (str): The summary text to enhance with references.
+ source_df (DataFrame): DataFrame containing the source documents.
+ reference_column (str): Column name to use for reference IDs.
+ url_column (str, optional): Column name containing URLs for hyperlinks.
+ llm (LLM, optional): Language model for source attribution.
+ Returns:
+ str: Enhanced summary with references as HTML if possible.
+ """
+ if summary.strip() == "" or source_df.empty or reference_column not in source_df.columns:
+ return summary
+
+ # If no LLM is provided, we can't do source attribution
+ if llm is None:
+ return summary
+
+ # Split the summary into paragraphs first
+ paragraphs = summary.split('\n\n')
+ enhanced_paragraphs = []
+
+ # Prepare source texts with their reference IDs
+ source_texts = []
+ reference_ids = []
+ urls = []
+ for _, row in source_df.iterrows():
+ if 'text' in row and pd.notna(row['text']) and pd.notna(row[reference_column]):
+ source_texts.append(str(row['text']))
+ reference_ids.append(str(row[reference_column]))
+ if url_column and url_column in row and pd.notna(row[url_column]):
+ urls.append(str(row[url_column]))
+ else:
+ urls.append(None)
+ if not source_texts:
+ return summary
+
+ # Create a mapping between URLs and reference IDs
+ url_map = {}
+ for ref_id, u in zip(reference_ids, urls):
+ if u:
+ url_map[ref_id] = u
+
+    # Pre-format the source list once; hoisting it out of the prompt f-strings below also
+    # avoids backslashes inside f-string expressions, which are a syntax error before Python 3.12.
+    sources_block = '\n'.join(
+        f"ID: {ref_id}, Text: {text[:500]}..." for ref_id, text in zip(reference_ids, source_texts)
+    )
+
+    # Define the system prompt for source attribution
+ system_prompt = """
+ You are an expert at identifying the source of information. You will be given:
+ 1. A sentence or bullet point from a summary
+ 2. A list of source texts with their IDs
+
+ Your task is to identify which source text(s) the text most likely came from.
+ Return ONLY the IDs of the source texts that contributed to the text, separated by commas.
+ If you cannot confidently attribute the text to any source, return "unknown".
+ """
+
+ for paragraph in paragraphs:
+ if not paragraph.strip():
+ enhanced_paragraphs.append('')
+ continue
+
+ # Check if it's a bullet point list
+ if any(line.strip().startswith('- ') or line.strip().startswith('* ') for line in paragraph.split('\n')):
+ # Handle bullet points
+ bullet_lines = paragraph.split('\n')
+ enhanced_bullets = []
+ for line in bullet_lines:
+ if not line.strip():
+ enhanced_bullets.append(line)
+ continue
+
+ if line.strip().startswith('- ') or line.strip().startswith('* '):
+ # Process each bullet point
+ user_prompt = f"""
+ Text: {line.strip()}
+
+ Source texts:
+                            {sources_block}
+
+ Which source ID(s) did this text most likely come from? Return only the ID(s) separated by commas, or "unknown".
+ """
+
+ try:
+ system_message = SystemMessagePromptTemplate.from_template(system_prompt)
+ human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
+ chain = LLMChain(llm=llm, prompt=chat_prompt)
+ response = chain.run(user_prompt=user_prompt)
+ source_ids = response.strip()
+
+ if source_ids.lower() == "unknown":
+ enhanced_bullets.append(line)
+ else:
+ # Extract just the IDs
+ source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
+ source_ids = re.sub(r'\s+', '', source_ids)
+ ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
+
+ if ids:
+ ref_parts = []
+ for id_ in ids:
+ if id_ in url_map:
+ ref_parts.append(f'{id_}')
+ else:
+ ref_parts.append(id_)
+ ref_string = ", ".join(ref_parts)
+ enhanced_bullets.append(f"{line} [{ref_string}]")
+ else:
+ enhanced_bullets.append(line)
+ except Exception:
+ enhanced_bullets.append(line)
+ else:
+ enhanced_bullets.append(line)
+
+ enhanced_paragraphs.append('\n'.join(enhanced_bullets))
+ else:
+ # Handle regular paragraphs
+ sentences = re.split(r'(?<=[.!?])\s+', paragraph)
+ enhanced_sentences = []
+
+ for sentence in sentences:
+ if not sentence.strip():
+ continue
+
+ user_prompt = f"""
+ Sentence: {sentence.strip()}
+
+ Source texts:
+                        {sources_block}
+
+ Which source ID(s) did this sentence most likely come from? Return only the ID(s) separated by commas, or "unknown".
+ """
+
+ try:
+ system_message = SystemMessagePromptTemplate.from_template(system_prompt)
+ human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
+ chain = LLMChain(llm=llm, prompt=chat_prompt)
+ response = chain.run(user_prompt=user_prompt)
+ source_ids = response.strip()
+
+ if source_ids.lower() == "unknown":
+ enhanced_sentences.append(sentence)
+ else:
+ # Extract just the IDs
+ source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
+ source_ids = re.sub(r'\s+', '', source_ids)
+ ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
+
+ if ids:
+ ref_parts = []
+ for id_ in ids:
+ if id_ in url_map:
+ ref_parts.append(f'{id_}')
+ else:
+ ref_parts.append(id_)
+ ref_string = ", ".join(ref_parts)
+ enhanced_sentences.append(f"{sentence} [{ref_string}]")
+ else:
+ enhanced_sentences.append(sentence)
+ except Exception:
+ enhanced_sentences.append(sentence)
+
+ enhanced_paragraphs.append(' '.join(enhanced_sentences))
+
+ # Join paragraphs back together with double newlines to preserve formatting
+ return '\n\n'.join(enhanced_paragraphs)
+
+
+st.sidebar.image("static/SNAP_logo.png", width=350)
+
+###############################################################################
+# Device / GPU Info
+###############################################################################
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+if device == 'cuda':
+ st.sidebar.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
+else:
+ st.sidebar.info("Using CPU")
+
+###############################################################################
+# Load or Compute Embeddings
+###############################################################################
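+# st.cache_resource keeps a single SentenceTransformer instance in memory across reruns and sessions.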
+@st.cache_resource
+def get_embedding_model():
+ model_dir = get_model_dir()
+ st_model_dir = os.path.join(model_dir, 'sentence_transformer')
+ os.makedirs(st_model_dir, exist_ok=True)
+
+ model_name = 'all-MiniLM-L6-v2'
+ try:
+ # Try to load from local directory first
+ model = SentenceTransformer(st_model_dir)
+ #st.success("Loaded sentence transformer from local storage")
+ except Exception as e:
+ #st.warning("Downloading sentence transformer model (one-time operation)...")
+ try:
+ # Download and save to local directory
+ model = SentenceTransformer(model_name)
+ model.save(st_model_dir)
+ #st.success("Downloaded and saved sentence transformer model")
+ except Exception as download_e:
+ st.error(f"Error downloading sentence transformer model: {str(download_e)}")
+ raise
+
+ return model.to(device)
+
+def generate_embeddings(texts, model):
+ with st.spinner('Calculating embeddings...'):
+ embeddings = model.encode(texts, show_progress_bar=True, device=device)
+ return embeddings
+
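+# st.cache_data caches the returned DataFrame keyed on the function arguments (here, the dataset path).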
+@st.cache_data
+def load_default_dataset(default_dataset_path):
+ if os.path.exists(default_dataset_path):
+ df_ = pd.read_excel(default_dataset_path)
+ return df_
+ else:
+ st.error("Default dataset not found. Please ensure the file exists in the 'input' directory.")
+ return None
+
+@st.cache_data
+def load_uploaded_dataset(uploaded_file):
+ df_ = pd.read_excel(uploaded_file)
+ return df_
+
+def load_or_compute_embeddings(df, using_default_dataset, uploaded_file_name=None, text_columns=None):
+ """
+ Loads pre-computed embeddings from a pickle file if they match current data,
+ otherwise computes and caches them.
+ """
+ if not text_columns:
+ return None, None
+
+ base_name = "PRMS_2022_2023_2024_QAed" if using_default_dataset else "custom_dataset"
+ if uploaded_file_name:
+ base_name = os.path.splitext(uploaded_file_name)[0]
+
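+    # Encode the selected text columns in the cache key so different column choices get separate embedding files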
+ cols_key = "_".join(sorted(text_columns))
+ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+ embeddings_dir = BASE_DIR
+ if using_default_dataset:
+ embeddings_file = os.path.join(embeddings_dir, f'{base_name}_{cols_key}.pkl')
+ else:
+ # For custom dataset, we still try to avoid regenerating each time
+ embeddings_file = os.path.join(embeddings_dir, f"{base_name}_{cols_key}.pkl")
+
+ df_fill = df.fillna("")
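+    # Build one string per row by concatenating the selected text columns with spaces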
+ texts = df_fill[text_columns].astype(str).agg(' '.join, axis=1).tolist()
+
+ # If already in session_state with matching columns and length, reuse
+ if ('embeddings' in st.session_state
+ and 'last_text_columns' in st.session_state
+ and st.session_state['last_text_columns'] == text_columns
+ and len(st.session_state['embeddings']) == len(texts)):
+ return st.session_state['embeddings'], st.session_state.get('embeddings_file', None)
+
+ # Try to load from disk
+ if os.path.exists(embeddings_file):
+ with open(embeddings_file, 'rb') as f:
+ embeddings = pickle.load(f)
+ if len(embeddings) == len(texts):
+ st.write("Loaded pre-calculated embeddings.")
+ st.session_state['embeddings'] = embeddings
+ st.session_state['embeddings_file'] = embeddings_file
+ st.session_state['last_text_columns'] = text_columns
+ return embeddings, embeddings_file
+
+ # Otherwise compute
+ st.write("Generating embeddings...")
+ model = get_embedding_model()
+ embeddings = generate_embeddings(texts, model)
+ with open(embeddings_file, 'wb') as f:
+ pickle.dump(embeddings, f)
+
+ st.session_state['embeddings'] = embeddings
+ st.session_state['embeddings_file'] = embeddings_file
+ st.session_state['last_text_columns'] = text_columns
+ return embeddings, embeddings_file
+
+
+###############################################################################
+# Reset Filter Function
+###############################################################################
+def reset_filters():
+ st.session_state['selected_additional_filters'] = {}
+
+# View selector
+st.sidebar.radio("Select view", ["Automatic Mode", "Power User Mode"], key="view")
+
+if st.session_state.view == "Power User Mode":
+ st.header("Power User Mode")
+ ###############################################################################
+ # Sidebar: Dataset Selection
+ ###############################################################################
+ st.sidebar.title("Data Selection")
+ dataset_option = st.sidebar.selectbox('Select Dataset', ('PRMS 2022+2023+2024 QAed', 'Upload my dataset'))
+
+ if 'df' not in st.session_state:
+ st.session_state['df'] = pd.DataFrame()
+ if 'filtered_df' not in st.session_state:
+ st.session_state['filtered_df'] = pd.DataFrame()
+
+ if dataset_option == 'PRMS 2022+2023+2024 QAed':
+ default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
+ df = load_default_dataset(default_dataset_path)
+ if df is not None:
+ st.session_state['df'] = df.copy()
+ st.session_state['using_default_dataset'] = True
+
+ # Initialize filtered_df with full dataset by default
+ if 'filtered_df' not in st.session_state or st.session_state['filtered_df'].empty:
+ st.session_state['filtered_df'] = df.copy()
+
+ # Initialize filter_state if not exists
+ if 'filter_state' not in st.session_state:
+ st.session_state['filter_state'] = {
+ 'applied': False,
+ 'filters': {}
+ }
+
+ # Set default text columns if not already set
+ if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
+ default_text_cols = []
+ if 'Title' in df.columns and 'Description' in df.columns:
+ default_text_cols = ['Title', 'Description']
+ st.session_state['text_columns'] = default_text_cols
+
+ #st.write("Using default dataset:")
+ #st.write("Data Preview:")
+ #st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
+ #st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
+
+ df_cols = df.columns.tolist()
+
+ # Additional filter columns
+ st.subheader("Select Filters")
+ if 'additional_filters_selected' not in st.session_state:
+ st.session_state['additional_filters_selected'] = []
+ if 'filter_values' not in st.session_state:
+ st.session_state['filter_values'] = {}
+
+ with st.form("filter_selection_form"):
+ all_columns = df.columns.tolist()
+ selected_additional_cols = st.multiselect(
+ "Select columns from your dataset to use as filters:",
+ all_columns,
+ default=st.session_state['additional_filters_selected']
+ )
+ add_filters_submitted = st.form_submit_button("Add Additional Filters")
+
+ if add_filters_submitted:
+ if selected_additional_cols != st.session_state['additional_filters_selected']:
+ st.session_state['additional_filters_selected'] = selected_additional_cols
+ # Reset removed columns
+ st.session_state['filter_values'] = {
+ k: v for k, v in st.session_state['filter_values'].items()
+ if k in selected_additional_cols
+ }
+
+ # Show dynamic filters form if any selected columns
+ if st.session_state['additional_filters_selected']:
+ st.subheader("Apply Filters")
+
+ # Quick search section (outside form)
+ for col_name in st.session_state['additional_filters_selected']:
+ unique_vals = sorted(df[col_name].dropna().unique().tolist())
+
+ # Add a search box for quick selection
+ search_key = f"search_{col_name}"
+ if search_key not in st.session_state:
+ st.session_state[search_key] = ""
+
+ col1, col2 = st.columns([3, 1])
+ with col1:
+ search_term = st.text_input(
+ f"Search in {col_name}",
+ key=search_key,
+ help="Enter text to find and select all matching values"
+ )
+ with col2:
+ if st.button(f"Select Matching", key=f"select_{col_name}"):
+ # Handle comma-separated values
+ if search_term:
+ matching_vals = [
+ val for val in unique_vals
+ if any(search_term.lower() in str(part).lower()
+ for part in (val.split(',') if isinstance(val, str) else [val]))
+ ]
+ # Update the multiselect default value
+ current_selected = st.session_state['filter_values'].get(col_name, [])
+ st.session_state['filter_values'][col_name] = list(set(current_selected + matching_vals))
+
+ # Show feedback about matches
+ if matching_vals:
+ st.success(f"Found and selected {len(matching_vals)} matching values")
+ else:
+ st.warning("No matching values found")
+
+ # Filter application form
+ with st.form("apply_filters_form"):
+ for col_name in st.session_state['additional_filters_selected']:
+ unique_vals = sorted(df[col_name].dropna().unique().tolist())
+ selected_vals = st.multiselect(
+ f"Filter by {col_name}",
+ options=unique_vals,
+ default=st.session_state['filter_values'].get(col_name, [])
+ )
+ st.session_state['filter_values'][col_name] = selected_vals
+
+ # Add clear filters button and apply filters button
+ col1, col2 = st.columns([1, 4])
+ with col1:
+ clear_filters = st.form_submit_button("Clear All")
+ with col2:
+ apply_filters_submitted = st.form_submit_button("Apply Filters to Dataset")
+
+ if clear_filters:
+ st.session_state['filter_values'] = {}
+ # Clear any existing summary data when filters are cleared
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+ st.rerun()
+
+ # Text columns selection moved to Advanced Settings
+ with st.expander("⚙️ Advanced Settings", expanded=False):
+ st.subheader("**Select Text Columns for Embedding**")
+ text_columns_selected = st.multiselect(
+ "Text Columns:",
+ df_cols,
+ default=st.session_state['text_columns'],
+ help="Choose columns containing text for semantic search and clustering. "
+ "If multiple are selected, their text will be concatenated."
+ )
+ st.session_state['text_columns'] = text_columns_selected
+
+ # Apply filters to the dataset
+ filtered_df = df.copy()
+ if 'apply_filters_submitted' in locals() and apply_filters_submitted:
+ # Clear any existing summary data when new filters are applied
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+
+ for col_name in st.session_state['additional_filters_selected']:
+ selected_vals = st.session_state['filter_values'].get(col_name, [])
+ if selected_vals:
+ filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
+ st.success("Filters applied successfully!")
+ st.session_state['filtered_df'] = filtered_df.copy()
+ st.session_state['filter_state'] = {
+ 'applied': True,
+ 'filters': st.session_state['filter_values'].copy()
+ }
+ # Reset any existing clustering results
+ for k in ['clustered_data', 'topic_model', 'current_clustering_data',
+ 'current_clustering_option', 'hierarchy']:
+ if k in st.session_state:
+ del st.session_state[k]
+
+ elif 'filter_state' in st.session_state and st.session_state['filter_state']['applied']:
+ # Reapply stored filters
+ for col_name, selected_vals in st.session_state['filter_state']['filters'].items():
+ if selected_vals:
+ filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
+ st.session_state['filtered_df'] = filtered_df.copy()
+
+ # Show current data preview and download button
+ if st.session_state['filtered_df'] is not None:
+ if st.session_state['filter_state']['applied']:
+ st.write("Filtered Data Preview:")
+ else:
+ st.write("Current Data Preview:")
+ st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
+ st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
+
+ output = io.BytesIO()
+ writer = pd.ExcelWriter(output, engine='openpyxl')
+ st.session_state['filtered_df'].to_excel(writer, index=False)
+ writer.close()
+ processed_data = output.getvalue()
+
+ st.download_button(
+ label="Download Current Data",
+ data=processed_data,
+ file_name='data.xlsx',
+ mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
+ )
+ else:
+ st.warning("Please ensure the default dataset exists in the 'input' directory.")
+
+ else:
+ # Upload custom dataset
+ uploaded_file = st.sidebar.file_uploader("Upload your Excel file", type=["xlsx"])
+ if uploaded_file is not None:
+ df = load_uploaded_dataset(uploaded_file)
+ if df is not None:
+ st.session_state['df'] = df.copy()
+ st.session_state['using_default_dataset'] = False
+ st.session_state['uploaded_file_name'] = uploaded_file.name
+ st.write("Data preview:")
+ st.write(df.head())
+ df_cols = df.columns.tolist()
+
+ st.subheader("**Select Text Columns for Embedding**")
+ text_columns_selected = st.multiselect(
+ "Text Columns:",
+ df_cols,
+ default=df_cols[:1] if df_cols else []
+ )
+ st.session_state['text_columns'] = text_columns_selected
+
+ st.write("**Additional Filters**")
+ selected_additional_cols = st.multiselect(
+ "Select additional columns from your dataset to use as filters:",
+ df_cols,
+ default=[]
+ )
+ st.session_state['additional_filters_selected'] = selected_additional_cols
+
+ filtered_df = df.copy()
+ for col_name in selected_additional_cols:
+ if f'selected_filter_{col_name}' not in st.session_state:
+ st.session_state[f'selected_filter_{col_name}'] = []
+ unique_vals = sorted(df[col_name].dropna().unique().tolist())
+ selected_vals = st.multiselect(
+ f"Filter by {col_name}",
+ options=unique_vals,
+ default=st.session_state[f'selected_filter_{col_name}']
+ )
+ st.session_state[f'selected_filter_{col_name}'] = selected_vals
+ if selected_vals:
+ filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
+
+ st.session_state['filtered_df'] = filtered_df
+ st.write("Filtered Data Preview:")
+ st.dataframe(filtered_df.head(), hide_index=True)
+ st.write(f"Total number of results: {len(filtered_df)}")
+
+ output = io.BytesIO()
+ writer = pd.ExcelWriter(output, engine='openpyxl')
+ filtered_df.to_excel(writer, index=False)
+ writer.close()
+ processed_data = output.getvalue()
+
+ st.download_button(
+ label="Download Filtered Data",
+ data=processed_data,
+ file_name='filtered_data.xlsx',
+ mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
+ )
+ else:
+ st.warning("Failed to load the uploaded dataset.")
+ else:
+ st.warning("Please upload an Excel file to proceed.")
+
+ if 'filtered_df' in st.session_state:
+ st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
+
+
+ ###############################################################################
+ # Preserve active tab across reruns
+ ###############################################################################
+ if 'active_tab_index' not in st.session_state:
+ st.session_state.active_tab_index = 0
+
+ tabs_titles = ["Semantic Search", "Clustering", "Summarization", "Chat", "Help"]
+ tabs = st.tabs(tabs_titles)
+ # We just create these references so we can navigate more easily
+ tab_semantic, tab_clustering, tab_summarization, tab_chat, tab_help = tabs
+
+ ###############################################################################
+ # Tab: Help
+ ###############################################################################
+ with tab_help:
+ st.header("Help")
+ st.markdown("""
+ ### About SNAP
+
+ SNAP allows you to explore, filter, search, cluster, and summarize textual datasets.
+
+ **Workflow**:
+ 1. **Data Selection (Sidebar)**: Choose the default dataset or upload your own.
+ 2. **Filtering**: Set additional filters for your dataset.
+ 3. **Select Text Columns**: Which columns to embed.
+ 4. **Semantic Search** (Tab): Provide a query and threshold to find relevant documents.
+ 5. **Clustering** (Tab): Group documents into topics.
+ 6. **Summarization** (Tab): Summarize the clustered documents (with optional references).
+
+ ### Troubleshooting
+ - If you see no results, try lowering the similarity threshold or removing negative/required keywords.
+ - Ensure you have at least one text column selected for embeddings.
+ """)
+
+ ###############################################################################
+ # Tab: Semantic Search
+ ###############################################################################
+ with tab_semantic:
+ st.header("Semantic Search")
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns:
+ st.warning("No text columns selected. Please select at least one column for text embedding.")
+ else:
+ df_full = st.session_state['df']
+ # Load or compute embeddings if necessary
+ embeddings, _ = load_or_compute_embeddings(
+ df_full,
+ st.session_state.get('using_default_dataset', False),
+ st.session_state.get('uploaded_file_name'),
+ text_columns
+ )
+
+ if embeddings is not None:
+ with st.expander("ℹ️ How Semantic Search Works", expanded=False):
+ st.markdown("""
+ ### Understanding Semantic Search
+
+ Unlike traditional keyword search that looks for exact matches, semantic search understands the meaning and context of your query. Here's how it works:
+
+ 1. **Query Processing**:
+ - Your search query is converted into a numerical representation (embedding) that captures its meaning
+ - Example: Searching for "Climate Smart Villages" will understand the concept, not just the words
+ - Related terms like "sustainable communities", "resilient farming", or "agricultural adaptation" might be found even if they don't contain the exact words
+
+ 2. **Similarity Matching**:
+ - Documents are ranked by how closely their meaning matches your query
+ - The similarity threshold controls how strict this matching is
+ - Higher threshold (e.g., 0.8) = more precise but fewer results
+ - Lower threshold (e.g., 0.3) = more results but might be less relevant
+
+ 3. **Advanced Features**:
+ - **Negative Keywords**: Use to explicitly exclude documents containing certain terms
+ - **Required Keywords**: Ensure specific terms appear in the results
+ - These work as traditional keyword filters after the semantic search
+
+ ### Search Tips
+
+ - **Phrase Queries**: Enter complete phrases for better context
+ - "Climate Smart Villages" (as one concept)
+ - Better than separate terms: "climate", "smart", "villages"
+
+ - **Descriptive Queries**: Add context for better results
+ - Instead of: "water"
+ - Better: "water management in agriculture"
+
+ - **Conceptual Queries**: Focus on concepts rather than specific terms
+ - Instead of: "increased yield"
+ - Better: "agricultural productivity improvements"
+
+ ### Example Searches
+
+ 1. **Query**: "Climate Smart Villages"
+ - Will find: Documents about climate-resilient communities, adaptive farming practices, sustainable village development
+ - Even if they don't use these exact words
+
+ 2. **Query**: "Gender equality in agriculture"
+ - Will find: Women's empowerment in farming, female farmer initiatives, gender-inclusive rural development
+ - Related concepts are captured semantically
+
+ 3. **Query**: "Sustainable water management"
+ + Required keyword: "irrigation"
+ - Combines semantic understanding of water sustainability with specific irrigation focus
+ """)
+
+ with st.form("search_parameters"):
+ query = st.text_input("Enter your search query:")
+ include_keywords = st.text_input("Include only documents containing these words (comma-separated):")
+ similarity_threshold = st.slider("Similarity threshold", 0.0, 1.0, 0.35)
+ submitted = st.form_submit_button("Search")
+
+ if submitted:
+ if query.strip():
+ with st.spinner("Performing Semantic Search..."):
+ # Clear any existing summary data when new search is run
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+
+ model = get_embedding_model()
+ df_filtered = st.session_state['filtered_df'].fillna("")
+ search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
+
+ # Filter the embeddings to the same subset
+ subset_indices = df_filtered.index
+ subset_embeddings = embeddings[subset_indices]
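+                                # Row labels are used as positions into the full-dataset embeddings array,
+                                # which assumes the DataFrame keeps its original integer index.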
+
+ query_embedding = model.encode([query], device=device)
+ similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
+
+ # Show distribution
+ fig = px.histogram(
+ x=similarities,
+ nbins=30,
+ labels={'x': 'Similarity Score', 'y': 'Number of Documents'},
+ title='Distribution of Similarity Scores'
+ )
+ fig.add_vline(
+ x=similarity_threshold,
+ line_dash="dash",
+ line_color="red",
+ annotation_text=f"Threshold: {similarity_threshold:.2f}",
+ annotation_position="top"
+ )
+ st.write("### Similarity Score Distribution")
+ st.plotly_chart(fig)
+
+ above_threshold_indices = np.where(similarities > similarity_threshold)[0]
+ if len(above_threshold_indices) == 0:
+ st.warning("No results found above the similarity threshold.")
+ if 'search_results' in st.session_state:
+ del st.session_state['search_results']
+ else:
+ selected_indices = subset_indices[above_threshold_indices]
+ results = df_filtered.loc[selected_indices].copy()
+ results['similarity_score'] = similarities[above_threshold_indices]
+ results.sort_values(by='similarity_score', ascending=False, inplace=True)
+
+ # Include keyword filtering
+ if include_keywords.strip():
+ inc_words = [w.strip().lower() for w in include_keywords.split(',') if w.strip()]
+ if inc_words:
+ results = results[
+ results.apply(
+ lambda row: all(
+ w in (' '.join(row.astype(str)).lower()) for w in inc_words
+ ),
+ axis=1
+ )
+ ]
+
+ if results.empty:
+ st.warning("No results found after applying keyword filters.")
+ if 'search_results' in st.session_state:
+ del st.session_state['search_results']
+ else:
+ st.session_state['search_results'] = results.copy()
+ output = io.BytesIO()
+ writer = pd.ExcelWriter(output, engine='openpyxl')
+ results.to_excel(writer, index=False)
+ writer.close()
+ processed_data = output.getvalue()
+ st.session_state['search_results_processed_data'] = processed_data
+ else:
+ st.warning("Please enter a query to search.")
+
+ # Display search results if available
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ st.write("## Search Results")
+ results = st.session_state['search_results']
+ cols_to_display = [c for c in results.columns if c != 'similarity_score'] + ['similarity_score']
+ st.dataframe(results[cols_to_display], hide_index=True)
+ st.write(f"Total number of results: {len(results)}")
+
+ if 'search_results_processed_data' in st.session_state:
+ st.download_button(
+ label="Download Full Results",
+ data=st.session_state['search_results_processed_data'],
+ file_name='search_results.xlsx',
+ mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+ key='download_search_results'
+ )
+ else:
+ st.info("No search results to display. Enter a query and click 'Search'.")
+ else:
+ st.warning("No embeddings available because no text columns were chosen.")
+ else:
+ st.warning("Filtered dataset is empty or not loaded. Please adjust your filters or upload data.")
+
+
+ ###############################################################################
+ # Tab: Clustering
+ ###############################################################################
+ with tab_clustering:
+ st.header("Clustering")
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ # Add explanation about clustering
+ with st.expander("ℹ️ How Clustering Works", expanded=False):
+ st.markdown("""
+ ### Understanding Document Clustering
+
+ Clustering automatically groups similar documents together, helping you discover patterns and themes in your data. Here's how it works:
+
+ 1. **Cluster Formation**:
+ - Documents are grouped based on their semantic similarity
+ - Each cluster represents a distinct theme or topic
+ - Documents that are too different from others may remain unclustered (labeled as -1)
+ - The "Min Cluster Size" parameter controls how clusters are formed
+
+ 2. **Interpreting Results**:
+ - Each cluster is assigned a number (e.g., 0, 1, 2...)
+ - Cluster -1 contains "outlier" documents that didn't fit well in other clusters
+ - The size of each cluster indicates how common that theme is
+ - Keywords for each cluster show the main topics/concepts
+
+ 3. **Visualizations**:
+ - **Intertopic Distance Map**: Shows how clusters relate to each other
+ - Closer clusters are more semantically similar
+ - Size of circles indicates number of documents
+ - Hover to see top terms for each cluster
+
+ - **Topic Document Visualization**: Shows individual documents
+ - Each point is a document
+ - Colors indicate cluster membership
+ - Distance between points shows similarity
+
+ - **Topic Hierarchy**: Shows how topics are related
+ - Tree structure shows topic relationships
+ - Parent topics contain broader themes
+ - Child topics show more specific sub-themes
+
+ ### How to Use Clusters
+
+ 1. **Exploration**:
+ - Use clusters to discover main themes in your data
+ - Look for unexpected groupings that might reveal insights
+ - Identify outliers that might need special attention
+
+ 2. **Analysis**:
+ - Compare cluster sizes to understand theme distribution
+ - Examine keywords to understand what defines each cluster
+ - Use hierarchy to see how themes are nested
+
+ 3. **Practical Applications**:
+ - Generate summaries for specific clusters
+ - Focus detailed analysis on clusters of interest
+ - Use clusters to organize and categorize documents
+ - Identify gaps or overlaps in your dataset
+
+ ### Tips for Better Results
+
+ - **Adjust Min Cluster Size**:
+ - Larger values (15-20): Fewer, broader clusters
+ - Smaller values (2-5): More specific, smaller clusters
+ - Balance between too many small clusters and too few large ones
+
+ - **Choose Data Wisely**:
+ - Cluster full dataset for overall themes
+ - Cluster search results for focused analysis
+ - More documents generally give better clusters
+
+ - **Interpret with Context**:
+ - Consider your domain knowledge
+ - Look for patterns across multiple visualizations
+ - Use cluster insights to guide further analysis
+ """)
+
+ df_to_cluster = None
+
+ # Create a single form for clustering settings
+ with st.form("clustering_form"):
+ st.subheader("Clustering Settings")
+
+ # Data source selection
+ clustering_option = st.radio(
+ "Select data for clustering:",
+ ('Full Dataset', 'Filtered Dataset', 'Semantic Search Results')
+ )
+
+ # Clustering parameters
+ min_cluster_size_val = st.slider(
+ "Min Cluster Size",
+ min_value=2,
+ max_value=50,
+ value=st.session_state.get('min_cluster_size', 5),
+            help="Minimum size of each cluster in HDBSCAN; in other words, the minimum number of documents/texts that must be grouped together to form a valid cluster.\n\n- A larger value (e.g., 20) will result in fewer, larger clusters\n- A smaller value (e.g., 2-5) will allow for more clusters, including smaller ones\n- Documents that don't fit into any cluster meeting this minimum size requirement are labeled as noise (typically assigned to cluster -1)"
+ )
+
+ run_clustering = st.form_submit_button("Run Clustering")
+
+ if run_clustering:
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.session_state['min_cluster_size'] = min_cluster_size_val
+
+ # Decide which DataFrame is used based on the selection
+ if clustering_option == 'Semantic Search Results':
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ df_to_cluster = st.session_state['search_results'].copy()
+ else:
+ st.warning("No semantic search results found. Please run a search first.")
+ elif clustering_option == 'Filtered Dataset':
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ df_to_cluster = st.session_state['filtered_df'].copy()
+ else:
+ st.warning("Filtered dataset is empty. Please check your filters.")
+ else: # Full Dataset
+ if 'df' in st.session_state and not st.session_state['df'].empty:
+ df_to_cluster = st.session_state['df'].copy()
+
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns:
+ st.warning("No text columns selected. Please select text columns to embed before clustering.")
+ else:
+ # Ensure embeddings are available
+ df_full = st.session_state['df']
+ embeddings, _ = load_or_compute_embeddings(
+ df_full,
+ st.session_state.get('using_default_dataset', False),
+ st.session_state.get('uploaded_file_name'),
+ text_columns
+ )
+
+ if df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering:
+ with st.spinner("Performing clustering..."):
+ # Clear any existing summary data when clustering is run
+ if 'summary_df' in st.session_state:
+ del st.session_state['summary_df']
+ if 'high_level_summary' in st.session_state:
+ del st.session_state['high_level_summary']
+ if 'enhanced_summary' in st.session_state:
+ del st.session_state['enhanced_summary']
+
+ dfc = df_to_cluster.copy().fillna("")
+ dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
+
+ # Filter embeddings to those rows
+ selected_indices = dfc.index
+ embeddings_clustering = embeddings[selected_indices]
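+                        # As in semantic search, row labels index into the full-dataset embeddings array.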
+
+ # Basic cleaning
+ stop_words = set(stopwords.words('english'))
+ texts_cleaned = []
+ for text in dfc['text'].tolist():
+ try:
+ # First try with word_tokenize
+ try:
+ word_tokens = word_tokenize(text)
+ except LookupError:
+ # If punkt is missing, try downloading it again
+ nltk.download('punkt_tab', quiet=False)
+ word_tokens = word_tokenize(text)
+ except Exception as e:
+ # If word_tokenize fails, fall back to simple splitting
+ st.warning(f"Using fallback tokenization due to error: {e}")
+ word_tokens = text.split()
+
+ filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
+ texts_cleaned.append(filtered_text)
+ except Exception as e:
+ st.error(f"Error processing text: {e}")
+ # Add the original text if processing fails
+ texts_cleaned.append(text)
+
+ try:
+ # Validation checks before clustering
+ if len(texts_cleaned) < min_cluster_size_val:
+ st.error(f"Not enough documents to form clusters. You have {len(texts_cleaned)} documents but minimum cluster size is set to {min_cluster_size_val}.")
+ st.session_state['clustering_error'] = "Insufficient documents for clustering"
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Convert embeddings to CPU numpy if needed
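+                                # (SentenceTransformer.encode returns NumPy arrays by default, so this is mainly a safety check.)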
+ if torch.is_tensor(embeddings_clustering):
+ embeddings_for_clustering = embeddings_clustering.cpu().numpy()
+ else:
+ embeddings_for_clustering = embeddings_clustering
+
+ # Additional validation
+ if embeddings_for_clustering.shape[0] != len(texts_cleaned):
+ st.error("Mismatch between number of embeddings and texts.")
+ st.session_state['clustering_error'] = "Embedding and text count mismatch"
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Build the HDBSCAN model with error handling
+ try:
+ hdbscan_model = HDBSCAN(
+ min_cluster_size=min_cluster_size_val,
+ metric='euclidean',
+ cluster_selection_method='eom'
+ )
+
+ # Build the BERTopic model
+ topic_model = BERTopic(
+ embedding_model=get_embedding_model(),
+ hdbscan_model=hdbscan_model
+ )
+
+ # Fit the model and get topics
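+                                    # Passing the precomputed embeddings lets BERTopic skip its own embedding step.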
+ topics, probs = topic_model.fit_transform(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+
+ # Validate clustering results
+ unique_topics = set(topics)
+ if len(unique_topics) < 2:
+ st.warning("Clustering resulted in too few clusters. Retry or try reducing the minimum cluster size.")
+ if -1 in unique_topics:
+ non_noise_docs = sum(1 for t in topics if t != -1)
+ st.info(f"Only {non_noise_docs} documents were assigned to clusters. The rest were marked as noise (-1).")
+ if non_noise_docs < min_cluster_size_val:
+ st.error("Not enough documents were successfully clustered. Try reducing the minimum cluster size.")
+ st.session_state['clustering_error'] = "Insufficient clustered documents"
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Store results if validation passes
+ dfc['Topic'] = topics
+ st.session_state['topic_model'] = topic_model
+ st.session_state['clustered_data'] = dfc.copy()
+ st.session_state['clustering_texts_cleaned'] = texts_cleaned
+ st.session_state['clustering_embeddings'] = embeddings_for_clustering
+ st.session_state['clustering_completed'] = True
+
+ # Try to generate visualizations with error handling
+ try:
+ st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
+ except Exception as viz_error:
+ st.warning("Could not generate topic visualization. This usually happens when there are too few total clusters. Try adjusting the minimum cluster size or adding more documents.")
+ st.session_state['intertopic_distance_fig'] = None
+
+ try:
+ st.session_state['topic_document_fig'] = topic_model.visualize_documents(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+ except Exception as viz_error:
+ st.warning("Could not generate document visualization. This might happen when the clustering results are not optimal. Try adjusting the clustering parameters.")
+ st.session_state['topic_document_fig'] = None
+
+ try:
+ hierarchy = topic_model.hierarchical_topics(texts_cleaned)
+ st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
+ st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
+ except Exception as viz_error:
+ st.warning("Could not generate topic hierarchy visualization. This usually happens when there aren't enough distinct topics to form a hierarchy.")
+ st.session_state['hierarchy'] = pd.DataFrame()
+ st.session_state['hierarchy_fig'] = None
+
+ except ValueError as ve:
+ if "zero-size array to reduction operation maximum which has no identity" in str(ve):
+ st.error("Clustering failed: No valid clusters could be formed. Try reducing the minimum cluster size.")
+ elif "Cannot use scipy.linalg.eigh for sparse A with k > N" in str(ve):
+ st.error("Clustering failed: Too many components requested for the number of documents. Try with more documents or adjust clustering parameters.")
+ else:
+ st.error(f"Clustering error: {str(ve)}")
+ st.session_state['clustering_error'] = str(ve)
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ except Exception as e:
+ st.error(f"An error occurred during clustering: {str(e)}")
+ st.session_state['clustering_error'] = str(e)
+ st.session_state['clustering_completed'] = False
+ st.session_state.active_tab_index = tabs_titles.index("Clustering")
+ st.stop()
+
+ # Display clustering results if they exist
+ if st.session_state.get('clustering_completed', False):
+ st.subheader("Topic Overview")
+ dfc = st.session_state['clustered_data']
+ topic_model = st.session_state['topic_model']
+ topics = dfc['Topic'].tolist()
+
+ unique_topics = sorted(list(set(topics)))
+ cluster_info = []
+ for t in unique_topics:
+ cluster_docs = dfc[dfc['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ if top_words:
+ top_keywords = ", ".join([w[0] for w in top_words[:5]])
+ else:
+ top_keywords = "N/A"
+ cluster_info.append((t, count, top_keywords))
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
+
+ st.write("### Topic Overview")
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ st.subheader("Clustering Results")
+ columns_to_display = [c for c in dfc.columns if c != 'text']
+ st.dataframe(dfc[columns_to_display], hide_index=True)
+
+ # Display stored visualizations with error handling
+ st.write("### Intertopic Distance Map")
+ if st.session_state.get('intertopic_distance_fig') is not None:
+ try:
+ st.plotly_chart(st.session_state['intertopic_distance_fig'])
+ except Exception:
+ st.info("Topic visualization is not available for the current clustering results.")
+
+ st.write("### Topic Document Visualization")
+ if st.session_state.get('topic_document_fig') is not None:
+ try:
+ st.plotly_chart(st.session_state['topic_document_fig'])
+ except Exception:
+ st.info("Document visualization is not available for the current clustering results.")
+
+ st.write("### Topic Hierarchy")
+ if st.session_state.get('hierarchy_fig') is not None:
+ try:
+ st.plotly_chart(st.session_state['hierarchy_fig'])
+ except Exception:
+ st.info("Topic hierarchy visualization is not available for the current clustering results.")
+        if df_to_cluster is None or embeddings is None or df_to_cluster.empty:
+            st.warning("Please select or upload a dataset and filter as needed.")
+
+
+ ###############################################################################
+ # Tab: Summarization
+ ###############################################################################
+ with tab_summarization:
+ st.header("Summarization")
+ # Add explanation about summarization
+ with st.expander("ℹ️ How Summarization Works", expanded=False):
+ st.markdown("""
+ ### Understanding Document Summarization
+
+ Summarization condenses multiple documents into concise, meaningful summaries while preserving key information. Here's how it works:
+
+ 1. **Summary Generation**:
+ - Documents are processed using advanced language models
+ - Key themes and important points are identified
+ - Content is condensed while maintaining context
+ - Both high-level and cluster-specific summaries are available
+
+ 2. **Reference System**:
+ - Summaries can include references to source documents
+ - References are shown as [ID] or as clickable links
+ - Each statement can be traced back to its source
+ - Helps maintain accountability and verification
+
+ 3. **Types of Summaries**:
+ - **High-Level Summary**: Overview of all selected documents
+ - Captures main themes across the entire selection
+ - Ideal for quick understanding of large document sets
+ - Shows relationships between different topics
+
+ - **Cluster-Specific Summaries**: Focused on each cluster
+ - More detailed for specific themes
+ - Shows unique aspects of each cluster
+ - Helps understand sub-topics in depth
+
+ ### How to Use Summaries
+
+ 1. **Configuration**:
+ - Choose between all clusters or specific ones
+ - Set temperature for creativity vs. consistency
+ - Adjust max tokens for summary length
+ - Enable/disable reference system
+
+ 2. **Reference Options**:
+ - Select column for reference IDs
+ - Add hyperlinks to references
+ - Choose URL column for clickable links
+ - References help track information sources
+
+ 3. **Practical Applications**:
+ - Quick overview of large datasets
+ - Detailed analysis of specific themes
+ - Evidence-based reporting with references
+ - Compare different document groups
+
+ ### Tips for Better Results
+
+ - **Temperature Setting**:
+ - Higher (0.7-1.0): More creative, varied summaries
+ - Lower (0.1-0.3): More consistent, conservative summaries
+ - Balance based on your needs for creativity vs. consistency
+
+ - **Token Length**:
+ - Longer limits: More detailed summaries
+ - Shorter limits: More concise, focused summaries
+ - Adjust based on document complexity
+
+ - **Reference Usage**:
+ - Enable references for traceability
+ - Use hyperlinks for easy navigation
+ - Choose meaningful reference columns
+ - Helps validate summary accuracy
+
+ ### Best Practices
+
+ 1. **For General Overview**:
+ - Use high-level summary
+ - Keep temperature moderate (0.5-0.7)
+ - Enable references for verification
+ - Focus on broader themes
+
+ 2. **For Detailed Analysis**:
+ - Use cluster-specific summaries
+ - Adjust temperature based on need
+ - Include references with hyperlinks
+ - Look for patterns within clusters
+
+ 3. **For Reporting**:
+ - Combine both summary types
+ - Use references extensively
+ - Balance detail and brevity
+ - Ensure source traceability
+ """)
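+
+        # A minimal sketch (not called by the app) of the reference formatting described in the
+        # expander above: a reference is rendered as plain "[ID]" when no URL is known, or as a
+        # clickable HTML anchor when one is. The summaries themselves get their references from
+        # process_summaries_in_parallel defined elsewhere in this file; the helper name here is
+        # illustrative only.
+        def _example_format_reference(ref_id, url=None):
+            """Return "[ID]" or an HTML anchor such as <a href="...">[ID]</a>."""
+            if url:
+                return f'<a href="{url}" target="_blank">[{ref_id}]</a>'
+            return f"[{ref_id}]"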
+
+ df_summ = None
+ # We'll try to summarize either the clustered data or just the filtered dataset
+ if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
+ df_summ = st.session_state['clustered_data']
+ elif 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ df_summ = st.session_state['filtered_df']
+ else:
+ st.warning("No data available for summarization. Please cluster first or have some filtered data.")
+
+ if df_summ is not None and not df_summ.empty:
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns:
+ st.warning("No text columns selected. Please select columns for text embedding first.")
+ else:
+ if 'Topic' not in df_summ.columns or 'topic_model' not in st.session_state:
+ st.warning("No 'Topic' column found. Summaries per cluster are only available if you've run clustering.")
+ else:
+ topic_model = st.session_state['topic_model']
+                    df_summ['text'] = df_summ[text_columns].fillna("").astype(str).agg(' '.join, axis=1)
+
+ # List of topics
+ topics = sorted(df_summ['Topic'].unique())
+ cluster_info = []
+ for t in topics:
+ cluster_docs = df_summ[df_summ['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ if top_words:
+ top_keywords = ", ".join([w[0] for w in top_words[:5]])
+ else:
+ top_keywords = "N/A"
+ cluster_info.append((t, count, top_keywords))
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
+
+ # If we have cluster names from previous summarization, add them
+ if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
+ summary_df = st.session_state['summary_df']
+ # Create a mapping of topic to name for merging
+ topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
+ # Add cluster names to cluster_df
+ cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
+ # Reorder columns to show name after topic
+ cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
+
+ st.write("### Available Clusters:")
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ # Summarization settings
+ st.subheader("Summarization Settings")
+ # Summaries scope
+ summary_scope = st.radio(
+ "Generate summaries for:",
+ ["All clusters", "Specific clusters"]
+ )
+ if summary_scope == "Specific clusters":
+ # Format options to include cluster names if available
+ if 'Cluster_Name' in cluster_df.columns:
+ topic_options = [f"Cluster {t} - {name}" for t, name in zip(cluster_df['Topic'], cluster_df['Cluster_Name'])]
+ topic_to_id = {opt: t for opt, t in zip(topic_options, cluster_df['Topic'])}
+ selected_topic_options = st.multiselect("Select clusters to summarize", topic_options)
+ selected_topics = [topic_to_id[opt] for opt in selected_topic_options]
+ else:
+ selected_topics = st.multiselect("Select clusters to summarize", topics)
+ else:
+ selected_topics = topics
+
+ # Add system prompt configuration
+ default_system_prompt = """You are an expert summarizer skilled in creating concise and relevant summaries.
+ You will be given text and an objective context. Please produce a clear, cohesive,
+ and thematically relevant summary.
+ Focus on key points, insights, or patterns that emerge from the text."""
+
+ if 'system_prompt' not in st.session_state:
+ st.session_state['system_prompt'] = default_system_prompt
+
+ with st.expander("🔧 Advanced Settings", expanded=False):
+ st.markdown("""
+ ### System Prompt Configuration
+
+ The system prompt guides the AI in how to generate summaries. You can customize it to better suit your needs:
+ - Be specific about the style and focus you want
+ - Add domain-specific context if needed
+ - Include any special formatting requirements
+ """)
+
+ system_prompt = st.text_area(
+ "Customize System Prompt",
+ value=st.session_state['system_prompt'],
+ height=150,
+ help="This prompt guides the AI in how to generate summaries. Edit it to customize the summary style and focus."
+ )
+
+ if st.button("Reset to Default"):
+ system_prompt = default_system_prompt
+ st.session_state['system_prompt'] = default_system_prompt
+
+ st.markdown("### Generation Parameters")
+ temperature = st.slider(
+ "Temperature",
+ 0.0, 1.0, 0.7,
+ help="Higher values (0.7-1.0) make summaries more creative but less predictable. Lower values (0.1-0.3) make them more focused and consistent."
+ )
+ max_tokens = st.slider(
+ "Max Tokens",
+ 100, 3000, 1000,
+ help="Maximum length of generated summaries. Higher values allow for more detailed summaries but take longer to generate."
+ )
+
+ st.session_state['system_prompt'] = system_prompt
+
+ st.write("### Enhanced Summary References")
+ st.write("Select columns for references (optional).")
+ all_cols = [c for c in df_summ.columns if c not in ['text', 'Topic', 'similarity_score']]
+
+ # By default, let's guess the first column as reference ID if available
+ if 'reference_id_column' not in st.session_state:
+ st.session_state.reference_id_column = all_cols[0] if all_cols else None
+ # If there's a column that looks like a URL, guess that
+ url_guess = next((c for c in all_cols if 'url' in c.lower() or 'link' in c.lower()), None)
+ if 'url_column' not in st.session_state:
+ st.session_state.url_column = url_guess
+
+ enable_references = st.checkbox(
+ "Enable references in summaries",
+ value=True, # default to True as requested
+ help="Add source references to the final summary text."
+ )
+ reference_id_column = st.selectbox(
+ "Select column to use as reference ID:",
+ all_cols,
+ index=all_cols.index(st.session_state.reference_id_column) if st.session_state.reference_id_column in all_cols else 0
+ )
+ add_hyperlinks = st.checkbox(
+ "Add hyperlinks to references",
+ value=True, # default to True
+ help="If the reference column has a matching URL, make it clickable."
+ )
+ url_column = None
+ if add_hyperlinks:
+ url_column = st.selectbox(
+ "Select column containing URLs:",
+ all_cols,
+ index=all_cols.index(st.session_state.url_column) if (st.session_state.url_column in all_cols) else 0
+ )
+
+ # Summarization button
+ if st.button("Generate Summaries"):
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
+ if not openai_api_key:
+ st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
+ else:
+ # Set flag to indicate summarization button was clicked
+ st.session_state['_summarization_button_clicked'] = True
+
+ llm = ChatOpenAI(
+ api_key=openai_api_key,
+ model_name='gpt-4o-mini', # or 'gpt-4o'
+ temperature=temperature,
+ max_tokens=max_tokens
+ )
+
+ # Filter to selected topics
+ if selected_topics:
+ df_scope = df_summ[df_summ['Topic'].isin(selected_topics)]
+ else:
+ st.warning("No topics selected for summarization.")
+ df_scope = pd.DataFrame()
+
+ if df_scope.empty:
+ st.warning("No documents match the selected topics for summarization.")
+ else:
+ all_texts = df_scope['text'].tolist()
+ combined_text = " ".join(all_texts)
+ if not combined_text.strip():
+ st.warning("No text data available for summarization.")
+ else:
+ # For cluster-specific summaries, use the customized prompt
+ local_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
+ local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
+
+ # Summaries per cluster
+ # Only if multiple clusters are selected
+ unique_selected_topics = df_scope['Topic'].unique()
+ if len(unique_selected_topics) > 1:
+ st.write("### Summaries per Selected Cluster")
+
+ # Process summaries in parallel
+ with st.spinner("Generating cluster summaries in parallel..."):
+ summaries = process_summaries_in_parallel(
+ df_scope=df_scope,
+ unique_selected_topics=unique_selected_topics,
+ llm=llm,
+ chat_prompt=local_chat_prompt,
+ enable_references=enable_references,
+ reference_id_column=reference_id_column,
+ url_column=url_column if add_hyperlinks else None,
+ max_workers=min(16, len(unique_selected_topics)) # Limit workers based on clusters
+ )
+
+ if summaries:
+ summary_df = pd.DataFrame(summaries)
+ # Store the summaries DataFrame in session state
+ st.session_state['summary_df'] = summary_df
+ # Store additional summary info in session state
+ st.session_state['has_references'] = enable_references
+ st.session_state['reference_id_column'] = reference_id_column
+ st.session_state['url_column'] = url_column if add_hyperlinks else None
+
+ # Update cluster_df with new names
+ if 'Cluster_Name' in summary_df.columns:
+ topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
+ cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
+ cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
+
+ # Immediately display updated cluster overview
+ st.write("### Updated Topic Overview:")
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ # Now generate high-level summary from the cluster summaries
+ with st.spinner("Generating high-level summary from cluster summaries..."):
+ # Format cluster summaries with proper markdown and HTML
+ formatted_summaries = []
+ total_tokens = 0
+ MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75) # Leave room for system prompt and completion
+ summary_batches = []
+ current_batch = []
+ current_batch_tokens = 0
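+                                    # Two-pass synthesis: cluster summaries are packed greedily into batches
+                                    # that stay under MAX_SAFE_TOKENS, each batch is condensed into an overview,
+                                    # and the batch overviews are then merged into one high-level summary.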
+
+ for _, row in summary_df.iterrows():
+ summary_text = row.get('Enhanced_Summary', row['Summary'])
+ formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
+ summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
+
+ # If adding this summary would exceed the safe token limit, start a new batch
+ if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
+ if current_batch: # Only append if we have summaries in the current batch
+ summary_batches.append(current_batch)
+ current_batch = []
+ current_batch_tokens = 0
+
+ current_batch.append(formatted_summary)
+ current_batch_tokens += summary_tokens
+
+ # Add the last batch if it has any summaries
+ if current_batch:
+ summary_batches.append(current_batch)
+
+ # Generate overview for each batch
+ batch_overviews = []
+ with st.spinner("Generating batch summaries..."):
+ for i, batch in enumerate(summary_batches, 1):
+ st.write(f"Processing batch {i} of {len(summary_batches)}...")
+
+ batch_text = "\n\n".join(batch)
+                                                    batch_prompt = f"""Below are summaries for a subset of clusters produced with transformer-based NLP from results in the CGIAR reporting system. Each summary contains references to source documents, shown either as plain IDs like [ID] or as HTML-hyperlinked IDs.
+
+Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
+1. Preserve all hyperlinked references exactly as they appear in the input summaries
+2. Maintain the HTML anchor tags (<a href="...">...</a>) intact when using information from the summaries
+3. Keep the markdown formatting for better readability
+4. Note that this is part {i} of {len(summary_batches)} parts, so focus on the themes present in these specific clusters
+
+Here are the cluster summaries to synthesize:
+
+{batch_text}"""
+
+ # Generate overview for this batch
+ high_level_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
+ high_level_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ high_level_chat_prompt = ChatPromptTemplate.from_messages([high_level_system_message, high_level_human_message])
+ high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
+ batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
+ batch_overviews.append(batch_overview)
+
+ # Now combine the batch overviews
+ with st.spinner("Generating final combined summary..."):
+ combined_overviews = "\n\n### Part ".join([f"{i+1}:\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ final_prompt = f"""Below are {len(batch_overviews)} overview summaries, each covering different clusters of research results. Each part maintains its original references to source documents.
+
+Please create a final comprehensive synthesis that:
+1. Integrates the key themes and findings from all parts
+2. Preserves all hyperlinked references exactly as they appear
+3. Maintain the HTML anchor tags (<a href="...">...</a>) intact
+4. Keeps the markdown formatting for better readability
+5. Creates a coherent narrative across all parts
+6. Highlights any themes that span multiple parts
+
+Here are the overviews to synthesize:
+
+### Part {combined_overviews}"""
+
+ # Verify the final prompt's token count
+ final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
+ if final_prompt_tokens > MAX_SAFE_TOKENS:
+ st.error(f"❌ Final synthesis prompt ({final_prompt_tokens:,} tokens) exceeds safe limit ({MAX_SAFE_TOKENS:,}). Using batch summaries separately.")
+ high_level_summary = "# Overall Summary\n\n" + "\n\n".join([f"## Batch {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ else:
+ # Generate final synthesis
+ high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
+ high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
+
+ # Store both versions of the summary
+ st.session_state['high_level_summary'] = high_level_summary
+ st.session_state['enhanced_summary'] = high_level_summary
+
+ # Set flag to indicate summarization is complete
+ st.session_state['summarization_completed'] = True
+
+ # Update the display without rerunning
+ st.write("### High-Level Summary:")
+ st.markdown(high_level_summary, unsafe_allow_html=True)
+
+ # Display cluster summaries
+ st.write("### Cluster Summaries:")
+ if enable_references and 'Enhanced_Summary' in summary_df.columns:
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ st.write(f"**Topic {row['Topic']} - {cluster_name}**")
+ st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
+ st.write("---")
+ with st.expander("View original summaries in table format"):
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.write("### Summaries per Cluster:")
+ if 'Cluster_Name' in summary_df.columns:
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.dataframe(summary_df, hide_index=True)
+
+ # Download
+                                            if 'Cluster_Name' in summary_df.columns:
+ dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ else:
+ dl_df = summary_df
+ csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
+ b64 = base64.b64encode(csv_bytes).decode()
+                                            href = f'<a href="data:file/csv;base64,{b64}" download="cluster_summaries.csv">Download Summaries CSV</a>'
+ st.markdown(href, unsafe_allow_html=True)
+
+ # Display existing summaries if available and summarization was completed
+ if st.session_state.get('summarization_completed', False):
+ if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
+ if 'high_level_summary' in st.session_state:
+ st.write("### High-Level Summary:")
+ st.markdown(st.session_state['enhanced_summary'] if st.session_state.get('enhanced_summary') else st.session_state['high_level_summary'], unsafe_allow_html=True)
+
+ st.write("### Cluster Summaries:")
+ summary_df = st.session_state['summary_df']
+ if 'Enhanced_Summary' in summary_df.columns:
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ st.write(f"**Topic {row['Topic']} - {cluster_name}**")
+ st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
+ st.write("---")
+ with st.expander("View original summaries in table format"):
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.dataframe(summary_df, hide_index=True)
+
+ # Add download button for existing summaries
+ dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
+ if 'Cluster_Name' in dl_df.columns:
+ dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
+ b64 = base64.b64encode(csv_bytes).decode()
+                            href = f'<a href="data:file/csv;base64,{b64}" download="cluster_summaries.csv">Download Summaries CSV</a>'
+ st.markdown(href, unsafe_allow_html=True)
+ else:
+ st.warning("No data available for summarization.")
+
+ # Display existing summaries if available (when returning to the tab)
+ if not st.session_state.get('_summarization_button_clicked', False): # Only show if not just generated
+ if 'high_level_summary' in st.session_state:
+ st.write("### Existing High-Level Summary:")
+ if st.session_state.get('enhanced_summary'):
+ st.markdown(st.session_state['enhanced_summary'], unsafe_allow_html=True)
+ with st.expander("View original summary (without references)"):
+ st.write(st.session_state['high_level_summary'])
+ else:
+ st.write(st.session_state['high_level_summary'])
+
+ if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
+ st.write("### Existing Cluster Summaries:")
+ summary_df = st.session_state['summary_df']
+ if 'Enhanced_Summary' in summary_df.columns:
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ st.write(f"**Topic {row['Topic']} - {cluster_name}**")
+ st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
+ st.write("---")
+ with st.expander("View original summaries in table format"):
+ display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
+ display_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ st.dataframe(display_df, hide_index=True)
+ else:
+ st.dataframe(summary_df, hide_index=True)
+
+ # Add download button for existing summaries
+ dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
+ if 'Cluster_Name' in dl_df.columns:
+ dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
+ csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
+ b64 = base64.b64encode(csv_bytes).decode()
+                href = f'<a href="data:file/csv;base64,{b64}" download="cluster_summaries.csv">Download Summaries CSV</a>'
+ st.markdown(href, unsafe_allow_html=True)
+
+
+ ###############################################################################
+ # Tab: Chat
+ ###############################################################################
+ with tab_chat:
+ st.header("Chat with Your Data")
+
+ # Add explanation about chat functionality
+ with st.expander("ℹ️ How Chat Works", expanded=False):
+ st.markdown("""
+ ### Understanding Chat with Your Data
+
+ The chat functionality allows you to have an interactive conversation about your data, whether it's filtered, clustered, or raw. Here's how it works:
+
+ 1. **Data Selection**:
+ - Choose which dataset to chat about (filtered, clustered, or search results)
+ - Optionally focus on specific clusters if clustering was performed
+ - System automatically includes relevant context from your selection
+
+ 2. **Context Window**:
+ - Shows how much of the GPT-4 context window is being used
+ - Helps you understand if you need to filter data further
+ - Displays token usage statistics
+
+ 3. **Chat Features**:
+ - Ask questions about your data
+ - Get insights and analysis
+ - Reference specific documents or clusters
+ - Download chat context for transparency
+
+ ### Best Practices
+
+ 1. **Data Selection**:
+ - Start with filtered or clustered data for more focused conversations
+ - Select specific clusters if you want to dive deep into a topic
+ - Consider the context window usage when selecting data
+
+ 2. **Asking Questions**:
+ - Be specific in your questions
+ - Ask about patterns, trends, or insights
+ - Reference clusters or documents by their IDs
+ - Build on previous questions for deeper analysis
+
+ 3. **Managing Context**:
+ - Monitor the context window usage
+ - Filter data further if context is too full
+ - Download chat context for documentation
+ - Clear chat history to start fresh
+
+ ### Tips for Better Results
+
+ - **Question Types**:
+ - "What are the main themes in cluster 3?"
+ - "Compare the findings between clusters 1 and 2"
+ - "Summarize the methodology used across these documents"
+ - "What are the common outcomes reported?"
+
+ - **Follow-up Questions**:
+ - Build on previous answers
+ - Ask for clarification
+ - Request specific examples
+ - Explore relationships between findings
+ """)
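+
+        # A minimal sketch (not called by the app) of the context-window accounting described
+        # above: token counts come from the GPT-2 tokenizer loaded at startup (an approximation,
+        # since the chat model is GPT-4o) and usage is reported as a share of MAX_CONTEXT_WINDOW.
+        # The helper name is illustrative only.
+        def _example_context_usage(text):
+            """Return the fraction of the model context window that `text` would occupy."""
+            if tokenizer is None:
+                return 0.0
+            n_tokens = len(tokenizer(text)["input_ids"])
+            return n_tokens / MAX_CONTEXT_WINDOW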
+
+ # Function to check data source availability
+ def get_available_data_sources():
+ sources = []
+ if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
+ sources.append("Filtered Dataset")
+ if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
+ sources.append("Clustered Data")
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ sources.append("Search Results")
+ if ('high_level_summary' in st.session_state or
+ ('summary_df' in st.session_state and not st.session_state['summary_df'].empty)):
+ sources.append("Summarized Data")
+ return sources
+
+ # Get available data sources
+ available_sources = get_available_data_sources()
+
+ if not available_sources:
+ st.warning("No data available for chat. Please filter, cluster, search, or summarize first.")
+ st.stop()
+
+ # Initialize or update data source in session state
+ if 'chat_data_source' not in st.session_state:
+ st.session_state.chat_data_source = available_sources[0]
+ elif st.session_state.chat_data_source not in available_sources:
+ st.session_state.chat_data_source = available_sources[0]
+
+ # Data source selection with automatic fallback
+ data_source = st.radio(
+ "Select data to chat about:",
+ available_sources,
+ index=available_sources.index(st.session_state.chat_data_source),
+ help="Choose which dataset you want to analyze in the chat."
+ )
+
+ # Update session state if data source changed
+ if data_source != st.session_state.chat_data_source:
+ st.session_state.chat_data_source = data_source
+ # Clear any cluster-specific selections if switching data sources
+ if 'chat_selected_cluster' in st.session_state:
+ del st.session_state.chat_selected_cluster
+
+ # Get the appropriate DataFrame based on selected source
+ df_chat = None
+ if data_source == "Filtered Dataset":
+ df_chat = st.session_state['filtered_df']
+ elif data_source == "Clustered Data":
+ df_chat = st.session_state['clustered_data']
+ elif data_source == "Search Results":
+ df_chat = st.session_state['search_results']
+ elif data_source == "Summarized Data":
+ # Create DataFrame with selected summaries
+ summary_rows = []
+
+ # Add high-level summary if available
+ if 'high_level_summary' in st.session_state:
+ summary_rows.append({
+ 'Summary_Type': 'High-Level Summary',
+ 'Content': st.session_state.get('enhanced_summary', st.session_state['high_level_summary'])
+ })
+
+ # Add cluster summaries if available
+ if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
+ summary_df = st.session_state['summary_df']
+ for _, row in summary_df.iterrows():
+ summary_rows.append({
+ 'Summary_Type': f"Cluster {row['Topic']} Summary",
+ 'Content': row.get('Enhanced_Summary', row['Summary'])
+ })
+
+ if summary_rows:
+ df_chat = pd.DataFrame(summary_rows)
+
+ if df_chat is not None and not df_chat.empty:
+ # If we have clustered data, allow cluster selection
+ selected_cluster = None
+ if data_source != "Summarized Data" and 'Topic' in df_chat.columns:
+ cluster_option = st.radio(
+ "Choose cluster scope:",
+ ["All Clusters", "Specific Cluster"]
+ )
+ if cluster_option == "Specific Cluster":
+ unique_topics = sorted(df_chat['Topic'].unique())
+ # Check if we have cluster names
+ if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
+ summary_df = st.session_state['summary_df']
+ # Create a mapping of topic to name
+ topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
+ # Format the selectbox options
+ topic_options = [
+ (t, f"Cluster {t} - {topic_names.get(t, 'Unnamed Cluster')}")
+ for t in unique_topics
+ ]
+ selected_cluster = st.selectbox(
+ "Select cluster to focus on:",
+ [t[0] for t in topic_options],
+ format_func=lambda x: next(opt[1] for opt in topic_options if opt[0] == x)
+ )
+ else:
+ selected_cluster = st.selectbox(
+ "Select cluster to focus on:",
+ unique_topics,
+ format_func=lambda x: f"Cluster {x}"
+ )
+ if selected_cluster is not None:
+ df_chat = df_chat[df_chat['Topic'] == selected_cluster]
+ st.session_state.chat_selected_cluster = selected_cluster
+ elif 'chat_selected_cluster' in st.session_state:
+ del st.session_state.chat_selected_cluster
+
+ # Prepare the data for chat context
+ text_columns = st.session_state.get('text_columns', [])
+ if not text_columns and data_source != "Summarized Data":
+ st.warning("No text columns selected. Please select text columns to enable chat functionality.")
+ st.stop()
+
+ # Instead of limiting to 210 documents, we'll limit by tokens
+ MAX_ALLOWED_TOKENS = int(MAX_CONTEXT_WINDOW * 0.95) # 95% of context window
+
+ # Prepare system message first to account for its tokens
+ system_msg = {
+ "role": "system",
+ "content": """You are a specialized assistant analyzing data from a research database.
+ Your role is to:
+ 1. Provide clear, concise answers based on the data provided
+ 2. Highlight relevant information from specific results when answering
+ 3. When referencing specific results, use their row index or ID if available
+ 4. Clearly state if information is not available in the results
+ 5. Maintain a professional and analytical tone
+ 6. Format your responses using Markdown:
+ - Use **bold** for emphasis
+ - Use bullet points and numbered lists for structured information
+ - Create tables using Markdown syntax when presenting structured data
+ - Use backticks for code or technical terms
+ - Include hyperlinks when referencing external sources
+ - Use headings (###) to organize long responses
+
+ The data is provided in a structured format where:""" + ("""
+ - Each result contains multiple fields
+ - Text content is primarily in the following columns: """ + ", ".join(text_columns) + """
+ - Additional metadata and fields are available for reference
+ - If clusters are present, they are numbered (e.g., Cluster 0, Cluster 1, etc.)""" if data_source != "Summarized Data" else """
+ - The data consists of AI-generated summaries of the documents
+ - Each summary may contain references to source documents in markdown format
+ - References are shown as [ID] or as clickable hyperlinks
+ - Summaries may be high-level (covering all documents) or cluster-specific""") + """
+ """
+ }
+
+ # Calculate system message tokens
+ system_tokens = len(tokenizer(system_msg["content"])["input_ids"])
+ remaining_tokens = MAX_ALLOWED_TOKENS - system_tokens
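+            # Rows are appended below in their current order until this budget is exhausted;
+            # anything that does not fit is dropped, and the coverage figures shown further
+            # down report how many documents made it into the context.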
+
+ # Prepare the data context with token limiting
+ data_text = "Available Data:\n"
+ included_rows = 0
+ total_rows = len(df_chat)
+
+ if data_source == "Summarized Data":
+ # For summarized data, process row by row
+ for idx, row in df_chat.iterrows():
+ row_text = f"\n{row['Summary_Type']}:\n{row['Content']}\n"
+ row_tokens = len(tokenizer(row_text)["input_ids"])
+
+ if remaining_tokens - row_tokens > 0:
+ data_text += row_text
+ remaining_tokens -= row_tokens
+ included_rows += 1
+ else:
+ break
+ else:
+ # For regular data, process row by row
+ for idx, row in df_chat.iterrows():
+ row_text = f"\nItem {idx}:\n"
+ for col in df_chat.columns:
+ if not pd.isna(row[col]) and str(row[col]).strip() and col != 'similarity_score':
+ row_text += f"{col}: {row[col]}\n"
+
+ row_tokens = len(tokenizer(row_text)["input_ids"])
+ if remaining_tokens - row_tokens > 0:
+ data_text += row_text
+ remaining_tokens -= row_tokens
+ included_rows += 1
+ else:
+ break
+
+ # Calculate token usage
+ data_tokens = len(tokenizer(data_text)["input_ids"])
+ total_tokens = system_tokens + data_tokens
+ context_usage_percent = (total_tokens / MAX_CONTEXT_WINDOW) * 100
+
+ # Display token usage and data coverage
+ st.subheader("Context Window Usage")
+ st.write(f"System Message: {system_tokens:,} tokens")
+ st.write(f"Data Context: {data_tokens:,} tokens")
+ st.write(f"Total: {total_tokens:,} tokens ({context_usage_percent:.1f}% of available context)")
+ st.write(f"Documents included: {included_rows:,} out of {total_rows:,} ({(included_rows/total_rows*100):.1f}%)")
+
+ if context_usage_percent > 90:
+ st.warning("⚠️ High context usage! Consider reducing the number of results or filtering further.")
+ elif context_usage_percent > 75:
+ st.info("ℹ️ Moderate context usage. Still room for your question, but consider reducing results if asking a long question.")
+
+ # Add download button for chat context
+ chat_context = f"""System Message:
+ {system_msg['content']}
+
+ {data_text}"""
+ st.download_button(
+ label="📥 Download Chat Context",
+ data=chat_context,
+ file_name="chat_context.txt",
+ mime="text/plain",
+ help="Download the exact context that the chatbot receives"
+ )
+
+ # Chat interface
+ col_chat1, col_chat2 = st.columns([3, 1])
+ with col_chat1:
+ user_input = st.text_area("Ask a question about your data:", key="chat_input")
+ with col_chat2:
+ if st.button("Clear Chat History"):
+ st.session_state.chat_history = []
+ st.rerun()
+
+ # Store current tab index before processing
+ current_tab = tabs_titles.index("Chat")
+
+ if st.button("Send", key="send_button"):
+ if user_input:
+ # Set the active tab index to stay on Chat
+ st.session_state.active_tab_index = current_tab
+
+ with st.spinner("Processing your question..."):
+ # Add user's question to chat history
+ st.session_state.chat_history.append({"role": "user", "content": user_input})
+
+ # Prepare messages for API call
+ messages = [system_msg]
+ messages.append({"role": "user", "content": f"Here is the data to reference:\n\n{data_text}\n\nUser question: {user_input}"})
+
+ # Get response from OpenAI
+ response = get_chat_response(messages)
+
+ if response:
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+ # Display chat history
+ st.subheader("Chat History")
+ for message in st.session_state.chat_history:
+ if message["role"] == "user":
+ st.write("**You:**", message["content"])
+ else:
+ st.write("**Assistant:**")
+ st.markdown(message["content"], unsafe_allow_html=True)
+ st.write("---") # Add a separator between messages
+
+
+ ###############################################################################
+ # Tab: Internal Validation
+ ###############################################################################
+
+else: # Simple view
+ st.header("Automatic Mode")
+
+ # Initialize session state for automatic view
+ if 'df' not in st.session_state:
+ default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
+ df = load_default_dataset(default_dataset_path)
+ if df is not None:
+ st.session_state['df'] = df.copy()
+ st.session_state['using_default_dataset'] = True
+ st.session_state['filtered_df'] = df.copy()
+
+ # Set default text columns if not already set
+ if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
+ default_text_cols = []
+ if 'Title' in df.columns and 'Description' in df.columns:
+ default_text_cols = ['Title', 'Description']
+ st.session_state['text_columns'] = default_text_cols
+
+ # Single search bar for automatic processing
+ #st.write("Enter your query to automatically search, cluster, and summarize the results:")
+ query = st.text_input("Write your query here:")
+
+ if st.button("SNAP!"):
+ if query.strip():
+ # Step 1: Semantic Search
+ st.write("### Step 1: Semantic Search")
+ with st.spinner("Performing Semantic Search..."):
+ text_columns = st.session_state.get('text_columns', [])
+ if text_columns:
+ df_full = st.session_state['df']
+ embeddings, _ = load_or_compute_embeddings(
+ df_full,
+ st.session_state.get('using_default_dataset', False),
+ st.session_state.get('uploaded_file_name'),
+ text_columns
+ )
+
+ if embeddings is not None:
+ model = get_embedding_model()
+ df_filtered = st.session_state['filtered_df'].fillna("")
+ search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
+
+ subset_indices = df_filtered.index
+ subset_embeddings = embeddings[subset_indices]
+
+ query_embedding = model.encode([query], device=device)
+ similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
+
+ similarity_threshold = 0.35 # Default threshold
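+                        # Only documents whose cosine similarity to the query embedding exceeds this
+                        # threshold are kept; raising it narrows the results, lowering it broadens them.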
+ above_threshold_indices = np.where(similarities > similarity_threshold)[0]
+
+ if len(above_threshold_indices) > 0:
+ selected_indices = subset_indices[above_threshold_indices]
+ results = df_filtered.loc[selected_indices].copy()
+ results['similarity_score'] = similarities[above_threshold_indices]
+ results.sort_values(by='similarity_score', ascending=False, inplace=True)
+ st.session_state['search_results'] = results.copy()
+ st.write(f"Found {len(results)} relevant documents")
+ else:
+ st.warning("No results found above the similarity threshold.")
+ st.stop()
+
+ # Step 2: Clustering
+ if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
+ st.write("### Step 2: Clustering")
+ with st.spinner("Performing clustering..."):
+ df_to_cluster = st.session_state['search_results'].copy()
+ dfc = df_to_cluster.copy().fillna("")
+ dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
+
+ # Filter embeddings to those rows
+ selected_indices = dfc.index
+ embeddings_clustering = embeddings[selected_indices]
+
+ # Basic cleaning
+ stop_words = set(stopwords.words('english'))
+ texts_cleaned = []
+ for text in dfc['text'].tolist():
+ try:
+ word_tokens = word_tokenize(text)
+ filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
+ texts_cleaned.append(filtered_text)
+ except Exception as e:
+ texts_cleaned.append(text)
+
+ min_cluster_size = 5 # Default value
+
+ try:
+ # Convert embeddings to CPU numpy if needed
+ if torch.is_tensor(embeddings_clustering):
+ embeddings_for_clustering = embeddings_clustering.cpu().numpy()
+ else:
+ embeddings_for_clustering = embeddings_clustering
+
+ # Build the HDBSCAN model
+ hdbscan_model = HDBSCAN(
+ min_cluster_size=min_cluster_size,
+ metric='euclidean',
+ cluster_selection_method='eom'
+ )
+
+ # Build the BERTopic model
+ topic_model = BERTopic(
+ embedding_model=get_embedding_model(),
+ hdbscan_model=hdbscan_model
+ )
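+                            # The precomputed sentence embeddings are passed to fit_transform below,
+                            # so BERTopic skips re-embedding the documents and only runs dimensionality
+                            # reduction and HDBSCAN clustering on them.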
+
+ # Fit the model and get topics
+ topics, probs = topic_model.fit_transform(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+
+ # Store results
+ dfc['Topic'] = topics
+ st.session_state['topic_model'] = topic_model
+ st.session_state['clustered_data'] = dfc.copy()
+ st.session_state['clustering_completed'] = True
+
+ # Display clustering results summary
+ unique_topics = sorted(list(set(topics)))
+ num_clusters = len([t for t in unique_topics if t != -1]) # Exclude noise cluster (-1)
+ noise_docs = len([t for t in topics if t == -1])
+ clustered_docs = len(topics) - noise_docs
+
+ st.write(f"Found {num_clusters} distinct clusters")
+ #st.write(f"Documents successfully clustered: {clustered_docs}")
+ #if noise_docs > 0:
+ # st.write(f"Documents not fitting in any cluster: {noise_docs}")
+
+ # Show quick cluster overview
+ cluster_info = []
+ for t in unique_topics:
+ if t != -1: # Skip noise cluster in the overview
+ cluster_docs = dfc[dfc['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
+ cluster_info.append((t, count, top_keywords))
+
+ if cluster_info:
+ #st.write("### Quick Cluster Overview:")
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
+ # st.dataframe(
+ # cluster_df,
+ # column_config={
+ # "Topic": st.column_config.NumberColumn("Topic", help="Topic ID"),
+ # "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ # "Top Keywords": st.column_config.TextColumn(
+ # "Top Keywords",
+ # help="Top 5 keywords that characterize this topic"
+ # )
+ # },
+ # hide_index=True
+ # )
+
+ # Generate visualizations
+ try:
+ st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
+ except Exception:
+ st.session_state['intertopic_distance_fig'] = None
+
+ try:
+ st.session_state['topic_document_fig'] = topic_model.visualize_documents(
+ texts_cleaned,
+ embeddings=embeddings_for_clustering
+ )
+ except Exception:
+ st.session_state['topic_document_fig'] = None
+
+ try:
+ hierarchy = topic_model.hierarchical_topics(texts_cleaned)
+ st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
+ st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
+ except Exception:
+ st.session_state['hierarchy'] = pd.DataFrame()
+ st.session_state['hierarchy_fig'] = None
+
+ except Exception as e:
+ st.error(f"An error occurred during clustering: {str(e)}")
+ st.stop()
+
+ # Step 3: Summarization
+ if st.session_state.get('clustering_completed', False):
+ st.write("### Step 3: Summarization")
+
+ # Initialize OpenAI client
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
+ if not openai_api_key:
+ st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
+ st.stop()
+
+ llm = ChatOpenAI(
+ api_key=openai_api_key,
+ model_name='gpt-4o-mini',
+ temperature=0.7,
+ max_tokens=1000
+ )
+
+ df_scope = st.session_state['clustered_data']
+ unique_selected_topics = df_scope['Topic'].unique()
+
+ # Process summaries in parallel
+ with st.spinner("Generating summaries..."):
+ local_system_message = SystemMessagePromptTemplate.from_template("""You are an expert summarizer skilled in creating concise and relevant summaries.
+You will be given text and an objective context. Please produce a clear, cohesive,
+and thematically relevant summary.
+Focus on key points, insights, or patterns that emerge from the text.""")
+ local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
+ local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
+
+ # Find URL column if it exists
+ url_column = next((col for col in df_scope.columns if 'url' in col.lower() or 'link' in col.lower() or 'pdf' in col.lower()), None)
+
+ summaries = process_summaries_in_parallel(
+ df_scope=df_scope,
+ unique_selected_topics=unique_selected_topics,
+ llm=llm,
+ chat_prompt=local_chat_prompt,
+ enable_references=True,
+ reference_id_column=df_scope.columns[0],
+ url_column=url_column, # Add URL column for clickable links
+ max_workers=min(16, len(unique_selected_topics))
+ )
+
+ if summaries:
+ summary_df = pd.DataFrame(summaries)
+ st.session_state['summary_df'] = summary_df
+
+ # Display updated cluster overview
+ if 'Cluster_Name' in summary_df.columns:
+ st.write("### Updated Topic Overview:")
+ cluster_info = []
+ for t in unique_selected_topics:
+ cluster_docs = df_scope[df_scope['Topic'] == t]
+ count = len(cluster_docs)
+ top_words = topic_model.get_topic(t)
+ top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
+ cluster_name = summary_df[summary_df['Topic'] == t]['Cluster_Name'].iloc[0]
+ cluster_info.append((t, cluster_name, count, top_keywords))
+
+ cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Cluster_Name", "Count", "Top Keywords"])
+ st.dataframe(
+ cluster_df,
+ column_config={
+ "Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
+ "Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
+ "Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
+ "Top Keywords": st.column_config.TextColumn(
+ "Top Keywords",
+ help="Top 5 keywords that characterize this topic"
+ )
+ },
+ hide_index=True
+ )
+
+ # Generate and display high-level summary
+ with st.spinner("Generating high-level summary..."):
+ formatted_summaries = []
+ summary_batches = []
+ current_batch = []
+ current_batch_tokens = 0
+ MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75)
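+                            # Same two-pass approach as the advanced view: pack cluster summaries into
+                            # batches under MAX_SAFE_TOKENS, summarize each batch, then merge the batch
+                            # overviews into a single high-level summary.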
+
+ for _, row in summary_df.iterrows():
+ summary_text = row.get('Enhanced_Summary', row['Summary'])
+ formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
+ summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
+
+ if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
+ if current_batch:
+ summary_batches.append(current_batch)
+ current_batch = []
+ current_batch_tokens = 0
+
+ current_batch.append(formatted_summary)
+ current_batch_tokens += summary_tokens
+
+ if current_batch:
+ summary_batches.append(current_batch)
+
+ # Process each batch separately first
+ batch_overviews = []
+ for i, batch in enumerate(summary_batches, 1):
+ st.write(f"Processing summary batch {i} of {len(summary_batches)}...")
+ batch_text = "\n\n".join(batch)
+                            batch_prompt = f"""Below are summaries for a subset of clusters produced with transformer-based NLP from results in the CGIAR reporting system. Each summary contains references to source documents, shown either as plain IDs like [ID] or as HTML-hyperlinked IDs.
+
+Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
+1. Preserve all hyperlinked references exactly as they appear in the input summaries
+2. Maintain the HTML anchor tags (<a href="...">...</a>) intact when using information from the summaries
+3. Keep the markdown formatting for better readability
+4. Create clear sections with headings for different themes
+5. Use bullet points or numbered lists where appropriate
+6. Focus on synthesizing the main themes and findings
+
+Here are the cluster summaries to synthesize:
+
+{batch_text}"""
+
+ high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
+ batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
+ batch_overviews.append(batch_overview)
+
+ # Now create the final synthesis
+ if len(batch_overviews) > 1:
+ st.write("Generating final synthesis...")
+ combined_overviews = "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ final_prompt = f"""Below are multiple overview summaries, each covering different aspects of CGIAR research results. Each part maintains its original references to source documents.
+
+Please create a final comprehensive synthesis that:
+1. Integrates the key themes and findings from all parts into a cohesive narrative
+2. Preserves all hyperlinked references exactly as they appear
+3. Maintain the HTML anchor tags (<a href="...">...</a>) intact
+4. Uses clear section headings and structured formatting
+5. Highlights cross-cutting themes and relationships between different aspects
+6. Provides a clear introduction and conclusion
+
+Here are the overviews to synthesize:
+
+# Part {combined_overviews}"""
+
+ final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
+ if final_prompt_tokens > MAX_SAFE_TOKENS:
+ # If too long, just combine with headers
+ high_level_summary = "# Comprehensive Overview\n\n" + "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
+ else:
+ high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
+ high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
+ else:
+ # If only one batch, use its overview directly
+ high_level_summary = batch_overviews[0]
+
+ st.session_state['high_level_summary'] = high_level_summary
+ st.session_state['enhanced_summary'] = high_level_summary
+
+ # Display summaries
+ st.write("### High-Level Summary:")
+ with st.expander("High-Level Summary", expanded=True):
+ st.markdown(high_level_summary, unsafe_allow_html=True)
+
+ st.write("### Cluster Summaries:")
+ for idx, row in summary_df.iterrows():
+ cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
+ with st.expander(f"Topic {row['Topic']} - {cluster_name}", expanded=False):
+ st.markdown(row.get('Enhanced_Summary', row['Summary']), unsafe_allow_html=True)
+ st.markdown("##### About this tool")
+ with st.expander("Click to expand/collapse", expanded=True):
+ st.markdown("""
+        This tool draws on CGIAR quality-assured results data from 2022-2024 to provide verifiable responses to user questions about the themes and areas CGIAR has worked on and is currently working on.
+
+ **Tips:**
+ - **Craft a phrase** that describes your topic of interest (e.g., `"climate-smart agriculture"`, `"gender equality livestock"`).
+ - Avoid writing full questions — **this is not a chatbot**.
+ - Combine **related terms** for better results (e.g., `"irrigation water access smallholders"`).
+ - Focus on **concepts or themes** — not single words like `"climate"` or `"yield"` alone.
+ - Example good queries:
+ - `"climate adaptation smallholder farming"`
+ - `"digital agriculture innovations"`
+ - `"nutrition-sensitive value chains"`
+
+ **Example use case**:
+ You're interested in CGIAR's contributions to **poverty reduction through improved maize varieties in Africa**.
+ A good search phrase would be:
+ 👉 `"poverty reduction maize Africa"`
+ This will retrieve results related to improved crop varieties, livelihood outcomes, and region-specific interventions, even if the documents use different wording like *"enhanced maize genetics"*, *"smallholder income"*, or *"eastern Africa trials"*.
+ """)
\ No newline at end of file