neuralworm commited on
Commit
85c02a6
Β·
verified Β·
1 Parent(s): 36f2148

revert to matplotlib

Browse files
Files changed (1) hide show
  1. psychohistory.py +72 -34
psychohistory.py CHANGED
@@ -1,5 +1,5 @@
1
- import plotly.graph_objects as go # Import Plotly for interactive plots
2
- from mpl_toolkits.mplot3d import Axes3D # Not needed anymore, but you can keep it if you use it elsewhere
3
  import networkx as nx
4
  import numpy as np
5
  import json
@@ -32,6 +32,8 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
32
  return node_count_per_depth
33
 
34
 
 
 
35
  def build_graph_from_json(json_data, G):
36
  """Builds a graph from JSON data, handling subevents recursively."""
37
 
@@ -57,6 +59,7 @@ def build_graph_from_json(json_data, G):
57
  add_event(None, event_data, 0) # Add each event as a root node
58
 
59
 
 
60
  def find_paths(G):
61
  """Finds paths with highest/lowest probability and longest/shortest durations."""
62
  best_path, worst_path = None, None
@@ -92,63 +95,98 @@ def find_paths(G):
92
 
93
  return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
94
 
95
- def draw_graph_plotly(G, title="3D Event Tree", highlight_color='gray'):
96
- """Draws the graph in 3D using Plotly and returns the HTML string."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  pos = nx.get_node_attributes(G, 'pos')
98
  labels = nx.get_node_attributes(G, 'label')
99
 
100
  if not pos:
101
  print("Graph is empty. No nodes to visualize.")
102
- return ""
103
 
104
  x_vals, y_vals, z_vals = zip(*pos.values())
 
 
105
 
106
  node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in pos.values()]
107
- node_trace = go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers+text',
108
- marker=dict(size=10, color=node_colors, line=dict(width=1, color='black')),
109
- text=list(labels.values()), textposition='top center', hoverinfo='text')
110
 
111
- edge_traces = []
112
  for edge in G.edges():
113
  x_start, y_start, z_start = pos[edge[0]]
114
  x_end, y_end, z_end = pos[edge[1]]
115
- edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
116
- mode='lines', line=dict(width=2, color=highlight_color), hoverinfo='none')
117
- edge_traces.append(edge_trace)
118
 
119
- layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
120
- title=title)
121
- fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
122
 
123
- # Convert Plotly figure to HTML string
124
- html_str = fig.to_html(full_html=False, include_plotlyjs='cdn')
125
- return html_str
 
 
 
 
126
 
127
 
128
  def main(json_data):
129
  G = nx.DiGraph()
130
- build_graph_from_json(json_data, G)
131
 
132
- # Find the best, worst, longest, and shortest paths FIRST
133
- best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path = find_paths(G)
134
 
135
- # Generate the HTML string for the Plotly graph
136
- html_graph = draw_graph_plotly(G)
137
 
138
- # Now you can use the path variables
139
  if best_path:
140
- best_path_graph = draw_graph_plotly(G.subgraph(best_path), title="Best Path", highlight_color='blue')
141
- html_graph += best_path_graph
142
  if worst_path:
143
- worst_path_graph = draw_graph_plotly(G.subgraph(worst_path), title="Worst Path", highlight_color='red')
144
- html_graph += worst_path_graph
145
  if longest_path:
146
- longest_path_graph = draw_graph_plotly(G.subgraph(longest_path), title="Longest Path", highlight_color='green')
147
- html_graph += longest_path_graph
148
  if shortest_path:
149
- shortest_path_graph = draw_graph_plotly(G.subgraph(shortest_path), title="Shortest Path", highlight_color='purple')
150
- html_graph += shortest_path_graph
151
-
152
- return html_graph # Return the HTML string
153
 
 
 
 
 
 
 
 
 
154
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from mpl_toolkits.mplot3d import Axes3D
3
  import networkx as nx
4
  import numpy as np
5
  import json
 
32
  return node_count_per_depth
33
 
34
 
35
+
36
+
37
  def build_graph_from_json(json_data, G):
38
  """Builds a graph from JSON data, handling subevents recursively."""
39
 
 
59
  add_event(None, event_data, 0) # Add each event as a root node
60
 
61
 
62
+
63
  def find_paths(G):
64
  """Finds paths with highest/lowest probability and longest/shortest durations."""
65
  best_path, worst_path = None, None
 
95
 
96
  return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
97
 
98
+ def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
99
+ """Draws a specific path in 3D."""
100
+ H = G.subgraph(path).copy()
101
+ pos = nx.get_node_attributes(G, 'pos')
102
+ x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
103
+
104
+ fig = plt.figure(figsize=(16, 12))
105
+ ax = fig.add_subplot(111, projection='3d')
106
+
107
+ node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
108
+ ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
109
+
110
+ for edge in H.edges():
111
+ x_start, y_start, z_start = pos[edge[0]]
112
+ x_end, y_end, z_end = pos[edge[1]]
113
+ ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
114
+
115
+ for node, (x, y, z) in pos.items():
116
+ if node in path:
117
+ ax.text(x, y, z, str(node), fontsize=12, color='black')
118
+
119
+ ax.set_xlabel('Time (weeks)')
120
+ ax.set_ylabel('Event Probability')
121
+ ax.set_zlabel('Event Number')
122
+ ax.set_title('3D Event Tree - Path')
123
+
124
+ plt.savefig(filename, bbox_inches='tight')
125
+ plt.close()
126
+
127
+
128
+ def draw_global_tree_3d(G, filename='global_tree.png'):
129
+ """Draws the entire graph in 3D."""
130
  pos = nx.get_node_attributes(G, 'pos')
131
  labels = nx.get_node_attributes(G, 'label')
132
 
133
  if not pos:
134
  print("Graph is empty. No nodes to visualize.")
135
+ return
136
 
137
  x_vals, y_vals, z_vals = zip(*pos.values())
138
+ fig = plt.figure(figsize=(16, 12))
139
+ ax = fig.add_subplot(111, projection='3d')
140
 
141
  node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in pos.values()]
142
+ ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
 
 
143
 
 
144
  for edge in G.edges():
145
  x_start, y_start, z_start = pos[edge[0]]
146
  x_end, y_end, z_end = pos[edge[1]]
147
+ ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
 
 
148
 
149
+ for node, (x, y, z) in pos.items():
150
+ label = labels.get(node, f"{node}")
151
+ ax.text(x, y, z, label, fontsize=12, color='black')
152
 
153
+ ax.set_xlabel('Time')
154
+ ax.set_ylabel('Probability')
155
+ ax.set_zlabel('Event Number')
156
+ ax.set_title('3D Event Tree')
157
+
158
+ plt.savefig(filename, bbox_inches='tight')
159
+ plt.close()
160
 
161
 
162
  def main(json_data):
163
  G = nx.DiGraph()
164
+ build_graph_from_json(json_data, G) # Build graph from the provided JSON data
165
 
166
+ draw_global_tree_3d(G, filename='global_tree.png')
 
167
 
168
+ best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path = find_paths(G)
 
169
 
 
170
  if best_path:
171
+ print(f"\nPath with the highest average probability: {' -> '.join(map(str, best_path))}")
172
+ print(f"Average probability: {best_mean_prob:.2f}")
173
  if worst_path:
174
+ print(f"\nPath with the lowest average probability: {' -> '.join(map(str, worst_path))}")
175
+ print(f"Average probability: {worst_mean_prob:.2f}")
176
  if longest_path:
177
+ print(f"\nPath with the longest duration: {' -> '.join(map(str, longest_path))}")
178
+ print(f"Duration: {max(G.nodes[node]['pos'][0] for node in longest_path) - min(G.nodes[node]['pos'][0] for node in longest_path):.2f}")
179
  if shortest_path:
180
+ print(f"\nPath with the shortest duration: {' -> '.join(map(str, shortest_path))}")
181
+ print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_path) - min(G.nodes[node]['pos'][0] for node in shortest_path):.2f}")
 
 
182
 
183
+ if best_path:
184
+ draw_path_3d(G, best_path, 'best_path.png', 'blue')
185
+ if worst_path:
186
+ draw_path_3d(G, worst_path, 'worst_path.png', 'red')
187
+ if longest_path:
188
+ draw_path_3d(G, longest_path, 'longest_duration_path.png', 'green')
189
+ if shortest_path:
190
+ draw_path_3d(G, shortest_path, 'shortest_duration_path.png', 'purple')
191
 
192
+ return 'global_tree.png' # Return the filename of the global tree