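"""Train a multi-label skin-condition classifier on precomputed Derm
Foundation embeddings for the SCIN dataset, and explore results in Gradio."""
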
import ast
import io
import os

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from google.cloud import storage
from huggingface_hub import hf_hub_download, login
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from tensorflow.keras import layers, regularizers
|
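# Authenticate with the Hugging Face Hub; downloading files from the
# google/derm-foundation repo requires an account that has accepted its terms.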
login() |
|
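# Public SCIN dataset resources on Google Cloud Storage.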
SCIN_GCP_PROJECT = 'dx-scin-public'
SCIN_GCS_BUCKET_NAME = 'dx-scin-public-data'
SCIN_GCS_CASES_CSV = 'dataset/scin_cases.csv'
SCIN_GCS_LABELS_CSV = 'dataset/scin_labels.csv'
|
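# Hugging Face repo that hosts the Derm Foundation model and the
# precomputed SCIN embeddings.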
SCIN_HF_MODEL_NAME = 'google/derm-foundation'
SCIN_HF_EMBEDDING_FILE = 'scin_dataset_precomputed_embeddings.npz'
|
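# Skin conditions the classifier is trained to predict.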
CONDITIONS_TO_PREDICT = [
    'Eczema',
    'Allergic Contact Dermatitis',
    'Insect Bite',
    'Urticaria',
    'Psoriasis',
    'Folliculitis',
    'Irritant Contact Dermatitis',
    'Tinea',
    'Herpes Zoster',
    'Drug Rash',
]
|
|
def initialize_df_with_metadata(bucket, csv_path):
    """Load the SCIN cases CSV from GCS into a DataFrame."""
    csv_bytes = bucket.blob(csv_path).download_as_string()
    df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    df['case_id'] = df['case_id'].astype(str)
    return df
|
|
def augment_metadata_with_labels(df, bucket, csv_path):
    """Merge the SCIN labels CSV into the cases DataFrame on case_id."""
    csv_bytes = bucket.blob(csv_path).download_as_string()
    labels_df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    labels_df['case_id'] = labels_df['case_id'].astype(str)
    merged_df = pd.merge(df, labels_df, on='case_id')
    return merged_df
|
|
def load_embeddings_from_file(repo_id, object_name):
    """Download the precomputed-embeddings .npz from the Hugging Face Hub
    and return a dict mapping image path -> embedding vector."""
    file_path = hf_hub_download(repo_id=repo_id, filename=object_name, local_dir='./')
    embeddings = {}
    with open(file_path, 'rb') as f:
        npz_file = np.load(f, allow_pickle=True)
        for key, value in npz_file.items():
            embeddings[key] = value
    return embeddings
|
|
def prepare_data(df, embeddings):
    """Build (X, y) training pairs from graded cases and precomputed embeddings.

    Each image of a case becomes one example; its label set is every listed
    condition that is in CONDITIONS_TO_PREDICT and meets the confidence floor.
    """
    MINIMUM_CONFIDENCE = 0
    X = []
    y = []
    poor_image_quality_counter = 0
    missing_embedding_counter = 0
    not_in_condition_counter = 0
    condition_confidence_low_counter = 0

    for row in df.itertuples():
        # Skip cases whose images the dermatologists could not grade.
        if getattr(row, 'dermatologist_gradable_for_skin_condition_1', None) != 'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT':
            poor_image_quality_counter += 1
            continue

        # Labels and confidences are stored as stringified Python lists;
        # parse them safely and skip rows that fail to parse.
        try:
            labels = ast.literal_eval(getattr(row, 'dermatologist_skin_condition_on_label_name', '[]'))
            confidences = ast.literal_eval(getattr(row, 'dermatologist_skin_condition_confidence', '[]'))
        except (ValueError, SyntaxError, TypeError):
            continue

        row_labels = []
        for label, conf in zip(labels, confidences):
            if label not in CONDITIONS_TO_PREDICT:
                not_in_condition_counter += 1
                continue
            if conf < MINIMUM_CONFIDENCE:
                condition_confidence_low_counter += 1
                continue
            row_labels.append(label)

        # A case has up to three images; each image with a precomputed
        # embedding becomes one training example.
        for image_path in [getattr(row, 'image_1_path', None),
                           getattr(row, 'image_2_path', None),
                           getattr(row, 'image_3_path', None)]:
            if image_path is None or pd.isna(image_path):
                continue
            if image_path not in embeddings:
                missing_embedding_counter += 1
                continue
            X.append(embeddings[image_path])
            y.append(row_labels)

    print(f'Poor image quality count: {poor_image_quality_counter}')
    print(f'Missing embedding count: {missing_embedding_counter}')
    print(f'Condition not in list count: {not_in_condition_counter}')
    print(f'Excluded due to low confidence count: {condition_confidence_low_counter}')
    return X, y
|
|
def build_model(input_dim, output_dim, weight_decay=1e-4):
    """Two-layer MLP with L2 regularization, dropout, and per-class
    sigmoid outputs for multi-label classification."""
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = layers.Dense(256, activation="relu",
                          kernel_regularizer=regularizers.l2(weight_decay),
                          bias_regularizer=regularizers.l2(weight_decay))(inputs)
    hidden = layers.Dropout(0.1)(hidden)
    hidden = layers.Dense(128, activation="relu",
                          kernel_regularizer=regularizers.l2(weight_decay),
                          bias_regularizer=regularizers.l2(weight_decay))(hidden)
    hidden = layers.Dropout(0.1)(hidden)
    output = layers.Dense(output_dim, activation="sigmoid")(hidden)
    model = tf.keras.Model(inputs, output)
    # Sigmoid outputs with binary cross-entropy treat each condition as an
    # independent yes/no prediction.
    model.compile(loss="binary_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))
    return model
|
|
def main():
    # Connect to the public SCIN bucket on Google Cloud Storage.
    storage_client = storage.Client(SCIN_GCP_PROJECT)
    bucket = storage_client.bucket(SCIN_GCS_BUCKET_NAME)
|
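    # Load case metadata, merge in dermatologist labels, and index by case_id.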
    df_cases = initialize_df_with_metadata(bucket, SCIN_GCS_CASES_CSV)
    df_full = augment_metadata_with_labels(df_cases, bucket, SCIN_GCS_LABELS_CSV)
    df_full.set_index('case_id', inplace=True)
|
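    # Download the precomputed Derm Foundation embeddings for the SCIN images.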
print("Loading embeddings...") |
|
embeddings = load_embeddings_from_file(SCIN_HF_MODEL_NAME, SCIN_HF_EMBEDDING_FILE) |
|
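    # Build per-image examples and binarize label lists into multi-hot vectors.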
print("Preparing training data...") |
|
X, y = prepare_data(df_full, embeddings) |
|
X = np.array(X) |
|
|
|
mlb = MultiLabelBinarizer(classes=CONDITIONS_TO_PREDICT) |
|
y_bin = mlb.fit_transform(y) |
|
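    # Hold out 20% of the examples for validation.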
    X_train, X_test, y_train, y_test = train_test_split(X, y_bin, test_size=0.2, random_state=42)
|
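    # Derm Foundation embeddings are 6144-dimensional, matching input_dim here.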
    model = build_model(input_dim=6144, output_dim=len(CONDITIONS_TO_PREDICT))
|
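    # Load a previously saved model if one exists; otherwise train and save it.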
    model_file = "model.h5"
    if os.path.exists(model_file):
        print("Loading existing model from", model_file)
        model = tf.keras.models.load_model(model_file)
    else:
        print("Training model... This may take a few minutes.")
        train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
        test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)
        model.fit(train_ds, validation_data=test_ds, epochs=15)
        model.save(model_file)
|
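    # Case IDs populate the dropdown in the UI below.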
    case_ids = list(df_full.index)
|
    def predict_case(case_id: str):
        """Fetch images, model predictions, and dermatologist labels for a case ID."""
        if case_id not in df_full.index:
            return [], "Case ID not found!", "N/A", "N/A"

        row = df_full.loc[case_id]
        image_paths = [row.get('image_1_path'), row.get('image_2_path'), row.get('image_3_path')]
        images, predictions_text = [], []

        # Dermatologist labels are stored as stringified Python lists.
        dermatologist_conditions = row.get('dermatologist_skin_condition_on_label_name', "N/A")
        dermatologist_confidence = row.get('dermatologist_skin_condition_confidence', "N/A")
        if isinstance(dermatologist_conditions, str):
            try:
                dermatologist_conditions = ast.literal_eval(dermatologist_conditions)
                dermatologist_confidence = ast.literal_eval(dermatologist_confidence)
            except (ValueError, SyntaxError, TypeError):
                pass

        for path in image_paths:
            if isinstance(path, str) and (path in embeddings):
                # Fetch the image for display; skip this image if the download fails.
                try:
                    img_bytes = bucket.blob(path).download_as_string()
                    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    images.append(img)
                except Exception:
                    continue

                # Score the image's precomputed embedding with the trained model.
                emb = np.expand_dims(embeddings[path], axis=0)
                pred = model.predict(emb)[0]
                pred_dict = {cond: round(float(prob), 3) for cond, prob in zip(mlb.classes_, pred)}
                predictions_text.append(str(pred_dict))

        predictions_text = "\n".join(predictions_text) if predictions_text else "No predictions available."
        return images, predictions_text, str(dermatologist_conditions), str(dermatologist_confidence)
|
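    # Wire up the Gradio explorer UI.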
    iface = gr.Interface(
        fn=predict_case,
        inputs=gr.Dropdown(choices=case_ids, label="Select a Case ID"),
        outputs=[
            gr.Gallery(label="Case Images"),
            gr.Textbox(label="Model's Predictions"),
            gr.Textbox(label="Dermatologist's Skin Conditions"),
            gr.Textbox(label="Dermatologist's Confidence Ratings"),
        ],
        title="Derm Foundation Skin Conditions Explorer",
        description="Select a Case ID from the dropdown to view images and predictions.",
    )
|
    iface.launch()
|
|
if __name__ == "__main__":
    main()
|
|