Spaces:
Sleeping
Sleeping
File size: 3,399 Bytes
a7a0009 b6966f8 a7a0009 8e03eb3 673466f a7a0009 673466f a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 2478080 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 673466f a7a0009 8e03eb3 a7a0009 a786ddb a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 8e03eb3 a7a0009 a786ddb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# Standard library first, so we can configure the environment *before*
# TensorFlow is imported.
import os

# BUGFIX: TF_USE_LEGACY_KERAS is read when TensorFlow initializes its Keras
# bindings during `import tensorflow`. Setting it after the import (as the
# original code did) has no effect, so it MUST be set here, first.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis
import networkx as nx
import matplotlib.pyplot as plt
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
    """Build a Keras GNN model for the given GraphTensor spec.

    Node features are first projected to a 128-dim hidden state, then two
    rounds of MtAlbis message passing are applied.

    Args:
        graph_tensor_spec: Type spec describing the input GraphTensor.

    Returns:
        A `tf.keras.Model` mapping an input GraphTensor to an updated one.
    """
    inputs = tf.keras.Input(type_spec=graph_tensor_spec)

    # Project each node set's raw features to the 128-dim hidden width
    # expected by the message-passing layers below.
    hidden = tfgnn.keras.layers.MapFeatures(
        node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(128)(
            node_set['features']
        )
    )(inputs)

    # Two rounds of message passing, each applied as a fresh Keras layer.
    for _round in range(2):
        hidden = mt_albis.MtAlbisGraphUpdate(
            units=128,
            message_dim=64,
            attention_type="none",
            simple_conv_reduce_type="mean",
            normalization_type="layer",
            next_state_type="residual",
            state_dropout_rate=0.2,
            l2_regularization=1e-5,
            receiver_tag=tfgnn.TARGET,  # aggregate messages at edge targets
        )(hidden)

    return tf.keras.Model(inputs, hidden)
def create_sample_graph():
    """Create a random directed graph and its GraphTensor counterpart.

    Returns:
        A `(nx_graph, graph_tensor)` pair: a NetworkX digraph with 10 nodes
        and 15 edges, and a `tfgnn.GraphTensor` with one "papers" node set
        (128-dim random features, matching the model's hidden width) and one
        "cites" edge set (32-dim random features).
    """
    num_nodes, num_edges = 10, 15
    nx_graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)

    # Random feature tensors; node width matches the model's Dense(128).
    node_features = tf.random.normal((num_nodes, 128))
    edge_features = tf.random.normal((num_edges, 32))

    # Split the edge list into parallel source/target index vectors.
    sources = [u for u, _ in nx_graph.edges()]
    targets = [v for _, v in nx_graph.edges()]

    graph_tensor = tfgnn.GraphTensor.from_pieces(
        node_sets={
            "papers": tfgnn.NodeSet.from_fields(
                sizes=[num_nodes],
                features={"features": node_features},
            )
        },
        edge_sets={
            "cites": tfgnn.EdgeSet.from_fields(
                sizes=[num_edges],
                adjacency=tfgnn.Adjacency.from_indices(
                    source=("papers", tf.constant(sources, dtype=tf.int32)),
                    target=("papers", tf.constant(targets, dtype=tf.int32)),
                ),
                features={"features": edge_features},
            )
        },
    )
    return nx_graph, graph_tensor
def main():
    """Streamlit entry point: build a sample graph/model and visualize both."""
    st.title("Graph Neural Network Architecture Visualization")

    # Build the demo data and the model over its spec.
    nx_graph, graph_tensor = create_sample_graph()
    model = model_fn(graph_tensor.spec)
    model.compile(optimizer='adam', loss='binary_crossentropy')

    # Model architecture, captured line-by-line into the Streamlit page.
    st.subheader("Model Summary")
    model.summary(print_fn=lambda line: st.text(line))

    # Draw the NetworkX graph with a force-directed layout.
    st.subheader("Sample Graph Visualization")
    fig, ax = plt.subplots(figsize=(10, 8))
    layout = nx.spring_layout(nx_graph)
    nx.draw(
        nx_graph,
        layout,
        with_labels=True,
        node_color='lightblue',
        node_size=500,
        arrowsize=20,
        ax=ax,
    )
    st.pyplot(fig)

    # Basic statistics about the GraphTensor.
    st.subheader("Graph Tensor Information")
    st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")
    st.text(f"Number of edges: {graph_tensor.edge_sets['cites'].total_size}")
    st.text(f"Node feature shape: {graph_tensor.node_sets['papers']['features'].shape}")
    st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
# Run the Streamlit app only when executed as a script.
if __name__ == "__main__":
    main()
|