File size: 3,399 Bytes
a7a0009
 
 
 
 
 
 
b6966f8
 
 
 
a7a0009
 
 
8e03eb3
673466f
a7a0009
673466f
a7a0009
8e03eb3
 
a7a0009
8e03eb3
a7a0009
8e03eb3
 
 
 
2478080
a7a0009
8e03eb3
a7a0009
 
 
 
 
 
8e03eb3
a7a0009
8e03eb3
a7a0009
673466f
a7a0009
8e03eb3
a7a0009
 
 
 
 
 
 
 
 
 
 
a786ddb
 
a7a0009
 
 
 
 
8e03eb3
a7a0009
 
 
 
 
8e03eb3
a7a0009
 
8e03eb3
a7a0009
 
 
8e03eb3
a7a0009
 
 
8e03eb3
a7a0009
 
 
 
8e03eb3
a7a0009
 
8e03eb3
a7a0009
 
 
 
 
 
 
 
a786ddb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# TF-GNN needs Keras 2 (the tf_keras package). The TF_USE_LEGACY_KERAS switch
# must be set before TensorFlow / TF-GNN are imported: importing
# tensorflow_gnn already defines its Keras layers (and so resolves tf.keras),
# which makes assigning the variable afterwards too late.
import os

os.environ['TF_USE_LEGACY_KERAS'] = '1'

# Third-party imports intentionally come after the environment setup above.
import streamlit as st  # noqa: E402
import tensorflow as tf  # noqa: E402
import tensorflow_gnn as tfgnn  # noqa: E402
from tensorflow_gnn.models import mt_albis  # noqa: E402
import networkx as nx  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402

# Define the model function
def model_fn(
    graph_tensor_spec: tfgnn.GraphTensorSpec,
    *,
    units: int = 128,
    message_dim: int = 64,
    num_rounds: int = 2,
) -> tf.keras.Model:
    """Build a GNN that encodes node features and runs MtAlbis message passing.

    Args:
        graph_tensor_spec: Spec of the input GraphTensor. Every node set is
            expected to carry a "features" field.
        units: Width of the node hidden states. Used both for the initial
            Dense encoding and for every MtAlbisGraphUpdate round, so the two
            can never drift apart.
        message_dim: Width of the messages computed in each round.
        num_rounds: Number of message-passing rounds to stack.

    Returns:
        A Keras model whose output is the GraphTensor produced by the last
        update round.
    """
    graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)

    def encode_node_set(node_set, node_set_name):
        # MapFeatures passes node_set_name by keyword; every node set gets the
        # same encoding here, so it is unused.
        del node_set_name
        # The returned tensor becomes the node set's hidden state. With
        # next_state_type="residual" it is added to each round's output, so
        # its width has to equal `units`.
        return tf.keras.layers.Dense(units)(node_set["features"])

    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=encode_node_set)(graph)

    for _ in range(num_rounds):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=units,
            message_dim=message_dim,
            attention_type="none",
            simple_conv_reduce_type="mean",
            normalization_type="layer",
            next_state_type="residual",
            state_dropout_rate=0.2,
            l2_regularization=1e-5,
            # Messages are aggregated at the TARGET endpoint of each edge.
            receiver_tag=tfgnn.TARGET,
        )(graph)

    return tf.keras.Model(inputs, graph)

# Function to create a sample graph
def create_sample_graph(
    num_nodes: int = 10,
    num_edges: int = 15,
    *,
    node_feature_dim: int = 128,
    edge_feature_dim: int = 32,
    seed=None,
):
    """Build a random directed graph and the matching TF-GNN GraphTensor.

    The GraphTensor has one node set "papers" and one edge set "cites"
    (papers -> papers). Each carries a random-normal "features" field.

    Args:
        num_nodes: Number of nodes.
        num_edges: Requested number of edges. networkx returns a complete
            graph (num_nodes * (num_nodes - 1) edges) when more are requested
            than fit, so the real count may be lower.
        node_feature_dim: Width of each node's "features" vector.
        edge_feature_dim: Width of each edge's "features" vector.
        seed: Optional integer seed for both the graph structure and the
            features. None (the default) yields a fresh sample on every call.

    Returns:
        A ``(nx_graph, graph_tensor)`` pair: the networkx DiGraph (for
        plotting) and the equivalent GraphTensor.
    """
    nx_graph = nx.gnm_random_graph(num_nodes, num_edges, seed=seed, directed=True)

    # Derive sizes from the edges actually generated rather than the request,
    # so `sizes`, adjacency indices and edge features always agree.
    edges = list(nx_graph.edges())
    actual_num_edges = len(edges)
    sources = [u for u, _ in edges]
    targets = [v for _, v in edges]

    def random_features(shape, stream):
        # Stateless ops return identical values for identical seeds, even when
        # called repeatedly in one process (e.g. Streamlit reruns).
        if seed is None:
            return tf.random.normal(shape)
        return tf.random.stateless_normal(shape, seed=(seed, stream))

    node_features = random_features((num_nodes, node_feature_dim), 0)
    edge_features = random_features((actual_num_edges, edge_feature_dim), 1)

    graph_tensor = tfgnn.GraphTensor.from_pieces(
        node_sets={
            "papers": tfgnn.NodeSet.from_fields(
                sizes=[num_nodes],
                features={"features": node_features},
            )
        },
        edge_sets={
            "cites": tfgnn.EdgeSet.from_fields(
                sizes=[actual_num_edges],
                adjacency=tfgnn.Adjacency.from_indices(
                    source=("papers", tf.constant(sources, dtype=tf.int32)),
                    target=("papers", tf.constant(targets, dtype=tf.int32)),
                ),
                features={"features": edge_features},
            )
        },
    )

    return nx_graph, graph_tensor

# Streamlit app
def main():
    """Render the Streamlit page: model summary, sample graph plot, tensor info."""
    st.title("Graph Neural Network Architecture Visualization")

    nx_graph, graph_tensor = create_sample_graph()

    # Build the model only to display its architecture. It outputs a
    # GraphTensor, so no loss applies yet, and summary() needs no compile().
    model = model_fn(graph_tensor.spec)

    st.subheader("Model Summary")
    summary_lines = []
    # Some Keras versions pass extra keywords (e.g. line_break) to print_fn;
    # accept and ignore them.
    model.summary(print_fn=lambda line, **_: summary_lines.append(line))
    # Render the whole table as a single text element, not one per line.
    st.text("\n".join(summary_lines))

    st.subheader("Sample Graph Visualization")
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        pos = nx.spring_layout(nx_graph)
        nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
                node_size=500, arrowsize=20, ax=ax)
        st.pyplot(fig)
    finally:
        # Streamlit re-executes the script on every interaction; close the
        # figure so pyplot does not accumulate open figures across reruns.
        plt.close(fig)

    st.subheader("Graph Tensor Information")
    papers = graph_tensor.node_sets['papers']
    cites = graph_tensor.edge_sets['cites']
    # int() turns the scalar size tensors into plain numbers for display.
    st.text(f"Number of nodes: {int(papers.total_size)}")
    st.text(f"Number of edges: {int(cites.total_size)}")
    st.text(f"Node feature shape: {papers['features'].shape}")
    st.text(f"Edge feature shape: {cites['features'].shape}")

if __name__ == "__main__":
    # Entry point for direct execution (`streamlit run` also executes the
    # script as __main__).
    main()