import socket
import urllib.error
import urllib.request
from pathlib import Path
from typing import Literal, Optional

import anndata
import pandas as pd
import plotly.express as px
import streamlit as st

from constants import MODELS


def ui_model_selection():
    """Render the species/version selectors in the sidebar.

    The chosen values are stored in ``st.session_state`` under the keys
    ``"SPECIE"`` and ``"VERSION"`` so that they are shared between pages.
    The state is only updated once both a species and a version are selected.
    """
    # shared state variables between pages
    for state_key in ("SPECIE", "VERSION"):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    species = list(MODELS.keys())
    stored_specie = st.session_state["SPECIE"]
    specie = st.sidebar.selectbox(
        "**Species**",
        species,
        # Fall back to "nothing selected" when the remembered species is
        # missing or no longer offered, instead of failing in ``.index``.
        index=species.index(stored_specie) if stored_specie in species else None,
        placeholder="Supported species",
    )

    version = None
    if specie:
        versions = MODELS[specie]
        stored_version = st.session_state["VERSION"]
        version = st.sidebar.selectbox(
            "**Version**",
            versions,
            # The remembered version may belong to a previously selected
            # species; only preselect it when it exists for this one.
            index=versions.index(stored_version) if stored_version in versions else None,
            placeholder="Version",
        )

    st.sidebar.divider()

    if specie and version:
        st.session_state["SPECIE"] = specie
        st.session_state["VERSION"] = version


@st.cache_data
def _fetch_resource(url: str, filename: str) -> str:
    """Helper function for downloading datasets

    Parameters
    ----------
    url : str
        Link of the file to download
    filename : str
        Name under which the file is stored in /tmp

    Returns
    -------
    str
        Path where the file was downloaded to, default /tmp

    Raises
    ------
    ValueError
        Filename is empty
    ConnectionError
        The file could not be downloaded
    """
    if not filename:
        raise ValueError("Filename not specified!")

    destination = Path(f"/tmp/{filename}")
    if not destination.exists():
        # Download next to the final location and rename afterwards, so an
        # interrupted transfer never leaves a truncated file that a later
        # call would mistake for a complete download.
        partial = destination.with_name(f"{destination.name}.part")
        try:
            urllib.request.urlretrieve(url, partial)
        except (socket.gaierror, urllib.error.URLError) as err:
            partial.unlink(missing_ok=True)
            raise ConnectionError(f"could not download {url} due to {err}") from err
        partial.replace(destination)

    return destination.as_posix()


def fetch_resource(specie: str, version: str) -> anndata.AnnData:
    """Load H5AD dataset from Hugging Face (https://huggingface.co/brickmanlab)

    Parameters
    ----------
    specie : str
        Specie
    version : str
        Model version

    Returns
    -------
    anndata.AnnData
        Annotated dataset

    Raises
    ------
    ValueError
        Specie and Version have to exist
    """
    # ``or`` (not ``and``): either an unknown specie or an unknown version is
    # an error, and the specie must be checked first so that the lookup
    # ``MODELS[specie]`` cannot raise a KeyError.
    if specie not in MODELS or version not in MODELS[specie]:
        raise ValueError(f"Provided {specie} and {version} are not present on Hugging Face models!")

    url: str = (
        f"https://huggingface.co/brickmanlab/{specie.lower()}-scanvi/resolve/{version}/adata.h5ad"
    )
    local_path = _fetch_resource(url, filename=f"{specie.lower()}_v{version}.h5ad")

    return anndata.read_h5ad(local_path)


def get_embedding(adata: anndata.AnnData, key: str) -> pd.DataFrame:
    """
    Helper function which retrieves embedding coordinates for each cell.

    Parameters
    ----------
    adata : anndata.AnnData
        scrna-seq dataset
    key : str
        Dimension reduction key, usually starts with X_

    Returns
    -------
    pd.DataFrame
        Embedding coordinates (first two dimensions only)

    Raises
    ------
    ValueError
        Fail if reduction key doesn't exist
    """
    if key not in adata.obsm:
        raise ValueError(f"Reduction key: {key} not available")

    # The first two characters of the key are dropped (the usual "X_" prefix)
    # to build the column names, e.g. "X_umap" -> "UMAP_1", "UMAP_2".
    label = key[2:].upper()
    dimension_names = f"{label}_1", f"{label}_2"

    return pd.DataFrame(adata.obsm[key][:, :2], columns=dimension_names)


def plot_sc_embedding(
    adata: anndata.AnnData,
    reduction_key: str,
    group_by: Optional[str] = None,
    feature: Optional[str] = None,
    layer: Optional[str] = None,
    ax=None,
):
    """
    Plot single-cell dataset

    Parameters
    ----------
    adata : anndata.AnnData
        scrna-seq dataset
    reduction_key : str
        Reduced space key
    group_by : str, optional
        Key used to color cells
    feature : str, optional
        Gene; when given it takes precedence over ``group_by`` for coloring
    layer : str, optional
        When set (any truthy value), expression is read from the
        "scVI_normalized" layer, otherwise from ``adata.raw``
    ax : optional
        Object exposing ``plotly_chart`` on which the figure is drawn,
        by default the ``streamlit`` module itself
    """
    embeddings = get_embedding(adata, reduction_key)

    # Without ``group_by`` and ``feature`` the cells are drawn uncolored.
    kwargs = {}

    if group_by:
        embeddings[group_by] = adata.obs[group_by].values
        embeddings = embeddings.sort_values(by=group_by)

        kwargs = {"color": embeddings[group_by].values.tolist()}
        if not isinstance(adata.obs[group_by].dtype, pd.CategoricalDtype):
            kwargs["color_continuous_scale"] = px.colors.sequential.Viridis

    if feature:
        X = (
            adata[:, feature].layers["scVI_normalized"].toarray()
            if layer
            else adata.raw[:, feature].X.toarray()
        )
        # ``embeddings`` may have been re-ordered by the sort above; its index
        # still holds the original row positions, so use it to keep every
        # expression value attached to the right cell.
        embeddings[feature] = X.ravel()[embeddings.index.to_numpy()]
        kwargs = {
            "color": embeddings[feature].values.tolist(),
            "color_continuous_scale": px.colors.sequential.Viridis,
        }

    ax_ = ax if ax is not None else st
    ax_.plotly_chart(
        px.scatter(
            data_frame=embeddings,
            x=embeddings.columns[0],
            y=embeddings.columns[1],
            **kwargs,
        ),
        use_container_width=True,
    )


def plot_feature(
    adata: anndata.AnnData,
    feature: str,
    group_by: str,
    kind: Literal["box"] = "box",
    ax=None,
):
    """Plot feature expression

    Parameters
    ----------
    adata : anndata.AnnData
        Dataset
    feature : str
        Gene name
    group_by : str
        Metadata column
    kind : str
        Type of plot, only "box" is supported
    ax : optional
        Object exposing ``plotly_chart`` on which the figure is drawn,
        by default the ``streamlit`` module itself

    Raises
    ------
    ValueError
        Unsupported kind of plot
    """
    # Reject unsupported plot kinds before extracting any data.
    if kind != "box":
        raise ValueError(f"Provided kind: {kind} not supported")

    df = pd.DataFrame(adata.raw[:, feature].X.toarray(), columns=[feature])
    df[group_by] = adata.obs[group_by].values
    df = df.sort_values(by=group_by)

    g = px.box(df, x=group_by, y=feature, color=group_by)

    ax_ = ax if ax is not None else st
    ax_.plotly_chart(g, use_container_width=True)


def get_degs(adata: anndata.AnnData, key: str) -> pd.DataFrame:
    """Format DEGs to dataframe.

    Code taken from https://github.com/scverse/scanpy/blob/1.10.4/src/scanpy/get/get.py#L27-L111

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated dataframe
    key : str
        Key used to store the degs

    Returns
    -------
    pd.DataFrame
        Dataframe of differentially expressed genes
    """
    # Group names are the field names of the structured "names" array.
    group = list(adata.uns[key]["names"].dtype.names)
    colnames = ["names", "scores", "logfoldchanges", "pvals", "pvals_adj"]

    # One wide frame per statistic, combined under a (statistic, group)
    # column MultiIndex and then reshaped to one row per gene and group.
    d = [pd.DataFrame(adata.uns[key][c])[group] for c in colnames]
    d = pd.concat(d, axis=1, names=[None, "group"], keys=colnames)
    d = d.stack(level=1).reset_index()
    d["group"] = pd.Categorical(d["group"], categories=group)
    d = d.sort_values(["group", "level_0"]).drop(columns="level_0")

    return d