File size: 4,302 Bytes
5984d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
import pathlib
from typing import List

import gradio as gr
import numpy as np
import pandas as pd
from gt4sd.algorithms.conditional_generation.paccmann_rl import (
    PaccMannRL,
    PaccMannRLOmicBasedGenerator,
    PaccMannRLProteinBasedGenerator,
)
from gt4sd.algorithms.generation.paccmann_vae import PaccMannVAE, PaccMannVAEGenerator
from gt4sd.algorithms.registry import ApplicationsRegistry

from utils import draw_grid_generate

# Module-level logger. The NullHandler keeps this module quiet unless the
# hosting application configures logging (library-style logging setup).
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def run_inference(
    algorithm_version: str,
    inference_type: str,
    protein_target: str,
    omics_target: str,
    temperature: float,
    length: float,
    number_of_samples: int,
):
    if inference_type == "Unbiased":
        algorithm_class = PaccMannVAEGenerator
        model_class = PaccMannVAE
        target = None
    elif inference_type == "Conditional":
        if "Protein" in algorithm_version:
            algorithm_class = PaccMannRLProteinBasedGenerator
            target = protein_target
        elif "Omic" in algorithm_version:
            algorithm_class = PaccMannRLOmicBasedGenerator
            try:
                test_target = [float(x) for x in omics_target.split(" ")]
            except Exception:
                raise ValueError(
                    f"Expected 2128 space-separated omics values, got {omics_target}"
                )
            if len(test_target) != 2128:
                raise ValueError(
                    f"Expected 2128 omics values, got {len(target)}: {target}"
                )
            target = f"[{omics_target.replace(' ', ',')}]"
        else:
            raise ValueError(f"Unknown algorithm version {algorithm_version}")
        model_class = PaccMannRL
    else:
        raise ValueError(f"Unknown inference type {inference_type}")

    config = algorithm_class(
        algorithm_version.split("_")[-1],
        temperature=temperature,
        generated_length=length,
    )
    print("Target is ", target)
    print(type(target), len(target))
    model = model_class(config, target=target)
    samples = list(model.sample(number_of_samples))

    return draw_grid_generate(samples=samples, n_cols=5)


if __name__ == "__main__":

    # Preparation (retrieve all available algorithms)
    all_algos = ApplicationsRegistry.list_available()
    # Dropdown entries "<Application>_<version>" for every PaccMannRL algorithm:
    # an application named like "PaccMannRLProteinBasedGenerator" with version
    # "v0" becomes "Protein_v0".
    algos = [
        x["algorithm_application"].split("Based")[0].split("PaccMannRL")[-1]
        + "_"
        + x["algorithm_version"]
        for x in all_algos
        if "PaccMannRL" in x["algorithm_name"]
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")

    # Explicit UTF-8: the default text encoding is locale-dependent.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="PaccMannRL",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="Protein_v0"),
            gr.Radio(
                choices=["Conditional", "Unbiased"],
                label="Inference type",
                value="Conditional",
            ),
            gr.Textbox(
                label="Protein target",
                placeholder="MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT",
                lines=1,
            ),
            gr.Textbox(
                label="Gene expression target",
                # Random example vector of 2128 values rounded to 2 decimals.
                placeholder=f"{' '.join(map(str, np.round(np.random.rand(2128), 2)))}",
                lines=1,
            ),
            gr.Slider(minimum=0.5, maximum=2, value=1, label="Decoding temperature"),
            gr.Slider(
                minimum=5,
                maximum=400,
                value=100,
                label="Maximal sequence length",
                step=1,
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)