|
|
|
|
|
import streamlit as st
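# st.set_page_config must be the first Streamlit command executed in the script,
# so it is called immediately after importing streamlit.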
|
|
|
|
|
|
st.set_page_config(page_title="SNAP", layout="wide")
|
|
|
|
|
|
import warnings
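# Suppress noisy (and harmless) torch.classes warnings that can be triggered
# when Streamlit's module watcher inspects the torch package.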
|
|
|
|
warnings.filterwarnings('ignore', message='.*torch.classes.*__path__._path.*')
|
|
warnings.filterwarnings('ignore', message='.*torch.classes.*registered via torch::class_.*')
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import os
|
|
import io
|
|
import time
|
|
from datetime import datetime
|
|
import base64
|
|
import re
|
|
import pickle
|
|
from typing import List, Dict, Any, Tuple
|
|
import plotly.express as px
|
|
import torch
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from functools import partial
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
from bertopic import BERTopic
|
|
from hdbscan import HDBSCAN
|
|
import nltk
|
|
from nltk.corpus import stopwords
|
|
from nltk.tokenize import word_tokenize
|
|
|
|
|
|
from langchain.chains import LLMChain
|
|
from langchain_community.chat_models import ChatOpenAI
|
|
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
|
from openai import OpenAI
|
|
from transformers import GPT2TokenizerFast
|
|
|
|
|
|
client = OpenAI()
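# The OpenAI client reads its API key from the OPENAI_API_KEY environment variable.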
|
|
|
|
|
|
|
|
|
|
def get_base_dir():
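    """Return the directory containing this script, falling back to the current working directory."""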
|
|
try:
|
|
base_dir = os.path.dirname(__file__)
|
|
if not base_dir:
|
|
return os.getcwd()
|
|
return base_dir
|
|
except NameError:
|
|
|
|
return os.getcwd()
|
|
|
|
BASE_DIR = get_base_dir()
|
|
|
|
|
|
def get_model_dir():
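    """Return the local 'models' cache directory, creating it if necessary."""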
|
|
base_dir = get_base_dir()
|
|
model_dir = os.path.join(base_dir, 'models')
|
|
os.makedirs(model_dir, exist_ok=True)
|
|
return model_dir
|
|
|
|
|
|
def load_tokenizer():
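    """Load the GPT-2 tokenizer from the local cache, downloading and caching it on first use."""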
|
|
model_dir = get_model_dir()
|
|
tokenizer_dir = os.path.join(model_dir, 'tokenizer')
|
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
|
|
try:
|
|
|
|
tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir)
|
|
|
|
except Exception as e:
|
|
|
|
try:
|
|
|
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
tokenizer.save_pretrained(tokenizer_dir)
|
|
|
|
except Exception as download_e:
|
|
|
|
raise
|
|
|
|
return tokenizer
|
|
|
|
|
|
try:
|
|
tokenizer = load_tokenizer()
|
|
except Exception as e:
|
|
|
|
tokenizer = None
|
|
|
|
MAX_CONTEXT_WINDOW = 128000
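# gpt-4o-mini accepts a 128k-token context window; cluster texts are trimmed against this limit below.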
|
|
|
|
|
|
if 'chat_history' not in st.session_state:
|
|
st.session_state.chat_history = []
|
|
|
|
|
|
|
|
|
|
def get_chat_response(messages):
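    """Send the messages to gpt-4o-mini and return the reply text, or None on error."""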
|
|
try:
|
|
response = client.chat.completions.create(
|
|
model="gpt-4o-mini",
|
|
messages=messages,
|
|
temperature=0,
|
|
)
|
|
return response.choices[0].message.content.strip()
|
|
except Exception as e:
|
|
st.error(f"Error querying OpenAI: {e}")
|
|
return None
|
|
|
|
|
|
|
|
|
|
def generate_raw_cluster_summary(
|
|
topic_val: int,
|
|
cluster_df: pd.DataFrame,
|
|
llm: Any,
|
|
chat_prompt: Any
|
|
) -> Dict[str, Any]:
|
|
"""Generate a summary for a single cluster without reference enhancement,
|
|
automatically trimming text if it exceeds a safe token limit."""
|
|
cluster_text = " ".join(cluster_df['text'].tolist())
|
|
if not cluster_text.strip():
|
|
return None
|
|
|
|
|
|
safe_limit = int(MAX_CONTEXT_WINDOW * 0.95)
|
|
|
|
|
|
    # Trim overly long cluster text to the safe token limit (skip if the tokenizer failed to load).
    if tokenizer is not None:
        encoded_text = tokenizer.encode(cluster_text, add_special_tokens=False)
        if len(encoded_text) > safe_limit:
            encoded_text = encoded_text[:safe_limit]
            cluster_text = tokenizer.decode(encoded_text)
|
|
|
|
user_prompt_local = f"**Text to summarize**: {cluster_text}"
|
|
try:
|
|
local_chain = LLMChain(llm=llm, prompt=chat_prompt)
|
|
summary_local = local_chain.run(user_prompt=user_prompt_local).strip()
|
|
return {'Topic': topic_val, 'Summary': summary_local}
|
|
except Exception as e:
|
|
st.error(f"Error generating summary for cluster {topic_val}: {str(e)}")
|
|
return None
|
|
|
|
|
|
|
|
|
|
def enhance_summary_with_references(
|
|
summary_dict: Dict[str, Any],
|
|
df_scope: pd.DataFrame,
|
|
reference_id_column: str,
|
|
url_column: str = None,
|
|
llm: Any = None
|
|
) -> Dict[str, Any]:
|
|
"""Add references to a summary."""
|
|
if not summary_dict or 'Summary' not in summary_dict:
|
|
return summary_dict
|
|
|
|
try:
|
|
cluster_df = df_scope[df_scope['Topic'] == summary_dict['Topic']]
|
|
enhanced = add_references_to_summary(
|
|
summary_dict['Summary'],
|
|
cluster_df,
|
|
reference_id_column,
|
|
url_column,
|
|
llm
|
|
)
|
|
summary_dict['Enhanced_Summary'] = enhanced
|
|
return summary_dict
|
|
except Exception as e:
|
|
st.error(f"Error enhancing summary for cluster {summary_dict.get('Topic')}: {str(e)}")
|
|
return summary_dict
|
|
|
|
|
|
|
|
|
|
def process_summaries_in_parallel(
|
|
df_scope: pd.DataFrame,
|
|
unique_selected_topics: List[int],
|
|
llm: Any,
|
|
chat_prompt: Any,
|
|
enable_references: bool = False,
|
|
reference_id_column: str = None,
|
|
url_column: str = None,
|
|
max_workers: int = 16
|
|
) -> List[Dict[str, Any]]:
|
|
"""Process multiple cluster summaries in parallel using ThreadPoolExecutor."""
|
|
summaries = []
|
|
total_topics = len(unique_selected_topics)
|
|
|
|
|
|
progress_text = st.empty()
|
|
progress_bar = st.progress(0)
|
|
|
|
try:
|
|
|
|
progress_text.text(f"Phase 1/3: Generating cluster summaries in parallel (0/{total_topics} completed)")
|
|
completed_summaries = 0
|
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
|
future_to_topic = {
|
|
executor.submit(
|
|
generate_raw_cluster_summary,
|
|
topic_val,
|
|
df_scope[df_scope['Topic'] == topic_val],
|
|
llm,
|
|
chat_prompt
|
|
): topic_val
|
|
for topic_val in unique_selected_topics
|
|
}
|
|
|
|
|
|
for future in future_to_topic:
|
|
try:
|
|
result = future.result()
|
|
if result:
|
|
summaries.append(result)
|
|
completed_summaries += 1
|
|
|
|
progress = completed_summaries / total_topics
|
|
progress_bar.progress(progress)
|
|
progress_text.text(
|
|
f"Phase 1/3: Generating cluster summaries in parallel ({completed_summaries}/{total_topics} completed)"
|
|
)
|
|
except Exception as e:
|
|
topic_val = future_to_topic[future]
|
|
st.error(f"Error in summary generation for cluster {topic_val}: {str(e)}")
|
|
completed_summaries += 1
|
|
continue
|
|
|
|
|
|
if enable_references and reference_id_column and summaries:
|
|
total_to_enhance = len(summaries)
|
|
completed_enhancements = 0
|
|
progress_text.text(f"Phase 2/3: Adding references to summaries (0/{total_to_enhance} completed)")
|
|
progress_bar.progress(0)
|
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
|
future_to_summary = {
|
|
executor.submit(
|
|
enhance_summary_with_references,
|
|
summary_dict,
|
|
df_scope,
|
|
reference_id_column,
|
|
url_column,
|
|
llm
|
|
): summary_dict.get('Topic')
|
|
for summary_dict in summaries
|
|
}
|
|
|
|
|
|
enhanced_summaries = []
|
|
for future in future_to_summary:
|
|
try:
|
|
result = future.result()
|
|
if result:
|
|
enhanced_summaries.append(result)
|
|
completed_enhancements += 1
|
|
|
|
progress = completed_enhancements / total_to_enhance
|
|
progress_bar.progress(progress)
|
|
progress_text.text(
|
|
f"Phase 2/3: Adding references to summaries ({completed_enhancements}/{total_to_enhance} completed)"
|
|
)
|
|
except Exception as e:
|
|
topic_val = future_to_summary[future]
|
|
st.error(f"Error in reference enhancement for cluster {topic_val}: {str(e)}")
|
|
completed_enhancements += 1
|
|
continue
|
|
|
|
summaries = enhanced_summaries
|
|
|
|
|
|
if summaries:
|
|
total_to_name = len(summaries)
|
|
completed_names = 0
|
|
progress_text.text(f"Phase 3/3: Generating cluster names (0/{total_to_name} completed)")
|
|
progress_bar.progress(0)
|
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
|
future_to_summary = {
|
|
executor.submit(
|
|
generate_cluster_name,
|
|
summary_dict.get('Enhanced_Summary', summary_dict['Summary']),
|
|
llm
|
|
): summary_dict.get('Topic')
|
|
for summary_dict in summaries
|
|
}
|
|
|
|
|
|
named_summaries = []
|
|
for future in future_to_summary:
|
|
try:
|
|
cluster_name = future.result()
|
|
topic_val = future_to_summary[future]
|
|
|
|
summary_dict = next(s for s in summaries if s['Topic'] == topic_val)
|
|
summary_dict['Cluster_Name'] = cluster_name
|
|
named_summaries.append(summary_dict)
|
|
completed_names += 1
|
|
|
|
progress = completed_names / total_to_name
|
|
progress_bar.progress(progress)
|
|
progress_text.text(
|
|
f"Phase 3/3: Generating cluster names ({completed_names}/{total_to_name} completed)"
|
|
)
|
|
except Exception as e:
|
|
topic_val = future_to_summary[future]
|
|
st.error(f"Error in cluster naming for cluster {topic_val}: {str(e)}")
|
|
completed_names += 1
|
|
continue
|
|
|
|
summaries = named_summaries
|
|
finally:
|
|
|
|
progress_text.empty()
|
|
progress_bar.empty()
|
|
|
|
return summaries
|
|
|
|
|
|
|
|
|
|
def generate_cluster_name(summary_text: str, llm: Any) -> str:
|
|
"""Generate a concise, descriptive name for a cluster based on its summary."""
|
|
system_prompt = """You are a cluster naming expert. Your task is to generate a very concise (3-6 words) but descriptive name for a cluster based on its summary. The name should capture the main theme or focus of the cluster.
|
|
|
|
Rules:
|
|
1. Keep it between 3-6 words
|
|
2. Be specific but concise
|
|
3. Capture the main theme/focus
|
|
4. Use title case
|
|
5. Do not include words like "Cluster", "Topic", or "Theme"
|
|
6. Focus on the content, not metadata
|
|
|
|
Example good names:
|
|
- Agricultural Water Management Innovation
|
|
- Gender Equality in Farming
|
|
- Climate-Smart Village Implementation
|
|
- Sustainable Livestock Practices"""
|
|
|
|
messages = [
|
|
{"role": "system", "content": system_prompt},
|
|
{"role": "user", "content": f"Generate a concise cluster name based on this summary:\n\n{summary_text}"}
|
|
]
|
|
|
|
try:
|
|
response = get_chat_response(messages)
|
|
|
|
        if not response:
            return "Unnamed Cluster"
        cluster_name = response.strip().strip('"').strip("'").strip()
|
|
return cluster_name
|
|
except Exception as e:
|
|
st.error(f"Error generating cluster name: {str(e)}")
|
|
return "Unnamed Cluster"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_nltk_resources():
|
|
"""Initialize NLTK resources with better error handling and less verbose output"""
|
|
nltk.data.path.append('/home/appuser/nltk_data')
|
|
|
|
resources = {
|
|
        'tokenizers/punkt_tab': 'punkt_tab',
|
|
'corpora/stopwords': 'stopwords'
|
|
}
|
|
|
|
for resource_path, resource_name in resources.items():
|
|
try:
|
|
nltk.data.find(resource_path)
|
|
except LookupError:
|
|
try:
|
|
nltk.download(resource_name, quiet=True)
|
|
except Exception as e:
|
|
st.warning(f"Error downloading NLTK resource {resource_name}: {e}")
|
|
|
|
|
|
try:
|
|
from nltk.tokenize import PunktSentenceTokenizer
|
|
tokenizer = PunktSentenceTokenizer()
|
|
tokenizer.tokenize("Test sentence.")
|
|
except Exception as e:
|
|
st.error(f"Error initializing NLTK tokenizer: {e}")
|
|
try:
|
|
nltk.download('punkt_tab', quiet=True)
|
|
except Exception as e:
|
|
st.error(f"Failed to download punkt_tab tokenizer: {e}")
|
|
|
|
|
|
init_nltk_resources()
|
|
|
|
|
|
|
|
|
|
def add_references_to_summary(summary, source_df, reference_column, url_column=None, llm=None):
|
|
"""
|
|
Add references to a summary by identifying which parts of the summary come
|
|
from which source documents. References will be appended as [ID],
|
|
optionally linked if a URL column is provided.
|
|
|
|
Args:
|
|
summary (str): The summary text to enhance with references.
|
|
source_df (DataFrame): DataFrame containing the source documents.
|
|
reference_column (str): Column name to use for reference IDs.
|
|
url_column (str, optional): Column name containing URLs for hyperlinks.
|
|
llm (LLM, optional): Language model for source attribution.
|
|
Returns:
|
|
str: Enhanced summary with references as HTML if possible.
|
|
"""
|
|
if summary.strip() == "" or source_df.empty or reference_column not in source_df.columns:
|
|
return summary
|
|
|
|
|
|
if llm is None:
|
|
return summary
|
|
|
|
|
|
paragraphs = summary.split('\n\n')
|
|
enhanced_paragraphs = []
|
|
|
|
|
|
source_texts = []
|
|
reference_ids = []
|
|
urls = []
|
|
for _, row in source_df.iterrows():
|
|
if 'text' in row and pd.notna(row['text']) and pd.notna(row[reference_column]):
|
|
source_texts.append(str(row['text']))
|
|
reference_ids.append(str(row[reference_column]))
|
|
if url_column and url_column in row and pd.notna(row[url_column]):
|
|
urls.append(str(row[url_column]))
|
|
else:
|
|
urls.append(None)
|
|
if not source_texts:
|
|
return summary
|
|
|
|
|
|
url_map = {}
|
|
for ref_id, u in zip(reference_ids, urls):
|
|
if u:
|
|
            url_map[ref_id] = u

    # Build the source-context block once, rather than re-joining it for every
    # sentence/bullet, and keep backslashes out of f-string expressions
    # (a SyntaxError on Python versions before 3.12).
    source_context = '\n'.join(
        f"ID: {ref_id}, Text: {text[:500]}..."
        for ref_id, text in zip(reference_ids, source_texts)
    )
|
|
|
|
|
|
system_prompt = """
|
|
You are an expert at identifying the source of information. You will be given:
|
|
1. A sentence or bullet point from a summary
|
|
2. A list of source texts with their IDs
|
|
|
|
Your task is to identify which source text(s) the text most likely came from.
|
|
Return ONLY the IDs of the source texts that contributed to the text, separated by commas.
|
|
If you cannot confidently attribute the text to any source, return "unknown".
|
|
"""
|
|
|
|
for paragraph in paragraphs:
|
|
if not paragraph.strip():
|
|
enhanced_paragraphs.append('')
|
|
continue
|
|
|
|
|
|
if any(line.strip().startswith('- ') or line.strip().startswith('* ') for line in paragraph.split('\n')):
|
|
|
|
bullet_lines = paragraph.split('\n')
|
|
enhanced_bullets = []
|
|
for line in bullet_lines:
|
|
if not line.strip():
|
|
enhanced_bullets.append(line)
|
|
continue
|
|
|
|
if line.strip().startswith('- ') or line.strip().startswith('* '):
|
|
|
|
user_prompt = f"""
|
|
Text: {line.strip()}
|
|
|
|
Source texts:
|
|
{source_context}
|
|
|
|
Which source ID(s) did this text most likely come from? Return only the ID(s) separated by commas, or "unknown".
|
|
"""
|
|
|
|
try:
|
|
system_message = SystemMessagePromptTemplate.from_template(system_prompt)
|
|
human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
|
chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
|
|
chain = LLMChain(llm=llm, prompt=chat_prompt)
|
|
response = chain.run(user_prompt=user_prompt)
|
|
source_ids = response.strip()
|
|
|
|
if source_ids.lower() == "unknown":
|
|
enhanced_bullets.append(line)
|
|
else:
|
|
|
|
source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
|
|
source_ids = re.sub(r'\s+', '', source_ids)
|
|
ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
|
|
|
|
if ids:
|
|
ref_parts = []
|
|
for id_ in ids:
|
|
if id_ in url_map:
|
|
ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>')
|
|
else:
|
|
ref_parts.append(id_)
|
|
ref_string = ", ".join(ref_parts)
|
|
enhanced_bullets.append(f"{line} [{ref_string}]")
|
|
else:
|
|
enhanced_bullets.append(line)
|
|
except Exception:
|
|
enhanced_bullets.append(line)
|
|
else:
|
|
enhanced_bullets.append(line)
|
|
|
|
enhanced_paragraphs.append('\n'.join(enhanced_bullets))
|
|
else:
|
|
|
|
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
|
|
enhanced_sentences = []
|
|
|
|
for sentence in sentences:
|
|
if not sentence.strip():
|
|
continue
|
|
|
|
user_prompt = f"""
|
|
Sentence: {sentence.strip()}
|
|
|
|
Source texts:
|
|
{source_context}
|
|
|
|
Which source ID(s) did this sentence most likely come from? Return only the ID(s) separated by commas, or "unknown".
|
|
"""
|
|
|
|
try:
|
|
system_message = SystemMessagePromptTemplate.from_template(system_prompt)
|
|
human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
|
chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])
|
|
chain = LLMChain(llm=llm, prompt=chat_prompt)
|
|
response = chain.run(user_prompt=user_prompt)
|
|
source_ids = response.strip()
|
|
|
|
if source_ids.lower() == "unknown":
|
|
enhanced_sentences.append(sentence)
|
|
else:
|
|
|
|
source_ids = re.sub(r'[^0-9,\s]', '', source_ids)
|
|
source_ids = re.sub(r'\s+', '', source_ids)
|
|
ids = [id_.strip() for id_ in source_ids.split(',') if id_.strip()]
|
|
|
|
if ids:
|
|
ref_parts = []
|
|
for id_ in ids:
|
|
if id_ in url_map:
|
|
ref_parts.append(f'<a href="{url_map[id_]}" target="_blank">{id_}</a>')
|
|
else:
|
|
ref_parts.append(id_)
|
|
ref_string = ", ".join(ref_parts)
|
|
enhanced_sentences.append(f"{sentence} [{ref_string}]")
|
|
else:
|
|
enhanced_sentences.append(sentence)
|
|
except Exception:
|
|
enhanced_sentences.append(sentence)
|
|
|
|
enhanced_paragraphs.append(' '.join(enhanced_sentences))
|
|
|
|
|
|
return '\n\n'.join(enhanced_paragraphs)
|
|
|
|
|
|
st.sidebar.image("static/SNAP_logo.png", width=350)
|
|
|
|
|
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
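# Prefer the GPU for embedding computation when one is available.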
|
|
if device == 'cuda':
|
|
st.sidebar.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
|
else:
|
|
st.sidebar.info("Using CPU")
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
|
|
def get_embedding_model():
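    """Load the cached all-MiniLM-L6-v2 sentence-transformer, downloading it on first use."""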
|
|
model_dir = get_model_dir()
|
|
st_model_dir = os.path.join(model_dir, 'sentence_transformer')
|
|
os.makedirs(st_model_dir, exist_ok=True)
|
|
|
|
model_name = 'all-MiniLM-L6-v2'
|
|
try:
|
|
|
|
model = SentenceTransformer(st_model_dir)
|
|
|
|
except Exception as e:
|
|
|
|
try:
|
|
|
|
model = SentenceTransformer(model_name)
|
|
model.save(st_model_dir)
|
|
|
|
except Exception as download_e:
|
|
st.error(f"Error downloading sentence transformer model: {str(download_e)}")
|
|
raise
|
|
|
|
return model.to(device)
|
|
|
|
def generate_embeddings(texts, model):
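    """Encode the given texts into embeddings with the provided SentenceTransformer model."""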
|
|
with st.spinner('Calculating embeddings...'):
|
|
embeddings = model.encode(texts, show_progress_bar=True, device=device)
|
|
return embeddings
|
|
|
|
@st.cache_data
|
|
def load_default_dataset(default_dataset_path):
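    """Read the bundled default Excel dataset, or return None if the file is missing."""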
|
|
if os.path.exists(default_dataset_path):
|
|
df_ = pd.read_excel(default_dataset_path)
|
|
return df_
|
|
else:
|
|
st.error("Default dataset not found. Please ensure the file exists in the 'input' directory.")
|
|
return None
|
|
|
|
@st.cache_data
|
|
def load_uploaded_dataset(uploaded_file):
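    """Read an uploaded Excel file into a DataFrame."""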
|
|
df_ = pd.read_excel(uploaded_file)
|
|
return df_
|
|
|
|
def load_or_compute_embeddings(df, using_default_dataset, uploaded_file_name=None, text_columns=None):
|
|
"""
|
|
Loads pre-computed embeddings from a pickle file if they match current data,
|
|
otherwise computes and caches them.
|
|
"""
|
|
if not text_columns:
|
|
return None, None
|
|
|
|
base_name = "PRMS_2022_2023_2024_QAed" if using_default_dataset else "custom_dataset"
|
|
if uploaded_file_name:
|
|
base_name = os.path.splitext(uploaded_file_name)[0]
|
|
|
|
cols_key = "_".join(sorted(text_columns))
|
|
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
embeddings_dir = BASE_DIR
|
|
    # The cache-file naming is identical for the default and uploaded datasets.
    embeddings_file = os.path.join(embeddings_dir, f"{base_name}_{cols_key}.pkl")
|
|
|
|
df_fill = df.fillna("")
|
|
texts = df_fill[text_columns].astype(str).agg(' '.join, axis=1).tolist()
|
|
|
|
|
|
if ('embeddings' in st.session_state
|
|
and 'last_text_columns' in st.session_state
|
|
and st.session_state['last_text_columns'] == text_columns
|
|
and len(st.session_state['embeddings']) == len(texts)):
|
|
return st.session_state['embeddings'], st.session_state.get('embeddings_file', None)
|
|
|
|
|
|
if os.path.exists(embeddings_file):
|
|
with open(embeddings_file, 'rb') as f:
|
|
embeddings = pickle.load(f)
|
|
if len(embeddings) == len(texts):
|
|
st.write("Loaded pre-calculated embeddings.")
|
|
st.session_state['embeddings'] = embeddings
|
|
st.session_state['embeddings_file'] = embeddings_file
|
|
st.session_state['last_text_columns'] = text_columns
|
|
return embeddings, embeddings_file
|
|
|
|
|
|
st.write("Generating embeddings...")
|
|
model = get_embedding_model()
|
|
embeddings = generate_embeddings(texts, model)
|
|
with open(embeddings_file, 'wb') as f:
|
|
pickle.dump(embeddings, f)
|
|
|
|
st.session_state['embeddings'] = embeddings
|
|
st.session_state['embeddings_file'] = embeddings_file
|
|
st.session_state['last_text_columns'] = text_columns
|
|
return embeddings, embeddings_file
|
|
|
|
|
|
|
|
|
|
|
|
def reset_filters():
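    """Clear any additional-filter selections stored in session state."""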
|
|
st.session_state['selected_additional_filters'] = {}
|
|
|
|
|
|
st.sidebar.radio("Select view", ["Automatic Mode", "Power User Mode"], key="view")
|
|
|
|
if st.session_state.view == "Power User Mode":
|
|
st.header("Power User Mode")
|
|
|
|
|
|
|
|
st.sidebar.title("Data Selection")
|
|
dataset_option = st.sidebar.selectbox('Select Dataset', ('PRMS 2022+2023+2024 QAed', 'Upload my dataset'))
|
|
|
|
if 'df' not in st.session_state:
|
|
st.session_state['df'] = pd.DataFrame()
|
|
if 'filtered_df' not in st.session_state:
|
|
st.session_state['filtered_df'] = pd.DataFrame()
|
|
|
|
if dataset_option == 'PRMS 2022+2023+2024 QAed':
|
|
default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
|
|
df = load_default_dataset(default_dataset_path)
|
|
if df is not None:
|
|
st.session_state['df'] = df.copy()
|
|
st.session_state['using_default_dataset'] = True
|
|
|
|
|
|
if 'filtered_df' not in st.session_state or st.session_state['filtered_df'].empty:
|
|
st.session_state['filtered_df'] = df.copy()
|
|
|
|
|
|
if 'filter_state' not in st.session_state:
|
|
st.session_state['filter_state'] = {
|
|
'applied': False,
|
|
'filters': {}
|
|
}
|
|
|
|
|
|
if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
|
|
default_text_cols = []
|
|
if 'Title' in df.columns and 'Description' in df.columns:
|
|
default_text_cols = ['Title', 'Description']
|
|
st.session_state['text_columns'] = default_text_cols
|
|
|
|
|
|
|
|
|
|
|
|
|
|
df_cols = df.columns.tolist()
|
|
|
|
|
|
st.subheader("Select Filters")
|
|
if 'additional_filters_selected' not in st.session_state:
|
|
st.session_state['additional_filters_selected'] = []
|
|
if 'filter_values' not in st.session_state:
|
|
st.session_state['filter_values'] = {}
|
|
|
|
with st.form("filter_selection_form"):
|
|
all_columns = df.columns.tolist()
|
|
selected_additional_cols = st.multiselect(
|
|
"Select columns from your dataset to use as filters:",
|
|
all_columns,
|
|
default=st.session_state['additional_filters_selected']
|
|
)
|
|
add_filters_submitted = st.form_submit_button("Add Additional Filters")
|
|
|
|
if add_filters_submitted:
|
|
if selected_additional_cols != st.session_state['additional_filters_selected']:
|
|
st.session_state['additional_filters_selected'] = selected_additional_cols
|
|
|
|
st.session_state['filter_values'] = {
|
|
k: v for k, v in st.session_state['filter_values'].items()
|
|
if k in selected_additional_cols
|
|
}
|
|
|
|
|
|
if st.session_state['additional_filters_selected']:
|
|
st.subheader("Apply Filters")
|
|
|
|
|
|
for col_name in st.session_state['additional_filters_selected']:
|
|
unique_vals = sorted(df[col_name].dropna().unique().tolist())
|
|
|
|
|
|
search_key = f"search_{col_name}"
|
|
if search_key not in st.session_state:
|
|
st.session_state[search_key] = ""
|
|
|
|
col1, col2 = st.columns([3, 1])
|
|
with col1:
|
|
search_term = st.text_input(
|
|
f"Search in {col_name}",
|
|
key=search_key,
|
|
help="Enter text to find and select all matching values"
|
|
)
|
|
with col2:
|
|
if st.button(f"Select Matching", key=f"select_{col_name}"):
|
|
|
|
if search_term:
|
|
matching_vals = [
|
|
val for val in unique_vals
|
|
if any(search_term.lower() in str(part).lower()
|
|
for part in (val.split(',') if isinstance(val, str) else [val]))
|
|
]
|
|
|
|
current_selected = st.session_state['filter_values'].get(col_name, [])
|
|
st.session_state['filter_values'][col_name] = list(set(current_selected + matching_vals))
|
|
|
|
|
|
if matching_vals:
|
|
st.success(f"Found and selected {len(matching_vals)} matching values")
|
|
else:
|
|
st.warning("No matching values found")
|
|
|
|
|
|
with st.form("apply_filters_form"):
|
|
for col_name in st.session_state['additional_filters_selected']:
|
|
unique_vals = sorted(df[col_name].dropna().unique().tolist())
|
|
selected_vals = st.multiselect(
|
|
f"Filter by {col_name}",
|
|
options=unique_vals,
|
|
default=st.session_state['filter_values'].get(col_name, [])
|
|
)
|
|
st.session_state['filter_values'][col_name] = selected_vals
|
|
|
|
|
|
col1, col2 = st.columns([1, 4])
|
|
with col1:
|
|
clear_filters = st.form_submit_button("Clear All")
|
|
with col2:
|
|
apply_filters_submitted = st.form_submit_button("Apply Filters to Dataset")
|
|
|
|
if clear_filters:
|
|
st.session_state['filter_values'] = {}
|
|
|
|
if 'summary_df' in st.session_state:
|
|
del st.session_state['summary_df']
|
|
if 'high_level_summary' in st.session_state:
|
|
del st.session_state['high_level_summary']
|
|
if 'enhanced_summary' in st.session_state:
|
|
del st.session_state['enhanced_summary']
|
|
st.rerun()
|
|
|
|
|
|
with st.expander("βοΈ Advanced Settings", expanded=False):
|
|
st.subheader("**Select Text Columns for Embedding**")
|
|
text_columns_selected = st.multiselect(
|
|
"Text Columns:",
|
|
df_cols,
|
|
default=st.session_state['text_columns'],
|
|
help="Choose columns containing text for semantic search and clustering. "
|
|
"If multiple are selected, their text will be concatenated."
|
|
)
|
|
st.session_state['text_columns'] = text_columns_selected
|
|
|
|
|
|
filtered_df = df.copy()
|
|
if 'apply_filters_submitted' in locals() and apply_filters_submitted:
|
|
|
|
if 'summary_df' in st.session_state:
|
|
del st.session_state['summary_df']
|
|
if 'high_level_summary' in st.session_state:
|
|
del st.session_state['high_level_summary']
|
|
if 'enhanced_summary' in st.session_state:
|
|
del st.session_state['enhanced_summary']
|
|
|
|
for col_name in st.session_state['additional_filters_selected']:
|
|
selected_vals = st.session_state['filter_values'].get(col_name, [])
|
|
if selected_vals:
|
|
filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
|
|
st.success("Filters applied successfully!")
|
|
st.session_state['filtered_df'] = filtered_df.copy()
|
|
st.session_state['filter_state'] = {
|
|
'applied': True,
|
|
'filters': st.session_state['filter_values'].copy()
|
|
}
|
|
|
|
for k in ['clustered_data', 'topic_model', 'current_clustering_data',
|
|
'current_clustering_option', 'hierarchy']:
|
|
if k in st.session_state:
|
|
del st.session_state[k]
|
|
|
|
elif 'filter_state' in st.session_state and st.session_state['filter_state']['applied']:
|
|
|
|
for col_name, selected_vals in st.session_state['filter_state']['filters'].items():
|
|
if selected_vals:
|
|
filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
|
|
st.session_state['filtered_df'] = filtered_df.copy()
|
|
|
|
|
|
if st.session_state['filtered_df'] is not None:
|
|
if st.session_state['filter_state']['applied']:
|
|
st.write("Filtered Data Preview:")
|
|
else:
|
|
st.write("Current Data Preview:")
|
|
st.dataframe(st.session_state['filtered_df'].head(), hide_index=True)
|
|
st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
|
|
|
|
output = io.BytesIO()
|
|
writer = pd.ExcelWriter(output, engine='openpyxl')
|
|
st.session_state['filtered_df'].to_excel(writer, index=False)
|
|
writer.close()
|
|
processed_data = output.getvalue()
|
|
|
|
st.download_button(
|
|
label="Download Current Data",
|
|
data=processed_data,
|
|
file_name='data.xlsx',
|
|
mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
|
)
|
|
else:
|
|
st.warning("Please ensure the default dataset exists in the 'input' directory.")
|
|
|
|
else:
|
|
|
|
uploaded_file = st.sidebar.file_uploader("Upload your Excel file", type=["xlsx"])
|
|
if uploaded_file is not None:
|
|
df = load_uploaded_dataset(uploaded_file)
|
|
if df is not None:
|
|
st.session_state['df'] = df.copy()
|
|
st.session_state['using_default_dataset'] = False
|
|
st.session_state['uploaded_file_name'] = uploaded_file.name
|
|
st.write("Data preview:")
|
|
st.write(df.head())
|
|
df_cols = df.columns.tolist()
|
|
|
|
st.subheader("**Select Text Columns for Embedding**")
|
|
text_columns_selected = st.multiselect(
|
|
"Text Columns:",
|
|
df_cols,
|
|
default=df_cols[:1] if df_cols else []
|
|
)
|
|
st.session_state['text_columns'] = text_columns_selected
|
|
|
|
st.write("**Additional Filters**")
|
|
selected_additional_cols = st.multiselect(
|
|
"Select additional columns from your dataset to use as filters:",
|
|
df_cols,
|
|
default=[]
|
|
)
|
|
st.session_state['additional_filters_selected'] = selected_additional_cols
|
|
|
|
filtered_df = df.copy()
|
|
for col_name in selected_additional_cols:
|
|
if f'selected_filter_{col_name}' not in st.session_state:
|
|
st.session_state[f'selected_filter_{col_name}'] = []
|
|
unique_vals = sorted(df[col_name].dropna().unique().tolist())
|
|
selected_vals = st.multiselect(
|
|
f"Filter by {col_name}",
|
|
options=unique_vals,
|
|
default=st.session_state[f'selected_filter_{col_name}']
|
|
)
|
|
st.session_state[f'selected_filter_{col_name}'] = selected_vals
|
|
if selected_vals:
|
|
filtered_df = filtered_df[filtered_df[col_name].isin(selected_vals)]
|
|
|
|
st.session_state['filtered_df'] = filtered_df
|
|
st.write("Filtered Data Preview:")
|
|
st.dataframe(filtered_df.head(), hide_index=True)
|
|
st.write(f"Total number of results: {len(filtered_df)}")
|
|
|
|
output = io.BytesIO()
|
|
writer = pd.ExcelWriter(output, engine='openpyxl')
|
|
filtered_df.to_excel(writer, index=False)
|
|
writer.close()
|
|
processed_data = output.getvalue()
|
|
|
|
st.download_button(
|
|
label="Download Filtered Data",
|
|
data=processed_data,
|
|
file_name='filtered_data.xlsx',
|
|
mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
|
)
|
|
else:
|
|
st.warning("Failed to load the uploaded dataset.")
|
|
else:
|
|
st.warning("Please upload an Excel file to proceed.")
|
|
|
|
if 'filtered_df' in st.session_state:
|
|
st.write(f"Total number of results: {len(st.session_state['filtered_df'])}")
|
|
|
|
|
|
|
|
|
|
|
|
if 'active_tab_index' not in st.session_state:
|
|
st.session_state.active_tab_index = 0
|
|
|
|
tabs_titles = ["Semantic Search", "Clustering", "Summarization", "Chat", "Help"]
|
|
tabs = st.tabs(tabs_titles)
|
|
|
|
tab_semantic, tab_clustering, tab_summarization, tab_chat, tab_help = tabs
|
|
|
|
|
|
|
|
|
|
with tab_help:
|
|
st.header("Help")
|
|
st.markdown("""
|
|
### About SNAP
|
|
|
|
SNAP allows you to explore, filter, search, cluster, and summarize textual datasets.
|
|
|
|
**Workflow**:
|
|
1. **Data Selection (Sidebar)**: Choose the default dataset or upload your own.
|
|
2. **Filtering**: Set additional filters for your dataset.
|
|
3. **Select Text Columns**: Which columns to embed.
|
|
4. **Semantic Search** (Tab): Provide a query and threshold to find relevant documents.
|
|
5. **Clustering** (Tab): Group documents into topics.
|
|
6. **Summarization** (Tab): Summarize the clustered documents (with optional references).
|
|
|
|
### Troubleshooting
|
|
- If you see no results, try lowering the similarity threshold or removing negative/required keywords.
|
|
- Ensure you have at least one text column selected for embeddings.
|
|
""")
|
|
|
|
|
|
|
|
|
|
with tab_semantic:
|
|
st.header("Semantic Search")
|
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
|
|
text_columns = st.session_state.get('text_columns', [])
|
|
if not text_columns:
|
|
st.warning("No text columns selected. Please select at least one column for text embedding.")
|
|
else:
|
|
df_full = st.session_state['df']
|
|
|
|
embeddings, _ = load_or_compute_embeddings(
|
|
df_full,
|
|
st.session_state.get('using_default_dataset', False),
|
|
st.session_state.get('uploaded_file_name'),
|
|
text_columns
|
|
)
|
|
|
|
if embeddings is not None:
|
|
with st.expander("βΉοΈ How Semantic Search Works", expanded=False):
|
|
st.markdown("""
|
|
### Understanding Semantic Search
|
|
|
|
Unlike traditional keyword search that looks for exact matches, semantic search understands the meaning and context of your query. Here's how it works:
|
|
|
|
1. **Query Processing**:
|
|
- Your search query is converted into a numerical representation (embedding) that captures its meaning
|
|
- Example: Searching for "Climate Smart Villages" will understand the concept, not just the words
|
|
- Related terms like "sustainable communities", "resilient farming", or "agricultural adaptation" might be found even if they don't contain the exact words
|
|
|
|
2. **Similarity Matching**:
|
|
- Documents are ranked by how closely their meaning matches your query
|
|
- The similarity threshold controls how strict this matching is
|
|
- Higher threshold (e.g., 0.8) = more precise but fewer results
|
|
- Lower threshold (e.g., 0.3) = more results but might be less relevant
|
|
|
|
3. **Advanced Features**:
|
|
- **Negative Keywords**: Use to explicitly exclude documents containing certain terms
|
|
- **Required Keywords**: Ensure specific terms appear in the results
|
|
- These work as traditional keyword filters after the semantic search
|
|
|
|
### Search Tips
|
|
|
|
- **Phrase Queries**: Enter complete phrases for better context
|
|
- "Climate Smart Villages" (as one concept)
|
|
- Better than separate terms: "climate", "smart", "villages"
|
|
|
|
- **Descriptive Queries**: Add context for better results
|
|
- Instead of: "water"
|
|
- Better: "water management in agriculture"
|
|
|
|
- **Conceptual Queries**: Focus on concepts rather than specific terms
|
|
- Instead of: "increased yield"
|
|
- Better: "agricultural productivity improvements"
|
|
|
|
### Example Searches
|
|
|
|
1. **Query**: "Climate Smart Villages"
|
|
- Will find: Documents about climate-resilient communities, adaptive farming practices, sustainable village development
|
|
- Even if they don't use these exact words
|
|
|
|
2. **Query**: "Gender equality in agriculture"
|
|
- Will find: Women's empowerment in farming, female farmer initiatives, gender-inclusive rural development
|
|
- Related concepts are captured semantically
|
|
|
|
3. **Query**: "Sustainable water management"
|
|
+ Required keyword: "irrigation"
|
|
- Combines semantic understanding of water sustainability with specific irrigation focus
|
|
""")
|
|
|
|
with st.form("search_parameters"):
|
|
query = st.text_input("Enter your search query:")
|
|
include_keywords = st.text_input("Include only documents containing these words (comma-separated):")
|
|
similarity_threshold = st.slider("Similarity threshold", 0.0, 1.0, 0.35)
|
|
submitted = st.form_submit_button("Search")
|
|
|
|
if submitted:
|
|
if query.strip():
|
|
with st.spinner("Performing Semantic Search..."):
|
|
|
|
if 'summary_df' in st.session_state:
|
|
del st.session_state['summary_df']
|
|
if 'high_level_summary' in st.session_state:
|
|
del st.session_state['high_level_summary']
|
|
if 'enhanced_summary' in st.session_state:
|
|
del st.session_state['enhanced_summary']
|
|
|
|
model = get_embedding_model()
|
|
df_filtered = st.session_state['filtered_df'].fillna("")
|
|
search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
|
|
|
|
|
|
subset_indices = df_filtered.index
|
|
subset_embeddings = embeddings[subset_indices]
|
|
|
|
query_embedding = model.encode([query], device=device)
|
|
similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
|
|
|
|
|
|
fig = px.histogram(
|
|
x=similarities,
|
|
nbins=30,
|
|
labels={'x': 'Similarity Score', 'y': 'Number of Documents'},
|
|
title='Distribution of Similarity Scores'
|
|
)
|
|
fig.add_vline(
|
|
x=similarity_threshold,
|
|
line_dash="dash",
|
|
line_color="red",
|
|
annotation_text=f"Threshold: {similarity_threshold:.2f}",
|
|
annotation_position="top"
|
|
)
|
|
st.write("### Similarity Score Distribution")
|
|
st.plotly_chart(fig)
|
|
|
|
above_threshold_indices = np.where(similarities > similarity_threshold)[0]
|
|
if len(above_threshold_indices) == 0:
|
|
st.warning("No results found above the similarity threshold.")
|
|
if 'search_results' in st.session_state:
|
|
del st.session_state['search_results']
|
|
else:
|
|
selected_indices = subset_indices[above_threshold_indices]
|
|
results = df_filtered.loc[selected_indices].copy()
|
|
results['similarity_score'] = similarities[above_threshold_indices]
|
|
results.sort_values(by='similarity_score', ascending=False, inplace=True)
|
|
|
|
|
|
if include_keywords.strip():
|
|
inc_words = [w.strip().lower() for w in include_keywords.split(',') if w.strip()]
|
|
if inc_words:
|
|
results = results[
|
|
results.apply(
|
|
lambda row: all(
|
|
w in (' '.join(row.astype(str)).lower()) for w in inc_words
|
|
),
|
|
axis=1
|
|
)
|
|
]
|
|
|
|
if results.empty:
|
|
st.warning("No results found after applying keyword filters.")
|
|
if 'search_results' in st.session_state:
|
|
del st.session_state['search_results']
|
|
else:
|
|
st.session_state['search_results'] = results.copy()
|
|
output = io.BytesIO()
|
|
writer = pd.ExcelWriter(output, engine='openpyxl')
|
|
results.to_excel(writer, index=False)
|
|
writer.close()
|
|
processed_data = output.getvalue()
|
|
st.session_state['search_results_processed_data'] = processed_data
|
|
else:
|
|
st.warning("Please enter a query to search.")
|
|
|
|
|
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
|
|
st.write("## Search Results")
|
|
results = st.session_state['search_results']
|
|
cols_to_display = [c for c in results.columns if c != 'similarity_score'] + ['similarity_score']
|
|
st.dataframe(results[cols_to_display], hide_index=True)
|
|
st.write(f"Total number of results: {len(results)}")
|
|
|
|
if 'search_results_processed_data' in st.session_state:
|
|
st.download_button(
|
|
label="Download Full Results",
|
|
data=st.session_state['search_results_processed_data'],
|
|
file_name='search_results.xlsx',
|
|
mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
|
key='download_search_results'
|
|
)
|
|
else:
|
|
st.info("No search results to display. Enter a query and click 'Search'.")
|
|
else:
|
|
st.warning("No embeddings available because no text columns were chosen.")
|
|
else:
|
|
st.warning("Filtered dataset is empty or not loaded. Please adjust your filters or upload data.")
|
|
|
|
|
|
|
|
|
|
|
|
with tab_clustering:
|
|
st.header("Clustering")
|
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
|
|
|
|
with st.expander("βΉοΈ How Clustering Works", expanded=False):
|
|
st.markdown("""
|
|
### Understanding Document Clustering
|
|
|
|
Clustering automatically groups similar documents together, helping you discover patterns and themes in your data. Here's how it works:
|
|
|
|
1. **Cluster Formation**:
|
|
- Documents are grouped based on their semantic similarity
|
|
- Each cluster represents a distinct theme or topic
|
|
- Documents that are too different from others may remain unclustered (labeled as -1)
|
|
- The "Min Cluster Size" parameter controls how clusters are formed
|
|
|
|
2. **Interpreting Results**:
|
|
- Each cluster is assigned a number (e.g., 0, 1, 2...)
|
|
- Cluster -1 contains "outlier" documents that didn't fit well in other clusters
|
|
- The size of each cluster indicates how common that theme is
|
|
- Keywords for each cluster show the main topics/concepts
|
|
|
|
3. **Visualizations**:
|
|
- **Intertopic Distance Map**: Shows how clusters relate to each other
|
|
- Closer clusters are more semantically similar
|
|
- Size of circles indicates number of documents
|
|
- Hover to see top terms for each cluster
|
|
|
|
- **Topic Document Visualization**: Shows individual documents
|
|
- Each point is a document
|
|
- Colors indicate cluster membership
|
|
- Distance between points shows similarity
|
|
|
|
- **Topic Hierarchy**: Shows how topics are related
|
|
- Tree structure shows topic relationships
|
|
- Parent topics contain broader themes
|
|
- Child topics show more specific sub-themes
|
|
|
|
### How to Use Clusters
|
|
|
|
1. **Exploration**:
|
|
- Use clusters to discover main themes in your data
|
|
- Look for unexpected groupings that might reveal insights
|
|
- Identify outliers that might need special attention
|
|
|
|
2. **Analysis**:
|
|
- Compare cluster sizes to understand theme distribution
|
|
- Examine keywords to understand what defines each cluster
|
|
- Use hierarchy to see how themes are nested
|
|
|
|
3. **Practical Applications**:
|
|
- Generate summaries for specific clusters
|
|
- Focus detailed analysis on clusters of interest
|
|
- Use clusters to organize and categorize documents
|
|
- Identify gaps or overlaps in your dataset
|
|
|
|
### Tips for Better Results
|
|
|
|
- **Adjust Min Cluster Size**:
|
|
- Larger values (15-20): Fewer, broader clusters
|
|
- Smaller values (2-5): More specific, smaller clusters
|
|
- Balance between too many small clusters and too few large ones
|
|
|
|
- **Choose Data Wisely**:
|
|
- Cluster full dataset for overall themes
|
|
- Cluster search results for focused analysis
|
|
- More documents generally give better clusters
|
|
|
|
- **Interpret with Context**:
|
|
- Consider your domain knowledge
|
|
- Look for patterns across multiple visualizations
|
|
- Use cluster insights to guide further analysis
|
|
""")
|
|
|
|
df_to_cluster = None
|
|
|
|
|
|
with st.form("clustering_form"):
|
|
st.subheader("Clustering Settings")
|
|
|
|
|
|
clustering_option = st.radio(
|
|
"Select data for clustering:",
|
|
('Full Dataset', 'Filtered Dataset', 'Semantic Search Results')
|
|
)
|
|
|
|
|
|
min_cluster_size_val = st.slider(
|
|
"Min Cluster Size",
|
|
min_value=2,
|
|
max_value=50,
|
|
value=st.session_state.get('min_cluster_size', 5),
|
|
help="Minimum size of each cluster in HDBSCAN; In other words, it's the minimum number of documents/texts that must be grouped together to form a valid cluster.\n\n- A larger value (e.g., 20) will result in fewer, larger clusters\n- A smaller value (e.g., 2-5) will allow for more clusters, including smaller ones\n- Documents that don't fit into any cluster meeting this minimum size requirement are labeled as noise (typically assigned to cluster -1)"
|
|
)
|
|
|
|
run_clustering = st.form_submit_button("Run Clustering")
|
|
|
|
if run_clustering:
|
|
st.session_state.active_tab_index = tabs_titles.index("Clustering")
|
|
st.session_state['min_cluster_size'] = min_cluster_size_val
|
|
|
|
|
|
if clustering_option == 'Semantic Search Results':
|
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
|
|
df_to_cluster = st.session_state['search_results'].copy()
|
|
else:
|
|
st.warning("No semantic search results found. Please run a search first.")
|
|
elif clustering_option == 'Filtered Dataset':
|
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
|
|
df_to_cluster = st.session_state['filtered_df'].copy()
|
|
else:
|
|
st.warning("Filtered dataset is empty. Please check your filters.")
|
|
else:
|
|
if 'df' in st.session_state and not st.session_state['df'].empty:
|
|
df_to_cluster = st.session_state['df'].copy()
|
|
|
|
text_columns = st.session_state.get('text_columns', [])
|
|
if not text_columns:
|
|
st.warning("No text columns selected. Please select text columns to embed before clustering.")
|
|
else:
|
|
|
|
df_full = st.session_state['df']
|
|
embeddings, _ = load_or_compute_embeddings(
|
|
df_full,
|
|
st.session_state.get('using_default_dataset', False),
|
|
st.session_state.get('uploaded_file_name'),
|
|
text_columns
|
|
)
|
|
|
|
if df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering:
|
|
with st.spinner("Performing clustering..."):
|
|
|
|
if 'summary_df' in st.session_state:
|
|
del st.session_state['summary_df']
|
|
if 'high_level_summary' in st.session_state:
|
|
del st.session_state['high_level_summary']
|
|
if 'enhanced_summary' in st.session_state:
|
|
del st.session_state['enhanced_summary']
|
|
|
|
dfc = df_to_cluster.copy().fillna("")
|
|
dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
|
|
|
|
|
|
selected_indices = dfc.index
|
|
embeddings_clustering = embeddings[selected_indices]
|
|
|
|
|
|
stop_words = set(stopwords.words('english'))
|
|
texts_cleaned = []
|
|
for text in dfc['text'].tolist():
|
|
try:
|
|
|
|
try:
|
|
word_tokens = word_tokenize(text)
|
|
except LookupError:
|
|
|
|
nltk.download('punkt_tab', quiet=False)
|
|
word_tokens = word_tokenize(text)
|
|
except Exception as e:
|
|
|
|
st.warning(f"Using fallback tokenization due to error: {e}")
|
|
word_tokens = text.split()
|
|
|
|
filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
|
|
texts_cleaned.append(filtered_text)
|
|
except Exception as e:
|
|
st.error(f"Error processing text: {e}")
|
|
|
|
texts_cleaned.append(text)
|
|
|
|
try:
|
|
|
|
if len(texts_cleaned) < min_cluster_size_val:
|
|
st.error(f"Not enough documents to form clusters. You have {len(texts_cleaned)} documents but minimum cluster size is set to {min_cluster_size_val}.")
|
|
st.session_state['clustering_error'] = "Insufficient documents for clustering"
|
|
st.session_state.active_tab_index = tabs_titles.index("Clustering")
|
|
st.stop()
|
|
|
|
|
|
if torch.is_tensor(embeddings_clustering):
|
|
embeddings_for_clustering = embeddings_clustering.cpu().numpy()
|
|
else:
|
|
embeddings_for_clustering = embeddings_clustering
|
|
|
|
|
|
if embeddings_for_clustering.shape[0] != len(texts_cleaned):
|
|
st.error("Mismatch between number of embeddings and texts.")
|
|
st.session_state['clustering_error'] = "Embedding and text count mismatch"
|
|
st.session_state.active_tab_index = tabs_titles.index("Clustering")
|
|
st.stop()
|
|
|
|
|
|
try:
|
|
hdbscan_model = HDBSCAN(
|
|
min_cluster_size=min_cluster_size_val,
|
|
metric='euclidean',
|
|
cluster_selection_method='eom'
|
|
)
|
|
|
|
|
|
topic_model = BERTopic(
|
|
embedding_model=get_embedding_model(),
|
|
hdbscan_model=hdbscan_model
|
|
)
|
|
|
|
|
|
topics, probs = topic_model.fit_transform(
|
|
texts_cleaned,
|
|
embeddings=embeddings_for_clustering
|
|
)
|
|
|
|
|
|
unique_topics = set(topics)
|
|
if len(unique_topics) < 2:
|
|
st.warning("Clustering resulted in too few clusters. Retry or try reducing the minimum cluster size.")
|
|
if -1 in unique_topics:
|
|
non_noise_docs = sum(1 for t in topics if t != -1)
|
|
st.info(f"Only {non_noise_docs} documents were assigned to clusters. The rest were marked as noise (-1).")
|
|
if non_noise_docs < min_cluster_size_val:
|
|
st.error("Not enough documents were successfully clustered. Try reducing the minimum cluster size.")
|
|
st.session_state['clustering_error'] = "Insufficient clustered documents"
|
|
st.session_state.active_tab_index = tabs_titles.index("Clustering")
|
|
st.stop()
|
|
|
|
|
|
dfc['Topic'] = topics
|
|
st.session_state['topic_model'] = topic_model
|
|
st.session_state['clustered_data'] = dfc.copy()
|
|
st.session_state['clustering_texts_cleaned'] = texts_cleaned
|
|
st.session_state['clustering_embeddings'] = embeddings_for_clustering
|
|
st.session_state['clustering_completed'] = True
|
|
|
|
|
|
try:
|
|
st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
|
|
except Exception as viz_error:
|
|
st.warning("Could not generate topic visualization. This usually happens when there are too few total clusters. Try adjusting the minimum cluster size or adding more documents.")
|
|
st.session_state['intertopic_distance_fig'] = None
|
|
|
|
try:
|
|
st.session_state['topic_document_fig'] = topic_model.visualize_documents(
|
|
texts_cleaned,
|
|
embeddings=embeddings_for_clustering
|
|
)
|
|
except Exception as viz_error:
|
|
st.warning("Could not generate document visualization. This might happen when the clustering results are not optimal. Try adjusting the clustering parameters.")
|
|
st.session_state['topic_document_fig'] = None
|
|
|
|
try:
|
|
hierarchy = topic_model.hierarchical_topics(texts_cleaned)
|
|
st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
|
|
st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
|
|
except Exception as viz_error:
|
|
st.warning("Could not generate topic hierarchy visualization. This usually happens when there aren't enough distinct topics to form a hierarchy.")
|
|
st.session_state['hierarchy'] = pd.DataFrame()
|
|
st.session_state['hierarchy_fig'] = None
|
|
|
|
except ValueError as ve:
|
|
if "zero-size array to reduction operation maximum which has no identity" in str(ve):
|
|
st.error("Clustering failed: No valid clusters could be formed. Try reducing the minimum cluster size.")
|
|
elif "Cannot use scipy.linalg.eigh for sparse A with k > N" in str(ve):
|
|
st.error("Clustering failed: Too many components requested for the number of documents. Try with more documents or adjust clustering parameters.")
|
|
else:
|
|
st.error(f"Clustering error: {str(ve)}")
|
|
st.session_state['clustering_error'] = str(ve)
|
|
st.session_state.active_tab_index = tabs_titles.index("Clustering")
|
|
st.stop()
|
|
|
|
except Exception as e:
|
|
st.error(f"An error occurred during clustering: {str(e)}")
|
|
st.session_state['clustering_error'] = str(e)
|
|
st.session_state['clustering_completed'] = False
|
|
st.session_state.active_tab_index = tabs_titles.index("Clustering")
|
|
st.stop()
|
|
|
|
|
|
if st.session_state.get('clustering_completed', False):
|
|
st.subheader("Topic Overview")
|
|
dfc = st.session_state['clustered_data']
|
|
topic_model = st.session_state['topic_model']
|
|
topics = dfc['Topic'].tolist()
|
|
|
|
unique_topics = sorted(list(set(topics)))
|
|
cluster_info = []
|
|
for t in unique_topics:
|
|
cluster_docs = dfc[dfc['Topic'] == t]
|
|
count = len(cluster_docs)
|
|
top_words = topic_model.get_topic(t)
|
|
if top_words:
|
|
top_keywords = ", ".join([w[0] for w in top_words[:5]])
|
|
else:
|
|
top_keywords = "N/A"
|
|
cluster_info.append((t, count, top_keywords))
|
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
|
|
|
|
st.write("### Topic Overview")
|
|
st.dataframe(
|
|
cluster_df,
|
|
column_config={
|
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
|
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
|
|
"Top Keywords": st.column_config.TextColumn(
|
|
"Top Keywords",
|
|
help="Top 5 keywords that characterize this topic"
|
|
)
|
|
},
|
|
hide_index=True
|
|
)
|
|
|
|
st.subheader("Clustering Results")
|
|
columns_to_display = [c for c in dfc.columns if c != 'text']
|
|
st.dataframe(dfc[columns_to_display], hide_index=True)
|
|
|
|
|
|
st.write("### Intertopic Distance Map")
|
|
if st.session_state.get('intertopic_distance_fig') is not None:
|
|
try:
|
|
st.plotly_chart(st.session_state['intertopic_distance_fig'])
|
|
except Exception:
|
|
st.info("Topic visualization is not available for the current clustering results.")
|
|
|
|
st.write("### Topic Document Visualization")
|
|
if st.session_state.get('topic_document_fig') is not None:
|
|
try:
|
|
st.plotly_chart(st.session_state['topic_document_fig'])
|
|
except Exception:
|
|
st.info("Document visualization is not available for the current clustering results.")
|
|
|
|
st.write("### Topic Hierarchy")
|
|
if st.session_state.get('hierarchy_fig') is not None:
|
|
try:
|
|
st.plotly_chart(st.session_state['hierarchy_fig'])
|
|
except Exception:
|
|
st.info("Topic hierarchy visualization is not available for the current clustering results.")
|
|
if not (df_to_cluster is not None and embeddings is not None and not df_to_cluster.empty and run_clustering):
|
|
pass
|
|
else:
|
|
st.warning("Please select or upload a dataset and filter as needed.")
|
|
|
|
|
|
|
|
|
|
|
|
with tab_summarization:
|
|
st.header("Summarization")
|
|
|
|
with st.expander("βΉοΈ How Summarization Works", expanded=False):
|
|
st.markdown("""
|
|
### Understanding Document Summarization
|
|
|
|
Summarization condenses multiple documents into concise, meaningful summaries while preserving key information. Here's how it works:
|
|
|
|
1. **Summary Generation**:
|
|
- Documents are processed using advanced language models
|
|
- Key themes and important points are identified
|
|
- Content is condensed while maintaining context
|
|
- Both high-level and cluster-specific summaries are available
|
|
|
|
2. **Reference System**:
|
|
- Summaries can include references to source documents
|
|
- References are shown as [ID] or as clickable links
|
|
- Each statement can be traced back to its source
|
|
- Helps maintain accountability and verification
|
|
|
|
3. **Types of Summaries**:
|
|
- **High-Level Summary**: Overview of all selected documents
|
|
- Captures main themes across the entire selection
|
|
- Ideal for quick understanding of large document sets
|
|
- Shows relationships between different topics
|
|
|
|
- **Cluster-Specific Summaries**: Focused on each cluster
|
|
- More detailed for specific themes
|
|
- Shows unique aspects of each cluster
|
|
- Helps understand sub-topics in depth
|
|
|
|
### How to Use Summaries
|
|
|
|
1. **Configuration**:
|
|
- Choose between all clusters or specific ones
|
|
- Set temperature for creativity vs. consistency
|
|
- Adjust max tokens for summary length
|
|
- Enable/disable reference system
|
|
|
|
2. **Reference Options**:
|
|
- Select column for reference IDs
|
|
- Add hyperlinks to references
|
|
- Choose URL column for clickable links
|
|
- References help track information sources
|
|
|
|
3. **Practical Applications**:
|
|
- Quick overview of large datasets
|
|
- Detailed analysis of specific themes
|
|
- Evidence-based reporting with references
|
|
- Compare different document groups
|
|
|
|
### Tips for Better Results
|
|
|
|
- **Temperature Setting**:
|
|
- Higher (0.7-1.0): More creative, varied summaries
|
|
- Lower (0.1-0.3): More consistent, conservative summaries
|
|
- Balance based on your needs for creativity vs. consistency
|
|
|
|
- **Token Length**:
|
|
- Longer limits: More detailed summaries
|
|
- Shorter limits: More concise, focused summaries
|
|
- Adjust based on document complexity
|
|
|
|
- **Reference Usage**:
|
|
- Enable references for traceability
|
|
- Use hyperlinks for easy navigation
|
|
- Choose meaningful reference columns
|
|
- Helps validate summary accuracy
|
|
|
|
### Best Practices
|
|
|
|
1. **For General Overview**:
|
|
- Use high-level summary
|
|
- Keep temperature moderate (0.5-0.7)
|
|
- Enable references for verification
|
|
- Focus on broader themes
|
|
|
|
2. **For Detailed Analysis**:
|
|
- Use cluster-specific summaries
|
|
- Adjust temperature based on need
|
|
- Include references with hyperlinks
|
|
- Look for patterns within clusters
|
|
|
|
3. **For Reporting**:
|
|
- Combine both summary types
|
|
- Use references extensively
|
|
- Balance detail and brevity
|
|
- Ensure source traceability
|
|
""")
|
|
|
|
df_summ = None
|
|
|
|
if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
|
|
df_summ = st.session_state['clustered_data']
|
|
elif 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
|
|
df_summ = st.session_state['filtered_df']
|
|
else:
|
|
st.warning("No data available for summarization. Please cluster first or have some filtered data.")
|
|
|
|
if df_summ is not None and not df_summ.empty:
|
|
text_columns = st.session_state.get('text_columns', [])
|
|
if not text_columns:
|
|
st.warning("No text columns selected. Please select columns for text embedding first.")
|
|
else:
|
|
if 'Topic' not in df_summ.columns or 'topic_model' not in st.session_state:
|
|
st.warning("No 'Topic' column found. Summaries per cluster are only available if you've run clustering.")
|
|
else:
|
|
topic_model = st.session_state['topic_model']
|
|
df_summ['text'] = df_summ.fillna("").astype(str)[text_columns].agg(' '.join, axis=1)
|
|
|
|
|
|
topics = sorted(df_summ['Topic'].unique())
|
|
cluster_info = []
|
|
for t in topics:
|
|
cluster_docs = df_summ[df_summ['Topic'] == t]
|
|
count = len(cluster_docs)
|
|
top_words = topic_model.get_topic(t)
|
|
if top_words:
|
|
top_keywords = ", ".join([w[0] for w in top_words[:5]])
|
|
else:
|
|
top_keywords = "N/A"
|
|
cluster_info.append((t, count, top_keywords))
|
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
|
|
|
|
|
|
if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
|
|
summary_df = st.session_state['summary_df']
|
|
|
|
topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
|
|
|
|
cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
|
|
|
|
cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
|
|
|
|
st.write("### Available Clusters:")
|
|
st.dataframe(
|
|
cluster_df,
|
|
column_config={
|
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
|
|
"Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
|
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
|
|
"Top Keywords": st.column_config.TextColumn(
|
|
"Top Keywords",
|
|
help="Top 5 keywords that characterize this topic"
|
|
)
|
|
},
|
|
hide_index=True
|
|
)
|
|
|
|
|
|
st.subheader("Summarization Settings")
|
|
|
|
summary_scope = st.radio(
|
|
"Generate summaries for:",
|
|
["All clusters", "Specific clusters"]
|
|
)
|
|
if summary_scope == "Specific clusters":
|
|
|
|
if 'Cluster_Name' in cluster_df.columns:
|
|
topic_options = [f"Cluster {t} - {name}" for t, name in zip(cluster_df['Topic'], cluster_df['Cluster_Name'])]
|
|
topic_to_id = {opt: t for opt, t in zip(topic_options, cluster_df['Topic'])}
|
|
selected_topic_options = st.multiselect("Select clusters to summarize", topic_options)
|
|
selected_topics = [topic_to_id[opt] for opt in selected_topic_options]
|
|
else:
|
|
selected_topics = st.multiselect("Select clusters to summarize", topics)
|
|
else:
|
|
selected_topics = topics
|
|
|
|
|
|
default_system_prompt = """You are an expert summarizer skilled in creating concise and relevant summaries.
|
|
You will be given text and an objective context. Please produce a clear, cohesive,
|
|
and thematically relevant summary.
|
|
Focus on key points, insights, or patterns that emerge from the text."""
|
|
|
|
if 'system_prompt' not in st.session_state:
|
|
st.session_state['system_prompt'] = default_system_prompt
|
|
|
|
with st.expander("π§ Advanced Settings", expanded=False):
|
|
st.markdown("""
|
|
### System Prompt Configuration
|
|
|
|
The system prompt guides the AI in how to generate summaries. You can customize it to better suit your needs:
|
|
- Be specific about the style and focus you want
|
|
- Add domain-specific context if needed
|
|
- Include any special formatting requirements
|
|
""")
|
|
|
|
system_prompt = st.text_area(
|
|
"Customize System Prompt",
|
|
value=st.session_state['system_prompt'],
|
|
height=150,
|
|
help="This prompt guides the AI in how to generate summaries. Edit it to customize the summary style and focus."
|
|
)
|
|
|
|
if st.button("Reset to Default"):
|
|
system_prompt = default_system_prompt
|
|
st.session_state['system_prompt'] = default_system_prompt
|
|
|
|
st.markdown("### Generation Parameters")
|
|
temperature = st.slider(
|
|
"Temperature",
|
|
0.0, 1.0, 0.7,
|
|
help="Higher values (0.7-1.0) make summaries more creative but less predictable. Lower values (0.1-0.3) make them more focused and consistent."
|
|
)
|
|
max_tokens = st.slider(
|
|
"Max Tokens",
|
|
100, 3000, 1000,
|
|
help="Maximum length of generated summaries. Higher values allow for more detailed summaries but take longer to generate."
|
|
)
|
|
|
|
st.session_state['system_prompt'] = system_prompt
|
|
|
|
st.write("### Enhanced Summary References")
|
|
st.write("Select columns for references (optional).")
|
|
all_cols = [c for c in df_summ.columns if c not in ['text', 'Topic', 'similarity_score']]
|
|
|
|
|
|
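# Default reference settings: fall back to the first available column as the reference ID and guess a URL column from its name ('url' or 'link').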
if 'reference_id_column' not in st.session_state:
|
|
st.session_state.reference_id_column = all_cols[0] if all_cols else None
|
|
|
|
url_guess = next((c for c in all_cols if 'url' in c.lower() or 'link' in c.lower()), None)
|
|
if 'url_column' not in st.session_state:
|
|
st.session_state.url_column = url_guess
|
|
|
|
enable_references = st.checkbox(
|
|
"Enable references in summaries",
|
|
value=True,
|
|
help="Add source references to the final summary text."
|
|
)
|
|
reference_id_column = st.selectbox(
|
|
"Select column to use as reference ID:",
|
|
all_cols,
|
|
index=all_cols.index(st.session_state.reference_id_column) if st.session_state.reference_id_column in all_cols else 0
|
|
)
|
|
add_hyperlinks = st.checkbox(
|
|
"Add hyperlinks to references",
|
|
value=True,
|
|
help="If the reference column has a matching URL, make it clickable."
|
|
)
|
|
url_column = None
|
|
if add_hyperlinks:
|
|
url_column = st.selectbox(
|
|
"Select column containing URLs:",
|
|
all_cols,
|
|
index=all_cols.index(st.session_state.url_column) if (st.session_state.url_column in all_cols) else 0
|
|
)
|
|
|
|
|
|
if st.button("Generate Summaries"):
|
|
openai_api_key = os.environ.get('OPENAI_API_KEY')
|
|
if not openai_api_key:
|
|
st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
|
|
else:
|
|
|
|
st.session_state['_summarization_button_clicked'] = True
|
|
|
|
llm = ChatOpenAI(
|
|
api_key=openai_api_key,
|
|
model_name='gpt-4o-mini',
|
|
temperature=temperature,
|
|
max_tokens=max_tokens
|
|
)
|
|
|
|
|
|
if selected_topics:
|
|
df_scope = df_summ[df_summ['Topic'].isin(selected_topics)]
|
|
else:
|
|
st.warning("No topics selected for summarization.")
|
|
df_scope = pd.DataFrame()
|
|
|
|
if df_scope.empty:
|
|
st.warning("No documents match the selected topics for summarization.")
|
|
else:
|
|
all_texts = df_scope['text'].tolist()
|
|
combined_text = " ".join(all_texts)
|
|
if not combined_text.strip():
|
|
st.warning("No text data available for summarization.")
|
|
else:
|
|
|
|
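# Rebuild the prompt template from the current (possibly customized) system prompt so edits made in Advanced Settings apply to this run.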
local_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
|
|
local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
|
local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
|
|
|
|
|
|
|
|
unique_selected_topics = df_scope['Topic'].unique()
|
|
if len(unique_selected_topics) > 1:
|
|
st.write("### Summaries per Selected Cluster")
|
|
|
|
|
|
with st.spinner("Generating cluster summaries in parallel..."):
|
|
summaries = process_summaries_in_parallel(
|
|
df_scope=df_scope,
|
|
unique_selected_topics=unique_selected_topics,
|
|
llm=llm,
|
|
chat_prompt=local_chat_prompt,
|
|
enable_references=enable_references,
|
|
reference_id_column=reference_id_column,
|
|
url_column=url_column if add_hyperlinks else None,
|
|
max_workers=min(16, len(unique_selected_topics))
|
|
)
|
|
|
|
if summaries:
|
|
summary_df = pd.DataFrame(summaries)
|
|
|
|
st.session_state['summary_df'] = summary_df
|
|
|
|
st.session_state['has_references'] = enable_references
|
|
st.session_state['reference_id_column'] = reference_id_column
|
|
st.session_state['url_column'] = url_column if add_hyperlinks else None
|
|
|
|
|
|
if 'Cluster_Name' in summary_df.columns:
|
|
topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
|
|
cluster_df['Cluster_Name'] = cluster_df['Topic'].map(lambda x: topic_names.get(x, 'Unnamed Cluster'))
|
|
cluster_df = cluster_df[['Topic', 'Cluster_Name', 'Count', 'Top Keywords']]
|
|
|
|
|
|
st.write("### Updated Topic Overview:")
|
|
st.dataframe(
|
|
cluster_df,
|
|
column_config={
|
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
|
|
"Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
|
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
|
|
"Top Keywords": st.column_config.TextColumn(
|
|
"Top Keywords",
|
|
help="Top 5 keywords that characterize this topic"
|
|
)
|
|
},
|
|
hide_index=True
|
|
)
|
|
|
|
|
|
with st.spinner("Generating high-level summary from cluster summaries..."):
|
|
|
|
formatted_summaries = []
|
|
total_tokens = 0
|
|
MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75)
|
|
summary_batches = []
|
|
current_batch = []
|
|
current_batch_tokens = 0
|
|
|
|
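# Greedily pack the per-cluster summaries into batches that stay below ~75% of the context window, leaving headroom for the synthesis instructions and the response.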
for _, row in summary_df.iterrows():
|
|
summary_text = row.get('Enhanced_Summary', row['Summary'])
|
|
formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
|
|
summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
|
|
|
|
|
|
if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
|
|
if current_batch:
|
|
summary_batches.append(current_batch)
|
|
current_batch = []
|
|
current_batch_tokens = 0
|
|
|
|
current_batch.append(formatted_summary)
|
|
current_batch_tokens += summary_tokens
|
|
|
|
|
|
if current_batch:
|
|
summary_batches.append(current_batch)
|
|
|
|
|
|
batch_overviews = []
|
|
with st.spinner("Generating batch summaries..."):
|
|
for i, batch in enumerate(summary_batches, 1):
|
|
st.write(f"Processing batch {i} of {len(summary_batches)}...")
|
|
|
|
batch_text = "\n\n".join(batch)
|
|
batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or <a href="...">ID</a>.
|
|
|
|
Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
|
|
1. Preserve all hyperlinked references exactly as they appear in the input summaries
|
|
2. Maintain the HTML anchor tags (<a href="...">) intact when using information from the summaries
|
|
3. Keep the markdown formatting for better readability
|
|
4. Note that this is part {i} of {len(summary_batches)} parts, so focus on the themes present in these specific clusters
|
|
|
|
Here are the cluster summaries to synthesize:
|
|
|
|
{batch_text}"""
|
|
|
|
|
|
high_level_system_message = SystemMessagePromptTemplate.from_template(st.session_state['system_prompt'])
|
|
high_level_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
|
high_level_chat_prompt = ChatPromptTemplate.from_messages([high_level_system_message, high_level_human_message])
|
|
high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
|
|
batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
|
|
batch_overviews.append(batch_overview)
|
|
|
|
|
|
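# Second synthesis pass: merge the per-batch overviews into one high-level summary, or fall back to concatenating them if the combined prompt would overflow the safe token limit.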
with st.spinner("Generating final combined summary..."):
|
|
combined_overviews = "\n\n### Part ".join([f"{i+1}:\n\n{overview}" for i, overview in enumerate(batch_overviews)])
|
|
final_prompt = f"""Below are {len(batch_overviews)} overview summaries, each covering different clusters of research results. Each part maintains its original references to source documents.
|
|
|
|
Please create a final comprehensive synthesis that:
|
|
1. Integrates the key themes and findings from all parts
|
|
2. Preserves all hyperlinked references exactly as they appear
|
|
3. Maintains the HTML anchor tags (<a href="...">) intact
|
|
4. Keeps the markdown formatting for better readability
|
|
5. Creates a coherent narrative across all parts
|
|
6. Highlights any themes that span multiple parts
|
|
|
|
Here are the overviews to synthesize:
|
|
|
|
|
|
|
|
{combined_overviews}"""
|
|
|
|
|
|
final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
|
|
if final_prompt_tokens > MAX_SAFE_TOKENS:
|
|
st.error(f"β Final synthesis prompt ({final_prompt_tokens:,} tokens) exceeds safe limit ({MAX_SAFE_TOKENS:,}). Using batch summaries separately.")
|
|
high_level_summary = "# Overall Summary\n\n" + "\n\n".join([f"## Batch {i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
|
|
else:
|
|
|
|
high_level_chain = LLMChain(llm=llm, prompt=high_level_chat_prompt)
|
|
high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
|
|
|
|
|
|
st.session_state['high_level_summary'] = high_level_summary
|
|
st.session_state['enhanced_summary'] = high_level_summary
|
|
|
|
|
|
st.session_state['summarization_completed'] = True
|
|
|
|
|
|
st.write("### High-Level Summary:")
|
|
st.markdown(high_level_summary, unsafe_allow_html=True)
|
|
|
|
|
|
st.write("### Cluster Summaries:")
|
|
if enable_references and 'Enhanced_Summary' in summary_df.columns:
|
|
for idx, row in summary_df.iterrows():
|
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
|
|
st.write(f"**Topic {row['Topic']} - {cluster_name}**")
|
|
st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
|
|
st.write("---")
|
|
with st.expander("View original summaries in table format"):
|
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
|
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
st.dataframe(display_df, hide_index=True)
|
|
else:
|
|
st.write("### Summaries per Cluster:")
|
|
if 'Cluster_Name' in summary_df.columns:
|
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
|
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
st.dataframe(display_df, hide_index=True)
|
|
else:
|
|
st.dataframe(summary_df, hide_index=True)
|
|
|
|
|
|
if 'Enhanced_Summary' in summary_df.columns:
|
|
dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
|
|
dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
else:
|
|
dl_df = summary_df
|
|
csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
|
|
b64 = base64.b64encode(csv_bytes).decode()
|
|
href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>'
|
|
st.markdown(href, unsafe_allow_html=True)
|
|
|
|
|
|
if st.session_state.get('summarization_completed', False):
|
|
if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
|
|
if 'high_level_summary' in st.session_state:
|
|
st.write("### High-Level Summary:")
|
|
st.markdown(st.session_state['enhanced_summary'] if st.session_state.get('enhanced_summary') else st.session_state['high_level_summary'], unsafe_allow_html=True)
|
|
|
|
st.write("### Cluster Summaries:")
|
|
summary_df = st.session_state['summary_df']
|
|
if 'Enhanced_Summary' in summary_df.columns:
|
|
for idx, row in summary_df.iterrows():
|
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
|
|
st.write(f"**Topic {row['Topic']} - {cluster_name}**")
|
|
st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
|
|
st.write("---")
|
|
with st.expander("View original summaries in table format"):
|
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
|
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
st.dataframe(display_df, hide_index=True)
|
|
else:
|
|
st.dataframe(summary_df, hide_index=True)
|
|
|
|
|
|
dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
|
|
if 'Cluster_Name' in dl_df.columns:
|
|
dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
|
|
b64 = base64.b64encode(csv_bytes).decode()
|
|
href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>'
|
|
st.markdown(href, unsafe_allow_html=True)
|
|
else:
|
|
st.warning("No data available for summarization.")
|
|
|
|
|
|
if not st.session_state.get('_summarization_button_clicked', False):
|
|
if 'high_level_summary' in st.session_state:
|
|
st.write("### Existing High-Level Summary:")
|
|
if st.session_state.get('enhanced_summary'):
|
|
st.markdown(st.session_state['enhanced_summary'], unsafe_allow_html=True)
|
|
with st.expander("View original summary (without references)"):
|
|
st.write(st.session_state['high_level_summary'])
|
|
else:
|
|
st.write(st.session_state['high_level_summary'])
|
|
|
|
if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
|
|
st.write("### Existing Cluster Summaries:")
|
|
summary_df = st.session_state['summary_df']
|
|
if 'Enhanced_Summary' in summary_df.columns:
|
|
for idx, row in summary_df.iterrows():
|
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
|
|
st.write(f"**Topic {row['Topic']} - {cluster_name}**")
|
|
st.markdown(row['Enhanced_Summary'], unsafe_allow_html=True)
|
|
st.write("---")
|
|
with st.expander("View original summaries in table format"):
|
|
display_df = summary_df[['Topic', 'Cluster_Name', 'Summary']]
|
|
display_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
st.dataframe(display_df, hide_index=True)
|
|
else:
|
|
st.dataframe(summary_df, hide_index=True)
|
|
|
|
|
|
dl_df = summary_df[['Topic', 'Cluster_Name', 'Summary']] if 'Cluster_Name' in summary_df.columns else summary_df
|
|
if 'Cluster_Name' in dl_df.columns:
|
|
dl_df.columns = ['Topic', 'Cluster Name', 'Summary']
|
|
csv_bytes = dl_df.to_csv(index=False).encode('utf-8')
|
|
b64 = base64.b64encode(csv_bytes).decode()
|
|
href = f'<a href="data:file/csv;base64,{b64}" download="summaries.csv">Download Summaries CSV</a>'
|
|
st.markdown(href, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
with tab_chat:
|
|
st.header("Chat with Your Data")
|
|
|
|
|
|
with st.expander("βΉοΈ How Chat Works", expanded=False):
|
|
st.markdown("""
|
|
### Understanding Chat with Your Data
|
|
|
|
The chat functionality allows you to have an interactive conversation about your data, whether it's filtered, clustered, or raw. Here's how it works:
|
|
|
|
1. **Data Selection**:
|
|
- Choose which dataset to chat about (filtered, clustered, or search results)
|
|
- Optionally focus on specific clusters if clustering was performed
|
|
- System automatically includes relevant context from your selection
|
|
|
|
2. **Context Window**:
|
|
- Shows how much of the model's context window is being used
|
|
- Helps you understand if you need to filter data further
|
|
- Displays token usage statistics
|
|
|
|
3. **Chat Features**:
|
|
- Ask questions about your data
|
|
- Get insights and analysis
|
|
- Reference specific documents or clusters
|
|
- Download chat context for transparency
|
|
|
|
### Best Practices
|
|
|
|
1. **Data Selection**:
|
|
- Start with filtered or clustered data for more focused conversations
|
|
- Select specific clusters if you want to dive deep into a topic
|
|
- Consider the context window usage when selecting data
|
|
|
|
2. **Asking Questions**:
|
|
- Be specific in your questions
|
|
- Ask about patterns, trends, or insights
|
|
- Reference clusters or documents by their IDs
|
|
- Build on previous questions for deeper analysis
|
|
|
|
3. **Managing Context**:
|
|
- Monitor the context window usage
|
|
- Filter data further if context is too full
|
|
- Download chat context for documentation
|
|
- Clear chat history to start fresh
|
|
|
|
### Tips for Better Results
|
|
|
|
- **Question Types**:
|
|
- "What are the main themes in cluster 3?"
|
|
- "Compare the findings between clusters 1 and 2"
|
|
- "Summarize the methodology used across these documents"
|
|
- "What are the common outcomes reported?"
|
|
|
|
- **Follow-up Questions**:
|
|
- Build on previous answers
|
|
- Ask for clarification
|
|
- Request specific examples
|
|
- Explore relationships between findings
|
|
""")
|
|
|
|
|
|
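# Collect every dataset currently held in session state that the chat can draw on.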
def get_available_data_sources():
|
|
sources = []
|
|
if 'filtered_df' in st.session_state and not st.session_state['filtered_df'].empty:
|
|
sources.append("Filtered Dataset")
|
|
if 'clustered_data' in st.session_state and not st.session_state['clustered_data'].empty:
|
|
sources.append("Clustered Data")
|
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
|
|
sources.append("Search Results")
|
|
if ('high_level_summary' in st.session_state or
|
|
('summary_df' in st.session_state and not st.session_state['summary_df'].empty)):
|
|
sources.append("Summarized Data")
|
|
return sources
|
|
|
|
|
|
available_sources = get_available_data_sources()
|
|
|
|
if not available_sources:
|
|
st.warning("No data available for chat. Please filter, cluster, search, or summarize first.")
|
|
st.stop()
|
|
|
|
|
|
if 'chat_data_source' not in st.session_state:
|
|
st.session_state.chat_data_source = available_sources[0]
|
|
elif st.session_state.chat_data_source not in available_sources:
|
|
st.session_state.chat_data_source = available_sources[0]
|
|
|
|
|
|
data_source = st.radio(
|
|
"Select data to chat about:",
|
|
available_sources,
|
|
index=available_sources.index(st.session_state.chat_data_source),
|
|
help="Choose which dataset you want to analyze in the chat."
|
|
)
|
|
|
|
|
|
if data_source != st.session_state.chat_data_source:
|
|
st.session_state.chat_data_source = data_source
|
|
|
|
if 'chat_selected_cluster' in st.session_state:
|
|
del st.session_state.chat_selected_cluster
|
|
|
|
|
|
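# Resolve the selected source to a DataFrame; summaries are flattened into (Summary_Type, Content) rows so they can be packed into the prompt like any other dataset.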
df_chat = None
|
|
if data_source == "Filtered Dataset":
|
|
df_chat = st.session_state['filtered_df']
|
|
elif data_source == "Clustered Data":
|
|
df_chat = st.session_state['clustered_data']
|
|
elif data_source == "Search Results":
|
|
df_chat = st.session_state['search_results']
|
|
elif data_source == "Summarized Data":
|
|
|
|
summary_rows = []
|
|
|
|
|
|
if 'high_level_summary' in st.session_state:
|
|
summary_rows.append({
|
|
'Summary_Type': 'High-Level Summary',
|
|
'Content': st.session_state.get('enhanced_summary', st.session_state['high_level_summary'])
|
|
})
|
|
|
|
|
|
if 'summary_df' in st.session_state and not st.session_state['summary_df'].empty:
|
|
summary_df = st.session_state['summary_df']
|
|
for _, row in summary_df.iterrows():
|
|
summary_rows.append({
|
|
'Summary_Type': f"Cluster {row['Topic']} Summary",
|
|
'Content': row.get('Enhanced_Summary', row['Summary'])
|
|
})
|
|
|
|
if summary_rows:
|
|
df_chat = pd.DataFrame(summary_rows)
|
|
|
|
if df_chat is not None and not df_chat.empty:
|
|
|
|
selected_cluster = None
|
|
if data_source != "Summarized Data" and 'Topic' in df_chat.columns:
|
|
cluster_option = st.radio(
|
|
"Choose cluster scope:",
|
|
["All Clusters", "Specific Cluster"]
|
|
)
|
|
if cluster_option == "Specific Cluster":
|
|
unique_topics = sorted(df_chat['Topic'].unique())
|
|
|
|
if 'summary_df' in st.session_state and 'Cluster_Name' in st.session_state['summary_df'].columns:
|
|
summary_df = st.session_state['summary_df']
|
|
|
|
topic_names = {t: name for t, name in zip(summary_df['Topic'], summary_df['Cluster_Name'])}
|
|
|
|
topic_options = [
|
|
(t, f"Cluster {t} - {topic_names.get(t, 'Unnamed Cluster')}")
|
|
for t in unique_topics
|
|
]
|
|
selected_cluster = st.selectbox(
|
|
"Select cluster to focus on:",
|
|
[t[0] for t in topic_options],
|
|
format_func=lambda x: next(opt[1] for opt in topic_options if opt[0] == x)
|
|
)
|
|
else:
|
|
selected_cluster = st.selectbox(
|
|
"Select cluster to focus on:",
|
|
unique_topics,
|
|
format_func=lambda x: f"Cluster {x}"
|
|
)
|
|
if selected_cluster is not None:
|
|
df_chat = df_chat[df_chat['Topic'] == selected_cluster]
|
|
st.session_state.chat_selected_cluster = selected_cluster
|
|
elif 'chat_selected_cluster' in st.session_state:
|
|
del st.session_state.chat_selected_cluster
|
|
|
|
|
|
text_columns = st.session_state.get('text_columns', [])
|
|
if not text_columns and data_source != "Summarized Data":
|
|
st.warning("No text columns selected. Please select text columns to enable chat functionality.")
|
|
st.stop()
|
|
|
|
|
|
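# Token budget for the chat context: roughly 95% of the model window is reserved for the system message plus data rows, leaving the remainder for the question and the reply.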
MAX_ALLOWED_TOKENS = int(MAX_CONTEXT_WINDOW * 0.95)
|
|
|
|
|
|
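# System prompt for the chat model; the data-format notes switch depending on whether raw rows or generated summaries are being analyzed.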
system_msg = {
|
|
"role": "system",
|
|
"content": """You are a specialized assistant analyzing data from a research database.
|
|
Your role is to:
|
|
1. Provide clear, concise answers based on the data provided
|
|
2. Highlight relevant information from specific results when answering
|
|
3. When referencing specific results, use their row index or ID if available
|
|
4. Clearly state if information is not available in the results
|
|
5. Maintain a professional and analytical tone
|
|
6. Format your responses using Markdown:
|
|
- Use **bold** for emphasis
|
|
- Use bullet points and numbered lists for structured information
|
|
- Create tables using Markdown syntax when presenting structured data
|
|
- Use backticks for code or technical terms
|
|
- Include hyperlinks when referencing external sources
|
|
- Use headings (###) to organize long responses
|
|
|
|
The data is provided in a structured format where:""" + ("""
|
|
- Each result contains multiple fields
|
|
- Text content is primarily in the following columns: """ + ", ".join(text_columns) + """
|
|
- Additional metadata and fields are available for reference
|
|
- If clusters are present, they are numbered (e.g., Cluster 0, Cluster 1, etc.)""" if data_source != "Summarized Data" else """
|
|
- The data consists of AI-generated summaries of the documents
|
|
- Each summary may contain references to source documents in markdown format
|
|
- References are shown as [ID] or as clickable hyperlinks
|
|
- Summaries may be high-level (covering all documents) or cluster-specific""") + """
|
|
"""
|
|
}
|
|
|
|
|
|
system_tokens = len(tokenizer(system_msg["content"])["input_ids"])
|
|
remaining_tokens = MAX_ALLOWED_TOKENS - system_tokens
|
|
|
|
|
|
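# Pack rows into the prompt in order until the token budget is exhausted; rows that do not fit are dropped and the coverage is reported below.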
data_text = "Available Data:\n"
|
|
included_rows = 0
|
|
total_rows = len(df_chat)
|
|
|
|
if data_source == "Summarized Data":
|
|
|
|
for idx, row in df_chat.iterrows():
|
|
row_text = f"\n{row['Summary_Type']}:\n{row['Content']}\n"
|
|
row_tokens = len(tokenizer(row_text)["input_ids"])
|
|
|
|
if remaining_tokens - row_tokens > 0:
|
|
data_text += row_text
|
|
remaining_tokens -= row_tokens
|
|
included_rows += 1
|
|
else:
|
|
break
|
|
else:
|
|
|
|
for idx, row in df_chat.iterrows():
|
|
row_text = f"\nItem {idx}:\n"
|
|
for col in df_chat.columns:
|
|
if not pd.isna(row[col]) and str(row[col]).strip() and col != 'similarity_score':
|
|
row_text += f"{col}: {row[col]}\n"
|
|
|
|
row_tokens = len(tokenizer(row_text)["input_ids"])
|
|
if remaining_tokens - row_tokens > 0:
|
|
data_text += row_text
|
|
remaining_tokens -= row_tokens
|
|
included_rows += 1
|
|
else:
|
|
break
|
|
|
|
|
|
data_tokens = len(tokenizer(data_text)["input_ids"])
|
|
total_tokens = system_tokens + data_tokens
|
|
context_usage_percent = (total_tokens / MAX_CONTEXT_WINDOW) * 100
|
|
|
|
|
|
st.subheader("Context Window Usage")
|
|
st.write(f"System Message: {system_tokens:,} tokens")
|
|
st.write(f"Data Context: {data_tokens:,} tokens")
|
|
st.write(f"Total: {total_tokens:,} tokens ({context_usage_percent:.1f}% of available context)")
|
|
st.write(f"Documents included: {included_rows:,} out of {total_rows:,} ({(included_rows/total_rows*100):.1f}%)")
|
|
|
|
if context_usage_percent > 90:
|
|
st.warning("β οΈ High context usage! Consider reducing the number of results or filtering further.")
|
|
elif context_usage_percent > 75:
|
|
st.info("βΉοΈ Moderate context usage. Still room for your question, but consider reducing results if asking a long question.")
|
|
|
|
|
|
chat_context = f"""System Message:
|
|
{system_msg['content']}
|
|
|
|
{data_text}"""
|
|
st.download_button(
|
|
label="π₯ Download Chat Context",
|
|
data=chat_context,
|
|
file_name="chat_context.txt",
|
|
mime="text/plain",
|
|
help="Download the exact context that the chatbot receives"
|
|
)
|
|
|
|
|
|
col_chat1, col_chat2 = st.columns([3, 1])
|
|
with col_chat1:
|
|
user_input = st.text_area("Ask a question about your data:", key="chat_input")
|
|
with col_chat2:
|
|
if st.button("Clear Chat History"):
|
|
st.session_state.chat_history = []
|
|
st.rerun()
|
|
|
|
|
|
current_tab = tabs_titles.index("Chat")
|
|
|
|
if st.button("Send", key="send_button"):
|
|
if user_input:
|
|
|
|
st.session_state.active_tab_index = current_tab
|
|
|
|
with st.spinner("Processing your question..."):
|
|
|
|
st.session_state.chat_history.append({"role": "user", "content": user_input})
|
|
|
|
|
|
messages = [system_msg]
|
|
messages.append({"role": "user", "content": f"Here is the data to reference:\n\n{data_text}\n\nUser question: {user_input}"})
|
|
|
|
|
|
response = get_chat_response(messages)
|
|
|
|
if response:
|
|
st.session_state.chat_history.append({"role": "assistant", "content": response})
|
|
|
|
|
|
st.subheader("Chat History")
|
|
for message in st.session_state.chat_history:
|
|
if message["role"] == "user":
|
|
st.write("**You:**", message["content"])
|
|
else:
|
|
st.write("**Assistant:**")
|
|
st.markdown(message["content"], unsafe_allow_html=True)
|
|
st.write("---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
st.header("Automatic Mode")
|
|
|
|
|
|
if 'df' not in st.session_state:
|
|
default_dataset_path = os.path.join(BASE_DIR, 'input', 'export_data_table_results_20251203_101413CET.xlsx')
|
|
df = load_default_dataset(default_dataset_path)
|
|
if df is not None:
|
|
st.session_state['df'] = df.copy()
|
|
st.session_state['using_default_dataset'] = True
|
|
st.session_state['filtered_df'] = df.copy()
|
|
|
|
|
|
if 'text_columns' not in st.session_state or not st.session_state['text_columns']:
|
|
default_text_cols = []
|
|
if 'Title' in df.columns and 'Description' in df.columns:
|
|
default_text_cols = ['Title', 'Description']
|
|
st.session_state['text_columns'] = default_text_cols
|
|
|
|
|
|
|
|
query = st.text_input("Write your query here:")
|
|
|
|
|
|
|
|
|
|
if st.button("SNAP!"):
|
|
if query.strip():
|
|
|
|
st.write("### Step 1: Semantic Search")
|
|
with st.spinner("Performing Semantic Search..."):
|
|
text_columns = st.session_state.get('text_columns', [])
|
|
if text_columns:
|
|
df_full = st.session_state['df']
|
|
embeddings, _ = load_or_compute_embeddings(
|
|
df_full,
|
|
st.session_state.get('using_default_dataset', False),
|
|
st.session_state.get('uploaded_file_name'),
|
|
text_columns
|
|
)
|
|
|
|
if embeddings is not None:
|
|
model = get_embedding_model()
|
|
df_filtered = st.session_state['filtered_df'].fillna("")
|
|
search_texts = df_filtered[text_columns].agg(' '.join, axis=1).tolist()
|
|
|
|
subset_indices = df_filtered.index
|
|
subset_embeddings = embeddings[subset_indices]
|
|
|
|
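# Embed the query with the same sentence-transformer used for the documents and keep rows whose cosine similarity to the query clears the threshold.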
query_embedding = model.encode([query], device=device)
|
|
similarities = cosine_similarity(query_embedding, subset_embeddings)[0]
|
|
|
|
similarity_threshold = 0.35
|
|
above_threshold_indices = np.where(similarities > similarity_threshold)[0]
|
|
|
|
if len(above_threshold_indices) > 0:
|
|
selected_indices = subset_indices[above_threshold_indices]
|
|
results = df_filtered.loc[selected_indices].copy()
|
|
results['similarity_score'] = similarities[above_threshold_indices]
|
|
results.sort_values(by='similarity_score', ascending=False, inplace=True)
|
|
st.session_state['search_results'] = results.copy()
|
|
st.write(f"Found {len(results)} relevant documents")
|
|
else:
|
|
st.warning("No results found above the similarity threshold.")
|
|
st.stop()
|
|
|
|
|
|
if 'search_results' in st.session_state and not st.session_state['search_results'].empty:
|
|
st.write("### Step 2: Clustering")
|
|
with st.spinner("Performing clustering..."):
|
|
df_to_cluster = st.session_state['search_results'].copy()
|
|
dfc = df_to_cluster.copy().fillna("")
|
|
dfc['text'] = dfc[text_columns].astype(str).agg(' '.join, axis=1)
|
|
|
|
|
|
selected_indices = dfc.index
|
|
embeddings_clustering = embeddings[selected_indices]
|
|
|
|
|
|
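# Remove English stop words before fitting BERTopic so topic keywords are not dominated by function words; fall back to the raw text if tokenization fails.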
stop_words = set(stopwords.words('english'))
|
|
texts_cleaned = []
|
|
for text in dfc['text'].tolist():
|
|
try:
|
|
word_tokens = word_tokenize(text)
|
|
filtered_text = ' '.join([w for w in word_tokens if w.lower() not in stop_words])
|
|
texts_cleaned.append(filtered_text)
|
|
except Exception as e:
|
|
texts_cleaned.append(text)
|
|
|
|
min_cluster_size = 5
|
|
|
|
try:
|
|
|
|
if torch.is_tensor(embeddings_clustering):
|
|
embeddings_for_clustering = embeddings_clustering.cpu().numpy()
|
|
else:
|
|
embeddings_for_clustering = embeddings_clustering
|
|
|
|
|
|
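# Cluster the precomputed embeddings with HDBSCAN (documents that fit no cluster get topic -1) and let BERTopic derive the keywords for each cluster.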
hdbscan_model = HDBSCAN(
|
|
min_cluster_size=min_cluster_size,
|
|
metric='euclidean',
|
|
cluster_selection_method='eom'
|
|
)
|
|
|
|
|
|
topic_model = BERTopic(
|
|
embedding_model=get_embedding_model(),
|
|
hdbscan_model=hdbscan_model
|
|
)
|
|
|
|
|
|
topics, probs = topic_model.fit_transform(
|
|
texts_cleaned,
|
|
embeddings=embeddings_for_clustering
|
|
)
|
|
|
|
|
|
dfc['Topic'] = topics
|
|
st.session_state['topic_model'] = topic_model
|
|
st.session_state['clustered_data'] = dfc.copy()
|
|
st.session_state['clustering_completed'] = True
|
|
|
|
|
|
unique_topics = sorted(list(set(topics)))
|
|
num_clusters = len([t for t in unique_topics if t != -1])
|
|
noise_docs = len([t for t in topics if t == -1])
|
|
clustered_docs = len(topics) - noise_docs
|
|
|
|
st.write(f"Found {num_clusters} distinct clusters")
|
|
|
|
|
|
|
|
|
|
|
|
cluster_info = []
|
|
for t in unique_topics:
|
|
if t != -1:
|
|
cluster_docs = dfc[dfc['Topic'] == t]
|
|
count = len(cluster_docs)
|
|
top_words = topic_model.get_topic(t)
|
|
top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
|
|
cluster_info.append((t, count, top_keywords))
|
|
|
|
if cluster_info:
|
|
|
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Count", "Top Keywords"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
st.session_state['intertopic_distance_fig'] = topic_model.visualize_topics()
|
|
except Exception:
|
|
st.session_state['intertopic_distance_fig'] = None
|
|
|
|
try:
|
|
st.session_state['topic_document_fig'] = topic_model.visualize_documents(
|
|
texts_cleaned,
|
|
embeddings=embeddings_for_clustering
|
|
)
|
|
except Exception:
|
|
st.session_state['topic_document_fig'] = None
|
|
|
|
try:
|
|
hierarchy = topic_model.hierarchical_topics(texts_cleaned)
|
|
st.session_state['hierarchy'] = hierarchy if hierarchy is not None else pd.DataFrame()
|
|
st.session_state['hierarchy_fig'] = topic_model.visualize_hierarchy()
|
|
except Exception:
|
|
st.session_state['hierarchy'] = pd.DataFrame()
|
|
st.session_state['hierarchy_fig'] = None
|
|
|
|
except Exception as e:
|
|
st.error(f"An error occurred during clustering: {str(e)}")
|
|
st.stop()
|
|
|
|
|
|
if st.session_state.get('clustering_completed', False):
|
|
st.write("### Step 3: Summarization")
|
|
|
|
|
|
openai_api_key = os.environ.get('OPENAI_API_KEY')
|
|
if not openai_api_key:
|
|
st.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
|
|
st.stop()
|
|
|
|
llm = ChatOpenAI(
|
|
api_key=openai_api_key,
|
|
model_name='gpt-4o-mini',
|
|
temperature=0.7,
|
|
max_tokens=1000
|
|
)
|
|
|
|
df_scope = st.session_state['clustered_data']
|
|
unique_selected_topics = df_scope['Topic'].unique()
|
|
|
|
|
|
with st.spinner("Generating summaries..."):
|
|
local_system_message = SystemMessagePromptTemplate.from_template("""You are an expert summarizer skilled in creating concise and relevant summaries.
|
|
You will be given text and an objective context. Please produce a clear, cohesive,
|
|
and thematically relevant summary.
|
|
Focus on key points, insights, or patterns that emerge from the text.""")
|
|
local_human_message = HumanMessagePromptTemplate.from_template("{user_prompt}")
|
|
local_chat_prompt = ChatPromptTemplate.from_messages([local_system_message, local_human_message])
|
|
|
|
|
|
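# Automatic mode defaults: use the first column as the reference ID and guess a URL column by name for hyperlinked references.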
url_column = next((col for col in df_scope.columns if 'url' in col.lower() or 'link' in col.lower() or 'pdf' in col.lower()), None)
|
|
|
|
summaries = process_summaries_in_parallel(
|
|
df_scope=df_scope,
|
|
unique_selected_topics=unique_selected_topics,
|
|
llm=llm,
|
|
chat_prompt=local_chat_prompt,
|
|
enable_references=True,
|
|
reference_id_column=df_scope.columns[0],
|
|
url_column=url_column,
|
|
max_workers=min(16, len(unique_selected_topics))
|
|
)
|
|
|
|
if summaries:
|
|
summary_df = pd.DataFrame(summaries)
|
|
st.session_state['summary_df'] = summary_df
|
|
|
|
|
|
if 'Cluster_Name' in summary_df.columns:
|
|
st.write("### Updated Topic Overview:")
|
|
cluster_info = []
|
|
for t in unique_selected_topics:
|
|
cluster_docs = df_scope[df_scope['Topic'] == t]
|
|
count = len(cluster_docs)
|
|
top_words = topic_model.get_topic(t)
|
|
top_keywords = ", ".join([w[0] for w in top_words[:5]]) if top_words else "N/A"
|
|
cluster_name = summary_df[summary_df['Topic'] == t]['Cluster_Name'].iloc[0]
|
|
cluster_info.append((t, cluster_name, count, top_keywords))
|
|
|
|
cluster_df = pd.DataFrame(cluster_info, columns=["Topic", "Cluster_Name", "Count", "Top Keywords"])
|
|
st.dataframe(
|
|
cluster_df,
|
|
column_config={
|
|
"Topic": st.column_config.NumberColumn("Topic", help="Topic ID (-1 represents outliers)"),
|
|
"Cluster_Name": st.column_config.TextColumn("Cluster Name", help="AI-generated name describing the cluster theme"),
|
|
"Count": st.column_config.NumberColumn("Count", help="Number of documents in this topic"),
|
|
"Top Keywords": st.column_config.TextColumn(
|
|
"Top Keywords",
|
|
help="Top 5 keywords that characterize this topic"
|
|
)
|
|
},
|
|
hide_index=True
|
|
)
|
|
|
|
|
|
with st.spinner("Generating high-level summary..."):
|
|
formatted_summaries = []
|
|
summary_batches = []
|
|
current_batch = []
|
|
current_batch_tokens = 0
|
|
MAX_SAFE_TOKENS = int(MAX_CONTEXT_WINDOW * 0.75)
|
|
|
|
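# Same batching as in the Summarization tab: group cluster summaries under ~75% of the context window before requesting a combined overview.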
for _, row in summary_df.iterrows():
|
|
summary_text = row.get('Enhanced_Summary', row['Summary'])
|
|
formatted_summary = f"### Cluster {row['Topic']} Summary:\n\n{summary_text}"
|
|
summary_tokens = len(tokenizer(formatted_summary)["input_ids"])
|
|
|
|
if current_batch_tokens + summary_tokens > MAX_SAFE_TOKENS:
|
|
if current_batch:
|
|
summary_batches.append(current_batch)
|
|
current_batch = []
|
|
current_batch_tokens = 0
|
|
|
|
current_batch.append(formatted_summary)
|
|
current_batch_tokens += summary_tokens
|
|
|
|
if current_batch:
|
|
summary_batches.append(current_batch)
|
|
|
|
|
|
batch_overviews = []
|
|
for i, batch in enumerate(summary_batches, 1):
|
|
st.write(f"Processing summary batch {i} of {len(summary_batches)}...")
|
|
batch_text = "\n\n".join(batch)
|
|
batch_prompt = f"""Below are summaries from a subset of clusters from results made using Transformers NLP on a set of results from the CGIAR reporting system. Each summary contains references to source documents in the form of hyperlinked IDs like [ID] or <a href="...">ID</a>.
|
|
|
|
Please create a comprehensive overview that synthesizes these clusters so that both the main themes and findings are covered in an organized way. IMPORTANT:
|
|
1. Preserve all hyperlinked references exactly as they appear in the input summaries
|
|
2. Maintain the HTML anchor tags (<a href="...">) intact when using information from the summaries
|
|
3. Keep the markdown formatting for better readability
|
|
4. Create clear sections with headings for different themes
|
|
5. Use bullet points or numbered lists where appropriate
|
|
6. Focus on synthesizing the main themes and findings
|
|
|
|
Here are the cluster summaries to synthesize:
|
|
|
|
{batch_text}"""
|
|
|
|
high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
|
|
batch_overview = high_level_chain.run(user_prompt=batch_prompt).strip()
|
|
batch_overviews.append(batch_overview)
|
|
|
|
|
|
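# A single batch already is the high-level summary; with several batches, run one more synthesis pass over the batch overviews.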
if len(batch_overviews) > 1:
|
|
st.write("Generating final synthesis...")
|
|
combined_overviews = "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
|
|
final_prompt = f"""Below are multiple overview summaries, each covering different aspects of CGIAR research results. Each part maintains its original references to source documents.
|
|
|
|
Please create a final comprehensive synthesis that:
|
|
1. Integrates the key themes and findings from all parts into a cohesive narrative
|
|
2. Preserves all hyperlinked references exactly as they appear
|
|
3. Maintains the HTML anchor tags (<a href="...">) intact
|
|
4. Uses clear section headings and structured formatting
|
|
5. Highlights cross-cutting themes and relationships between different aspects
|
|
6. Provides a clear introduction and conclusion
|
|
|
|
Here are the overviews to synthesize:
|
|
|
|
|
|
|
|
{combined_overviews}"""
|
|
|
|
final_prompt_tokens = len(tokenizer(final_prompt)["input_ids"])
|
|
if final_prompt_tokens > MAX_SAFE_TOKENS:
|
|
|
|
high_level_summary = "# Comprehensive Overview\n\n" + "\n\n# Part ".join([f"{i+1}\n\n{overview}" for i, overview in enumerate(batch_overviews)])
|
|
else:
|
|
high_level_chain = LLMChain(llm=llm, prompt=local_chat_prompt)
|
|
high_level_summary = high_level_chain.run(user_prompt=final_prompt).strip()
|
|
else:
|
|
|
|
high_level_summary = batch_overviews[0]
|
|
|
|
st.session_state['high_level_summary'] = high_level_summary
|
|
st.session_state['enhanced_summary'] = high_level_summary
|
|
|
|
|
|
st.write("### High-Level Summary:")
|
|
with st.expander("High-Level Summary", expanded=True):
|
|
st.markdown(high_level_summary, unsafe_allow_html=True)
|
|
|
|
st.write("### Cluster Summaries:")
|
|
for idx, row in summary_df.iterrows():
|
|
cluster_name = row.get('Cluster_Name', 'Unnamed Cluster')
|
|
with st.expander(f"Topic {row['Topic']} - {cluster_name}", expanded=False):
|
|
st.markdown(row.get('Enhanced_Summary', row['Summary']), unsafe_allow_html=True)
|
|
st.markdown("##### About this tool")
|
|
with st.expander("Click to expand/collapse", expanded=True):
|
|
st.markdown("""
|
|
This tool draws on CGIAR quality-assured results data from 2022-2024 to provide verifiable responses to user questions about the themes and areas CGIAR has worked on and is currently working on.
|
|
|
|
**Tips:**
|
|
- **Craft a phrase** that describes your topic of interest (e.g., `"climate-smart agriculture"`, `"gender equality livestock"`).
|
|
- Avoid writing full questions; **this is not a chatbot**.
|
|
- Combine **related terms** for better results (e.g., `"irrigation water access smallholders"`).
|
|
- Focus on **concepts or themes** rather than single words like `"climate"` or `"yield"` alone.
|
|
- Example good queries:
|
|
- `"climate adaptation smallholder farming"`
|
|
- `"digital agriculture innovations"`
|
|
- `"nutrition-sensitive value chains"`
|
|
|
|
**Example use case**:
|
|
You're interested in CGIAR's contributions to **poverty reduction through improved maize varieties in Africa**.
|
|
A good search phrase would be:
|
|
π `"poverty reduction maize Africa"`
|
|
This will retrieve results related to improved crop varieties, livelihood outcomes, and region-specific interventions, even if the documents use different wording like *"enhanced maize genetics"*, *"smallholder income"*, or *"eastern Africa trials"*.
|
|
""") |