import gradio as gr
from vectorstore import FAISSVectorStore
from langchain_community.graphs import Neo4jGraph
import os
import json
import html
import pandas as pd
import time
# Fixed pause before connecting to anything. NOTE(review): presumably this
# gives a co-started Neo4j service time to accept connections — confirm; a
# bounded retry loop would be more robust than an unconditional 30 s sleep.
time.sleep(30)

# Route outbound HTTP(S) traffic through a hard-coded proxy while keeping
# connections to localhost direct. NOTE(review): this overwrites any proxy
# settings already present in the environment.
for _proxy_var in ("http_proxy", "https_proxy"):
    os.environ[_proxy_var] = "185.46.212.98:80"
os.environ["NO_PROXY"] = "localhost"

# Neo4j connection; URI and credentials can be overridden via environment variables.
neo4j_graph = Neo4jGraph(
    url=os.getenv("NEO4J_URI", "bolt://localhost:7999"),
    username=os.getenv("NEO4J_USERNAME", "neo4j"),
    password=os.getenv("NEO4J_PASSWORD", "graph_test"),
)

# FAISS-backed vector store over precomputed embeddings (needs roughly 1GB RAM).
vector_store = FAISSVectorStore(
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    dimension=1024,
    trust_remote_code=True,
    embedding_file="/usr/src/app/doc_explorer/embeddings_full.npy",
)
# Get document types from Neo4j database
def get_document_types():
    """Return the distinct primary labels of all nodes in the Neo4j graph.

    The first label of each node is treated as its document type. Nodes
    without any label are skipped, and each type appears only once, in the
    order it is first returned by the query.

    Returns:
        list: Unique document-type labels, used as dropdown choices.
    """
    query = """
    MATCH (n)
    RETURN DISTINCT labels(n) AS document_type
    """
    result = neo4j_graph.query(query)
    # DISTINCT compares whole label lists, so nodes labelled [A] and [A, B]
    # would both contribute "A"; dict.fromkeys de-duplicates while keeping
    # order. Unlabelled nodes yield an empty list, which would raise
    # IndexError on [0], so they are filtered out.
    return list(dict.fromkeys(
        row["document_type"][0] for row in result if row["document_type"]
    ))
def search(query, doc_types, use_mmr, lambda_param, top_k):
    """Run a vector similarity search and format the hits for the UI.

    Args:
        query: Free-text search query.
        doc_types: Document types to filter by; forwarded to the vector store.
        use_mmr: Whether the vector store should use MMR re-ranking.
        lambda_param: MMR diversity parameter; passed as None when
            ``use_mmr`` is false.
        top_k: Number of results to request.

    Returns:
        tuple: ``(formatted_results, dropdown_update, node_ids)``.
        ``formatted_results`` is a list of "N. <document> (Score: x)" lines,
        ``dropdown_update`` replaces the selection dropdown's choices with
        the same lines (document text truncated to 100 chars) and clears the
        selection, and ``node_ids`` is passed through from the vector store.
    """
    results, node_ids = vector_store.similarity_search(
        query,
        k=top_k,
        use_mmr=use_mmr,
        lambda_param=lambda_param if use_mmr else None,
        doc_types=doc_types,
        neo4j_graph=neo4j_graph,
    )
    formatted_results = []
    formatted_choices = []
    # Number from 1: get_neighbors_and_graph_data maps a selected choice back
    # to node_ids with int(prefix) - 1, so 0-based numbering made choice "0."
    # resolve to node_ids[-1] and every other choice to the previous node.
    for i, result in enumerate(results, start=1):
        score = f"(Score: {result['score']:.4f})"
        formatted_results.append(f"{i}. {result['document']} {score}")
        formatted_choices.append(f"{i}. {str(result['document'])[:100]} {score}")
    # NOTE(review): formatted_results is a list; if it is bound to a
    # gr.Textbox it will render as a Python list repr — consider "\n".join.
    return formatted_results, gr.update(choices=formatted_choices, value=[]), node_ids
def get_docs_from_ids(graph_data: dict):
    """Fetch the Neo4j nodes for every node currently in the graph view.

    Args:
        graph_data: Graph state with a ``"nodes"`` list of dicts carrying an
            ``"id"`` key, and an ``"edges"`` list.

    Returns:
        tuple: ``(docs, edges)`` where ``docs`` is the query result (rows with
        ``id``, ``doc`` and ``category`` keys) and ``edges`` is
        ``graph_data["edges"]`` unchanged.
    """
    node_ids = [node["id"] for node in graph_data["nodes"]]
    if not node_ids:
        # An empty IN-list matches nothing; skip the database round-trip.
        return [], graph_data["edges"]
    query = """
    MATCH (n)
    WHERE n.id IN $node_ids
    RETURN n.id AS id, n AS doc, labels(n) AS category
    """
    return neo4j_graph.query(query, {"node_ids": node_ids}), graph_data["edges"]
def _primary_label(labels):
    """Return the first Neo4j label of a node, or "" when it has none."""
    return labels[0] if labels else ""


def _graph_node(node_id, node_text, node_type):
    """Build a vis-network node dict (id, short label, group, HTML tooltip).

    The ``title`` is inserted into the page via ``innerHTML`` by the front-end
    script, so every interpolated value is HTML-escaped. It is a single root
    element with no surrounding whitespace because the script uses
    ``div.firstChild`` as the tooltip element.
    """
    text = str(node_text)
    return {
        "id": node_id,
        "label": text[:30] + "...",
        "group": node_type,
        "title": (
            "<div class='node-tooltip'>"
            f"<h3>{html.escape(str(node_type))}</h3>"
            f"<p>{html.escape(text)}</p>"
            "</div>"
        ),
    }


def get_neighbors_and_graph_data(selected_documents, node_ids, graph_data):
    """Expand the graph view with the direct neighbors of the selected documents.

    Args:
        selected_documents: Selected dropdown choices of the form "N. <text>",
            where N is the 1-based position of the node in ``node_ids``.
        node_ids: Node ids aligned with the dropdown choices.
        graph_data: Current graph state, a dict with ``"nodes"`` and
            ``"edges"`` lists (vis-network node/edge dicts as built here).

    Returns:
        A 5-tuple ``(neighbors_text, graph_json, graph_data, dropdown_update,
        node_ids)``. Every path returns all five values, since Gradio expects
        the same number of outputs on each call; when nothing is selected or
        no neighbors are found, the graph, choices and node ids are unchanged.
    """
    if not selected_documents:
        return ("No documents selected.", json.dumps(graph_data), graph_data,
                gr.update(), node_ids)

    selected_indices = [int(doc.split('.')[0]) - 1 for doc in selected_documents]
    selected_node_ids = [node_ids[i] for i in selected_indices]
    query = """
    MATCH (n)-[r]-(neighbor)
    WHERE n.id IN $node_ids
    RETURN n.id AS source_id, n AS source_text, labels(n) AS source_type,
           neighbor.id AS neighbor_id, neighbor AS neighbor_text,
           labels(neighbor) AS neighbor_type, type(r) AS relationship_type
    """
    results = neo4j_graph.query(query, {"node_ids": selected_node_ids})
    if not results:
        # Keep the current graph: returning a bare "[]" here would break the
        # renderer, which expects an object with a nodes array.
        return ("No neighbors found for the selected documents.",
                json.dumps(graph_data), graph_data, gr.update(), node_ids)

    # Build on copies so the incoming state is not mutated in place.
    nodes = list(graph_data["nodes"])
    edges = list(graph_data["edges"])
    node_set = {node["id"] for node in nodes}
    # Set of (from, to, label) keys: O(1) duplicate checks instead of
    # scanning the edge list for every row.
    edge_keys = {(e.get("from"), e.get("to"), e.get("label")) for e in edges}

    neighbor_info = {}
    for row in results:
        source_id = row["source_id"]
        source_type = _primary_label(row["source_type"])
        neighbor_id = row["neighbor_id"]
        neighbor_type = _primary_label(row["neighbor_type"])
        relationship = row["relationship_type"]

        if source_id not in neighbor_info:
            neighbor_info[source_id] = {
                "source_type": source_type,
                "source_text": row["source_text"],
                "neighbors": [],
            }
        if source_id not in node_set:
            nodes.append(_graph_node(source_id, row["source_text"], source_type))
            node_set.add(source_id)

        neighbor_info[source_id]["neighbors"].append(
            f"[{relationship}] [{neighbor_type}] {str(row['neighbor_text'])[:200]}"
        )
        if neighbor_id not in node_set:
            nodes.append(_graph_node(neighbor_id, row["neighbor_text"], neighbor_type))
            node_set.add(neighbor_id)

        edge_key = (source_id, neighbor_id, relationship)
        if edge_key not in edge_keys:
            edges.append({"from": source_id, "to": neighbor_id, "label": relationship})
            edge_keys.add(edge_key)

    output = []
    for info in neighbor_info.values():
        output.append(f"Neighbors for: [{info['source_type']}] {str(info['source_text'])[:100]}")
        output.extend(info["neighbors"])
        output.append("\n\n")  # Empty line for separation

    new_graph_data = {**graph_data, "nodes": nodes, "edges": edges}
    # Choices are numbered from 1 to match the "- 1" parsing above.
    formatted_choices = [f"{i}. {str(node['label'])}" for i, node in enumerate(nodes, start=1)]
    new_node_ids = [node["id"] for node in nodes]
    return ("\n".join(output), json.dumps(new_graph_data), new_graph_data,
            gr.update(choices=formatted_choices, value=[]), new_node_ids)
def save_docs_to_excel(exported_docs: list[dict], exported_relationships: list[dict]):
    """Write exported documents and their outgoing relationships to an Excel file.

    Args:
        exported_docs: Rows with ``"doc"`` (mapping of node properties),
            ``"id"`` and ``"category"`` (list of labels) keys — presumably
            the rows returned by ``get_docs_from_ids``.
        exported_relationships: Edge dicts with ``"from"``, ``"to"`` and
            ``"label"`` keys.

    Returns:
        A ``gr.update`` setting ``value`` to the written workbook path and
        ``visible=True``.
    """
    export_path = "doc_explorer/exported_docs/docs.xlsx"
    cleaned_docs = [
        dict(
            doc["doc"],
            id=doc["id"],
            # Guard against nodes without labels instead of raising IndexError.
            category=doc["category"][0] if doc["category"] else "",
            relationships="",
        )
        for doc in exported_docs
    ]
    # Index docs by id once so each relationship is matched in O(1) rather
    # than scanning every doc; a list per id keeps duplicate ids working.
    docs_by_id = {}
    for doc in cleaned_docs:
        docs_by_id.setdefault(doc["id"], []).append(doc)
    for relationship in exported_relationships:
        for doc in docs_by_id.get(relationship["from"], ()):
            doc["relationships"] += f"[{relationship['label']}] {relationship['to']}\n"
    df = pd.DataFrame(cleaned_docs)
    # to_excel does not create missing directories.
    os.makedirs(os.path.dirname(export_path), exist_ok=True)
    df.to_excel(export_path)
    return gr.update(value=export_path, visible=True)
# JavaScript code for graph visualization: receives the graph JSON string,
# parses it and draws it with vis-network into the #graph-container element.
# It relies on a global `vis` (vis.DataSet / vis.Network) being loaded by the page.
js_code = """
function(graph_data_str) {
    if (!graph_data_str) return;
    const container = document.getElementById('graph-container');
    if (!container) {
        console.error("Graph container #graph-container not found.");
        return;
    }
    container.innerHTML = '';
    let data;
    try {
        data = JSON.parse(graph_data_str);
    } catch (error) {
        console.error("Failed to parse graph data:", error);
        container.innerHTML = "Error: Failed to load graph data.";
        return;
    }
    // Valid JSON without a nodes array (e.g. "[]") would otherwise throw
    // on data.nodes.forEach below.
    if (!data || !Array.isArray(data.nodes)) {
        console.error("Graph data has no nodes array:", data);
        container.innerHTML = "Error: Failed to load graph data.";
        return;
    }
    // Turn each HTML title string into a DOM element so the tooltip renders
    // as markup rather than as escaped text.
    data.nodes.forEach(node => {
        if (!node.title) return;
        const div = document.createElement('div');
        div.innerHTML = node.title;
        node.title = div.firstChild;
    });
    const nodes = new vis.DataSet(data.nodes);
    const edges = new vis.DataSet(Array.isArray(data.edges) ? data.edges : []);
    const options = {
        nodes: {
            shape: 'dot',
            size: 16,
            font: {
                size: 12,
                color: '#000000'
            },
            borderWidth: 2
        },
        edges: {
            width: 1,
            font: {
                size: 10,
                align: 'middle'
            },
            color: { color: '#7A7A7A', hover: '#2B7CE9' }
        },
        physics: {
            forceAtlas2Based: {
                gravitationalConstant: -26,
                centralGravity: 0.005,
                springLength: 230,
                springConstant: 0.18
            },
            maxVelocity: 146,
            solver: 'forceAtlas2Based',
            timestep: 0.35,
            stabilization: { iterations: 150 }
        },
        interaction: {
            hover: true,
            tooltipDelay: 200
        }
    };
    const network = new vis.Network(container, { nodes: nodes, edges: edges }, options);
}
"""
# Extra markup injected into the page <head> (passed as gr.Blocks(head=head)).
# NOTE(review): this is empty, yet js_code uses a global `vis`
# (vis.DataSet / vis.Network). Unless the vis-network library is loaded some
# other way, a <script> tag loading it presumably belongs here — confirm.
head = """
"""
# Page CSS passed to gr.Blocks(css=custom_css): styles the #graph-container
# element that js_code draws into, vis-network's hover tooltip box
# (.vis-tooltip), and the .node-tooltip heading/paragraph markup presumably
# used in each node's HTML title.
custom_css = """
#graph-container {
border: 1px solid #ddd;
border-radius: 4px;
}
.vis-tooltip {
font-family: Arial, sans-serif;
padding: 10px;
border-radius: 4px;
background-color: rgba(255, 255, 255, 0.9);
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
max-width: 300px;
color: #333;
word-wrap: break-word;
overflow-wrap: break-word;
}
.node-tooltip {
width: 100%;
}
.node-tooltip h3 {
margin: 0 0 5px 0;
font-size: 14px;
color: #333;
}
.node-tooltip p {
margin: 0;
font-size: 12px;
color: #666;
white-space: normal;
}
"""
with gr.Blocks(head=head, css=custom_css) as demo:
with gr.Tab("Search"):
gr.Markdown("# Document Search Engine")
gr.Markdown("Enter a query to search for similar documents. You can filter by document type and use MMR for diverse results.")
with gr.Row():
with gr.Column(scale=3):
query_input = gr.Textbox(label="Enter your query")
doc_type_input = gr.Dropdown(choices=get_document_types(), label="Select document type", multiselect=True)
with gr.Column(scale=2):
mmr_input = gr.Checkbox(label="Use MMR for diverse results")
lambda_input = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Lambda parameter (MMR diversity)", visible=False)
top_k_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of results")
search_button = gr.Button("Search")
results_output = gr.Textbox(label="Search Results")
selected_documents = gr.Dropdown(label="Select documents to view their neighbors", choices=[], multiselect=True, interactive=True)
with gr.Row():
neighbor_search_button = gr.Button("Find Neighbors")
send_to_export = gr.Button("Send docs to export Tab")
neighbors_output = gr.Textbox(label="Document Neighbors")
graph_data_state = gr.State({"nodes": [], "edges": []})
graph_data_str = gr.Textbox(visible=False)
graph_container = gr.HTML('