import faiss import numpy as np import pandas as pd import os import yaml import glob from easydict import EasyDict from utils.constants import sequence_level from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel from tqdm import tqdm print(os.listdir("/data")) def load_model(): model_config = { "protein_config": glob.glob(f"{config.model_dir}/esm2_*")[0], "text_config": f"{config.model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", "structure_config": glob.glob(f"{config.model_dir}/foldseek_*")[0], "load_protein_pretrained": False, "load_text_pretrained": False, "from_checkpoint": glob.glob(f"{config.model_dir}/*.pt")[0] } model = ProTrekTrimodalModel(**model_config) model.eval() return model def load_faiss_index(index_path: str): if config.faiss_config.IO_FLAG_MMAP: index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP) else: index = faiss.read_index(index_path) index.metric_type = faiss.METRIC_INNER_PRODUCT return index def load_index(): all_index = {} # Load protein sequence index all_index["sequence"] = {} for db in tqdm(config.sequence_index_dir, desc="Loading sequence index..."): db_name = db["name"] index_dir = db["index_dir"] index_path = f"{index_dir}/sequence.index" sequence_index = load_faiss_index(index_path) id_path = f"{index_dir}/ids.tsv" uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten() all_index["sequence"][db_name] = {"index": sequence_index, "ids": uniprot_ids} # Load protein structure index print("Loading structure index...") all_index["structure"] = {} for db in tqdm(config.structure_index_dir, desc="Loading structure index..."): db_name = db["name"] index_dir = db["index_dir"] index_path = f"{index_dir}/structure.index" structure_index = load_faiss_index(index_path) id_path = f"{index_dir}/ids.tsv" uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten() all_index["structure"][db_name] = {"index": structure_index, "ids": uniprot_ids} # Load text index all_index["text"] = {} valid_subsections = {} for db in tqdm(config.text_index_dir, desc="Loading text index..."): db_name = db["name"] index_dir = db["index_dir"] all_index["text"][db_name] = {} text_dir = f"{index_dir}/subsections" # Remove "Taxonomic lineage" from sequence_level. This is a special case which we don't need to index. valid_subsections[db_name] = set() sequence_level.add("Global") for subsection in tqdm(sequence_level): index_path = f"{text_dir}/{subsection.replace(' ', '_')}.index" if not os.path.exists(index_path): continue text_index = load_faiss_index(index_path) id_path = f"{text_dir}/{subsection.replace(' ', '_')}_ids.tsv" text_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten() all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids} valid_subsections[db_name].add(subsection) # Sort valid_subsections for db_name in valid_subsections: valid_subsections[db_name] = sorted(list(valid_subsections[db_name])) return all_index, valid_subsections # Load the config file root_dir = __file__.rsplit("/", 3)[0] config_path = f"{root_dir}/demo/config.yaml" with open(config_path, 'r', encoding='utf-8') as r: config = EasyDict(yaml.safe_load(r)) device = "cuda" print("Loading model...") model = load_model() # model.to(device) all_index, valid_subsections = load_index() print("Done...") # model = None # all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}