SarthakBhatore committed on
Commit 18b37d7 · verified · 1 parent: 4cf8cef

Create app.py

Files changed (1)
app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
import os
import random
import logging
from pathlib import Path

import pandas as pd
import numpy as np
import networkx as nx
import seaborn as sns
import gradio as gr
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pyvis.network import Network

from helpers.df_helpers import documents2Dataframe, df2Graph, graph2Df

# Constants
CHUNK_SIZE = 1500      # characters per text chunk
CHUNK_OVERLAP = 150    # overlap between consecutive chunks
WEIGHT_MULTIPLIER = 4  # base weight given to LLM-extracted relations
COLOR_PALETTE = "hls"
GRAPH_OUTPUT_DIRECTORY = "./docs/index.html"  # not used below

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def colors2Community(communities) -> pd.DataFrame:
    """Assign one palette color to all nodes of each community."""
    palette = sns.color_palette(COLOR_PALETTE, len(communities)).as_hex()
    random.shuffle(palette)
    rows = [
        {"node": node, "color": palette[group], "group": group + 1}
        for group, community in enumerate(communities)
        for node in community
    ]
    return pd.DataFrame(rows)


def contextual_proximity(df: pd.DataFrame) -> pd.DataFrame:
    """Derive edges between node pairs that co-occur in the same text chunk."""
    # Melt the two node columns into long form: one node per row.
    dfg_long = pd.melt(
        df, id_vars=["chunk_id"], value_vars=["node_1", "node_2"], value_name="node"
    ).drop(columns=["variable"])
    # Self-join on chunk_id to enumerate every node pair within a chunk.
    dfg_wide = pd.merge(dfg_long, dfg_long, on="chunk_id", suffixes=("_1", "_2"))
    dfg_wide = dfg_wide[dfg_wide["node_1"] != dfg_wide["node_2"]].reset_index(drop=True)
    # Count co-occurrences and keep the chunk ids supporting each pair.
    dfg2 = (
        dfg_wide.groupby(["node_1", "node_2"])
        .agg({"chunk_id": [",".join, "count"]})
        .reset_index()
    )
    dfg2.columns = ["node_1", "node_2", "chunk_id", "count"]
    dfg2.dropna(subset=["node_1", "node_2"], inplace=True)
    # A single co-occurrence is weak evidence; drop those pairs.
    dfg2 = dfg2[dfg2["count"] != 1]
    dfg2["edge"] = "contextual proximity"
    return dfg2


def load_documents(input_dir):
    loader = DirectoryLoader(input_dir, show_progress=True)
    return loader.load()


def split_documents(documents):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
        is_separator_regex=False,
    )
    return splitter.split_documents(documents)


def save_dataframes(df, dfg1, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    dfg1.to_csv(output_dir / "graph.csv", sep="|", index=False)
    df.to_csv(output_dir / "chunks.csv", sep="|", index=False)


def load_dataframes(output_dir):
    df = pd.read_csv(output_dir / "chunks.csv", sep="|")
    dfg1 = pd.read_csv(output_dir / "graph.csv", sep="|")
    return df, dfg1


def build_graph(dfg):
    """Build an undirected NetworkX graph from the edge dataframe."""
    nodes = pd.concat([dfg["node_1"], dfg["node_2"]], axis=0).unique()
    G = nx.Graph()
    G.add_nodes_from(nodes)
    for _, row in dfg.iterrows():
        G.add_edge(
            row["node_1"],
            row["node_2"],
            title=row["edge"],
            weight=row["count"] / WEIGHT_MULTIPLIER,
        )
    return G


def visualize_graph(G, communities):
    """Color nodes by community and render the graph with pyvis in an iframe."""
    colors = colors2Community(communities)
    for _, row in colors.iterrows():
        G.nodes[row["node"]].update(
            group=row["group"], color=row["color"], size=G.degree[row["node"]]
        )
    nt = Network(
        notebook=False, cdn_resources="remote", height="900px", width="100%", select_menu=True
    )
    nt.from_nx(G)
    nt.force_atlas_2based(central_gravity=0.015, gravity=-31)
    nt.show_buttons(filter_=["physics"])
    # Swap quote styles so the HTML can sit inside a single-quoted srcdoc attribute.
    html = nt.generate_html().replace("'", '"')
    return f"""<iframe style="width: 100%; height: 600px; margin:0 auto"
    name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen
    allowpaymentrequest frameborder="0" srcdoc='{html}'></iframe>"""


def process_pdfs(input_dir, output_dir, regenerate=False):
    if regenerate:
        # Extract concept pairs from the documents with the LLM and cache them.
        documents = load_documents(input_dir)
        pages = split_documents(documents)
        df = documents2Dataframe(pages)
        concepts_list = df2Graph(df, model="zephyr:latest")
        dfg1 = graph2Df(concepts_list)
        save_dataframes(df, dfg1, output_dir)
    else:
        df, dfg1 = load_dataframes(output_dir)

    dfg1.replace("", np.nan, inplace=True)
    dfg1.dropna(subset=["node_1", "node_2", "edge"], inplace=True)
    dfg1["count"] = WEIGHT_MULTIPLIER
    # Augment the LLM-extracted edges with co-occurrence edges, then merge
    # duplicate node pairs, joining their chunk ids and summing their weights.
    dfg2 = contextual_proximity(dfg1)
    dfg = (
        pd.concat([dfg1, dfg2], axis=0)
        .groupby(["node_1", "node_2"])
        .agg({"chunk_id": ",".join, "edge": ",".join, "count": "sum"})
        .reset_index()
    )
    G = build_graph(dfg)

    # Use the second level of the Girvan-Newman hierarchy as the community split.
    communities_generator = nx.community.girvan_newman(G)
    next(communities_generator)  # discard the first (coarsest) partition
    next_level_communities = next(communities_generator)
    communities = sorted(map(sorted, next_level_communities))
    logger.info(f"Number of communities = {len(communities)}")
    logger.info(communities)

    return visualize_graph(G, communities)


def main():
    data_dir = "cureus"
    input_dir = Path(f"./data_input/{data_dir}")
    output_dir = Path(f"./data_output/{data_dir}")
    html = process_pdfs(input_dir, output_dir, regenerate=False)

    demo = gr.Interface(
        fn=lambda: html,
        inputs=None,
        outputs=gr.HTML(),
        title="Text to knowledge graph",
        allow_flagging="never",
    )
    demo.launch()


if __name__ == "__main__":
    main()
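
app.py leans on three functions from helpers/df_helpers that are not part of this commit: documents2Dataframe, df2Graph, and graph2Df. A minimal, hypothetical sketch of the interface the script appears to assume, with column names inferred from how the dataframes are used above and the actual LLM call stubbed out:

import uuid
import pandas as pd


def documents2Dataframe(documents) -> pd.DataFrame:
    # One row per LangChain chunk: its text, metadata, and a unique chunk_id
    # that later ties every extracted edge back to its source chunk.
    rows = [
        {"text": d.page_content, **d.metadata, "chunk_id": uuid.uuid4().hex}
        for d in documents
    ]
    return pd.DataFrame(rows)


def df2Graph(df: pd.DataFrame, model: str = "zephyr:latest") -> list:
    # Stub: a real implementation prompts `model` (e.g. served via Ollama) for
    # concept pairs in each chunk and returns a flat list of
    # {"node_1", "node_2", "edge", "chunk_id"} dictionaries.
    concepts = []
    for _, row in df.iterrows():
        pass  # concepts += llm_extract_pairs(row["text"], row["chunk_id"], model)
    return concepts


def graph2Df(concepts: list) -> pd.DataFrame:
    # Flatten the concept dictionaries into the edge dataframe process_pdfs expects.
    graph_df = pd.DataFrame(concepts)
    for col in ("node_1", "node_2"):
        graph_df[col] = graph_df[col].astype(str).str.lower().str.strip()
    return graph_df

With the cached chunks.csv and graph.csv already present under ./data_output/cureus, the app only re-renders the graph on launch; set regenerate=True in main() to rebuild the cache from the documents in ./data_input/cureus, which then requires the zephyr:latest model to be available for extraction.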