File size: 6,679 Bytes
fc741fb
 
 
 
 
 
 
 
 
 
0e1164c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc741fb
 
0e1164c
fc741fb
 
 
 
 
 
 
 
 
 
 
 
 
0e1164c
 
 
 
 
fc741fb
0e1164c
fc741fb
 
 
0e1164c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc741fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e1164c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import socket
import urllib.request
from pathlib import Path
from typing import Literal

import anndata
import pandas as pd
import plotly.express as px
import streamlit as st

from constants import MODELS


def ui_model_selection():
    """Render species/version selectors in the sidebar and persist the choice.

    The selection is stored in ``st.session_state`` under the ``"SPECIE"``
    and ``"VERSION"`` keys so it is shared between pages. The stored values
    are only overwritten once both a species and a version are selected.
    """

    # shared state variables between pages
    if "SPECIE" not in st.session_state:
        st.session_state["SPECIE"] = None
    if "VERSION" not in st.session_state:
        st.session_state["VERSION"] = None

    species = list(MODELS.keys())
    saved_specie = st.session_state["SPECIE"]
    specie = st.sidebar.selectbox(
        "**Species**",
        species,
        # Only preselect a saved value that is still a valid option;
        # list.index() would raise ValueError on a stale one.
        index=species.index(saved_specie) if saved_specie in species else None,
        placeholder="Supported species",
    )

    # Bound up front so the check below never depends on the branch being taken.
    version = None
    if specie:
        versions = list(MODELS[specie])
        saved_version = st.session_state["VERSION"]
        version = st.sidebar.selectbox(
            "**Version**",
            versions,
            # The saved version may belong to a previously selected species
            # and be absent here; fall back to the empty placeholder then.
            index=versions.index(saved_version) if saved_version in versions else None,
            placeholder="Version",
        )

    st.sidebar.divider()

    if specie and version:
        st.session_state["SPECIE"] = specie
        st.session_state["VERSION"] = version


@st.cache_data
def _fetch_resource(url: str, filename: str) -> str:
    """Helper function for downloading datasets

    The file is only downloaded when ``/tmp/<filename>`` does not exist yet.

    Parameters
    ----------
    url : str
        Direct download link of the resource
    filename : str
        Name under which the file is stored inside /tmp

    Returns
    -------
    str
        Path where the file was downloaded to (/tmp/<filename>)

    Raises
    ------
    ValueError
        If ``filename`` is empty
    ConnectionError
        If the resource could not be downloaded
    """

    # Validate first: an empty name would make the destination "/tmp/" itself.
    if not filename:
        raise ValueError("Filename not specified!")

    destination = Path(f"/tmp/{filename}")

    if not destination.exists():
        # Download next to the final location and rename afterwards, so an
        # interrupted transfer never leaves a truncated file that the
        # exists() check above would accept on the next call.
        partial = destination.with_name(f"{destination.name}.part")
        try:
            urllib.request.urlretrieve(url, partial)
        except (socket.gaierror, urllib.error.URLError) as err:
            partial.unlink(missing_ok=True)
            raise ConnectionError(f"could not download {url} due to {err}") from err
        partial.replace(destination)

    return destination.as_posix()


def fetch_resource(specie: str, version: str) -> anndata.AnnData:
    """Load H5AD dataset from Hugging Face (https://huggingface.co/brickmanlab)

    Parameters
    ----------
    specie : str
        Specie, has to be a key of ``MODELS``
    version : str
        Model version, has to be listed in ``MODELS[specie]``

    Returns
    -------
    anndata.AnnData
        Annotated dataset

    Raises
    ------
    ValueError
        Specie and Version have to exist
    """

    # `or` short-circuits: MODELS[specie] is only evaluated for a known specie,
    # and a known specie with an unknown version is rejected as well.
    if specie not in MODELS or version not in MODELS[specie]:
        raise ValueError(f"Provided {specie} and {version} are not present on Hugging Face models!")

    url: str = f"https://huggingface.co/brickmanlab/{specie.lower()}-scanvi/resolve/{version}/adata.h5ad"
    return anndata.read_h5ad(_fetch_resource(url, filename=f"{specie.lower()}_v{version}.h5ad"))


def get_embedding(adata: anndata.AnnData, key: str) -> pd.DataFrame:
    """
    Return the first two columns of an embedding stored in ``adata.obsm``.

    The two output columns are named after ``key`` with its first two
    characters dropped and the rest upper-cased, suffixed with ``_1``
    and ``_2``.

    Parameters
    ----------
    adata : anndata.AnnData
        scrna-seq dataset
    key : str
        Dimension reduction key, usually starts with X_

    Returns
    -------
    pd.DataFrame
        Embedding coordinates, one row per row of ``adata.obsm[key]``

    Raises
    ------
    ValueError
        Fail if reduction key doesn't exist
    """
    if key in adata.obsm.keys():
        prefix = key[2:].upper()
        column_names = (f"{prefix}_1", f"{prefix}_2")
        coordinates = adata.obsm[key][:, :2]
        return pd.DataFrame(coordinates, columns=column_names)

    raise ValueError(f"Reduction key: {key} not available")


def plot_sc_embedding(
    adata: anndata.AnnData,
    reduction_key: str,
    group_by: str = None,
    feature: str = None,
    layer: str = None,
    ax=None,
):
    """
    Plot single-cell dataset

    Draws a Plotly scatter plot of the first two embedding dimensions. When
    both ``group_by`` and ``feature`` are given, points are ordered by
    ``group_by`` but colored by ``feature``. When neither is given, the
    points are drawn without a color mapping.

    Parameters
    ----------
    adata : anndata.AnnData
        scrna-seq dataset
    reduction_key : str
        Reduced space key
    group_by : str
        Key in ``adata.obs`` used to color cells
    feature : str
        Gene used to color cells
    layer : str
        Layer to read the expression of ``feature`` from. When empty,
        ``adata.raw`` is used. A value that is not a key of ``adata.layers``
        falls back to "scVI_normalized".
    ax : _type_
        Object exposing ``plotly_chart`` to draw into, by default the
        ``streamlit`` module itself
    """
    embeddings = get_embedding(adata, reduction_key)

    # Default to no color mapping so the chart can still be drawn when
    # neither `group_by` nor `feature` is requested.
    kwargs = {}

    if feature:
        if layer:
            # Honor the requested layer; the fallback keeps callers working
            # that only used `layer` as an on/off switch.
            layer_key = layer if layer in adata.layers else "scVI_normalized"
            X = adata[:, feature].layers[layer_key].toarray()
        else:
            X = adata.raw[:, feature].X.toarray()
        # Assigned positionally, so this must happen before any sorting
        # to keep each value attached to its own cell.
        embeddings[feature] = X.ravel()

    if group_by:
        embeddings[group_by] = adata.obs[group_by].values
        embeddings = embeddings.sort_values(by=group_by)

        kwargs = {"color": embeddings[group_by].values.tolist()}
        # Only non-categorical columns get a continuous color scale
        if adata.obs[group_by].dtype != "category":
            kwargs["color_continuous_scale"] = px.colors.sequential.Viridis

    if feature:
        # Feature coloring takes precedence over `group_by` coloring
        kwargs = {
            "color": embeddings[feature].values.tolist(),
            "color_continuous_scale": px.colors.sequential.Viridis,
        }

    ax_ = ax if ax else st
    ax_.plotly_chart(
        px.scatter(
            data_frame=embeddings,
            x=embeddings.columns[0],
            y=embeddings.columns[1],
            **kwargs,
        ),
        use_container_width=True,
    )


def plot_feature(
    adata: anndata.AnnData,
    feature: str,
    group_by: str,
    kind: Literal["box"] = "box",
    ax=None,
):
    """Plot feature expression

    Expression values are read from ``adata.raw`` and rows are ordered by
    ``group_by`` before plotting.

    Parameters
    ----------
    adata : anndata.AnnData
        Dataset
    feature : str
        Gene name
    group_by : str
        Metadata column in ``adata.obs``
    kind : str
        Type of plot, only "box" is supported
    ax : _type_, optional
        Object exposing ``plotly_chart`` to draw into, by default None
        (the ``streamlit`` module itself is used)

    Raises
    ------
    ValueError
        If ``kind`` is not supported
    """

    expression = adata.raw[:, feature].X.toarray()
    table = pd.DataFrame(expression, columns=[feature])
    table[group_by] = adata.obs[group_by].values
    table = table.sort_values(by=group_by)

    if kind != "box":
        raise ValueError(f"Provided kind: {kind} not supported")
    figure = px.box(table, x=group_by, y=feature, color=group_by)

    target = ax or st
    target.plotly_chart(figure, use_container_width=True)


def get_degs(adata: anndata.AnnData, key: str) -> pd.DataFrame:
    """Format DEGs to dataframe.

    Code taken from https://github.com/scverse/scanpy/blob/1.10.4/src/scanpy/get/get.py#L27-L111

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated dataframe
    key : str
        Key in ``adata.uns`` used to store the degs

    Returns
    -------
    pd.DataFrame
        Dataframe of differentially expressed genes in long format, with a
        categorical ``group`` column followed by the names, scores,
        logfoldchanges, pvals and pvals_adj columns
    """

    stored = adata.uns[key]
    # Group labels are the field names of the structured "names" array
    groups = list(stored["names"].dtype.names)
    fields = ["names", "scores", "logfoldchanges", "pvals", "pvals_adj"]

    per_field = []
    for field in fields:
        per_field.append(pd.DataFrame(stored[field])[groups])

    # Wide table with (field, group) column levels, then one row per (rank, group)
    wide = pd.concat(per_field, axis=1, names=[None, "group"], keys=fields)
    degs = wide.stack(level=1).reset_index()
    degs["group"] = pd.Categorical(degs["group"], categories=groups)
    # "level_0" holds the original row position; it is used for ordering only
    degs = degs.sort_values(["group", "level_0"])

    return degs.drop(columns="level_0")