FoldMark / scripts /msa /step3-uniref_add_taxid.py
Zaixi's picture
Add large file
89c0b51
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import join as opjoin
from typing import Dict, List
from tqdm import tqdm
def read_a3m(a3m_file: str) -> tuple[List[str], List[str]]:
"""read a3m file from output of mmseqs
Args:
a3m_file (str): the a3m file searched by mmseqs(colabfold search)
Returns:
tuple[List[str], List[str]]: the header and seqs of a3m files
"""
heads = []
seqs = []
# Record the row index. The index before this index is the MSA of Uniref30 DB,
# and the index after this index is the MSA of ColabfoldDB.
uniref_index = 0
with open(a3m_file, "r") as infile:
for idx, line in enumerate(infile):
if line.startswith(">"):
heads.append(line)
if idx == 0:
query_name = line
elif idx > 0 and line == query_name:
uniref_index = idx
else:
seqs.append(line)
return heads, seqs, uniref_index
def read_m8(m8_file: str) -> Dict[str, str]:
"""the uniref_tax.m8 from output of mmseqs
Args:
m8_file (str): the uniref_tax.m8 from output of mmseqs(colabfold search)
Returns:
Dict[str, str]: the dict mapping uniref hit_name to NCBI TaxID
"""
uniref_to_ncbi_taxid = {}
with open(m8_file, "r") as infile:
for line in infile:
line_list = line.replace("\n", "").split("\t")
hit_name = line_list[1]
ncbi_taxid = line_list[2]
uniref_to_ncbi_taxid[hit_name] = ncbi_taxid
return uniref_to_ncbi_taxid
def update_a3m(
a3m_path: str,
uniref_to_ncbi_taxid: Dict,
save_root: str,
) -> None:
"""add NCBI TaxID to header if "UniRef" in header
Args:
a3m_path (str): the original a3m path returned by mmseqs(colabfold search)
uniref_to_ncbi_taxid (Dict): the dict mapping uniref hit_name to NCBI TaxID
save_root (str): the updated a3m
"""
heads, seqs, uniref_index = read_a3m(a3m_path)
print(uniref_index)
fname = a3m_path.split("/")[-1]
out_a3m_path = opjoin(save_root, fname)
with open(out_a3m_path, "w") as ofile:
for idx, (head, seq) in enumerate(zip(heads, seqs)):
uniref_id = head.split("\t")[0][1:]
ncbi_taxid = uniref_to_ncbi_taxid.get(uniref_id, None)
if (ncbi_taxid is not None) and (idx < (uniref_index // 2)):
if not uniref_id.startswith("UniRef100_"):
head = head.replace(
uniref_id, f"UniRef100_{uniref_id}_{ncbi_taxid}/"
)
else:
head = head.replace(uniref_id, f"{uniref_id}_{ncbi_taxid}/")
ofile.write(f"{head}{seq}")
if __name__ == "__main__":
input_msa_dir = "./scripts/msa/data/mmcif_msa_initial"
output_msa_dir = "./scripts/msa/data/mmcif_msa_with_taxid"
os.makedirs(output_msa_dir, exist_ok=True)
a3m_paths = os.listdir(input_msa_dir)
a3m_paths = [opjoin(input_msa_dir, x) for x in a3m_paths if x.endswith(".a3m")]
m8_file = f"{input_msa_dir}/uniref_tax.m8"
uniref_to_ncbi_taxid = read_m8(m8_file)
for a3m_path in tqdm(a3m_paths):
update_a3m(
a3m_path=a3m_path,
uniref_to_ncbi_taxid=uniref_to_ncbi_taxid,
save_root=output_msa_dir,
)