eaglelandsonce commited on
Commit
0e91aec
1 Parent(s): 673466f

Create 17_Graph2.py

Browse files
Files changed (1) hide show
  1. pages/17_Graph2.py +110 -0
pages/17_Graph2.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import tensorflow_gnn as tfgnn
4
+ from tensorflow_gnn.models import mt_albis
5
+ import networkx as nx
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ # Set environment variable for legacy Keras
10
+ import os
11
+ os.environ['TF_USE_LEGACY_KERAS'] = '1'
12
+
13
+ # Define the model function
14
+ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
15
+ graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
16
+
17
+ # Encode input features to match the required output shape of 128
18
+ graph = tfgnn.keras.layers.MapFeatures(
19
+ node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(128)(node_set['features'])
20
+ )(graph)
21
+
22
+ # For each round of message passing...
23
+ for _ in range(2):
24
+ # ... create and apply a Keras layer.
25
+ graph = mt_albis.MtAlbisGraphUpdate(
26
+ units=128, message_dim=64,
27
+ attention_type="none", simple_conv_reduce_type="mean",
28
+ normalization_type="layer", next_state_type="residual",
29
+ state_dropout_rate=0.2, l2_regularization=1e-5,
30
+ receiver_tag=tfgnn.TARGET # Use TARGET instead of NODES
31
+ )(graph)
32
+
33
+ return tf.keras.Model(inputs, graph)
34
+
35
+ # Function to create a sample graph with meaningful synthetic data
36
+ def create_sample_graph():
37
+ num_nodes = 10
38
+ num_edges = 15
39
+
40
+ graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
41
+
42
+ # Create synthetic features
43
+ years_published = np.random.randint(1990, 2022, size=num_nodes).astype(np.float32)
44
+ num_authors = np.random.randint(1, 10, size=num_nodes).astype(np.float32)
45
+ citation_weights = np.random.uniform(0.1, 5.0, size=num_edges).astype(np.float32)
46
+
47
+ # Combine features into a single array per node
48
+ node_features = np.stack([years_published, num_authors], axis=-1)
49
+ edge_features = citation_weights.reshape(-1, 1)
50
+
51
+ graph_tensor = tfgnn.GraphTensor.from_pieces(
52
+ node_sets={
53
+ "papers": tfgnn.NodeSet.from_fields(
54
+ sizes=[num_nodes],
55
+ features={"features": tf.convert_to_tensor(node_features)}
56
+ )
57
+ },
58
+ edge_sets={
59
+ "cites": tfgnn.EdgeSet.from_fields(
60
+ sizes=[num_edges],
61
+ adjacency=tfgnn.Adjacency.from_indices(
62
+ source=("papers", tf.constant([e[0] for e in graph.edges()], dtype=tf.int32)),
63
+ target=("papers", tf.constant([e[1] for e in graph.edges()], dtype=tf.int32))
64
+ ),
65
+ features={"features": tf.convert_to_tensor(edge_features)}
66
+ )
67
+ }
68
+ )
69
+
70
+ return graph, graph_tensor
71
+
72
+ # Streamlit app
73
+ def main():
74
+ st.title("Graph Neural Network Architecture Visualization")
75
+
76
+ # Create sample graph
77
+ nx_graph, graph_tensor = create_sample_graph()
78
+
79
+ # Create and compile the model
80
+ model = model_fn(graph_tensor.spec)
81
+ model.compile(optimizer='adam', loss='binary_crossentropy')
82
+
83
+ # Display model summary
84
+ st.subheader("Model Summary")
85
+ model.summary(print_fn=lambda x: st.text(x))
86
+
87
+ # Visualize the graph
88
+ st.subheader("Sample Graph Visualization")
89
+ fig, ax = plt.subplots(figsize=(10, 8))
90
+ pos = nx.spring_layout(nx_graph)
91
+ nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
92
+ node_size=500, arrowsize=20, ax=ax)
93
+ st.pyplot(fig)
94
+
95
+ # Display graph tensor info
96
+ st.subheader("Graph Tensor Information")
97
+ st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")
98
+ st.text(f"Number of edges: {graph_tensor.edge_sets['cites'].total_size}")
99
+ st.text(f"Node feature shape: {graph_tensor.node_sets['papers']['features'].shape}")
100
+ st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
101
+
102
+ # Display sample node and edge features
103
+ st.subheader("Sample Node and Edge Features")
104
+ st.write("Node Features (Year Published, Number of Authors):")
105
+ st.write(node_features)
106
+ st.write("Edge Features (Citation Weight):")
107
+ st.write(edge_features)
108
+
109
+ if __name__ == "__main__":
110
+ main()