# bjhk / doc_explorer / explorer.py
# Uploaded by heymenn ("Upload 15 files", commit 6aaddef, verified).
import gradio as gr
from vectorstore import FAISSVectorStore
from langchain_community.graphs import Neo4jGraph
import os
import json
import html
import pandas as pd
import time
# NOTE(review): fixed 30 s pause at import time, presumably to give the Neo4j
# service time to start before connecting — TODO confirm; a connect-with-retry
# loop would start faster and fail more clearly.
time.sleep(30)
# Route HTTP(S) traffic from this process through a hard-coded proxy. These
# assignments overwrite any proxy settings already present in the environment.
# NOTE(review): the proxy values carry no "http://" scheme, which some HTTP
# clients reject — verify; consider making the proxy configurable.
os.environ["http_proxy"] = "185.46.212.98:80"
os.environ["https_proxy"] = "185.46.212.98:80"
os.environ["NO_PROXY"] = "localhost"
# Shared Neo4j connection used by every query in this module. Connection
# settings come from NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD, with local
# fallbacks when those variables are unset.
# NOTE(review): the fallback password is hard-coded in source.
neo4j_graph = Neo4jGraph(
    url=os.getenv("NEO4J_URI", "bolt://localhost:7999"),
    username=os.getenv("NEO4J_USERNAME", "neo4j"),
    password=os.getenv("NEO4J_PASSWORD", "graph_test")
)
# Requires ~1GB RAM
# Vector store over precomputed embeddings loaded from `embedding_file`;
# dimension=1024 presumably has to match the embedding model's output size.
# NOTE(review): trust_remote_code=True presumably allows the model repository to
# run its own code when loaded — confirm this is intended.
vector_store = FAISSVectorStore(model_name='Alibaba-NLP/gte-large-en-v1.5', dimension=1024, trust_remote_code=True, embedding_file="/usr/src/app/doc_explorer/embeddings_full.npy")
# Get document types from Neo4j database
def get_document_types():
    """Return the distinct document types (primary node labels) found in Neo4j.

    A node's document type is the first entry of its label list. Nodes without
    any label are skipped. Each type appears once, in first-seen order.

    Returns:
        list: label strings, used as the choices of the type-filter dropdown.
    """
    query = """
    MATCH (n)
    RETURN DISTINCT labels(n) AS document_type
    """
    result = neo4j_graph.query(query)
    # DISTINCT works on the whole label list, so ["A"] and ["A", "B"] are two
    # rows that both map to "A"; dict.fromkeys de-duplicates while keeping
    # order. Unlabelled nodes give [] and would otherwise raise IndexError.
    return list(dict.fromkeys(
        row["document_type"][0] for row in result if row["document_type"]
    ))
def search(query, doc_types, use_mmr, lambda_param, top_k):
    """Run a vector similarity search and format the hits for the UI.

    Args:
        query: free-text search query.
        doc_types: document types selected in the filter dropdown; forwarded
            unchanged to the vector store.
        use_mmr: whether to re-rank results with MMR.
        lambda_param: MMR lambda; only forwarded when ``use_mmr`` is true.
        top_k: number of results to request.

    Returns:
        tuple: (newline-separated result text, dropdown update whose choices are
        "<n>. <text>" with 1-based n, node ids returned by the vector store —
        presumably aligned with the results; verify in FAISSVectorStore).
    """
    results, node_ids = vector_store.similarity_search(
        query,
        k=top_k,
        use_mmr=use_mmr,
        lambda_param=lambda_param if use_mmr else None,
        doc_types=doc_types,
        neo4j_graph=neo4j_graph
    )
    formatted_results = []
    formatted_choices = []
    # Number from 1: get_neighbors_and_graph_data parses the leading number and
    # subtracts 1 to index node_ids. With 0-based numbering, choosing "0."
    # selected node_ids[-1], i.e. the wrong document.
    for i, result in enumerate(results, start=1):
        score = f"(Score: {result['score']:.4f})"
        formatted_results.append(f"{i}. {result['document']} {score}")
        formatted_choices.append(f"{i}. {str(result['document'])[:100]} {score}")
    # Join the lines: a Textbox stringifies a list as its Python repr.
    return "\n".join(formatted_results), gr.update(choices=formatted_choices, value=[]), node_ids
def get_docs_from_ids(graph_data: dict):
    """Fetch from Neo4j every node currently shown in the graph, for export.

    Args:
        graph_data: graph state with a "nodes" list (each node has an "id") and
            an "edges" list.

    Returns:
        tuple: (query rows with keys "id", "doc" and "category" (the node's
        label list), the graph's edges list unchanged).
    """
    node_ids = [node["id"] for node in graph_data["nodes"]]
    if not node_ids:
        # Nothing is displayed yet; the IN-query would return no rows anyway.
        return [], graph_data["edges"]
    query = """
    MATCH (n)
    WHERE n.id IN $node_ids
    RETURN n.id AS id, n AS doc, labels(n) AS category
    """
    return neo4j_graph.query(query, {"node_ids": node_ids}), graph_data["edges"]
def _node_tooltip(node_type, text):
    """Build the HTML tooltip for a graph node, escaping all node-derived text."""
    return (
        f"<div class='node-tooltip'><h3>{html.escape(str(node_type))}</h3>"
        f"<p>{html.escape(str(text))}</p></div>"
    )


def _graph_node(node_id, text, node_type):
    """Return a vis-network node dict (id, short label, group, tooltip)."""
    return {
        "id": node_id,
        "label": str(text)[:30] + "...",
        "group": node_type,
        "title": _node_tooltip(node_type, text),
    }


def get_neighbors_and_graph_data(selected_documents, node_ids, graph_data):
    """Expand the selected documents with their direct Neo4j neighbours.

    Args:
        selected_documents: dropdown values of the form "<n>. <text>", where n
            is the 1-based position of the node in ``node_ids``.
        node_ids: node ids aligned with the current dropdown choices.
        graph_data: accumulated graph state {"nodes": [...], "edges": [...]};
            newly found nodes and edges are appended to it in place.

    Returns:
        5-tuple matching the wired Gradio outputs: (neighbour summary text,
        graph JSON string for the JS renderer, updated graph state, dropdown
        update, node ids aligned with the new dropdown choices).
    """
    # Every return path must supply all five outputs. The early returns
    # previously produced 3 and 2 values, which Gradio rejects.
    if not selected_documents:
        return "No documents selected.", json.dumps(graph_data), graph_data, gr.update(), node_ids
    selected_indices = [int(doc.split('.')[0]) - 1 for doc in selected_documents]
    selected_node_ids = [node_ids[i] for i in selected_indices]
    query = """
    MATCH (n)-[r]-(neighbor)
    WHERE n.id IN $node_ids
    RETURN n.id AS source_id, n AS source_text, labels(n) AS source_type,
           neighbor.id AS neighbor_id, neighbor AS neighbor_text,
           labels(neighbor) AS neighbor_type, type(r) AS relationship_type
    """
    results = neo4j_graph.query(query, {"node_ids": selected_node_ids})
    if not results:
        # Send the real graph JSON; "[]" made the JS fail on `data.nodes`.
        return ("No neighbors found for the selected documents.",
                json.dumps(graph_data), graph_data, gr.update(), node_ids)

    neighbor_info = {}
    node_set = {node["id"] for node in graph_data["nodes"]}
    # Set of (from, to, label) keys: O(1) de-duplication instead of scanning
    # the edge list once per result row.
    edge_set = {(edge["from"], edge["to"], edge["label"]) for edge in graph_data["edges"]}
    for row in results:
        source_id = row["source_id"]
        source_type = row["source_type"][0]
        neighbor_id = row["neighbor_id"]
        neighbor_type = row["neighbor_type"][0]
        relationship = row["relationship_type"]

        if source_id not in neighbor_info:
            neighbor_info[source_id] = {
                "source_type": source_type,
                "source_text": row["source_text"],
                "neighbors": [],
            }
        if source_id not in node_set:
            graph_data["nodes"].append(_graph_node(source_id, row["source_text"], source_type))
            node_set.add(source_id)

        neighbor_info[source_id]["neighbors"].append(
            f"[{relationship}] [{neighbor_type}] {str(row['neighbor_text'])[:200]}"
        )
        if neighbor_id not in node_set:
            graph_data["nodes"].append(_graph_node(neighbor_id, row["neighbor_text"], neighbor_type))
            node_set.add(neighbor_id)

        edge_key = (source_id, neighbor_id, relationship)
        if edge_key not in edge_set:
            graph_data["edges"].append({"from": source_id, "to": neighbor_id, "label": relationship})
            edge_set.add(edge_key)

    output = []
    for info in neighbor_info.values():
        output.append(f"Neighbors for: [{info['source_type']}] {str(info['source_text'])[:100]}")
        output.extend(info["neighbors"])
        output.append("\n\n")  # Empty line for separation

    # Re-offer every graph node for selection, numbered from 1 to match the
    # "- 1" parsing above.
    formatted_choices = [
        f"{i}. {str(node['label'])}" for i, node in enumerate(graph_data["nodes"], start=1)
    ]
    new_node_ids = [node["id"] for node in graph_data["nodes"]]
    return (
        "\n".join(output),
        json.dumps(graph_data),
        graph_data,
        gr.update(choices=formatted_choices, value=[]),
        new_node_ids,
    )
def save_docs_to_excel(exported_docs: list[dict], exported_relationships: list[dict],
                       output_path: str = "doc_explorer/exported_docs/docs.xlsx"):
    """Write the exported documents and their outgoing relationships to Excel.

    Args:
        exported_docs: rows with "id", "doc" (mapping of node data) and
            "category" (label list; its first entry becomes the category column).
        exported_relationships: edges with "from", "to" and "label" keys.
        output_path: destination .xlsx file; missing parent directories are
            created.

    Returns:
        A Gradio update that makes the File component visible with the written file.
    """
    cleaned_docs = [
        dict(doc["doc"], id=doc["id"], category=doc["category"][0], relationships="")
        for doc in exported_docs
    ]
    # Index docs by id (a list per id, in case several nodes share one) so each
    # relationship is attached in O(1) instead of scanning every doc.
    docs_by_id = {}
    for doc in cleaned_docs:
        docs_by_id.setdefault(doc["id"], []).append(doc)
    for relationship in exported_relationships:
        for doc in docs_by_id.get(relationship["from"], ()):
            doc["relationships"] += f"[{relationship['label']}] {relationship['to']}\n"
    # to_excel does not create directories; without this the export fails on a
    # fresh checkout.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    pd.DataFrame(cleaned_docs).to_excel(output_path)
    return gr.update(value=output_path, visible=True)
# JavaScript code for graph visualization.
# Run in the browser after "Find Neighbors" (the `.then(js=js_code)` step), it
# receives the hidden graph JSON string, clears #graph-container and draws the
# graph with vis-network. Each node's `title` HTML string is turned into a DOM
# element so vis shows it as a rich tooltip instead of plain text.
# NOTE(review): the created `network` object is not kept; every call builds a
# new one from scratch.
js_code = """
function(graph_data_str) {
if (!graph_data_str) return;
const container = document.getElementById('graph-container');
container.innerHTML = '';
let data;
try {
data = JSON.parse(graph_data_str);
} catch (error) {
console.error("Failed to parse graph data:", error);
container.innerHTML = "Error: Failed to load graph data.";
return;
}
data.nodes.forEach(node => {
const div = document.createElement('div');
div.innerHTML = node.title;
node.title = div.firstChild;
});
const nodes = new vis.DataSet(data.nodes);
const edges = new vis.DataSet(data.edges);
const options = {
nodes: {
shape: 'dot',
size: 16,
font: {
size: 12,
color: '#000000'
},
borderWidth: 2
},
edges: {
width: 1,
font: {
size: 10,
align: 'middle'
},
color: { color: '#7A7A7A', hover: '#2B7CE9' }
},
physics: {
forceAtlas2Based: {
gravitationalConstant: -26,
centralGravity: 0.005,
springLength: 230,
springConstant: 0.18
},
maxVelocity: 146,
solver: 'forceAtlas2Based',
timestep: 0.35,
stabilization: { iterations: 150 }
},
interaction: {
hover: true,
tooltipDelay: 200
}
};
const network = new vis.Network(container, { nodes: nodes, edges: edges }, options);
}
"""
# Extra <head> content: the vis-network script and stylesheet that js_code uses.
# NOTE(review): loaded from unpkg with no pinned version — consider pinning one.
head = """
<script type="text/javascript" src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<link href="https://unpkg.com/vis-network/styles/vis-network.min.css" rel="stylesheet" type="text/css" />
"""
# Page CSS: border for the graph container, plus styling for vis tooltips and
# the .node-tooltip markup that the Python side puts in node titles.
custom_css = """
#graph-container {
border: 1px solid #ddd;
border-radius: 4px;
}
.vis-tooltip {
font-family: Arial, sans-serif;
padding: 10px;
border-radius: 4px;
background-color: rgba(255, 255, 255, 0.9);
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
max-width: 300px;
color: #333;
word-wrap: break-word;
overflow-wrap: break-word;
}
.node-tooltip {
width: 100%;
}
.node-tooltip h3 {
margin: 0 0 5px 0;
font-size: 14px;
color: #333;
}
.node-tooltip p {
margin: 0;
font-size: 12px;
color: #666;
white-space: normal;
}
"""
# Gradio UI. Within each `with` context, the order in which components are
# created sets the on-page layout, so these statements must not be reordered.
# NOTE(review): the nesting below was reconstructed because the source lost its
# indentation — confirm which components belong in each Row/Column.
with gr.Blocks(head=head, css=custom_css) as demo:
    # --- Search tab: query form, results, neighbour expansion and graph view ---
    with gr.Tab("Search"):
        gr.Markdown("# Document Search Engine")
        gr.Markdown("Enter a query to search for similar documents. You can filter by document type and use MMR for diverse results.")
        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(label="Enter your query")
                # Choices are read from Neo4j once, when the UI is built.
                doc_type_input = gr.Dropdown(choices=get_document_types(), label="Select document type", multiselect=True)
            with gr.Column(scale=2):
                mmr_input = gr.Checkbox(label="Use MMR for diverse results")
                # Hidden until MMR is enabled (toggled by update_lambda_visibility).
                lambda_input = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Lambda parameter (MMR diversity)", visible=False)
                top_k_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of results")
        search_button = gr.Button("Search")
        results_output = gr.Textbox(label="Search Results")
        # Choices are filled by `search` and by the neighbour expansion.
        selected_documents = gr.Dropdown(label="Select documents to view their neighbors", choices=[], multiselect=True, interactive=True)
        with gr.Row():
            neighbor_search_button = gr.Button("Find Neighbors")
            send_to_export = gr.Button("Send docs to export Tab")
        neighbors_output = gr.Textbox(label="Document Neighbors")
        # Per-session state:
        #   graph_data_state - accumulated {"nodes", "edges"} drawn in the graph
        #   graph_data_str   - hidden JSON copy of the graph handed to js_code
        #   node_ids         - node ids aligned with selected_documents' choices
        #   exported_docs / exported_relationships - data for the Export tab
        graph_data_state = gr.State({"nodes": [], "edges": []})
        graph_data_str = gr.Textbox(visible=False)
        # Target div (id="graph-container") that js_code renders into.
        graph_container = gr.HTML('<div id="graph-container" style="height: 600px;"> Hey ! </div>')
        node_ids = gr.State([])
        exported_docs = gr.State([])
        exported_relationships = gr.State([])

        def update_lambda_visibility(use_mmr):
            """Show the MMR lambda slider only while the MMR checkbox is ticked."""
            return gr.update(visible=use_mmr)

        mmr_input.change(fn=update_lambda_visibility, inputs=mmr_input, outputs=lambda_input)
        search_button.click(
            fn=search,
            inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input],
            outputs=[results_output, selected_documents, node_ids]
        )
        # Two steps: the Python handler updates the graph state, then js_code runs
        # in the browser on the new graph JSON to redraw the network.
        neighbor_search_button.click(
            fn=get_neighbors_and_graph_data,
            inputs=[selected_documents, node_ids, graph_data_state],
            outputs=[neighbors_output, graph_data_str, graph_data_state, selected_documents, node_ids]
        ).then(
            fn=None,
            inputs=graph_data_str,
            outputs=None,
            js=js_code,
        )
        # Copy every node currently in the graph (and its edges) into export state.
        send_to_export.click(
            fn=get_docs_from_ids,
            inputs=graph_data_state,
            outputs=[exported_docs, exported_relationships]
        )
        # NOTE(review): disabled examples; their document types look like
        # placeholders and pass a single string to a multiselect dropdown.
        # gr.Examples(
        # examples=[
        # ["What is machine learning?", "Article", True, 0.5, 5],
        # ["How to implement a neural network?", "Tutorial", False, 0.5, 3],
        # ["Latest advancements in NLP", "Research Paper", True, 0.7, 10]
        # ],
        # inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input]
        # )
    # --- Export tab: preview exported documents and download them as Excel ---
    with gr.Tab("Export"):
        with gr.Row():
            exported_docs_btn = gr.Button("Display exported docs")
            exported_excel_btn = gr.Button("Export to excel")
        exported_excel = gr.File(visible=False)
        exported_docs_display = gr.Markdown(visible=False)
        # Show each exported doc as "[<first label>]" followed by its node data.
        exported_docs_btn.click(
            fn= lambda docs: gr.update(value='\n\n'.join([f"[{doc['category'][0]}]\n{doc['doc']}\n\n" for doc in docs]), visible=True),
            inputs=exported_docs,
            outputs=exported_docs_display
        )
        exported_excel_btn.click(
            fn=save_docs_to_excel,
            inputs=[exported_docs, exported_relationships],
            outputs=exported_excel
        )
# NOTE(review): launches as soon as the module is imported; consider wrapping in
# an `if __name__ == "__main__":` guard.
demo.launch()