neuralworm commited on
Commit
ddfd741
·
1 Parent(s): 1e7ebeb

better graph

Browse files
Files changed (1) hide show
  1. psychohistory.py +107 -84
psychohistory.py CHANGED
@@ -6,158 +6,181 @@ import json
6
  import sys
7
  import random
8
 
9
- def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None, node_count_per_depth=None):
10
- """Generates a tree of nodes with positions adjusted on the x-axis, y-axis, and number of nodes on the z-axis."""
11
- if node_count_per_depth is None:
12
- node_count_per_depth = {}
13
-
14
- if depth > max_depth:
15
- return node_count_per_depth
16
-
17
- if depth not in node_count_per_depth:
18
- node_count_per_depth[depth] = 0
19
-
20
- num_children = random.randint(1, max_nodes)
21
- x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
22
-
23
- for x in x_positions:
24
- node_id = len(G.nodes)
25
- node_count_per_depth[depth] += 1
26
- prob = random.uniform(0, 1)
27
- G.add_node(node_id, pos=(x, prob, depth))
28
- if parent is not None:
29
- G.add_edge(parent, node_id)
30
- generate_tree(x, current_y + 1, depth + 1, max_depth, max_nodes, x_range, G, parent=node_id, node_count_per_depth=node_count_per_depth)
31
-
32
- return node_count_per_depth
33
-
34
-
35
-
36
-
37
  def build_graph_from_json(json_data, G):
38
- """Builds a graph from JSON data, handling subevents recursively."""
39
-
40
  def add_event(parent_id, event_data, depth):
 
 
41
  node_id = len(G.nodes)
42
- prob = event_data['probability'] / 100.0
43
- # Use event_number as the z-coordinate for better visualization
44
- pos = (depth, prob, event_data['event_number'])
45
- label = event_data['name']
46
  G.add_node(node_id, pos=pos, label=label)
47
  if parent_id is not None:
48
- G.add_edge(parent_id, node_id) # Connect to parent
49
 
 
50
  subevents = event_data.get('subevents', {}).get('event', [])
51
  if not isinstance(subevents, list):
52
- subevents = [subevents]
53
 
54
  for subevent in subevents:
55
- add_event(node_id, subevent, depth + 1) # Recursively add subevents
56
-
57
- # Iterate through all top-level events
58
- for event_data in json_data.get('events', {}).values():
59
- add_event(None, event_data, 0) # Add each event as a root node
60
 
 
 
 
61
 
62
 
63
  def find_paths(G):
64
- """Finds paths with highest/lowest probability and longest/shortest durations."""
65
- best_path, worst_path = None, None
66
- longest_path, shortest_path = None, None
67
- best_mean_prob, worst_mean_prob = -1, float('inf')
68
- max_duration, min_duration = -1, float('inf')
69
-
70
- # Use nx.all_pairs_shortest_path for efficiency
71
- all_paths_dict = dict(nx.all_pairs_shortest_path(G))
72
-
73
- for source, paths_from_source in all_paths_dict.items():
74
- for target, path in paths_from_source.items():
75
- if source != target and all('pos' in G.nodes[node] for node in path):
76
- probabilities = [G.nodes[node]['pos'][1] for node in path]
77
- mean_prob = np.mean(probabilities)
78
-
79
- if mean_prob > best_mean_prob:
80
- best_mean_prob = mean_prob
81
- best_path = path
82
- if mean_prob < worst_mean_prob:
83
- worst_mean_prob = mean_prob
84
- worst_path = path
85
-
86
- x_positions = [G.nodes[node]['pos'][0] for node in path]
87
- duration = max(x_positions) - min(x_positions)
88
-
89
- if duration > max_duration:
90
- max_duration = duration
91
- longest_path = path
92
- if duration < min_duration and duration > 0: # Avoid paths with 0 duration
93
- min_duration = duration
94
- shortest_path = path
95
-
96
- return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
99
- """Draws a specific path in 3D."""
 
 
100
  H = G.subgraph(path).copy()
 
101
  pos = nx.get_node_attributes(G, 'pos')
 
 
102
  x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
103
 
104
  fig = plt.figure(figsize=(16, 12))
105
  ax = fig.add_subplot(111, projection='3d')
106
 
107
- node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
 
 
 
 
 
 
 
 
 
 
 
108
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
109
 
 
110
  for edge in H.edges():
111
  x_start, y_start, z_start = pos[edge[0]]
112
  x_end, y_end, z_end = pos[edge[1]]
113
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
114
 
 
115
  for node, (x, y, z) in pos.items():
116
  if node in path:
117
  ax.text(x, y, z, str(node), fontsize=12, color='black')
118
 
 
119
  ax.set_xlabel('Time (weeks)')
120
  ax.set_ylabel('Event Probability')
121
  ax.set_zlabel('Event Number')
122
  ax.set_title('3D Event Tree - Path')
123
 
124
- plt.savefig(filename, bbox_inches='tight')
125
- plt.close()
126
 
127
 
128
  def draw_global_tree_3d(G, filename='global_tree.png'):
129
- """Draws the entire graph in 3D."""
 
130
  pos = nx.get_node_attributes(G, 'pos')
131
  labels = nx.get_node_attributes(G, 'label')
132
 
 
133
  if not pos:
134
  print("Graph is empty. No nodes to visualize.")
135
  return
136
 
 
137
  x_vals, y_vals, z_vals = zip(*pos.values())
 
138
  fig = plt.figure(figsize=(16, 12))
139
  ax = fig.add_subplot(111, projection='3d')
140
 
141
- node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in pos.values()]
 
 
 
 
 
 
 
 
 
 
142
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
143
 
 
144
  for edge in G.edges():
145
  x_start, y_start, z_start = pos[edge[0]]
146
  x_end, y_end, z_end = pos[edge[1]]
147
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
148
 
 
149
  for node, (x, y, z) in pos.items():
150
  label = labels.get(node, f"{node}")
151
  ax.text(x, y, z, label, fontsize=12, color='black')
152
 
 
153
  ax.set_xlabel('Time')
154
  ax.set_ylabel('Probability')
155
  ax.set_zlabel('Event Number')
156
  ax.set_title('3D Event Tree')
157
 
158
- plt.savefig(filename, bbox_inches='tight')
159
- plt.close()
160
-
161
 
162
  def main(json_data):
163
  G = nx.DiGraph()
@@ -189,4 +212,4 @@ def main(json_data):
189
  if shortest_path:
190
  draw_path_3d(G, shortest_path, 'shortest_duration_path.png', 'purple')
191
 
192
- return 'global_tree.png' # Return the filename of the global tree
 
6
  import sys
7
  import random
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def build_graph_from_json(json_data, G):
10
+ """Builds a graph from JSON data."""
 
11
  def add_event(parent_id, event_data, depth):
12
+ """Recursively adds events and subevents to the graph."""
13
+ # Add the current event node
14
  node_id = len(G.nodes)
15
+ prob = event_data['probability'] / 100.0 # Convert percentage to probability
16
+ pos = (depth, prob, event_data['event_number']) # Use event_number for z position
17
+ label = event_data['name'] # Use event name as label
 
18
  G.add_node(node_id, pos=pos, label=label)
19
  if parent_id is not None:
20
+ G.add_edge(parent_id, node_id)
21
 
22
+ # Add child events
23
  subevents = event_data.get('subevents', {}).get('event', [])
24
  if not isinstance(subevents, list):
25
+ subevents = [subevents] # Ensure subevents is a list
26
 
27
  for subevent in subevents:
28
+ add_event(node_id, subevent, depth + 1)
 
 
 
 
29
 
30
+ # Start from the root event (assuming there's only one top-level event)
31
+ root_event = list(json_data.get('events', {}).values())[0]
32
+ add_event(None, root_event, 0)
33
 
34
 
35
  def find_paths(G):
36
+ """Finds the paths with the highest and lowest average probability,
37
+ and the longest and shortest durations in graph G."""
38
+ best_path = None
39
+ worst_path = None
40
+ longest_duration_path = None
41
+ shortest_duration_path = None
42
+ best_mean_prob = -1
43
+ worst_mean_prob = float('inf')
44
+ max_duration = -1
45
+ min_duration = float('inf')
46
+
47
+ for source in G.nodes:
48
+ for target in G.nodes:
49
+ if source != target:
50
+ all_paths = list(nx.all_simple_paths(G, source=source, target=target))
51
+ for path in all_paths:
52
+ # Check if all nodes in the path have the 'pos' attribute
53
+ if not all('pos' in G.nodes[node] for node in path):
54
+ continue # Skip paths with nodes missing the 'pos' attribute
55
+
56
+ # Calculate the mean probability of the path
57
+ probabilities = [G.nodes[node]['pos'][1] for node in path] # Get node probabilities
58
+ mean_prob = np.mean(probabilities)
59
+
60
+ # Evaluate path with the highest mean probability
61
+ if mean_prob > best_mean_prob:
62
+ best_mean_prob = mean_prob
63
+ best_path = path
64
+
65
+ # Evaluate path with the lowest mean probability
66
+ if mean_prob < worst_mean_prob:
67
+ worst_mean_prob = mean_prob
68
+ worst_path = path
69
+
70
+ # Calculate path duration
71
+ x_positions = [G.nodes[node]['pos'][0] for node in path]
72
+ duration = max(x_positions) - min(x_positions)
73
+
74
+ # Evaluate path with the longest duration
75
+ if duration > max_duration:
76
+ max_duration = duration
77
+ longest_duration_path = path
78
+
79
+ # Evaluate path with the shortest duration
80
+ if duration < min_duration:
81
+ min_duration = duration
82
+ shortest_duration_path = path
83
+
84
+ return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path
85
 
86
  def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
87
+ """Draws only the specific path in 3D using networkx and matplotlib
88
+ and saves the figure to a file."""
89
+ # Create a subgraph containing only the nodes and edges of the path
90
  H = G.subgraph(path).copy()
91
+
92
  pos = nx.get_node_attributes(G, 'pos')
93
+
94
+ # Get data for 3D visualization
95
  x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
96
 
97
  fig = plt.figure(figsize=(16, 12))
98
  ax = fig.add_subplot(111, projection='3d')
99
 
100
+ # Assign colors to nodes based on probability
101
+ node_colors = []
102
+ for node in path:
103
+ prob = G.nodes[node]['pos'][1]
104
+ if prob < 0.33:
105
+ node_colors.append('red')
106
+ elif prob < 0.67:
107
+ node_colors.append('blue')
108
+ else:
109
+ node_colors.append('green')
110
+
111
+ # Draw nodes
112
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
113
 
114
+ # Draw edges
115
  for edge in H.edges():
116
  x_start, y_start, z_start = pos[edge[0]]
117
  x_end, y_end, z_end = pos[edge[1]]
118
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
119
 
120
+ # Add labels to nodes
121
  for node, (x, y, z) in pos.items():
122
  if node in path:
123
  ax.text(x, y, z, str(node), fontsize=12, color='black')
124
 
125
+ # Set labels and title
126
  ax.set_xlabel('Time (weeks)')
127
  ax.set_ylabel('Event Probability')
128
  ax.set_zlabel('Event Number')
129
  ax.set_title('3D Event Tree - Path')
130
 
131
+ plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
132
+ plt.close() # Close the figure to free resources
133
 
134
 
135
  def draw_global_tree_3d(G, filename='global_tree.png'):
136
+ """Draws the entire graph in 3D using networkx and matplotlib
137
+ and saves the figure to a file."""
138
  pos = nx.get_node_attributes(G, 'pos')
139
  labels = nx.get_node_attributes(G, 'label')
140
 
141
+ # Check if the graph is empty
142
  if not pos:
143
  print("Graph is empty. No nodes to visualize.")
144
  return
145
 
146
+ # Get data for 3D visualization
147
  x_vals, y_vals, z_vals = zip(*pos.values())
148
+
149
  fig = plt.figure(figsize=(16, 12))
150
  ax = fig.add_subplot(111, projection='3d')
151
 
152
+ # Assign colors to nodes based on probability
153
+ node_colors = []
154
+ for node, (x, prob, z) in pos.items():
155
+ if prob < 0.33:
156
+ node_colors.append('red')
157
+ elif prob < 0.67:
158
+ node_colors.append('blue')
159
+ else:
160
+ node_colors.append('green')
161
+
162
+ # Draw nodes
163
  ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
164
 
165
+ # Draw edges
166
  for edge in G.edges():
167
  x_start, y_start, z_start = pos[edge[0]]
168
  x_end, y_end, z_end = pos[edge[1]]
169
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
170
 
171
+ # Add labels to nodes
172
  for node, (x, y, z) in pos.items():
173
  label = labels.get(node, f"{node}")
174
  ax.text(x, y, z, label, fontsize=12, color='black')
175
 
176
+ # Set labels and title
177
  ax.set_xlabel('Time')
178
  ax.set_ylabel('Probability')
179
  ax.set_zlabel('Event Number')
180
  ax.set_title('3D Event Tree')
181
 
182
+ plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
183
+ plt.close() # Close the figure to free resources
 
184
 
185
  def main(json_data):
186
  G = nx.DiGraph()
 
212
  if shortest_path:
213
  draw_path_3d(G, shortest_path, 'shortest_duration_path.png', 'purple')
214
 
215
+ return 'global_tree.png' # Return the filename of the global tree