# AioMedica / app.py
# Author: chris1nexus — commit 9b08ae2 ("Update app"), 7.71 kB.
import streamlit as st
import openslide
import os
from streamlit_option_menu import option_menu
import torch
from predict import Predictor
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    """Build the inference ``Predictor`` once and reuse it across Streamlit reruns.

    ``allow_output_mutation=True`` stops ``st.cache`` from re-hashing the
    returned object on every rerun to check for mutation. That check is
    expensive for model objects, and Streamlit's caching docs recommend
    disabling it for ML models.

    Returns:
        The cached ``Predictor`` instance.
    """
    return Predictor()
@st.cache(suppress_st_warning=True)
def load_dependencies():
    """Install the PyTorch Geometric extension packages for the current device.

    Uses the ``cu113`` wheel index when ``torch.cuda.is_available()`` is true
    and the ``cpu`` index otherwise. pip is invoked as
    ``<running interpreter> -m pip``, so packages are installed into the same
    environment the app runs in; a bare ``pip`` on PATH may belong to a
    different Python. The installs stay best-effort, as before: a failed
    install is reported on stderr and does not abort the app.

    Being wrapped in ``st.cache``, the installs run only on the first call
    per server process.
    """
    import subprocess
    import sys

    # NOTE(review): official torch 1.7.1 binaries were built for CUDA up to
    # 11.0; confirm the torch-1.7.1+cu113 wheel index actually exists.
    variant = 'cu113' if torch.cuda.is_available() else 'cpu'
    wheel_index = f"https://pytorch-geometric.com/whl/torch-1.7.1+{variant}.html"

    # Installed in this order, matching the original sequence of pip calls.
    for package in ("torch-scatter", "torch-sparse", "torch-geometric"):
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", package, "-f", wheel_index],
            check=False,
        )
        if result.returncode != 0:
            print(
                f"WARNING: pip install {package} failed "
                f"(exit code {result.returncode})",
                file=sys.stderr,
            )
ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool."
CONTACT_TEXT = """
_Built by Christian Cancedda and LabLab lads with love_ ❤️
[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus)
[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda)
Star project repository:
[![GitHub stars](https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social)](https://github.com/Chris1nexus/inference-graph-transformer)
"""
VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window"
DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease"
SIDEBAR_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">Project Repository</a>
</p>
<p class="aligncenter">
<a href="https://github.com/Chris1nexus/inference-graph-transformer" target="_blank">
<img src="https://img.shields.io/github/stars/Chris1nexus/inference-graph-transformer?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/chris_cancedda" target="_blank">
<img src="https://img.shields.io/twitter/follow/chris_cancedda?style=social"/>
</a>
</p>
"""


def _configure_environment():
    """Export the path environment variables used by the inference API.

    Also creates the GraphCAM output directory.
    """
    os.environ['DATA_DIR'] = 'queries'
    os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches')
    os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides')
    os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots')
    os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True)
    # manually put the metadata in the metadata folder
    os.environ['CLASS_METADATA'] = 'metadata/label_map.pkl'
    # manually put the desired weights in the weights folder
    os.environ['WEIGHTS_PATH'] = 'weights'
    os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(
        os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth')
    os.environ['GT_WEIGHT_PATH'] = os.path.join(
        os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth')


def _save_upload(uploaded_file):
    """Write a Streamlit upload into the working directory and return its path.

    Only the basename of the client-supplied file name is used, so a name with
    directory components cannot write outside the current directory.
    """
    path = os.path.basename(uploaded_file.name)
    with open(path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return path


def _show_slide_thumbnail(slide_path, reduction_factor=20):
    """Display a thumbnail of the slide at *slide_path*.

    Each dimension of the slide's ``dimensions`` is divided by
    *reduction_factor*.
    """
    slide = openslide.OpenSlide(slide_path)
    try:
        width, height = slide.dimensions
        target = (int(width / reduction_factor), int(height / reduction_factor))
        # get_thumbnail fits the image inside `target` while keeping the aspect
        # ratio; resize then pins it to exactly `target`.
        thumbnail = slide.get_thumbnail(target).resize(target)
    finally:
        slide.close()
    st.image(thumbnail, use_column_width='never')


def _show_detection(predictor, uploaded_file):
    """Run the predictor on an uploaded slide and show the class and images."""
    slide_path = _save_upload(uploaded_file)
    with st.spinner(text="Computation is running"):
        predicted_class, viz_dict = predictor.predict(slide_path)
    st.info('Computation completed.')
    st.header(f'Predicted to be: {predicted_class}')
    st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected')
    # The app labels viz_dict['ORI'] as the original slide and
    # viz_dict[predicted_class] as the disease heatmap; the images are ordered
    # to match their captions.
    st.image(
        [viz_dict['ORI'], viz_dict[predicted_class]],
        caption=['Original', f'{predicted_class} heatmap'],
        channels='BGR',
    )


def main():
    """Configure the inference environment and render the Streamlit UI."""
    _configure_environment()
    # set_page_config must be the first Streamlit command of the run.
    st.set_page_config(page_title="", layout='wide')
    predictor = load_model()

    with st.sidebar:
        choice = option_menu(
            "LastMinute - Diagnosis",
            ["About", "Visualize WSI slide", "Cancer Detection", "Contact"],
            icons=['house', 'upload', 'activity', 'person lines fill'],
            menu_icon="app-indicator",
            default_index=0,
            styles={"container": {"border-radius": ".0rem"}},
        )
    st.sidebar.markdown(SIDEBAR_HTML, unsafe_allow_html=True)

    if choice == "About":
        st.title(choice)
        st.markdown(ABOUT_TEXT)
    elif choice == "Visualize WSI slide":
        st.title(choice)
        st.markdown(VISUALIZE_TEXT)
        uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
        if uploaded_file is not None:
            # OpenSlide reads from a path, so the upload must be on disk first.
            _show_slide_thumbnail(_save_upload(uploaded_file))
    elif choice == "Cancer Detection":
        st.title(choice)
        st.markdown(DETECT_TEXT)
        uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)")
        if uploaded_file is not None:
            _show_detection(predictor, uploaded_file)
    elif choice == "Contact":
        st.title(choice)
        st.markdown(CONTACT_TEXT)
def _run_app():
    """Ensure the PyG extensions are installed, then render the app."""
    load_dependencies()
    main()


if __name__ == '__main__':
    # `streamlit run` executes this script as __main__ on every rerun.
    _run_app()