import os

# Module-level side effect: move the process working directory up one level
# before importing the project-local modules below. NOTE(review): presumably
# this file lives one directory below the project root and this makes
# ``dataloader`` / ``celle_main`` importable via a cwd-relative sys.path entry
# — confirm. ``base_dir`` records that directory so it can be restored later.
os.chdir('..')
base_dir = os.getcwd()

from dataloader import CellLoader  # noqa: E402  (must follow the chdir above)
from celle_main import instantiate_from_config  # noqa: E402
from omegaconf import OmegaConf  # noqa: E402


def run_image_prediction(
    sequence_input, nucleus_image, model_ckpt_path, model_config_path, device
):
    """Run the Celle model on one sequence / nucleus-image pair and return its predictions.

    :param sequence_input: Sequence passed to ``CellLoader.tokenize_sequence``
        (presumably an amino-acid string — confirm against ``CellLoader``).
    :param nucleus_image: Nucleus image used as the sampling condition; must
        support ``.to(device)`` (presumably a batched torch tensor — confirm).
    :param model_ckpt_path: Path to the model checkpoint. It is written into
        ``model.params.ckpt_path`` when the config leaves that unset, and its
        directory is the working directory while the model is instantiated.
    :param model_config_path: Path to the OmegaConf model config file.
    :param device: Device the sequence, nucleus image and model are moved to.
    :return: ``(predicted_threshold, predicted_heatmap)``, each moved to the
        CPU and indexed at ``[0, 0]`` (first batch element, first entry of the
        second dimension).
    """
    # Dataset object is only used here for its sequence tokenizer.
    dataset = CellLoader(
        sequence_mode="embedding",
        vocab="esm2",
        split_key="val",
        crop_method="center",
        resize=600,
        crop_size=256,
        text_seq_len=1000,
        pad_mode="end",
        threshold="median",
    )
    sequence = dataset.tokenize_sequence(sequence_input)

    # Resolve the checkpoint path against the *current* directory before any
    # chdir below, so a relative path still names the same file afterwards and
    # its directory is never the empty string.
    ckpt_path = os.path.abspath(model_ckpt_path)
    ckpt_dir = os.path.dirname(ckpt_path)

    config = OmegaConf.load(model_config_path)
    params = config["model"]["params"]
    # Fill in the checkpoint only when the config does not already provide one;
    # .get() also tolerates the key being absent from the config.
    if params.get("ckpt_path") is None:
        params["ckpt_path"] = ckpt_path
    # Explicitly disable separate condition-model / VQGAN checkpoints.
    params["condition_model_path"] = None
    params["vqgan_model_path"] = None

    # Instantiate with the checkpoint directory as cwd (NOTE(review): presumably
    # so paths inside the config resolve relative to it — confirm), and always
    # return to base_dir even if instantiation fails.
    try:
        os.chdir(ckpt_dir)
        model = instantiate_from_config(config.model).to(device)
    finally:
        os.chdir(base_dir)

    # Only the last two of the five returned values are used.
    _, _, _, predicted_threshold, predicted_heatmap = model.celle.sample(
        text=sequence.to(device),
        condition=nucleus_image.to(device),
        timesteps=1,
        temperature=1,
        progress=False,
    )

    predicted_threshold = predicted_threshold.cpu()[0, 0]
    predicted_heatmap = predicted_heatmap.cpu()[0, 0]
    return predicted_threshold, predicted_heatmap