|
import gradio as gr |
|
from vectorstore import FAISSVectorStore |
|
from langchain_community.graphs import Neo4jGraph |
|
import os |
|
import json |
|
import html |
|
import pandas as pd |
|
import time |
|
|
|
# Give dependent services (the Neo4j container in particular) a head start
# before connecting. Configurable via STARTUP_DELAY_SECONDS; the default keeps
# the original fixed 30-second wait.
time.sleep(float(os.getenv("STARTUP_DELAY_SECONDS", "30")))

# Send outbound HTTP(S) traffic (e.g. the embedding-model download) through the
# proxy unless the deployment environment already configures one.
os.environ.setdefault("http_proxy", "185.46.212.98:80")
os.environ.setdefault("https_proxy", "185.46.212.98:80")

# Keep loopback traffic off the proxy. NO_PROXY entries are matched by host
# name, so "localhost" alone does not cover requests made to 127.0.0.1
# (presumably including Gradio's local startup check — confirm for the
# Gradio version in use).
os.environ["NO_PROXY"] = "localhost,127.0.0.1"


def _connect_neo4j(attempts=10, delay_seconds=5.0):
    """Connect to Neo4j, retrying while the server may still be starting up.

    Connection settings come from NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD,
    falling back to local development defaults.

    Args:
        attempts: maximum number of connection attempts (at least one is made).
        delay_seconds: pause between failed attempts.

    Returns:
        A connected ``Neo4jGraph``.

    Raises:
        ValueError: the error from the last attempt if none succeeded.
    """
    attempts = max(1, attempts)
    for attempt in range(1, attempts + 1):
        try:
            return Neo4jGraph(
                url=os.getenv("NEO4J_URI", "bolt://localhost:7999"),
                username=os.getenv("NEO4J_USERNAME", "neo4j"),
                password=os.getenv("NEO4J_PASSWORD", "graph_test"),
            )
        except ValueError as err:
            # langchain_community's Neo4jGraph reports connectivity/auth
            # failures as ValueError.
            if attempt == attempts:
                raise
            print(f"Neo4j not reachable (attempt {attempt}/{attempts}): {err}. Retrying in {delay_seconds}s.")
            time.sleep(delay_seconds)


neo4j_graph = _connect_neo4j()

# NOTE(review): trust_remote_code=True is presumably forwarded to the model
# loader, allowing code from the model repository to run — keep the model
# source trusted.
vector_store = FAISSVectorStore(
    model_name='Alibaba-NLP/gte-large-en-v1.5',
    dimension=1024,
    trust_remote_code=True,
    embedding_file=os.getenv("EMBEDDING_FILE", "/usr/src/app/doc_explorer/embeddings_full.npy"),
)
|
|
|
|
|
def get_document_types():
    """Return the distinct document types found in Neo4j.

    A node's document type is its first label, the same convention used
    elsewhere in this module (``labels(n)[0]``). Nodes without any label are
    skipped, and each type is listed once, in first-seen order.

    Returns:
        list: the document types, used as dropdown choices.
    """
    query = """
    MATCH (n)
    RETURN DISTINCT labels(n) AS document_type
    """
    result = neo4j_graph.query(query)
    # DISTINCT applies to whole label lists, so nodes labelled [A] and [A, B]
    # would both contribute "A"; dict.fromkeys drops such duplicates while
    # keeping order. Empty label lists are skipped instead of raising IndexError.
    first_labels = (row["document_type"][0] for row in result if row["document_type"])
    return list(dict.fromkeys(first_labels))
|
|
|
def search(query, doc_types, use_mmr, lambda_param, top_k):
    """Run a similarity search and format the hits for the UI.

    Args:
        query: free-text query passed to the vector store.
        doc_types: document types to restrict the search to.
        use_mmr: whether to request MMR re-ranking for diverse results.
        lambda_param: MMR diversity parameter; passed as None unless ``use_mmr``.
        top_k: number of results to request.

    Returns:
        tuple: (newline-separated results text, a ``gr.update`` replacing the
        document dropdown's choices and clearing its selection, the node ids
        returned by the vector store).
    """
    results, node_ids = vector_store.similarity_search(
        query,
        k=top_k,
        use_mmr=use_mmr,
        lambda_param=lambda_param if use_mmr else None,
        doc_types=doc_types,
        neo4j_graph=neo4j_graph,
    )

    formatted_results = []
    formatted_choices = []
    # Number hits from 1: get_neighbors_and_graph_data maps a choice back via
    # node_ids[number - 1], so 0-based numbering selected the wrong node
    # (choice "0." wrapped around to the last id).
    for rank, result in enumerate(results, start=1):
        score = f"(Score: {result['score']:.4f})"
        formatted_results.append(f"{rank}. {result['document']} {score}")
        formatted_choices.append(f"{rank}. {str(result['document'])[:100]} {score}")

    # The results feed a Textbox; join them instead of showing a list repr.
    return "\n".join(formatted_results), gr.update(choices=formatted_choices, value=[]), node_ids
|
|
|
def get_docs_from_ids(graph_data: dict):
    """Fetch the Neo4j nodes currently shown in the graph, for the Export tab.

    Args:
        graph_data: graph state with a ``"nodes"`` list (dicts with an ``"id"``
            key) and an ``"edges"`` list.

    Returns:
        tuple: (query rows with ``id``, ``doc`` and ``category`` fields,
        a copy of the graph's edge list).
    """
    node_ids = [node["id"] for node in graph_data["nodes"]]
    query = """
    MATCH (n)
    WHERE n.id IN $node_ids
    RETURN n.id AS id, n AS doc, labels(n) AS category
    """
    docs = neo4j_graph.query(query, {"node_ids": node_ids})
    # Return a snapshot of the edges: the graph state's edge list is appended
    # to in place by later neighbor searches, and depending on how Gradio hands
    # State values to callbacks the exported relationships could otherwise
    # change without the user exporting again.
    return docs, list(graph_data["edges"])
|
|
|
def _graph_node(node_id, text, node_type):
    """Build a vis-network node dict with a truncated label and an HTML tooltip."""
    return {
        "id": node_id,
        "label": str(text)[:30] + "...",
        "group": node_type,
        # The client injects the tooltip with innerHTML, so database text
        # must be escaped.
        "title": f"<div class='node-tooltip'><h3>{html.escape(str(node_type))}</h3><p>{html.escape(str(text))}</p></div>",
    }


def _node_choices(nodes):
    """Return (dropdown choices numbered from 1, node ids in the same order)."""
    choices = [f"{number}. {node['label']}" for number, node in enumerate(nodes, start=1)]
    ids = [node["id"] for node in nodes]
    return choices, ids


def get_neighbors_and_graph_data(selected_documents, node_ids, graph_data):
    """Expand the graph with the direct neighbors of the selected documents.

    Args:
        selected_documents: dropdown choices of the form ``"<n>. <text>"``;
            ``n`` is mapped back to ``node_ids[n - 1]``.
        node_ids: node ids aligned with the dropdown choices.
        graph_data: graph state ``{"nodes": [...], "edges": [...]}``; new nodes
            and edges are appended to it in place.

    Returns:
        tuple of five values, one per UI output: (neighbor summary text,
        graph JSON for the client-side renderer, graph state, dropdown
        update, node ids aligned with the dropdown choices).
    """
    if not selected_documents:
        # Every path must return all five outputs wired to this callback;
        # gr.update() leaves the dropdown unchanged.
        return "No documents selected.", json.dumps(graph_data), graph_data, gr.update(), node_ids

    selected_indices = [int(doc.split('.')[0]) - 1 for doc in selected_documents]
    selected_node_ids = [node_ids[i] for i in selected_indices]

    query = """
    MATCH (n)-[r]-(neighbor)
    WHERE n.id IN $node_ids
    RETURN n.id AS source_id, n AS source_text, labels(n) AS source_type,
           neighbor.id AS neighbor_id, neighbor AS neighbor_text,
           labels(neighbor) AS neighbor_type, type(r) AS relationship_type
    """
    results = neo4j_graph.query(query, {"node_ids": selected_node_ids})

    if not results:
        # Send the current graph (not "[]") so the client renderer still
        # receives an object with nodes/edges.
        return "No neighbors found for the selected documents.", json.dumps(graph_data), graph_data, gr.update(), node_ids

    neighbor_info = {}
    node_set = {node["id"] for node in graph_data["nodes"]}
    # Set of (from, to, label) for O(1) duplicate-edge checks.
    edge_set = {(edge["from"], edge["to"], edge["label"]) for edge in graph_data["edges"]}

    for row in results:
        source_id = row['source_id']
        neighbor_id = row['neighbor_id']
        relationship_type = row['relationship_type']
        neighbor_type = row['neighbor_type'][0]

        if source_id not in neighbor_info:
            source_type = row['source_type'][0]
            neighbor_info[source_id] = {
                'source_type': source_type,
                'source_text': row['source_text'],
                'neighbors': [],
            }
            if source_id not in node_set:
                graph_data["nodes"].append(_graph_node(source_id, row['source_text'], source_type))
                node_set.add(source_id)

        neighbor_info[source_id]['neighbors'].append(
            f"[{relationship_type}] [{neighbor_type}] {str(row['neighbor_text'])[:200]}"
        )

        if neighbor_id not in node_set:
            graph_data["nodes"].append(_graph_node(neighbor_id, row['neighbor_text'], neighbor_type))
            node_set.add(neighbor_id)

        edge_key = (source_id, neighbor_id, relationship_type)
        if edge_key not in edge_set:
            graph_data['edges'].append({"from": source_id, "to": neighbor_id, "label": relationship_type})
            edge_set.add(edge_key)

    output = []
    for info in neighbor_info.values():
        output.append(f"Neighbors for: [{info['source_type']}] {str(info['source_text'])[:100]}")
        output.extend(info['neighbors'])
        output.append("\n\n")

    formatted_choices, node_ids = _node_choices(graph_data['nodes'])
    return "\n".join(output), json.dumps(graph_data), graph_data, gr.update(choices=formatted_choices, value=[]), node_ids
|
|
|
def save_docs_to_excel(exported_docs: list[dict], exported_relationships: list[dict],
                       output_path: str = "doc_explorer/exported_docs/docs.xlsx"):
    """Write the exported documents to an Excel file and reveal it for download.

    Each row holds the node's properties plus ``id``, ``category`` (first
    label) and ``relationships``: one ``"[<label>] <to>\\n"`` line per
    relationship whose ``from`` equals the row's id, in input order.

    Args:
        exported_docs: rows with ``id``, ``doc`` (property dict) and
            ``category`` (label list).
        exported_relationships: edge dicts with ``from``, ``to`` and ``label``.
        output_path: destination .xlsx path; its directory is created if missing.

    Returns:
        A ``gr.update`` pointing the File component at ``output_path``.
    """
    # Group relationship lines by source id once, instead of scanning every
    # document for every relationship.
    lines_by_source = {}
    for relationship in exported_relationships:
        lines_by_source.setdefault(relationship['from'], []).append(
            f"[{relationship['label']}] {relationship['to']}\n"
        )

    cleaned_docs = [
        dict(
            doc['doc'],
            id=doc['id'],
            category=doc['category'][0],
            relationships="".join(lines_by_source.get(doc['id'], ())),
        )
        for doc in exported_docs
    ]

    # to_excel does not create missing directories.
    directory = os.path.dirname(output_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    pd.DataFrame(cleaned_docs).to_excel(output_path)
    return gr.update(value=output_path, visible=True)
|
|
|
|
|
# Client-side renderer run after each neighbor search (via `.then(js=...)`).
# It parses the graph JSON and draws it with vis-network into the element
# with id "graph-container". Node tooltips arrive as HTML strings and are
# converted to DOM elements before being handed to vis-network.
js_code = """
function(graph_data_str) {
    if (!graph_data_str) return;
    const container = document.getElementById('graph-container');
    // Nothing to draw into if the container is not in the DOM.
    if (!container) return;
    container.innerHTML = '';
    let data;
    try {
        data = JSON.parse(graph_data_str);
    } catch (error) {
        console.error("Failed to parse graph data:", error);
        container.innerHTML = "Error: Failed to load graph data.";
        return;
    }

    // Tolerate payloads without node/edge lists (e.g. "[]").
    const nodeList = (data && Array.isArray(data.nodes)) ? data.nodes : [];
    const edgeList = (data && Array.isArray(data.edges)) ? data.edges : [];

    nodeList.forEach(node => {
        if (typeof node.title !== 'string') return;
        const div = document.createElement('div');
        div.innerHTML = node.title;
        node.title = div.firstChild;
    });

    const nodes = new vis.DataSet(nodeList);
    const edges = new vis.DataSet(edgeList);
    const options = {
        nodes: {
            shape: 'dot',
            size: 16,
            font: {
                size: 12,
                color: '#000000'
            },
            borderWidth: 2
        },
        edges: {
            width: 1,
            font: {
                size: 10,
                align: 'middle'
            },
            color: { color: '#7A7A7A', hover: '#2B7CE9' }
        },
        physics: {
            forceAtlas2Based: {
                gravitationalConstant: -26,
                centralGravity: 0.005,
                springLength: 230,
                springConstant: 0.18
            },
            maxVelocity: 146,
            solver: 'forceAtlas2Based',
            timestep: 0.35,
            stabilization: { iterations: 150 }
        },
        interaction: {
            hover: true,
            tooltipDelay: 200
        }
    };
    new vis.Network(container, { nodes: nodes, edges: edges }, options);
}
"""
|
|
|
# Extra markup for the page <head> (passed as gr.Blocks(head=...)): loads the
# vis-network script and stylesheet; js_code relies on the `vis` global.
# NOTE(review): the unpkg URLs are unversioned, so a new vis-network release
# could change or break the graph without any change here; consider pinning
# a version (e.g. `vis-network@<version>`) — TODO confirm which one.
head = """
<script type="text/javascript" src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<link href="https://unpkg.com/vis-network/styles/vis-network.min.css" rel="stylesheet" type="text/css" />
"""
|
|
|
# Page styles (passed as gr.Blocks(css=...)): a border around the graph
# container, styling for the vis-network hover tooltip, and layout for the
# `.node-tooltip` HTML (an <h3> type heading plus a <p> body) embedded in
# each node's title.
custom_css = """
#graph-container {
    border: 1px solid #ddd;
    border-radius: 4px;
}
.vis-tooltip {
    font-family: Arial, sans-serif;
    padding: 10px;
    border-radius: 4px;
    background-color: rgba(255, 255, 255, 0.9);
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    max-width: 300px;
    color: #333;
    word-wrap: break-word;
    overflow-wrap: break-word;
}
.node-tooltip {
    width: 100%;
}
.node-tooltip h3 {
    margin: 0 0 5px 0;
    font-size: 14px;
    color: #333;
}
.node-tooltip p {
    margin: 0;
    font-size: 12px;
    color: #666;
    white-space: normal;
}
"""
|
|
|
|
|
# Gradio UI. Component creation order and the nesting of the `with` blocks
# determine the page layout, so the statements below are order-sensitive.
with gr.Blocks(head=head, css=custom_css) as demo:
    # Search tab: query the vector store, pick hits, and expand their neighbors.
    with gr.Tab("Search"):

        gr.Markdown("# Document Search Engine")
        gr.Markdown("Enter a query to search for similar documents. You can filter by document type and use MMR for diverse results.")

        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(label="Enter your query")
                # Choices are read from Neo4j once, when the UI is built.
                doc_type_input = gr.Dropdown(choices=get_document_types(), label="Select document type", multiselect=True)
            with gr.Column(scale=2):
                mmr_input = gr.Checkbox(label="Use MMR for diverse results")
                # Hidden until the MMR checkbox is ticked (update_lambda_visibility).
                lambda_input = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Lambda parameter (MMR diversity)", visible=False)
                top_k_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of results")

        search_button = gr.Button("Search")
        results_output = gr.Textbox(label="Search Results")

        # Choices are replaced by `search` and by `get_neighbors_and_graph_data`;
        # each choice starts with a number that is mapped back into `node_ids`.
        selected_documents = gr.Dropdown(label="Select documents to view their neighbors", choices=[], multiselect=True, interactive=True)

        with gr.Row():
            neighbor_search_button = gr.Button("Find Neighbors")
            send_to_export = gr.Button("Send docs to export Tab")

        neighbors_output = gr.Textbox(label="Document Neighbors")

        # Graph accumulated across neighbor searches: {"nodes": [...], "edges": [...]}.
        graph_data_state = gr.State({"nodes": [], "edges": []})
        # Hidden JSON copy of the graph, passed to the client-side js_code renderer.
        graph_data_str = gr.Textbox(visible=False)
        # Mount point for the vis-network graph; js_code clears it before drawing.
        graph_container = gr.HTML('<div id="graph-container" style="height: 600px;"> Hey ! </div>')

        # Node ids aligned with the current `selected_documents` choices.
        node_ids = gr.State([])
        # Snapshot filled by "Send docs to export Tab" and read by the Export tab.
        exported_docs = gr.State([])
        exported_relationships = gr.State([])

        def update_lambda_visibility(use_mmr):
            """Show the MMR lambda slider only while MMR is enabled."""
            return gr.update(visible=use_mmr)

        mmr_input.change(fn=update_lambda_visibility, inputs=mmr_input, outputs=lambda_input)

        search_button.click(
            fn=search,
            inputs=[query_input, doc_type_input, mmr_input, lambda_input, top_k_input],
            outputs=[results_output, selected_documents, node_ids]
        )

        # Expand the graph on the server, then redraw it in the browser by
        # running js_code on the JSON written to graph_data_str.
        neighbor_search_button.click(
            fn=get_neighbors_and_graph_data,
            inputs=[selected_documents, node_ids, graph_data_state],
            outputs=[neighbors_output, graph_data_str, graph_data_state, selected_documents, node_ids]
        ).then(
            fn=None,
            inputs=graph_data_str,
            outputs=None,
            js=js_code,
        )

        send_to_export.click(
            fn=get_docs_from_ids,
            inputs=graph_data_state,
            outputs=[exported_docs, exported_relationships]
        )

    # Export tab: preview the exported documents or save them to Excel.
    with gr.Tab("Export"):
        with gr.Row():
            exported_docs_btn = gr.Button("Display exported docs")
            exported_excel_btn = gr.Button("Export to excel")
            # Hidden until save_docs_to_excel returns the written file.
            exported_excel = gr.File(visible=False)

        exported_docs_display = gr.Markdown(visible=False)

        # Render each exported document as "[<first label>]" followed by its
        # property dict, and make the Markdown component visible.
        exported_docs_btn.click(
            fn= lambda docs: gr.update(value='\n\n'.join([f"[{doc['category'][0]}]\n{doc['doc']}\n\n" for doc in docs]), visible=True),
            inputs=exported_docs,
            outputs=exported_docs_display
        )
        exported_excel_btn.click(
            fn=save_docs_to_excel,
            inputs=[exported_docs, exported_relationships],
            outputs=exported_excel
        )

# NOTE(review): the app launches at import time (no `if __name__ == "__main__"`
# guard); importing this module starts the server.
demo.launch()
|
|