akshatsanghvi commited on
Commit
fb793b8
·
1 Parent(s): 531b844

Create character_network_generator.py

Browse files
characters/character_network_generator.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import networkx as nx
4
+ from pyvis.network import Network
5
+
6
+ class CharacterNetworkGenerator:
7
+ def __init__(self):
8
+ pass
9
+
10
+ def generate_char_network(df):
11
+
12
+ windows = 10
13
+ entity_relationship = []
14
+
15
+ for row in df["chars"]:
16
+ prev_entity_window = []
17
+
18
+ for sentence in row:
19
+
20
+ # each sentence = ["Ted", "Lilly"]
21
+ prev_entity_window.append(list(sentence))
22
+
23
+ # We keep only the last 10 entities as previous.
24
+ prev_entity_window = prev_entity_window[-windows:]
25
+
26
+ # Flatten 2D list into 1D list
27
+ prev_entity_flattened = sum(prev_entity_window, [])
28
+
29
+ # Build relationship for each entity.
30
+ for entity in sentence:
31
+ # Check each entity with all previous 10 entities.
32
+ for entity_in_window in prev_entity_flattened:
33
+
34
+ # if they aren't same, append them because they are related.
35
+ if entity != entity_in_window:
36
+
37
+ # Sort them because (ted, lilly is same as lilly, ted.)
38
+ entity_relationship.append(sorted([entity, entity_in_window]))
39
+
40
+ relationship_df = pd.DataFrame({"value": entity_relationship})
41
+ relationship_df["source"] = relationship_df["value"].apply(lambda x: x[0])
42
+ relationship_df["target"] = relationship_df["value"].apply(lambda x: x[1])
43
+ relationship_df = relationship_df.groupby(["source", "target"]).count().reset_index()
44
+ relationship_df = relationship_df.sort_values("value", ascending=False)
45
+
46
+ return relationship_df
47
+
48
+ def draw_char_network(self, df):
49
+
50
+ df = df.sort_values("value", ascending=False).head(200)
51
+
52
+ G = nx.from_pandas_edgelist(
53
+ df,
54
+ source="source",
55
+ target="target",
56
+ edge_attr="value",
57
+ create_using=nx.Graph()
58
+ )
59
+
60
+ net = Network(
61
+ notebook=True, width="1000px",
62
+ height="700px",
63
+ bgcolor="#222222",
64
+ font_color="white",
65
+ cdn_resources="remote"
66
+ )
67
+
68
+ node_degree = dict(G.degree)
69
+
70
+ nx.set_node_attributes(G, node_degree, "size")
71
+ net.from_nx(G)
72
+
73
+ html = net.generate_html()
74
+ html = html.replace("'", "\"")
75
+
76
+ output_html = f"""<iframe style="width: 100%; height: 600px;margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
77
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
78
+ allow-scripts allow-same-origin allow-popups
79
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
80
+ allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
81
+
82
+ return output_html
83
+
84
+ def defaultGraph(self, save_path):
85
+
86
+
87
+ if save_path and not save_path.endswith(".html"):
88
+ save_path += "network.html"
89
+
90
+ with open("characters\himym.html", "r") as f:
91
+ html = f.read()
92
+
93
+ html = html.replace("'", "\"")
94
+
95
+ output_html = f"""<iframe style="width: 100%; height: 600px;margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
96
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
97
+ allow-scripts allow-same-origin allow-popups
98
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
99
+ allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
100
+
101
+ try:
102
+ if save_path:
103
+ with open("../" + save_path, "w") as f:
104
+ f.write(output_html)
105
+ except:
106
+ pass
107
+ return output_html
108
+
109
+ print(CharacterNetworkGenerator().defaultGraph("cache"))