|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from os.path import join as opjoin |
|
from typing import Dict, List |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
def read_a3m(a3m_file: str) -> tuple[List[str], List[str], int]:
    """Read an a3m file from the output of mmseqs.

    Lines starting with ">" are collected as headers; every other line is
    collected as a sequence line. Lines are kept verbatim, including their
    trailing newline.

    The first header seen is taken as the query header. Every later header
    that is identical to it overwrites the recorded line index, so the
    returned index is that of the last repeat of the query header.

    Args:
        a3m_file (str): the a3m file searched by mmseqs(colabfold search)

    Returns:
        tuple[List[str], List[str], int]: the headers, the sequence lines,
            and the 0-based line index of the last header identical to the
            query header (0 if the query header never repeats)
    """
    heads = []
    seqs = []
    # Bound before the loop so that a file whose first line is not a header
    # cannot hit an unbound name in the comparison below.
    query_name = None
    uniref_index = 0
    with open(a3m_file, "r") as infile:
        for idx, line in enumerate(infile):
            if not line.startswith(">"):
                seqs.append(line)
                continue
            heads.append(line)
            if query_name is None:
                query_name = line
            elif line == query_name:
                uniref_index = idx
    return heads, seqs, uniref_index
|
|
|
|
|
def read_m8(m8_file: str) -> Dict[str, str]:
    """Parse the uniref_tax.m8 file from the output of mmseqs.

    Each non-blank line is split on tabs; the second column is used as the
    hit name and the third column as the NCBI TaxID. When a hit name occurs
    on several lines, the value from the last such line is kept.

    Args:
        m8_file (str): the uniref_tax.m8 from output of mmseqs(colabfold search)

    Returns:
        Dict[str, str]: the dict mapping uniref hit_name to NCBI TaxID

    Raises:
        IndexError: if a non-blank line has fewer than three tab-separated
            columns.
    """
    uniref_to_ncbi_taxid = {}
    with open(m8_file, "r") as infile:
        for line in infile:
            row = line.rstrip("\n")
            # A blank line (e.g. an empty last line) carries no record; skip
            # it instead of failing on the column lookup below.
            if not row.strip():
                continue
            columns = row.split("\t")
            uniref_to_ncbi_taxid[columns[1]] = columns[2]
    return uniref_to_ncbi_taxid
|
|
|
|
|
def update_a3m(
    a3m_path: str,
    uniref_to_ncbi_taxid: Dict[str, str],
    save_root: str,
) -> None:
    """Add the NCBI TaxID to the headers of an a3m file and save a copy.

    The a3m file is read with ``read_a3m`` and rewritten, under the same file
    name, into ``save_root``. A header is rewritten only when its identifier
    (the text between the leading ">" and the first tab) is a key of
    ``uniref_to_ncbi_taxid`` and the header comes before the repeated query
    header reported by ``read_a3m``. The identifier then becomes
    ``<id>_<taxid>/``, with ``UniRef100_`` prepended when the identifier does
    not already start with it. All other headers and all sequence lines are
    written unchanged.

    If the query header never repeats in the file, ``read_a3m`` reports
    index 0 and no header is rewritten.

    Args:
        a3m_path (str): the original a3m path returned by mmseqs(colabfold search)
        uniref_to_ncbi_taxid (Dict[str, str]): the dict mapping uniref
            hit_name to NCBI TaxID
        save_root (str): the directory the updated a3m is written to
    """
    heads, seqs, uniref_index = read_a3m(a3m_path)
    # uniref_index is a line index while idx below counts headers; halving it
    # assumes every entry occupies exactly two lines (header + sequence).
    uniref_head_count = uniref_index // 2
    out_a3m_path = opjoin(save_root, os.path.basename(a3m_path))
    with open(out_a3m_path, "w") as ofile:
        for idx, (head, seq) in enumerate(zip(heads, seqs)):
            uniref_id = head.split("\t")[0][1:]
            ncbi_taxid = uniref_to_ncbi_taxid.get(uniref_id)
            if ncbi_taxid is not None and idx < uniref_head_count:
                prefix = "" if uniref_id.startswith("UniRef100_") else "UniRef100_"
                tagged_id = f"{prefix}{uniref_id}_{ncbi_taxid}/"
                # Rebuild the header around the leading identifier only, so a
                # second occurrence of the id later in the line stays intact.
                head = f"{head[:1]}{tagged_id}{head[1 + len(uniref_id):]}"
            ofile.write(f"{head}{seq}")
|
|
|
|
|
if __name__ == "__main__":
    # Input directory: scanned for *.a3m files and expected to hold
    # uniref_tax.m8.
    input_msa_dir = "./scripts/msa/data/mmcif_msa_initial"

    # Output directory for the rewritten a3m files; created if missing.
    output_msa_dir = "./scripts/msa/data/mmcif_msa_with_taxid"
    os.makedirs(output_msa_dir, exist_ok=True)

    a3m_paths = []
    for entry in os.listdir(input_msa_dir):
        if entry.endswith(".a3m"):
            a3m_paths.append(opjoin(input_msa_dir, entry))

    # The hit-name -> TaxID table is loaded once and shared by every file.
    m8_file = input_msa_dir + "/uniref_tax.m8"
    uniref_to_ncbi_taxid = read_m8(m8_file)

    for a3m_path in tqdm(a3m_paths):
        update_a3m(a3m_path, uniref_to_ncbi_taxid, output_msa_dir)
|
|