|
import socket |
|
import urllib.request |
|
from pathlib import Path |
|
from typing import Literal |
|
|
|
import anndata |
|
import pandas as pd |
|
import plotly.express as px |
|
import streamlit as st |
|
|
|
from constants import MODELS |
|
|
|
|
|
def ui_model_selection():
    """Render the species / model-version pickers in the Streamlit sidebar.

    The selection is persisted in ``st.session_state`` under ``"SPECIE"`` and
    ``"VERSION"`` so it survives Streamlit reruns. Both keys are initialised to
    ``None`` on first run and are only overwritten once the user has picked
    both a species and a version.
    """
    if "SPECIE" not in st.session_state:
        st.session_state["SPECIE"] = None
    if "VERSION" not in st.session_state:
        st.session_state["VERSION"] = None

    # presumably MODELS maps species name -> sequence of available versions
    species_options = list(MODELS.keys())
    saved_specie = st.session_state["SPECIE"]
    specie = st.sidebar.selectbox(
        "**Species**",
        species_options,
        # Only preselect a stored species that is still offered; a stale value
        # would make ``list.index`` raise ValueError.
        index=species_options.index(saved_specie) if saved_specie in species_options else None,
        placeholder="Supported species",
    )

    version = None
    if specie:
        version_options = list(MODELS[specie])
        saved_version = st.session_state["VERSION"]
        version = st.sidebar.selectbox(
            "**Version**",
            version_options,
            # After switching species the stored version may belong to the
            # previous species; preselect it only if this species offers it.
            index=version_options.index(saved_version) if saved_version in version_options else None,
            placeholder="Version",
        )

    st.sidebar.divider()

    # Persist only a complete selection.
    if specie and version:
        st.session_state["SPECIE"] = specie
        st.session_state["VERSION"] = version
|
|
|
|
|
@st.cache_data
def _fetch_resource(url: str, filename: str) -> str:
    """Download ``url`` to ``/tmp/<filename>`` unless that file already exists.

    Parameters
    ----------
    url : str
        Remote URL of the resource (e.g. a Hugging Face ``resolve`` link).
    filename : str
        Name of the file to create under ``/tmp``.

    Returns
    -------
    str
        POSIX path of the downloaded (or previously downloaded) file.

    Raises
    ------
    ValueError
        If ``filename`` is empty.
    ConnectionError
        If the download fails with a DNS or URL error.
    """
    # Validate first: an empty name would otherwise resolve to the /tmp directory itself.
    if not filename:
        raise ValueError("Filename not specified!")

    destination = Path(f"/tmp/{filename}")
    if destination.exists():
        return destination.as_posix()

    # Download into a sibling ".part" file and rename only on success, so an
    # interrupted download never leaves a truncated file that later calls
    # would mistake for a completed one.
    partial = destination.with_name(destination.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        partial.replace(destination)
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError(f"could not download {url} due to {err}") from err
    finally:
        # No-op after a successful rename; removes leftovers after any failure.
        partial.unlink(missing_ok=True)

    return destination.as_posix()
|
|
|
|
|
def fetch_resource(specie: str, version: str) -> anndata.AnnData:
    """Load the H5AD dataset of a model from Hugging Face (https://huggingface.co/brickmanlab).

    Parameters
    ----------
    specie : str
        Species; must be a key of ``MODELS``.
    version : str
        Model version; must be listed in ``MODELS[specie]``.

    Returns
    -------
    anndata.AnnData
        Annotated dataset read from the downloaded ``adata.h5ad``.

    Raises
    ------
    ValueError
        If the species or the version is not known in ``MODELS``.
    """
    # Reject either an unknown species OR an unknown version. With ``and``, an
    # unknown species raised KeyError on ``MODELS[specie]`` and an unknown
    # version of a known species was never rejected.
    if specie not in MODELS or version not in MODELS[specie]:
        raise ValueError(f"Provided {specie} and {version} are not present on Hugging Face models!")

    url: str = f"https://huggingface.co/brickmanlab/{specie.lower()}-scanvi/resolve/{version}/adata.h5ad"
    return anndata.read_h5ad(_fetch_resource(url, filename=f"{specie.lower()}_v{version}.h5ad"))
|
|
|
|
|
def get_embedding(adata: anndata.AnnData, key: str) -> pd.DataFrame:
    """
    Retrieve the first two embedding coordinates of every cell.

    Parameters
    ----------
    adata : anndata.AnnData
        scRNA-seq dataset
    key : str
        Dimension reduction key in ``adata.obsm``, usually starting with ``X_``

    Returns
    -------
    pd.DataFrame
        Two columns named ``<NAME>_1`` / ``<NAME>_2``, where ``NAME`` is the
        upper-cased key without its ``X_`` prefix; one row per cell, in
        ``adata.obsm`` order.

    Raises
    ------
    ValueError
        If the reduction key does not exist or has fewer than two dimensions.
    """
    if key not in adata.obsm.keys():
        raise ValueError(f"Reduction key: {key} not available")

    coordinates = adata.obsm[key]
    # obsm entries may be stored as DataFrames; slice positionally either way.
    if hasattr(coordinates, "iloc"):
        coordinates = coordinates.to_numpy()
    if coordinates.ndim < 2 or coordinates.shape[1] < 2:
        raise ValueError(f"Reduction key: {key} has fewer than 2 dimensions")

    # Strip only the conventional "X_" prefix; blindly dropping two characters
    # mangled keys stored without it (e.g. "umap" -> "AP").
    label = (key[2:] if key.startswith("X_") else key).upper()
    return pd.DataFrame(coordinates[:, :2], columns=[f"{label}_1", f"{label}_2"])
|
|
|
|
|
def plot_sc_embedding(
    adata: anndata.AnnData,
    reduction_key: str,
    group_by: str = None,
    feature: str = None,
    layer: str = None,
    ax=None,
):
    """
    Plot a 2-D embedding of a single-cell dataset as a Plotly scatter plot.

    Parameters
    ----------
    adata : anndata.AnnData
        scRNA-seq dataset
    reduction_key : str
        ``adata.obsm`` key of the reduced space (e.g. ``"X_umap"``)
    group_by : str, optional
        ``adata.obs`` column used to colour cells; rows are sorted by it.
        Non-categorical columns are drawn with the Viridis colour scale.
    feature : str, optional
        Gene whose expression colours the cells; overrides ``group_by`` colouring.
    layer : str, optional
        ``adata.layers`` key to read ``feature`` from; when omitted the
        expression comes from ``adata.raw.X``.
    ax : optional
        Object exposing ``plotly_chart`` (e.g. a Streamlit container);
        defaults to the ``st`` module.
    """
    embeddings = get_embedding(adata, reduction_key)
    # Plain scatter when neither group_by nor feature is given (previously a NameError).
    kwargs = {}

    if feature:
        expression = adata[:, feature].layers[layer] if layer else adata.raw[:, feature].X
        # Layers / raw may be sparse or dense; only sparse matrices have toarray().
        expression = expression.toarray() if hasattr(expression, "toarray") else expression
        # Assign before any sorting: rows are still in adata order here, so the
        # positional assignment pairs each value with its own cell.
        embeddings[feature] = expression.ravel()

    if group_by:
        embeddings[group_by] = adata.obs[group_by].values
        embeddings = embeddings.sort_values(by=group_by)
        kwargs["color"] = embeddings[group_by].values.tolist()
        # Categorical groups keep Plotly's discrete colours; anything else gets a continuous scale.
        if not isinstance(adata.obs[group_by].dtype, pd.CategoricalDtype):
            kwargs["color_continuous_scale"] = px.colors.sequential.Viridis

    if feature:
        # Feature expression takes precedence over group colouring.
        kwargs = {
            "color": embeddings[feature].values.tolist(),
            "color_continuous_scale": px.colors.sequential.Viridis,
        }

    target = ax if ax is not None else st
    target.plotly_chart(
        px.scatter(
            data_frame=embeddings,
            x=embeddings.columns[0],
            y=embeddings.columns[1],
            **kwargs,
        ),
        use_container_width=True,
    )
|
|
|
|
|
def plot_feature(
    adata: anndata.AnnData,
    feature: str,
    group_by: str,
    kind: Literal["box"] = "box",
    ax=None,
):
    """Plot the expression of one feature per group.

    Parameters
    ----------
    adata : anndata.AnnData
        Dataset; expression is read from ``adata.raw.X``
    feature : str
        Gene name
    group_by : str
        ``adata.obs`` column defining the groups (also used for colour)
    kind : str
        Type of plot; only ``"box"`` is supported
    ax : optional
        Object exposing ``plotly_chart`` (e.g. a Streamlit container);
        defaults to the ``st`` module.

    Raises
    ------
    ValueError
        If ``kind`` is not supported.
    """
    # Fail fast, before any data extraction, on unsupported plot kinds.
    if kind != "box":
        raise ValueError(f"Provided kind: {kind} not supported")

    expression = adata.raw[:, feature].X
    # raw.X may be sparse or dense; only sparse matrices have toarray().
    expression = expression.toarray() if hasattr(expression, "toarray") else expression

    df = pd.DataFrame(expression, columns=[feature])
    df[group_by] = adata.obs[group_by].values
    df = df.sort_values(by=group_by)

    figure = px.box(df, x=group_by, y=feature, color=group_by)

    target = ax if ax is not None else st
    target.plotly_chart(figure, use_container_width=True)
|
|
|
|
|
def get_degs(adata: anndata.AnnData, key: str) -> pd.DataFrame:
    """Format differentially expressed genes (DEGs) as a long dataframe.

    Code adapted from https://github.com/scverse/scanpy/blob/1.10.4/src/scanpy/get/get.py#L27-L111

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix
    key : str
        ``adata.uns`` key under which the DEGs are stored

    Returns
    -------
    pd.DataFrame
        One row per (group, gene) with a ``group`` column plus whichever of
        ``names``, ``scores``, ``logfoldchanges``, ``pvals``, ``pvals_adj``
        are stored; rows are ordered by group, then by original rank.
    """
    result = adata.uns[key]
    group = list(result["names"].dtype.names)
    # Use only the statistics actually stored (e.g. method="logreg" keeps just
    # names and scores); indexing a missing field would raise KeyError.
    colnames = [c for c in ("names", "scores", "logfoldchanges", "pvals", "pvals_adj") if c in result]

    d = [pd.DataFrame(result[c])[group] for c in colnames]
    d = pd.concat(d, axis=1, names=[None, "group"], keys=colnames)
    try:
        # pandas >= 2.1: opt into the new stack implementation, as scanpy does
        # (avoids the FutureWarning and keeps the column order of `colnames`).
        d = d.stack(level=1, future_stack=True)
    except TypeError:
        # Older pandas has no `future_stack` keyword.
        d = d.stack(level=1)
    d = d.reset_index()
    d["group"] = pd.Categorical(d["group"], categories=group)
    # "level_0" is the within-group rank position; sort by it, then drop it.
    d = d.sort_values(["group", "level_0"]).drop(columns="level_0")

    return d
|
|