Tazik Shahjahan commited on
Commit
aa28aa6
·
0 Parent(s):
Files changed (3) hide show
  1. .gitignore +23 -0
  2. app.py +252 -0
  3. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore the virtual environment
2
+ derm/
3
+
4
+ # Ignore model weights & large binary files
5
+ model.h5
6
+ scin_dataset_precomputed_embeddings.npz
7
+
8
+ # Ignore system files
9
+ __pycache__/
10
+ *.pyc
11
+ *.pyo
12
+ .DS_Store
13
+
14
+ # Ignore logs and temporary files
15
+ logs/
16
+ *.log
17
+ *.out
18
+ *.err
19
+
20
+ # Ignore IDE/Editor files
21
+ .vscode/
22
+ .idea/
23
+ *.swp
app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import os
3
+ import io
4
+ import numpy as np
5
+ import pandas as pd
6
+ import tensorflow as tf
7
+ from tensorflow.keras import layers, regularizers
8
+ from sklearn.preprocessing import MultiLabelBinarizer
9
+ from sklearn.model_selection import train_test_split
10
+ from google.cloud import storage
11
+ from huggingface_hub import hf_hub_download, notebook_login, login
12
+ from PIL import Image
13
+ import gradio as gr
14
+ import collections
15
+
16
+ login()
17
+
18
+ # ======================
19
+ # CONSTANTS & CONFIGURATION
20
+ # ======================
21
+
22
+ SCIN_GCP_PROJECT = 'dx-scin-public'
23
+ SCIN_GCS_BUCKET_NAME = 'dx-scin-public-data'
24
+ SCIN_GCS_CASES_CSV = 'dataset/scin_cases.csv'
25
+ SCIN_GCS_LABELS_CSV = 'dataset/scin_labels.csv'
26
+
27
+ SCIN_HF_MODEL_NAME = 'google/derm-foundation'
28
+ SCIN_HF_EMBEDDING_FILE = 'scin_dataset_precomputed_embeddings.npz'
29
+
30
+ # The 10 conditions we want to predict
31
+ CONDITIONS_TO_PREDICT = [
32
+ 'Eczema',
33
+ 'Allergic Contact Dermatitis',
34
+ 'Insect Bite',
35
+ 'Urticaria',
36
+ 'Psoriasis',
37
+ 'Folliculitis',
38
+ 'Irritant Contact Dermatitis',
39
+ 'Tinea',
40
+ 'Herpes Zoster',
41
+ 'Drug Rash'
42
+ ]
43
+
44
+ # ======================
45
+ # HELPER FUNCTIONS FOR DATA LOADING
46
+ # ======================
47
+
48
+ def initialize_df_with_metadata(bucket, csv_path):
49
+ csv_bytes = bucket.blob(csv_path).download_as_string()
50
+ df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
51
+ df['case_id'] = df['case_id'].astype(str)
52
+ return df
53
+
54
+ def augment_metadata_with_labels(df, bucket, csv_path):
55
+ csv_bytes = bucket.blob(csv_path).download_as_string()
56
+ labels_df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
57
+ labels_df['case_id'] = labels_df['case_id'].astype(str)
58
+ merged_df = pd.merge(df, labels_df, on='case_id')
59
+ return merged_df
60
+
61
+ def load_embeddings_from_file(repo_id, object_name):
62
+ file_path = hf_hub_download(repo_id=repo_id, filename=object_name, local_dir='./')
63
+ embeddings = {}
64
+ with open(file_path, 'rb') as f:
65
+ npz_file = np.load(f, allow_pickle=True)
66
+ for key, value in npz_file.items():
67
+ embeddings[key] = value
68
+ return embeddings
69
+
70
+ # ======================
71
+ # DATA PREPARATION FUNCTION
72
+ # ======================
73
+
74
+ def prepare_data(df, embeddings):
75
+ MINIMUM_CONFIDENCE = 0 # Adjust this if needed.
76
+ X = []
77
+ y = []
78
+ poor_image_quality_counter = 0
79
+ missing_embedding_counter = 0
80
+ not_in_condition_counter = 0
81
+ condition_confidence_low_counter = 0
82
+
83
+ for row in df.itertuples():
84
+ # Check if the image is marked as having sufficient quality.
85
+ if getattr(row, 'dermatologist_gradable_for_skin_condition_1', None) != 'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT':
86
+ poor_image_quality_counter += 1
87
+ continue
88
+
89
+ # Parse the labels and confidences.
90
+ try:
91
+ labels = eval(getattr(row, 'dermatologist_skin_condition_on_label_name', '[]'))
92
+ confidences = eval(getattr(row, 'dermatologist_skin_condition_confidence', '[]'))
93
+ except Exception as e:
94
+ continue
95
+
96
+ row_labels = []
97
+ for label, conf in zip(labels, confidences):
98
+ if label not in CONDITIONS_TO_PREDICT:
99
+ not_in_condition_counter += 1
100
+ continue
101
+ if conf < MINIMUM_CONFIDENCE:
102
+ condition_confidence_low_counter += 1
103
+ continue
104
+ row_labels.append(label)
105
+
106
+ # For each image associated with this case, add its embedding and labels.
107
+ for image_path in [getattr(row, 'image_1_path', None),
108
+ getattr(row, 'image_2_path', None),
109
+ getattr(row, 'image_3_path', None)]:
110
+ if pd.isna(image_path) or image_path is None:
111
+ continue
112
+ if image_path not in embeddings:
113
+ missing_embedding_counter += 1
114
+ continue
115
+ X.append(embeddings[image_path])
116
+ y.append(row_labels)
117
+
118
+ print(f'Poor image quality count: {poor_image_quality_counter}')
119
+ print(f'Missing embedding count: {missing_embedding_counter}')
120
+ print(f'Condition not in list count: {not_in_condition_counter}')
121
+ print(f'Excluded due to low confidence count: {condition_confidence_low_counter}')
122
+ return X, y
123
+
124
+ # ======================
125
+ # MODEL BUILDING FUNCTION
126
+ # ======================
127
+
128
+ def build_model(input_dim, output_dim, weight_decay=1e-4):
129
+ inputs = tf.keras.Input(shape=(input_dim,))
130
+ hidden = layers.Dense(256, activation="relu",
131
+ kernel_regularizer=regularizers.l2(weight_decay),
132
+ bias_regularizer=regularizers.l2(weight_decay))(inputs)
133
+ hidden = layers.Dropout(0.1)(hidden)
134
+ hidden = layers.Dense(128, activation="relu",
135
+ kernel_regularizer=regularizers.l2(weight_decay),
136
+ bias_regularizer=regularizers.l2(weight_decay))(hidden)
137
+ hidden = layers.Dropout(0.1)(hidden)
138
+ output = layers.Dense(output_dim, activation="sigmoid")(hidden)
139
+ model = tf.keras.Model(inputs, output)
140
+ model.compile(loss="binary_crossentropy",
141
+ optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))
142
+ return model
143
+
144
+ # ======================
145
+ # MAIN FUNCTION & GRADIO INTERFACE
146
+ # ======================
147
+
148
+ def main():
149
+ # Connect to the Google Cloud Storage bucket.
150
+ storage_client = storage.Client(SCIN_GCP_PROJECT)
151
+ bucket = storage_client.bucket(SCIN_GCS_BUCKET_NAME)
152
+
153
+ # Load SCIN dataset CSVs and merge them.
154
+ df_cases = initialize_df_with_metadata(bucket, SCIN_GCS_CASES_CSV)
155
+ df_full = augment_metadata_with_labels(df_cases, bucket, SCIN_GCS_LABELS_CSV)
156
+ df_full.set_index('case_id', inplace=True)
157
+
158
+ # Load precomputed embeddings from Hugging Face.
159
+ print("Loading embeddings...")
160
+ embeddings = load_embeddings_from_file(SCIN_HF_MODEL_NAME, SCIN_HF_EMBEDDING_FILE)
161
+
162
+ # Prepare the training data.
163
+ print("Preparing training data...")
164
+ X, y = prepare_data(df_full, embeddings)
165
+ X = np.array(X)
166
+ # Convert the list of label lists to binary arrays.
167
+ mlb = MultiLabelBinarizer(classes=CONDITIONS_TO_PREDICT)
168
+ y_bin = mlb.fit_transform(y)
169
+
170
+ # Split the data into train and test sets.
171
+ X_train, X_test, y_train, y_test = train_test_split(X, y_bin, test_size=0.2, random_state=42)
172
+
173
+ # Build the model.
174
+ model = build_model(input_dim=6144, output_dim=len(CONDITIONS_TO_PREDICT))
175
+
176
+ # If a saved model exists, load it; otherwise, train and save it.
177
+ model_file = "model.h5"
178
+ if os.path.exists(model_file):
179
+ print("Loading existing model from", model_file)
180
+ model = tf.keras.models.load_model(model_file)
181
+ else:
182
+ print("Training model... This may take a few minutes.")
183
+ train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
184
+ test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)
185
+ model.fit(train_ds, validation_data=test_ds, epochs=15)
186
+ model.save(model_file)
187
+
188
+ # Extract a list of case IDs for dropdown
189
+ case_ids = list(df_full.index)
190
+
191
+ def predict_case(case_id: str):
192
+ """Fetch images and predictions for a given case ID."""
193
+ if case_id not in df_full.index:
194
+ return [], "Case ID not found!", "N/A", "N/A"
195
+
196
+ row = df_full.loc[case_id]
197
+ image_paths = [row.get('image_1_path'), row.get('image_2_path'), row.get('image_3_path')]
198
+ images, predictions_text = [], []
199
+
200
+ # Get Dermatologist's Labels
201
+ dermatologist_conditions = row.get('dermatologist_skin_condition_on_label_name', "N/A")
202
+ dermatologist_confidence = row.get('dermatologist_skin_condition_confidence', "N/A")
203
+
204
+ if isinstance(dermatologist_conditions, str):
205
+ try:
206
+ dermatologist_conditions = eval(dermatologist_conditions)
207
+ dermatologist_confidence = eval(dermatologist_confidence)
208
+ except:
209
+ pass
210
+
211
+ # Process images & generate predictions
212
+ for path in image_paths:
213
+ if isinstance(path, str) and (path in embeddings):
214
+ try:
215
+ img_bytes = bucket.blob(path).download_as_string()
216
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
217
+ images.append(img)
218
+ except:
219
+ continue
220
+
221
+ # Model Prediction
222
+ emb = np.expand_dims(embeddings[path], axis=0)
223
+ pred = model.predict(emb)[0]
224
+ pred_dict = {cond: round(float(prob), 3) for cond, prob in zip(mlb.classes_, pred)}
225
+ predictions_text.append(str(pred_dict))
226
+
227
+ # Format the output
228
+ predictions_text = "\n".join(predictions_text) if predictions_text else "No predictions available."
229
+ dermatologist_conditions = str(dermatologist_conditions)
230
+ dermatologist_confidence = str(dermatologist_confidence)
231
+
232
+ return images, predictions_text, dermatologist_conditions, dermatologist_confidence
233
+
234
+ # Create the Gradio Interface with a Dropdown
235
+ iface = gr.Interface(
236
+ fn=predict_case,
237
+ inputs=gr.Dropdown(choices=case_ids, label="Select a Case ID"),
238
+ outputs=[
239
+ gr.Gallery(label="Case Images"),
240
+ gr.Textbox(label="Model's Predictions"),
241
+ gr.Textbox(label="Dermatologist's Skin Conditions"),
242
+ gr.Textbox(label="Dermatologist's Confidence Ratings")
243
+ ],
244
+ title="Derm Foundation Skin Conditions Explorer",
245
+ description="Select a Case ID from the dropdown to view images and predictions."
246
+ )
247
+
248
+ iface.launch()
249
+
250
+
251
+ if __name__ == "__main__":
252
+ main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ tensorflow
3
+ numpy
4
+ pandas
5
+ scikit-learn
6
+ google-cloud-storage
7
+ huggingface_hub
8
+ pillow