Create app.py

app.py (ADDED, 125 lines):
import os
import random
import logging
from pathlib import Path

import pandas as pd
import numpy as np
import networkx as nx
import seaborn as sns
import gradio as gr
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pyvis.network import Network

from helpers.df_helpers import documents2Dataframe, df2Graph, graph2Df
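# Note: helpers.df_helpers is this project's own module. Inferred from usage
# below (not from the module itself): documents2Dataframe(pages) returns a
# dataframe of text chunks keyed by chunk_id, df2Graph(df, model=...) extracts
# concept-pair records from each chunk via the LLM, and graph2Df flattens that
# list into a dataframe with node_1, node_2, edge and chunk_id columns.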

# Constants
CHUNK_SIZE = 1500        # characters per text chunk
CHUNK_OVERLAP = 150      # overlap between consecutive chunks
WEIGHT_MULTIPLIER = 4    # base count assigned to directly extracted relations
COLOR_PALETTE = "hls"
GRAPH_OUTPUT_DIRECTORY = "./docs/index.html"  # currently unused

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def colors2Community(communities) -> pd.DataFrame:
    """Map every node to its community's color and group id."""
    palette = sns.color_palette(COLOR_PALETTE, len(communities)).as_hex()
    random.shuffle(palette)
    # One palette entry per community, so all nodes in a community share a color.
    rows = [{"node": node, "color": palette[group], "group": group + 1}
            for group, community in enumerate(communities)
            for node in community]
    return pd.DataFrame(rows)
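# Illustration (hypothetical input): communities = [["a", "b"], ["c"]] yields
# three rows; "a" and "b" share one palette color and get group 1, while "c"
# gets another color and group 2.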

def contextual_proximity(df: pd.DataFrame) -> pd.DataFrame:
    """Infer an edge between every pair of nodes that co-occur in a chunk."""
    # Melt to one row per (chunk_id, node), then self-join on chunk_id to
    # enumerate all node pairs appearing in the same chunk.
    dfg_long = pd.melt(
        df, id_vars=["chunk_id"], value_vars=["node_1", "node_2"], value_name="node"
    ).drop(columns=["variable"])
    dfg_wide = pd.merge(dfg_long, dfg_long, on="chunk_id", suffixes=("_1", "_2"))
    # Drop self-pairs, then count how often each pair co-occurs.
    dfg_wide = dfg_wide[dfg_wide["node_1"] != dfg_wide["node_2"]].reset_index(drop=True)
    dfg2 = dfg_wide.groupby(["node_1", "node_2"]).agg({"chunk_id": [",".join, "count"]}).reset_index()
    dfg2.columns = ["node_1", "node_2", "chunk_id", "count"]
    dfg2.dropna(subset=["node_1", "node_2"], inplace=True)
    # Keep only pairs that co-occur more than once.
    dfg2 = dfg2[dfg2["count"] != 1]
    dfg2["edge"] = "contextual proximity"
    return dfg2
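# Illustration (hypothetical values): if chunk c1 yielded the relations
# (insulin, glucose) and (glucose, diabetes), the melt/self-join produces
# every pairing of {insulin, glucose, glucose, diabetes} within c1. The pair
# (insulin, glucose) co-occurs twice (glucose appears in two relations), so it
# passes the count != 1 filter; (insulin, diabetes) co-occurs only once and is
# dropped unless another chunk repeats it.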

def load_documents(input_dir):
    """Load every document found under input_dir via LangChain."""
    loader = DirectoryLoader(input_dir, show_progress=True)
    return loader.load()


def split_documents(documents):
    """Split documents into overlapping chunks for concept extraction."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
        is_separator_regex=False,
    )
    return splitter.split_documents(documents)

def save_dataframes(df, dfg1, output_dir):
    """Persist the chunk and graph dataframes as pipe-separated CSVs."""
    os.makedirs(output_dir, exist_ok=True)
    dfg1.to_csv(output_dir / "graph.csv", sep="|", index=False)
    df.to_csv(output_dir / "chunks.csv", sep="|", index=False)


def load_dataframes(output_dir):
    """Reload previously generated chunk and graph dataframes."""
    df = pd.read_csv(output_dir / "chunks.csv", sep="|")
    dfg1 = pd.read_csv(output_dir / "graph.csv", sep="|")
    return df, dfg1

def build_graph(dfg):
    """Build an undirected NetworkX graph from the aggregated edge dataframe."""
    nodes = pd.concat([dfg["node_1"], dfg["node_2"]], axis=0).unique()
    G = nx.Graph()
    G.add_nodes_from(nodes)
    for _, row in dfg.iterrows():
        # Directly extracted relations start at count = WEIGHT_MULTIPLIER
        # (weight 1.0); proximity-only edges end up lighter.
        G.add_edge(row["node_1"], row["node_2"], title=row["edge"], weight=row["count"] / WEIGHT_MULTIPLIER)
    return G

def visualize_graph(G, communities):
    """Color nodes by community and render the graph as an embeddable iframe."""
    colors = colors2Community(communities)
    for _, row in colors.iterrows():
        G.nodes[row["node"]].update(group=row["group"], color=row["color"], size=G.degree[row["node"]])
    nt = Network(notebook=False, cdn_resources="remote", height="900px", width="100%", select_menu=True)
    nt.from_nx(G)
    nt.force_atlas_2based(central_gravity=0.015, gravity=-31)
    nt.show_buttons(filter_=["physics"])
    # Swap single quotes for double quotes so the generated page can sit
    # inside the single-quoted srcdoc attribute below.
    html = nt.generate_html().replace("'", "\"")
    return f"""<iframe style="width: 100%; height: 600px; margin:0 auto"
            name="result" allow="midi; geolocation; microphone; camera;
            display-capture; encrypted-media;" sandbox="allow-modals allow-forms
            allow-scripts allow-same-origin allow-popups
            allow-top-navigation-by-user-activation allow-downloads" allowfullscreen
            allowpaymentrequest frameborder="0" srcdoc='{html}'></iframe>"""

def process_pdfs(input_dir, output_dir, regenerate=False):
    """Run the full text-to-graph pipeline and return the rendered HTML."""
    if regenerate:
        # Extract concept pairs from scratch with the LLM (slow).
        documents = load_documents(input_dir)
        pages = split_documents(documents)
        df = documents2Dataframe(pages)
        concepts_list = df2Graph(df, model="zephyr:latest")
        dfg1 = graph2Df(concepts_list)
        save_dataframes(df, dfg1, output_dir)
    else:
        # Reuse dataframes generated on a previous run.
        df, dfg1 = load_dataframes(output_dir)

    dfg1.replace("", np.nan, inplace=True)
    dfg1.dropna(subset=["node_1", "node_2", "edge"], inplace=True)
    dfg1["count"] = WEIGHT_MULTIPLIER
    # Add co-occurrence edges, then merge them with the extracted relations,
    # summing counts and concatenating edge labels per node pair.
    dfg2 = contextual_proximity(dfg1)
    dfg = (
        pd.concat([dfg1, dfg2], axis=0)
        .groupby(["node_1", "node_2"])
        .agg({"chunk_id": ",".join, "edge": ",".join, "count": "sum"})
        .reset_index()
    )
    G = build_graph(dfg)

    # Use the second level of the Girvan-Newman hierarchy as the community
    # partition.
    communities_generator = nx.community.girvan_newman(G)
    next(communities_generator)  # discard the coarsest split
    next_level_communities = next(communities_generator)
    communities = sorted(map(sorted, next_level_communities))
    logger.info(f"Number of Communities = {len(communities)}")
    logger.info(communities)

    return visualize_graph(G, communities)

def main():
    data_dir = "cureus"
    input_dir = Path(f"./data_input/{data_dir}")
    output_dir = Path(f"./data_output/{data_dir}")
    # Build the graph once at startup; the Gradio app just serves the result.
    html = process_pdfs(input_dir, output_dir, regenerate=False)

    demo = gr.Interface(
        fn=lambda: html,
        inputs=None,
        outputs=gr.HTML(),
        title="Text to knowledge graph",
        allow_flagging="never",
    )
    demo.launch()


if __name__ == "__main__":
    main()
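Usage note: the regenerate=True path needs the helpers module and a running LLM backend, but the graph plumbing can be exercised on its own. A minimal sketch with toy, hypothetical relation rows, assuming app.py's functions are importable:

    import pandas as pd
    from app import contextual_proximity, build_graph, WEIGHT_MULTIPLIER

    dfg1 = pd.DataFrame([
        {"node_1": "insulin", "node_2": "glucose", "edge": "regulates", "chunk_id": "c1"},
        {"node_1": "glucose", "node_2": "diabetes", "edge": "raised in", "chunk_id": "c1"},
        {"node_1": "insulin", "node_2": "diabetes", "edge": "treats", "chunk_id": "c2"},
    ])
    dfg1["count"] = WEIGHT_MULTIPLIER
    # Same aggregation process_pdfs() applies after adding proximity edges.
    dfg = (
        pd.concat([dfg1, contextual_proximity(dfg1)])
        .groupby(["node_1", "node_2"])
        .agg({"chunk_id": ",".join, "edge": ",".join, "count": "sum"})
        .reset_index()
    )
    G = build_graph(dfg)
    print(G.number_of_nodes(), G.number_of_edges())  # 3 nodes, 3 edges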