|
|
|
import streamlit as st |
|
from utils import fetch_resource, ui_model_selection |
|
|
|
st.set_page_config(layout="wide")

# Introductory text for the page. This is a raw string because the LaTeX
# below contains backslashes (``\mu``, ``\sigma``). In a normal string literal
# these are invalid escape sequences, which emit a SyntaxWarning on
# Python 3.12+.
st.markdown(r"""
# SHAP features

Predicted features (genes) used by the scANVI classifier to determine a cell type. The features
have been determined using [SHAP](https://shap.readthedocs.io/en/latest/).

Each metric for a feature is determined from 10 random bootstraps with replacement.

- weight_mean: $\mu$ of SHAP value
- weight_std: $\sigma$ of SHAP value
- weight_ci_upper: $\mu$ + $\sigma$
- weight_ci_lower: $\mu$ - $\sigma$
- logfoldchanges: Log2 fold change from differential expression analysis
- pvals_adj: Adjusted p-value from differential expression analysis
- scores: Estimated score from differential expression analysis
""")
|
|
|
ui_model_selection()

# Model identity chosen via ``ui_model_selection`` (stored in session state).
specie = st.session_state.get("SPECIE")
version = st.session_state.get("VERSION")

if specie and version:
    adata = fetch_resource(specie, version)
    explainers = adata.uns["explainer"]

    explainer = st.sidebar.selectbox(
        "**Explainer**",
        explainers.keys(),
        index=None,
        placeholder="Select explainer ...",
    )

    if explainer:
        explainer_info = explainers[explainer]

        # Read the table WITHOUT mutating ``explainer_info``. The original
        # used ``.pop``, which removes "shap_values" from ``adata.uns``. If
        # ``fetch_resource`` returns a cached object (NOTE(review): presumably
        # st.cache_resource; confirm in utils), every subsequent Streamlit
        # rerun would then fail with a KeyError.
        shap_values = (
            explainer_info["shap_values"]
            .reset_index()
            .rename(columns={"index": "feature"})
        )

        # Show the explainer's remaining entries as its parameters.
        st.sidebar.markdown("**Parameters**")
        for key, value in explainer_info.items():
            if key == "shap_values":
                continue
            st.sidebar.markdown(f"{key}:\t{value}")

        celltype = st.sidebar.selectbox(
            "**Cell type**",
            adata.obs.ct.cat.categories,
            index=None,
            placeholder="Select cell type ...",
        )

        features = st.sidebar.multiselect(
            "**Genes**",
            sorted(shap_values.feature.unique()),
            placeholder="Select genes ...",
        )

        # Apply the optional filters. Boolean indexing is equivalent to the
        # previous string-built ``DataFrame.query`` but needs no expression
        # evaluation.
        if celltype:
            shap_values = shap_values[shap_values["ct"] == celltype]
        if features:
            shap_values = shap_values[shap_values["feature"].isin(features)]

        st.dataframe(shap_values, use_container_width=True, height=650)
|
|