File size: 2,208 Bytes
5d2263b
 
3d993f1
5d2263b
 
 
 
 
 
5cbd5ac
5d2263b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2657b5
 
5d2263b
3d993f1
 
 
5d2263b
 
 
860c3d7
 
5d2263b
 
860c3d7
5d2263b
 
 
 
 
 
5cbd5ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
# NOTE(review): import-time side effect — the process working directory is
# moved to the parent directory before the local imports below; presumably so
# that `dataloader` / `celle_main` and relative data paths resolve from the
# project root. Confirm before reordering these lines.
os.chdir('..')
# Directory that run_image_prediction changes back to after temporarily
# switching into the checkpoint's directory.
base_dir = os.getcwd()
from dataloader import CellLoader
from celle_main import instantiate_from_config
from omegaconf import OmegaConf

def run_image_prediction(
    sequence_input,
    nucleus_image,
    model_ckpt_path,
    model_config_path,
    device
):
    """
    Run the CELL-E model on a protein sequence and a nucleus image.

    :param sequence_input: Protein sequence, converted to model input with
        ``CellLoader.tokenize_sequence``.
    :param nucleus_image: Nucleus image tensor used as the sampling
        condition; it is moved to ``device`` before sampling.
    :param model_ckpt_path: Path to the model checkpoint. It is used as the
        config's ``ckpt_path`` when the config does not set one, and the
        model is instantiated with the checkpoint's directory as the
        working directory.
    :param model_config_path: Path to the OmegaConf model config file.
    :param device: Device the model and its inputs are moved to
        (anything accepted by ``.to(...)``).
    :returns: ``(predicted_threshold, predicted_heatmap)``, each moved to
        the CPU and indexed with ``[0, 0]`` (first batch element; the second
        axis is presumably the channel — confirm against ``celle.sample``).
    """
    # The dataset is instantiated here only to reuse its sequence tokenizer.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )

    # Convert the raw sequence into the model's text input.
    sequence = dataset.tokenize_sequence(sequence_input)

    # Resolve the checkpoint path against the *current* working directory
    # before any chdir below: a relative path would otherwise be resolved
    # from the wrong directory, and a bare filename would give dirname() == ''
    # (which os.chdir rejects).
    ckpt_path = os.path.abspath(model_ckpt_path)
    ckpt_dir = os.path.dirname(ckpt_path)

    # Load model config and set ckpt_path if the config does not provide one.
    config = OmegaConf.load(model_config_path)
    params = config["model"]["params"]
    if params.get("ckpt_path") is None:
        params["ckpt_path"] = ckpt_path

    # Disable loading of the separate condition and VQGAN sub-model weights.
    params["condition_model_path"] = None
    params["vqgan_model_path"] = None

    # Build the model from inside the checkpoint directory (presumably so
    # relative paths referenced by the config resolve there), and always
    # return to base_dir afterwards, even if instantiation fails.
    os.chdir(ckpt_dir)
    try:
        model = instantiate_from_config(config.model).to(device)
    finally:
        os.chdir(base_dir)

    # Sample from the model using the tokenized sequence and nucleus image.
    _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
        text=sequence.to(device),
        condition=nucleus_image.to(device),
        timesteps=1,
        temperature=1,
        progress=False,
    )

    # Move predictions to the CPU and keep only element [0, 0].
    predicted_threshold = predicted_threshold.cpu()[0, 0]
    predicted_heatmap = predicted_heatmap.cpu()[0, 0]

    return predicted_threshold, predicted_heatmap