# demo/modules/init_model.py
import faiss
import numpy as np
import pandas as pd
import os
import yaml
import glob

from easydict import EasyDict
from utils.constants import sequence_level
from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
from tqdm import tqdm

# Debug print: show the contents of the mounted data volume.
print(os.listdir("/data"))


def load_model():
    """Build the ProTrek tri-modal model from the checkpoint files in config.model_dir."""
    model_config = {
        "protein_config": glob.glob(f"{config.model_dir}/esm2_*")[0],
        "text_config": f"{config.model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "structure_config": glob.glob(f"{config.model_dir}/foldseek_*")[0],
        "load_protein_pretrained": False,
        "load_text_pretrained": False,
        "from_checkpoint": glob.glob(f"{config.model_dir}/*.pt")[0]
    }

    model = ProTrekTrimodalModel(**model_config)
    model.eval()
    return model


def load_faiss_index(index_path: str):
    """Read a faiss index from disk (optionally memory-mapped) and score it by inner product."""
    if config.faiss_config.IO_FLAG_MMAP:
        index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    else:
        index = faiss.read_index(index_path)

    index.metric_type = faiss.METRIC_INNER_PRODUCT
    return index


def load_index():
    """Load the sequence, structure and text faiss indexes declared in the config.

    Returns a nested dict keyed by modality and database name, plus the sorted
    list of text subsections that actually have an index on disk.
    """
    all_index = {}

    # Load protein sequence index
    all_index["sequence"] = {}
    for db in tqdm(config.sequence_index_dir, desc="Loading sequence index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]

        index_path = f"{index_dir}/sequence.index"
        sequence_index = load_faiss_index(index_path)

        id_path = f"{index_dir}/ids.tsv"
        uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()

        all_index["sequence"][db_name] = {"index": sequence_index, "ids": uniprot_ids}

    # Load protein structure index
    print("Loading structure index...")
    all_index["structure"] = {}
    for db in tqdm(config.structure_index_dir, desc="Loading structure index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]

        index_path = f"{index_dir}/structure.index"
        structure_index = load_faiss_index(index_path)

        id_path = f"{index_dir}/ids.tsv"
        uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()

        all_index["structure"][db_name] = {"index": structure_index, "ids": uniprot_ids}

    # Load text index
    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]
        all_index["text"][db_name] = {}
        text_dir = f"{index_dir}/subsections"

        # "Taxonomic lineage" is a special case we don't index; subsections with
        # no index file on disk are skipped by the existence check below.
        valid_subsections[db_name] = set()
        sequence_level.add("Global")
        for subsection in tqdm(sequence_level):
            index_path = f"{text_dir}/{subsection.replace(' ', '_')}.index"
            if not os.path.exists(index_path):
                continue

            text_index = load_faiss_index(index_path)

            id_path = f"{text_dir}/{subsection.replace(' ', '_')}_ids.tsv"
            text_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()

            all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids}
            valid_subsections[db_name].add(subsection)

    # Sort valid_subsections
    for db_name in valid_subsections:
        valid_subsections[db_name] = sorted(list(valid_subsections[db_name]))

    return all_index, valid_subsections


# Load the config file
root_dir = __file__.rsplit("/", 3)[0]
config_path = f"{root_dir}/demo/config.yaml"
with open(config_path, 'r', encoding='utf-8') as r:
    config = EasyDict(yaml.safe_load(r))

device = "cuda"

print("Loading model...")
model = load_model()
# model.to(device)

all_index, valid_subsections = load_index()
print("Done...")

# model = None
# all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}
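
# Minimal usage sketch (not part of the original module): how downstream demo code
# might query one of the loaded indexes. It assumes ProTrekTrimodalModel exposes a
# get_text_repr() method returning normalised, numpy-convertible embeddings; the
# faiss calls themselves are the standard API.
#
# query = model.get_text_repr(["ATP binding"]).cpu().numpy()   # shape (1, d)
# db = all_index["sequence"]["UniRef50"]                       # one of the configured databases
# scores, ranks = db["index"].search(query, 10)                # top-10 inner-product search
# top_hits = db["ids"][ranks[0]]                               # map faiss rows back to UniProt ids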