eaglelandsonce committed on
Commit f72b662 · verified · 1 parent: 47a1d07

Rename pages/15_Graphs.py to pages/15_TransferLearning_HF.py

pages/15_Graphs.py DELETED
@@ -1,97 +0,0 @@
- import os
- os.environ['TF_USE_LEGACY_KERAS'] = '1'  # Must be set before TensorFlow is imported to take effect
-
- import streamlit as st
- import tensorflow as tf
- import tensorflow_gnn as tfgnn
- from tensorflow_gnn.models import mt_albis
- import networkx as nx
- import matplotlib.pyplot as plt
-
- # Define the model function
- def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
-     graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
-
-     # Encode input features to match the required output shape of 128
-     graph = tfgnn.keras.layers.MapFeatures(
-         node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(128)(node_set['features'])
-     )(graph)
-
-     # For each round of message passing, create and apply a Keras layer.
-     for _ in range(2):
-         graph = mt_albis.MtAlbisGraphUpdate(
-             units=128, message_dim=64,
-             attention_type="none", simple_conv_reduce_type="mean",
-             normalization_type="layer", next_state_type="residual",
-             state_dropout_rate=0.2, l2_regularization=1e-5,
-             receiver_tag=tfgnn.TARGET  # Use TARGET instead of NODES
-         )(graph)
-
-     return tf.keras.Model(inputs, graph)
-
- # Function to create a sample graph
- def create_sample_graph():
-     num_nodes = 10
-     num_edges = 15
-
-     graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
-
-     # Create a GraphTensor with random node and edge features
-     node_features = tf.random.normal((num_nodes, 128))  # Match the dense layer output
-     edge_features = tf.random.normal((num_edges, 32))
-
-     graph_tensor = tfgnn.GraphTensor.from_pieces(
-         node_sets={
-             "papers": tfgnn.NodeSet.from_fields(
-                 sizes=[num_nodes],
-                 features={"features": node_features}
-             )
-         },
-         edge_sets={
-             "cites": tfgnn.EdgeSet.from_fields(
-                 sizes=[num_edges],
-                 adjacency=tfgnn.Adjacency.from_indices(
-                     source=("papers", tf.constant([e[0] for e in graph.edges()], dtype=tf.int32)),
-                     target=("papers", tf.constant([e[1] for e in graph.edges()], dtype=tf.int32))
-                 ),
-                 features={"features": edge_features}
-             )
-         }
-     )
-
-     return graph, graph_tensor
-
- # Streamlit app
- def main():
-     st.title("Graph Neural Network Architecture Visualization")
-
-     # Create sample graph
-     nx_graph, graph_tensor = create_sample_graph()
-
-     # Create and compile the model
-     model = model_fn(graph_tensor.spec)
-     model.compile(optimizer='adam', loss='binary_crossentropy')
-
-     # Display model summary
-     st.subheader("Model Summary")
-     model.summary(print_fn=lambda x: st.text(x))
-
-     # Visualize the graph
-     st.subheader("Sample Graph Visualization")
-     fig, ax = plt.subplots(figsize=(10, 8))
-     pos = nx.spring_layout(nx_graph)
-     nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
-             node_size=500, arrowsize=20, ax=ax)
-     st.pyplot(fig)
-
-     # Display graph tensor info
-     st.subheader("Graph Tensor Information")
-     st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")
-     st.text(f"Number of edges: {graph_tensor.edge_sets['cites'].total_size}")
-     st.text(f"Node feature shape: {graph_tensor.node_sets['papers']['features'].shape}")
-     st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
-
- if __name__ == "__main__":
-     main()
pages/15_TransferLearning_HF.py ADDED
@@ -0,0 +1,85 @@
+ import streamlit as st
+ import tensorflow as tf
+ from transformers import TFViTModel
+ import tensorflow_datasets as tfds
+ import matplotlib.pyplot as plt
+
+ # Streamlit app
+ st.title("Transfer Learning with Vision Transformer for Image Classification")
+
+ # Input parameters
+ batch_size = st.slider("Batch Size", 16, 128, 32, 16)
+ epochs = st.slider("Epochs", 5, 50, 10, 5)
+
+ # Load the dataset
+ dataset_name = "cats_vs_dogs"
+ (ds_train, ds_val), ds_info = tfds.load(dataset_name, split=['train[:80%]', 'train[80%:]'], with_info=True, as_supervised=True)
+
+ # Preprocess the dataset: ViT expects 224x224 inputs normalized to [-1, 1]
+ # (mean=0.5, std=0.5) in channels-first (C, H, W) order
+ def preprocess_image(image, label):
+     image = tf.image.resize(image, (224, 224))
+     image = (image / 255.0 - 0.5) / 0.5
+     image = tf.transpose(image, (2, 0, 1))  # HWC -> CHW
+     return image, label
+
+ ds_train = ds_train.map(preprocess_image).batch(batch_size).prefetch(tf.data.AUTOTUNE)
+ ds_val = ds_val.map(preprocess_image).batch(batch_size).prefetch(tf.data.AUTOTUNE)
+
+ # Load the pre-trained Vision Transformer backbone (without a classification head)
+ model_name = "google/vit-base-patch16-224-in21k"
+ base_model = TFViTModel.from_pretrained(model_name)
+
+ # Freeze the pre-trained backbone
+ base_model.trainable = False
+
+ # Add custom layers on top
+ inputs = tf.keras.Input(shape=(3, 224, 224))
+ x = tf.keras.layers.Flatten()(base_model(inputs)[0])  # flatten the ViT token embeddings
+ x = tf.keras.layers.Dense(256, activation='relu')(x)
+ x = tf.keras.layers.Dropout(0.5)(x)
+ outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)  # Cats vs Dogs has 2 classes
+ model = tf.keras.Model(inputs, outputs)
+
+ model.summary()
+
+ # Compile the model
+ model.compile(optimizer='adam',
+               loss='binary_crossentropy',  # change the loss function for more than two classes
+               metrics=['accuracy'])
+
+ # Train the model
+ if st.button("Train Model"):
+     with st.spinner("Training the model..."):
+         history = model.fit(
+             ds_train,
+             epochs=epochs,
+             validation_data=ds_val
+         )
+
+     st.success("Model training completed!")
+
+     # Display training curves
+     st.subheader("Training and Validation Accuracy")
+     fig, ax = plt.subplots()
+     ax.plot(history.history['accuracy'], label='Training Accuracy')
+     ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
+     ax.set_xlabel('Epoch')
+     ax.set_ylabel('Accuracy')
+     ax.legend()
+     st.pyplot(fig)
+
+     st.subheader("Training and Validation Loss")
+     fig, ax = plt.subplots()
+     ax.plot(history.history['loss'], label='Training Loss')
+     ax.plot(history.history['val_loss'], label='Validation Loss')
+     ax.set_xlabel('Epoch')
+     ax.set_ylabel('Loss')
+     ax.legend()
+     st.pyplot(fig)
+
+ # Evaluate the model
+ if st.button("Evaluate Model"):
+     test_loss, test_acc = model.evaluate(ds_val, verbose=2)
+     st.write(f"Validation accuracy: {test_acc:.4f}")
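
A common second stage after this kind of frozen-backbone setup (not something this commit does) is to unfreeze the ViT and fine-tune end to end at a much lower learning rate once the new head has converged. A minimal sketch, reusing base_model, model, ds_train, and ds_val from the file above; the epoch count and learning rate are illustrative:

# Sketch: optional fine-tuning stage after the classification head has converged.
base_model.trainable = True  # unfreeze the ViT backbone
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # low LR to protect pre-trained weights
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.fit(ds_train, epochs=3, validation_data=ds_val)  # illustrative epoch count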