File size: 4,685 Bytes
0e91aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37fcf28
0e91aec
 
37fcf28
0e91aec
 
 
 
 
 
 
 
37fcf28
 
 
 
0e91aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db27fce
0e91aec
 
 
 
 
b1de838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e91aec
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# TF_USE_LEGACY_KERAS is read when TensorFlow is first imported, so it must be
# set before `tensorflow` (or `tensorflow_gnn`, which imports it) is loaded;
# setting it afterwards has no effect on which Keras `tf.keras` resolves to.
import os

os.environ['TF_USE_LEGACY_KERAS'] = '1'

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

# Define the model function
def model_fn(
    graph_tensor_spec: tfgnn.GraphTensorSpec,
    *,
    units: int = 128,
    message_dim: int = 64,
    num_message_passing: int = 2,
) -> tf.keras.Model:
    """Build a Keras model that runs MtAlbis message passing over a GraphTensor.

    Each node set's ``'features'`` tensor is projected to ``units`` dimensions
    with a Dense layer to form the initial hidden state. Then
    ``num_message_passing`` rounds of ``mt_albis.MtAlbisGraphUpdate`` are
    applied.

    Args:
        graph_tensor_spec: Spec of the GraphTensor the model consumes. Every
            node set is expected to carry a ``'features'`` tensor.
        units: Width of the node hidden states. The initial projection uses
            the same width so the input state matches the width produced by
            the residual next-state updates.
        message_dim: Width of the messages computed along edges.
        num_message_passing: Number of ``MtAlbisGraphUpdate`` rounds.

    Returns:
        A ``tf.keras.Model`` mapping the input GraphTensor to the updated
        GraphTensor.
    """
    graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)

    # Project raw node features to `units` dims; with next_state_type="residual"
    # the initial state must already have the same width as the updated state.
    graph = tfgnn.keras.layers.MapFeatures(
        node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(units)(
            node_set['features']
        )
    )(graph)

    # One freshly constructed (non-weight-sharing) layer per round.
    for _ in range(num_message_passing):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=units, message_dim=message_dim,
            attention_type="none", simple_conv_reduce_type="mean",
            normalization_type="layer", next_state_type="residual",
            state_dropout_rate=0.2, l2_regularization=1e-5,
            receiver_tag=tfgnn.TARGET,  # messages are aggregated at each edge's target node
        )(graph)

    return tf.keras.Model(inputs, graph)

# Function to create a sample graph with meaningful synthetic data
def create_sample_graph(*, num_nodes: int = 10, num_edges: int = 15):
    """Create a random directed "citation" graph with synthetic features.

    Nodes are papers with two float32 features:
    - year published, drawn from [1990, 2022);
    - number of authors, drawn from [1, 10).

    Edges are citations with one float32 weight drawn from [0.1, 5.0).
    Randomness comes from NumPy's global RNG and NetworkX's global random
    state, so each call yields a different graph unless those are seeded.

    Args:
        num_nodes: Number of paper nodes.
        num_edges: Requested number of directed citation edges. NetworkX
            returns the complete digraph when this is at least
            ``num_nodes * (num_nodes - 1)``, so the actual count may be lower.

    Returns:
        A tuple ``(graph, graph_tensor, node_features, edge_features)``:
        - the NetworkX digraph, where each node has a ``'title'`` attribute;
        - the equivalent ``tfgnn.GraphTensor``, with node set ``"papers"``
          and edge set ``"cites"``;
        - the ``(num_nodes, 2)`` node feature array;
        - the ``(actual_edge_count, 1)`` edge feature array.
    """
    graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)

    # Materialize the edge list once so adjacency indices and edge features
    # share one ordering, and size everything from the edges actually created.
    edges = list(graph.edges())
    edge_count = len(edges)

    # Generate synthetic features (randint's upper bound is exclusive).
    years_published = np.random.randint(1990, 2022, size=num_nodes).astype(np.float32)
    num_authors = np.random.randint(1, 10, size=num_nodes).astype(np.float32)
    citation_weights = np.random.uniform(0.1, 5.0, size=edge_count).astype(np.float32)

    node_features = np.stack([years_published, num_authors], axis=-1)
    edge_features = citation_weights.reshape(-1, 1)

    # Human-readable titles used as node labels when drawing.
    nx.set_node_attributes(
        graph, {i: {'title': f"Paper {i + 1}"} for i in range(num_nodes)}
    )

    graph_tensor = tfgnn.GraphTensor.from_pieces(
        node_sets={
            "papers": tfgnn.NodeSet.from_fields(
                sizes=[num_nodes],
                features={"features": tf.convert_to_tensor(node_features)},
            )
        },
        edge_sets={
            "cites": tfgnn.EdgeSet.from_fields(
                sizes=[edge_count],
                adjacency=tfgnn.Adjacency.from_indices(
                    source=("papers", tf.constant([u for u, _ in edges], dtype=tf.int32)),
                    target=("papers", tf.constant([v for _, v in edges], dtype=tf.int32)),
                ),
                features={"features": tf.convert_to_tensor(edge_features)},
            )
        },
    )

    return graph, graph_tensor, node_features, edge_features

# Streamlit app
def main():
    """Render the GNN architecture visualization page.

    Content below the title is generated only on the run in which the
    "Recreate Graph" button was clicked. Each click builds a new random
    sample graph and a fresh model, then shows:
    - the model summary;
    - a drawing of the graph;
    - basic GraphTensor sizes;
    - the raw feature arrays.
    """
    st.title("Graph Neural Network Architecture Visualization")

    # st.button is True only on the rerun triggered by the click.
    if not st.button("Recreate Graph"):
        return

    nx_graph, graph_tensor, node_features, edge_features = create_sample_graph()

    # Create and compile the model
    model = model_fn(graph_tensor.spec)
    model.compile(optimizer='adam', loss='binary_crossentropy')

    # Collect the summary and emit it as a single text element instead of one
    # st.text element per summary line.
    st.subheader("Model Summary")
    summary_lines = []
    model.summary(print_fn=summary_lines.append)
    st.text("\n".join(summary_lines))

    # Visualize the graph
    st.subheader("Sample Graph Visualization")
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        pos = nx.spring_layout(nx_graph)
        labels = nx.get_node_attributes(nx_graph, 'title')
        nx.draw(nx_graph, pos, labels=labels, with_labels=True, node_color='lightblue',
                node_size=3000, arrowsize=20, ax=ax)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates alive until it is closed; in a
        # long-running Streamlit server each click would otherwise leak one.
        plt.close(fig)

    # Display graph tensor info. total_size is a scalar tensor; convert it so
    # the page shows "10" rather than the tensor's repr.
    papers = graph_tensor.node_sets['papers']
    cites = graph_tensor.edge_sets['cites']
    st.subheader("Graph Tensor Information")
    st.text(f"Number of nodes: {int(papers.total_size)}")
    st.text(f"Number of edges: {int(cites.total_size)}")
    st.text(f"Node feature shape: {papers['features'].shape}")
    st.text(f"Edge feature shape: {cites['features'].shape}")

    # Display sample node and edge features
    st.subheader("Sample Node and Edge Features")
    st.write("Node Features (Year Published, Number of Authors):")
    st.write(node_features)
    st.write("Edge Features (Citation Weight):")
    st.write(edge_features)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()