|
import streamlit as st |
|
import openslide |
|
import os |
|
from streamlit_option_menu import option_menu |
|
import torch |
|
from predict import Predictor |
|
|
|
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model():
    """Build the ``Predictor`` once and reuse it across Streamlit reruns.

    ``allow_output_mutation=True`` tells legacy ``st.cache`` not to hash the
    returned object on every rerun to detect mutation; a model object can be
    expensive (or impossible) for Streamlit's hasher to process.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases;
    ``st.cache_resource`` is the replacement if the pinned version allows it.

    Returns:
        Predictor: the shared predictor instance.
    """
    return Predictor()
|
|
|
@st.cache(suppress_st_warning=True, show_spinner=False)
def load_dependencies():
    """Install the PyTorch Geometric packages for the available device.

    Picks the ``cu113`` wheel index when ``torch.cuda.is_available()`` and the
    ``cpu`` index otherwise. It then installs ``torch-scatter``,
    ``torch-sparse`` and ``torch-geometric`` from that index. Wrapped in
    ``st.cache`` so reruns reuse the cached result instead of reinstalling.

    ``show_spinner=False``: this function is called before ``main()`` runs
    ``st.set_page_config``. The cache spinner would render a placeholder
    element first, and Streamlit requires ``set_page_config`` to be the first
    Streamlit command of the script.

    Installation is best-effort, as before: a failed install is logged rather
    than raised.
    """
    import logging
    import subprocess
    import sys

    variant = "cu113" if torch.cuda.is_available() else "cpu"
    # NOTE(review): PyG did not publish torch-1.7.1 wheels for cu113 (only up
    # to cu110); confirm the intended torch/CUDA pair for this index URL.
    find_links = f"https://pytorch-geometric.com/whl/torch-1.7.1+{variant}.html"

    for package in ("torch-scatter", "torch-sparse", "torch-geometric"):
        # `sys.executable -m pip` installs into the interpreter running this
        # app (a bare `pip` on PATH may belong to another environment). The
        # argument list also avoids spawning a shell.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", package, "-f", find_links],
            check=False,
        )
        if result.returncode != 0:
            logging.getLogger(__name__).warning(
                "pip install %s failed with exit code %d",
                package,
                result.returncode,
            )
|
|
|
|
|
ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."

CONTACT_TEXT = """
_Built by Christian Cancedda and LabLab lads with love_ ❤️
[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
Star project repository:
[![GitHub stars](https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
"""

VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"

DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"

# The preview thumbnail is the full slide downscaled by this factor per axis.
REDUCTION_FACTOR = 20

_UPLOAD_LABEL = "Choose a WSI slide file to diagnose (.svs)"

_SIDEBAR_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">Project Repository</a>
</p>
<p class="aligncenter">
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">
<img src="https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social"/>
</a>
</p>

<p class="aligncenter">
<a href="https://twitter.com/chris_cancedda" target="_blank">
<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
</a>
</p>
"""


def _configure_environment():
    """Export the data/weight locations as environment variables.

    NOTE(review): these are presumably read by ``Predictor`` / the prediction
    pipeline, which is not visible here. Keep them set before ``load_model()``.
    """
    os.environ['DATA_DIR'] = 'queries'
    os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
    os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
    os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
    os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)

    os.environ['CLASS_METADATA'] = 'metadata/label_map.pkl'

    os.environ['WEIGHTS_PATH'] = 'weights'
    os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(
        os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
    os.environ['GT_WEIGHT_PATH'] = os.path.join(
        os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')


def _save_upload(uploaded_file):
    """Write a Streamlit upload into the working directory and return its path.

    The client-supplied name is reduced to its basename so the write cannot
    land outside the working directory.
    """
    path = os.path.basename(uploaded_file.name)
    with open(path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return path


def _render_sidebar():
    """Draw the navigation menu and project links; return the selected page."""
    with st.sidebar:
        choice = option_menu(
            "LastMinute - Diagnosis",
            ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
            icons=['house', 'upload', 'activity', 'person lines fill'],
            menu_icon="app-indicator",
            default_index=0,
            styles={"container": {"border-radius": ".0rem"}},
        )
    st.sidebar.markdown(_SIDEBAR_HTML, unsafe_allow_html=True)
    return choice


def _render_visualize(title):
    """Show a downscaled thumbnail of an uploaded whole-slide image."""
    st.title(title)
    st.markdown(VISUALIZE_TEXT)

    uploaded_file = st.file_uploader(_UPLOAD_LABEL)
    if uploaded_file is None:
        return

    # OpenSlide needs a real file path, so persist the in-memory upload first.
    slide = openslide.OpenSlide(_save_upload(uploaded_file))
    try:
        width, height = slide.dimensions
        thumb_size = (int(width / REDUCTION_FACTOR), int(height / REDUCTION_FACTOR))
        # get_thumbnail keeps the aspect ratio; resize forces the exact size.
        thumbnail = slide.get_thumbnail(thumb_size).resize(thumb_size)
    finally:
        slide.close()
    st.image(thumbnail, use_column_width='never')


def _render_detection(title, predictor):
    """Run the predictor on an uploaded slide and show the result and heatmap."""
    st.title(title)
    st.markdown(DETECT_TEXT)

    uploaded_file = st.file_uploader(_UPLOAD_LABEL)
    if uploaded_file is None:
        return

    slide_path = _save_upload(uploaded_file)
    with st.spinner(text="Computation is running"):
        predicted_class, viz_dict = predictor.predict(slide_path)
    st.info('Computation completed.')
    st.header(f'Predicted to be: {predicted_class}')
    st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')

    # viz_dict holds the original image under 'ORI' and the heatmap under the
    # predicted class name. Images are ordered to match their captions.
    st.image(
        [viz_dict['ORI'], viz_dict[predicted_class]],
        caption=['Original', f'{predicted_class} heatmap'],
        channels='BGR',
    )


def main():
    """Configure paths, load the model and render the selected page.

    ``st.set_page_config`` is kept as the first Streamlit command. The
    environment variables are exported before ``load_model()`` in case
    ``Predictor`` reads them.
    """
    _configure_environment()
    st.set_page_config(page_title="", layout='wide')
    predictor = load_model()

    choice = _render_sidebar()

    if choice == "About":
        st.title(choice)
        st.markdown(ABOUT_TEXT)
    elif choice == "Visualize WSI slide":
        _render_visualize(choice)
    elif choice == "Cancer Detection":
        _render_detection(choice, predictor)
    elif choice == "Contact":
        st.title(choice)
        st.markdown(CONTACT_TEXT)
|
|
|
def _run():
    """Install the graph-learning dependencies, then render the Streamlit app."""
    load_dependencies()
    main()


if __name__ == '__main__':
    _run()
|
|