louiecerv committed on
Commit
d06cee5
·
1 Parent(s): 701fb6c

sync with remote

Browse files
Files changed (2) hide show
  1. app.py +143 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import os
4
+ import requests
5
+ import tempfile
6
+ import matplotlib.pyplot as plt
7
+ from tensorflow.keras.models import Sequential
8
+ from tensorflow.keras.layers import Flatten, Dense, Reshape
9
+ from tensorflow.keras.losses import SparseCategoricalCrossentropy
10
+ from io import StringIO
11
+
12
# Local file names under which the two TFRecord files are stored.
TRAIN_FILE = "train_images.tfrecords"
VAL_FILE = "val_images.tfrecords"

# Remote locations the files are fetched from.
TRAIN_URL = "https://huggingface.co/datasets/louiecerv/cardiac_images/resolve/main/train_images.tfrecords"
VAL_URL = "https://huggingface.co/datasets/louiecerv/cardiac_images/resolve/main/val_images.tfrecords"

# Download target: the system temp directory reported by ``tempfile``.
# Files placed here are reused by later runs for as long as the directory
# is not cleaned up by the host.
tmpdir = tempfile.gettempdir()
20
+
21
# Function to download a file with progress display
def download_file(url, local_filename, target_dir):
    """Download *url* into *target_dir* unless the file is already present.

    Args:
        url: Address to fetch with an HTTP GET request.
        local_filename: Name the file is stored under inside *target_dir*.
        target_dir: Directory to download into; created when missing.

    Returns:
        The path of the downloaded (or already existing) file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If connecting or reading from the server stalls.
    """
    os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, local_filename)

    if os.path.exists(filepath):
        st.write(f"File already exists: {filepath}")
        return filepath

    # Stream into a sibling ".part" file and only rename it to the final name
    # once the transfer has finished.  Otherwise an interrupted download would
    # leave a truncated file that the existence check above accepts as
    # complete on the next run.
    partial_path = filepath + ".part"
    label = f"Downloading {local_filename}..."

    try:
        # (connect, read) timeouts so a stalled server cannot hang the app.
        with requests.get(url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            # 0 means the server did not announce a size.
            total_size = int(r.headers.get('content-length', 0))

            progress_bar = st.empty()  # Create a placeholder
            progress_bar.progress(0, text=label)

            downloaded_size = 0
            last_percent = 0
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded_size += len(chunk)

                    if total_size <= 0:
                        # Unknown size: no percentage can be computed.
                        continue
                    # Clamp: the announced length can be smaller than the
                    # number of bytes actually yielded (e.g. compressed
                    # transfer), and the widget rejects values above 100.
                    percent = min(100, downloaded_size * 100 // total_size)
                    if percent != last_percent:
                        # Redraw only on change instead of on every chunk.
                        progress_bar.progress(percent, text=label)
                        last_percent = percent

        os.replace(partial_path, filepath)
    finally:
        # Remove the partial file after a failure; after a successful
        # os.replace it no longer exists.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass

    return filepath
46
+
47
# Fetch both TFRecord files; download_file skips any that already exist.
train_file_path, val_file_path = (
    download_file(remote_url, file_name, tmpdir)
    for remote_url, file_name in ((TRAIN_URL, TRAIN_FILE), (VAL_URL, VAL_FILE))
)
50
+
51
# Schema of one TFRecord example: every field is a scalar, either an int64
# or a byte string.
_INT64_FIELDS = ('height', 'width', 'depth')
_STRING_FIELDS = ('name', 'image_raw', 'label_raw')

image_feature_description = {
    **{field: tf.io.FixedLenFeature([], tf.int64) for field in _INT64_FIELDS},
    **{field: tf.io.FixedLenFeature([], tf.string) for field in _STRING_FIELDS},
}
60
+
61
# Helper that turns one serialized TFRecord example into a feature dict
def _parse_image_function(example_proto):
    """Parse *example_proto* according to ``image_feature_description``."""
    parsed_features = tf.io.parse_single_example(
        serialized=example_proto,
        features=image_feature_description,
    )
    return parsed_features
64
+
65
# Decode the raw image and label bytes of one parsed example
@tf.function
def read_and_decode(example):
    """Turn a parsed example into an ``(image, label)`` tensor pair.

    ``example['image_raw']`` is decoded as int64 values and
    ``example['label_raw']`` as uint8 values; both are reshaped to
    ``[256, 256, 1]``.  The image is additionally cast to float32 and
    multiplied by ``1 / 1024``.
    """
    side = 256
    pixel_count = side * side  # 65536 values per image / label

    def _decode_plane(field, dtype):
        # Decode the byte string, pin its static length, then give it the
        # height x width x channel layout.
        flat = tf.io.decode_raw(example[field], dtype)
        flat.set_shape([pixel_count])
        return tf.reshape(flat, [side, side, 1])

    image = tf.cast(_decode_plane('image_raw', tf.int64), tf.float32) * (1. / 1024)
    label = _decode_plane('label_raw', tf.uint8)

    return image, label
79
+
80
# Pipeline settings
BUFFER_SIZE = 10
BATCH_SIZE = 1
tf_autotune = tf.data.experimental.AUTOTUNE

# Training split: raw records -> parsed feature dicts -> (image, label)
raw_training_dataset = tf.data.TFRecordDataset(train_file_path)
parsed_training_dataset = raw_training_dataset.map(_parse_image_function)
train = parsed_training_dataset.map(read_and_decode, num_parallel_calls=tf_autotune)

# Validation split: same steps, decoded without parallel calls
raw_val_dataset = tf.data.TFRecordDataset(val_file_path)
parsed_val_dataset = raw_val_dataset.map(_parse_image_function)
val = parsed_val_dataset.map(read_and_decode)

# Training batches are cached, shuffled, batched, repeated and prefetched;
# validation batches are only batched.
train_dataset = (
    train.cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf_autotune)
)
test_dataset = val.batch(BATCH_SIZE)

st.write(train_dataset)
100
+
101
def display(display_list):
    """Show the given tensors side by side in the Streamlit page.

    Each entry of *display_list* is reshaped to ``[256, 256]`` and drawn as
    a grayscale image without axes.  The first two panels are titled
    'Input Image' and 'Label'; any further panels are left untitled.

    Args:
        display_list: Sequence of tensors, each holding 256 * 256 values.
    """
    title = ['Input Image', 'Label']
    panel_count = len(display_list)

    fig = plt.figure(figsize=(10, 10))
    try:
        for index, tensor in enumerate(display_list):
            ax = fig.add_subplot(1, panel_count, index + 1)
            # Only two titles are defined; skip the title rather than fail
            # when more panels are passed in.
            if index < len(title):
                ax.set_title(title[index])
            ax.imshow(tf.reshape(tensor, [256, 256]), cmap='gray')
            ax.axis('off')

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # without this each call would leave one more figure in memory.
        plt.close(fig)
113
+
114
+ # Streamlit app interface
115
+ st.title("Cardiac Images Dataset")
116
+
117
+ # Display sample images
118
# Takes the first two (image, label) pairs from the unbatched `train` dataset.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether display(...) runs once per pair or once after the loop with the
# last pair only; confirm against the real file.
+ for image, label in train.take(2):
119
+ sample_image, sample_label = image, label
120
+ display([sample_image, sample_label])
121
+
122
tf.keras.backend.clear_session()

# set up the model architecture
# Flatten the 256x256x1 input, pass it through one hidden layer, then emit
# two values per pixel, reshaped back to the image layout.
model = tf.keras.models.Sequential([
    Flatten(input_shape=[256, 256, 1]),
    Dense(64, activation='relu'),
    # No activation here: the loss below is built with from_logits=True and
    # therefore expects raw logits.  A softmax on this layer would also
    # normalize across all 256*256*2 units at once instead of across the two
    # class scores of each pixel.
    Dense(256*256*2),
    Reshape((256, 256, 2))
])

# specify how to train the model with algorithm, the loss function and metrics
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
137
+
138
# Capture the model summary as text instead of printing it to stdout
model_summary = StringIO()
model.summary(print_fn=lambda x: model_summary.write(x + '\n'))

# Display the model summary in Streamlit
# The summary is a fixed-width text table; st.text shows it verbatim, whereas
# Markdown rendering would reinterpret its separator rows and collapse the
# column alignment.
st.text(model_summary.getvalue())
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ datasets
3
+ tensorflow
4
+ pandas
5
+ matplotlib