Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Rename pages/15_Graphs.py to pages/15_TransferLearning_HF.py
Browse files- pages/15_Graphs.py +0 -97
- pages/15_TransferLearning_HF.py +85 -0
pages/15_Graphs.py
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import tensorflow as tf
|
3 |
-
import tensorflow_gnn as tfgnn
|
4 |
-
from tensorflow_gnn.models import mt_albis
|
5 |
-
import networkx as nx
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
|
8 |
-
# Set environment variable for legacy Keras
|
9 |
-
import os
|
10 |
-
os.environ['TF_USE_LEGACY_KERAS'] = '1'
|
11 |
-
|
12 |
-
# Define the model function
|
13 |
-
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
14 |
-
graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
|
15 |
-
|
16 |
-
# Encode input features to match the required output shape of 128
|
17 |
-
graph = tfgnn.keras.layers.MapFeatures(
|
18 |
-
node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(128)(node_set['features'])
|
19 |
-
)(graph)
|
20 |
-
|
21 |
-
# For each round of message passing...
|
22 |
-
for _ in range(2):
|
23 |
-
# ... create and apply a Keras layer.
|
24 |
-
graph = mt_albis.MtAlbisGraphUpdate(
|
25 |
-
units=128, message_dim=64,
|
26 |
-
attention_type="none", simple_conv_reduce_type="mean",
|
27 |
-
normalization_type="layer", next_state_type="residual",
|
28 |
-
state_dropout_rate=0.2, l2_regularization=1e-5,
|
29 |
-
receiver_tag=tfgnn.TARGET # Use TARGET instead of NODES
|
30 |
-
)(graph)
|
31 |
-
|
32 |
-
return tf.keras.Model(inputs, graph)
|
33 |
-
|
34 |
-
# Function to create a sample graph
|
35 |
-
def create_sample_graph():
|
36 |
-
num_nodes = 10
|
37 |
-
num_edges = 15
|
38 |
-
|
39 |
-
graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
|
40 |
-
|
41 |
-
# Create a GraphTensor
|
42 |
-
node_features = tf.random.normal((num_nodes, 128)) # Match the dense layer output
|
43 |
-
edge_features = tf.random.normal((num_edges, 32))
|
44 |
-
|
45 |
-
graph_tensor = tfgnn.GraphTensor.from_pieces(
|
46 |
-
node_sets={
|
47 |
-
"papers": tfgnn.NodeSet.from_fields(
|
48 |
-
sizes=[num_nodes],
|
49 |
-
features={"features": node_features}
|
50 |
-
)
|
51 |
-
},
|
52 |
-
edge_sets={
|
53 |
-
"cites": tfgnn.EdgeSet.from_fields(
|
54 |
-
sizes=[num_edges],
|
55 |
-
adjacency=tfgnn.Adjacency.from_indices(
|
56 |
-
source=("papers", tf.constant([e[0] for e in graph.edges()], dtype=tf.int32)),
|
57 |
-
target=("papers", tf.constant([e[1] for e in graph.edges()], dtype=tf.int32))
|
58 |
-
),
|
59 |
-
features={"features": edge_features}
|
60 |
-
)
|
61 |
-
}
|
62 |
-
)
|
63 |
-
|
64 |
-
return graph, graph_tensor
|
65 |
-
|
66 |
-
# Streamlit app
|
67 |
-
def main():
|
68 |
-
st.title("Graph Neural Network Architecture Visualization")
|
69 |
-
|
70 |
-
# Create sample graph
|
71 |
-
nx_graph, graph_tensor = create_sample_graph()
|
72 |
-
|
73 |
-
# Create and compile the model
|
74 |
-
model = model_fn(graph_tensor.spec)
|
75 |
-
model.compile(optimizer='adam', loss='binary_crossentropy')
|
76 |
-
|
77 |
-
# Display model summary
|
78 |
-
st.subheader("Model Summary")
|
79 |
-
model.summary(print_fn=lambda x: st.text(x))
|
80 |
-
|
81 |
-
# Visualize the graph
|
82 |
-
st.subheader("Sample Graph Visualization")
|
83 |
-
fig, ax = plt.subplots(figsize=(10, 8))
|
84 |
-
pos = nx.spring_layout(nx_graph)
|
85 |
-
nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
|
86 |
-
node_size=500, arrowsize=20, ax=ax)
|
87 |
-
st.pyplot(fig)
|
88 |
-
|
89 |
-
# Display graph tensor info
|
90 |
-
st.subheader("Graph Tensor Information")
|
91 |
-
st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")
|
92 |
-
st.text(f"Number of edges: {graph_tensor.edge_sets['cites'].total_size}")
|
93 |
-
st.text(f"Node feature shape: {graph_tensor.node_sets['papers']['features'].shape}")
|
94 |
-
st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
|
95 |
-
|
96 |
-
if __name__ == "__main__":
|
97 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/15_TransferLearning_HF.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import tensorflow as tf
|
3 |
+
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification
|
4 |
+
import tensorflow_datasets as tfds
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# Load the dataset
|
9 |
+
dataset_name = "cats_vs_dogs"
|
10 |
+
(ds_train, ds_val), ds_info = tfds.load(dataset_name, split=['train[:80%]', 'train[80%:]'], with_info=True, as_supervised=True)
|
11 |
+
|
12 |
+
# Preprocess the dataset
|
13 |
+
def preprocess_image(image, label):
|
14 |
+
image = tf.image.resize(image, (224, 224)) # ViT requires 224x224 images
|
15 |
+
image = image / 255.0
|
16 |
+
return image, label
|
17 |
+
|
18 |
+
ds_train = ds_train.map(preprocess_image).batch(32).prefetch(tf.data.AUTOTUNE)
|
19 |
+
ds_val = ds_val.map(preprocess_image).batch(32).prefetch(tf.data.AUTOTUNE)
|
20 |
+
|
21 |
+
# Streamlit app
|
22 |
+
st.title("Transfer Learning with Vision Transformer for Image Classification")
|
23 |
+
|
24 |
+
# Input parameters
|
25 |
+
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
|
26 |
+
epochs = st.slider("Epochs", 5, 50, 10, 5)
|
27 |
+
|
28 |
+
# Load the pre-trained Vision Transformer model
|
29 |
+
model_name = "google/vit-base-patch16-224-in21k"
|
30 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
31 |
+
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
|
32 |
+
|
33 |
+
# Freeze the convolutional base
|
34 |
+
base_model.trainable = False
|
35 |
+
|
36 |
+
# Add custom layers on top
|
37 |
+
inputs = tf.keras.Input(shape=(224, 224, 3))
|
38 |
+
x = feature_extractor(inputs)
|
39 |
+
x = tf.keras.layers.Flatten()(base_model(x)[0])
|
40 |
+
x = tf.keras.layers.Dense(256, activation='relu')(x)
|
41 |
+
x = tf.keras.layers.Dropout(0.5)(x)
|
42 |
+
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
|
43 |
+
model = tf.keras.Model(inputs, outputs)
|
44 |
+
|
45 |
+
model.summary()
|
46 |
+
|
47 |
+
# Compile the model
|
48 |
+
model.compile(optimizer='adam',
|
49 |
+
loss='binary_crossentropy', # Change loss function based on the number of classes
|
50 |
+
metrics=['accuracy'])
|
51 |
+
|
52 |
+
# Train the model
|
53 |
+
if st.button("Train Model"):
|
54 |
+
with st.spinner("Training the model..."):
|
55 |
+
history = model.fit(
|
56 |
+
ds_train,
|
57 |
+
epochs=epochs,
|
58 |
+
validation_data=ds_val
|
59 |
+
)
|
60 |
+
|
61 |
+
st.success("Model training completed!")
|
62 |
+
|
63 |
+
# Display training curves
|
64 |
+
st.subheader("Training and Validation Accuracy")
|
65 |
+
fig, ax = plt.subplots()
|
66 |
+
ax.plot(history.history['accuracy'], label='Training Accuracy')
|
67 |
+
ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
|
68 |
+
ax.set_xlabel('Epoch')
|
69 |
+
ax.set_ylabel('Accuracy')
|
70 |
+
ax.legend()
|
71 |
+
st.pyplot(fig)
|
72 |
+
|
73 |
+
st.subheader("Training and Validation Loss")
|
74 |
+
fig, ax = plt.subplots()
|
75 |
+
ax.plot(history.history['loss'], label='Training Loss')
|
76 |
+
ax.plot(history.history['val_loss'], label='Validation Loss')
|
77 |
+
ax.set_xlabel('Epoch')
|
78 |
+
ax.set_ylabel('Loss')
|
79 |
+
ax.legend()
|
80 |
+
st.pyplot(fig)
|
81 |
+
|
82 |
+
# Evaluate the model
|
83 |
+
if st.button("Evaluate Model"):
|
84 |
+
test_loss, test_acc = model.evaluate(ds_val, verbose=2)
|
85 |
+
st.write(f"Validation accuracy: {test_acc}")
|