|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import tarfile |
|
import time |
|
from typing import List, Tuple |
|
|
|
import requests |
|
from requests.auth import HTTPBasicAuth |
|
from tqdm import tqdm |
|
|
|
# Progress-bar layout used while waiting on the MSA server.
TQDM_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} estimate remaining: {remaining}]"

logger = logging.getLogger(__name__)

# HTTP basic-auth credentials sent with every request to the MSA server.
# Real secrets should not live in source control, so they may be supplied via
# the environment; the original placeholder values remain the defaults so
# existing behavior is unchanged when the variables are unset.
username = os.environ.get("MSA_SERVER_USERNAME", "example_user")
password = os.environ.get("MSA_SERVER_PASSWORD", "example_password")
|
|
|
|
|
def run_mmseqs2_service(
    x,
    prefix,
    use_env=True,
    use_filter=True,
    use_templates=False,
    filter=None,
    use_pairing=False,
    pairing_strategy="greedy",
    host_url="https://api.colabfold.com",
    user_agent: str = "",
    email: str = "",
) -> None:
    """Run an MMseqs2 MSA search on a remote server and unpack the result into ``prefix``.

    Distinct sequences are submitted once, the returned ticket is polled until it
    completes, and the resulting ``out.tar.gz`` is downloaded into ``prefix`` and
    extracted there. If ``out.tar.gz`` already exists it is reused instead of
    submitting a new job (and is extracted only if the expected files are missing).

    Args:
        x: A single sequence string or a list of sequence strings.
        prefix: Output directory; created (with parents) if missing.
        use_env: Select the "env" search modes. Ignored when ``use_pairing`` is set.
        use_filter: Select the filtered search modes. Overridden by ``filter``.
        use_templates: Accepted for backward compatibility; not used by this function.
        filter: When not None, overrides ``use_filter``.
        use_pairing: Submit to ``ticket/pair`` instead of ``ticket/msa``.
        pairing_strategy: ``"greedy"`` or ``"complete"``; only used when pairing.
        host_url: Base URL of the MSA server.
        user_agent: ``User-Agent`` header value; a warning is logged when empty.
        email: Contact address sent along with the submission.

    Returns:
        None. Results are written to disk under ``prefix``.

    Raises:
        ValueError: No sequences were given, or ``pairing_strategy`` is unknown
            while ``use_pairing`` is set.
        Exception: The server reports ERROR or MAINTENANCE.
        requests.HTTPError: The result download returned a non-2xx status.
        FileNotFoundError: The extracted archive lacks the expected result files.
    """
    seqs = [x] if isinstance(x, str) else x
    if not seqs:
        raise ValueError("At least one sequence is required.")

    if filter is not None:
        use_filter = filter

    if use_pairing:
        # Pairing jobs have their own endpoint and modes; env/filter do not apply.
        pairing_modes = {"greedy": "pairgreedy", "complete": "paircomplete"}
        if pairing_strategy not in pairing_modes:
            # Previously this silently submitted mode="" to the server.
            raise ValueError(
                f"Unknown pairing_strategy {pairing_strategy!r}; "
                "expected 'greedy' or 'complete'."
            )
        submission_endpoint = "ticket/pair"
        mode = pairing_modes[pairing_strategy]
    else:
        submission_endpoint = "ticket/msa"
        if use_filter:
            mode = "env" if use_env else "all"
        else:
            mode = "env-nofilter" if use_env else "nofilter"

    headers = {}
    if user_agent != "":
        headers["User-Agent"] = user_agent
    else:
        logger.warning(
            "No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future."
        )

    auth = HTTPBasicAuth(username, password)

    path = prefix
    tar_gz_file = f"{path}/out.tar.gz"
    out_dir = os.path.dirname(tar_gz_file)
    expected_files = ("0.a3m", "pdb70_220313_db.m8", "uniref_tax.m8")

    def _request(send, url, action, max_errors=5, **kwargs):
        """Call ``send(url, ...)`` with shared headers/auth, retrying on failure.

        Timeouts are retried without limit; any other exception is retried up
        to ``max_errors`` times (5 s apart) and then re-raised.
        """
        # Must live outside the loop, otherwise the limit can never be reached.
        error_count = 0
        while True:
            try:
                return send(url, timeout=6.02, headers=headers, auth=auth, **kwargs)
            except requests.exceptions.Timeout:
                logger.warning(f"Timeout while {action} MSA server. Retrying...")
            except Exception as e:
                error_count += 1
                if error_count > max_errors:
                    raise
                logger.warning(
                    f"Error while {action} MSA server. Retrying... ({error_count}/{max_errors})"
                )
                logger.warning(f"Error: {e}")
                time.sleep(5)

    def _json_or_error(res):
        """Decode a JSON reply, mapping a non-JSON body to ``{"status": "ERROR"}``."""
        try:
            return res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            return {"status": "ERROR"}

    def submit(seqs, mode):
        """Submit sequences (one per line) and return the server's JSON reply."""
        # NOTE(review): the original loop advanced a counter starting at 101 that
        # was never used; ColabFold-style clients usually send FASTA with `>{n}`
        # header lines — confirm this server expects bare, newline-separated sequences.
        query = "".join(f"{seq}\n" for seq in seqs)
        res = _request(
            requests.post,
            f"{host_url}/{submission_endpoint}",
            "submitting to",
            data={"q": query, "mode": mode, "email": email},
        )
        return _json_or_error(res)

    def status(ticket_id):
        """Return the server's JSON status reply for ``ticket_id``."""
        res = _request(requests.get, f"{host_url}/ticket/{ticket_id}", "fetching status from")
        return _json_or_error(res)

    def download(ticket_id, dest):
        """Download the result archive for ``ticket_id`` to ``dest``."""
        res = _request(
            requests.get, f"{host_url}/result/download/{ticket_id}", "fetching result from"
        )
        # Never cache an HTTP error page as if it were the archive.
        res.raise_for_status()
        # Write to a temporary name and rename, so an interrupted write cannot
        # leave a truncated file that later runs would treat as a cached result.
        tmp_dest = f"{dest}.part"
        with open(tmp_dest, "wb") as fh:
            fh.write(res.content)
        os.replace(tmp_dest, dest)

    def _missing_files():
        """Return the expected result files not present in ``out_dir``."""
        present = set(os.listdir(out_dir))
        return [name for name in expected_files if name not in present]

    def _fetch_archive():
        """Submit the job, poll until it completes, then download the archive."""
        time_estimate = 100
        with tqdm(total=time_estimate, bar_format=TQDM_BAR_FORMAT) as pbar:
            while True:
                pbar.set_description("SUBMIT")
                out = submit(seqs_unique, mode)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 60
                    logger.warning(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode)

                if out["status"] == "ERROR":
                    raise Exception(
                        "MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later."
                    )
                if out["status"] == "MAINTENANCE":
                    raise Exception(
                        "MMseqs2 API is undergoing maintenance. Please try again in a few minutes."
                    )

                ticket_id, elapsed = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    poll_interval = 60
                    logger.info(f"Sleeping for {poll_interval}s. Reason: {out['status']}")
                    time.sleep(poll_interval)
                    out = status(ticket_id)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        elapsed += poll_interval
                        # Progress is elapsed time as a share of 30 minutes, capped at 99%.
                        pbar.n = min(99, int(100 * elapsed / (30.0 * 60)))
                        pbar.refresh()

                if out["status"] == "COMPLETE":
                    pbar.n = 100
                    pbar.refresh()
                    break
                if out["status"] == "ERROR":
                    raise Exception(
                        "MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later."
                    )
                # Any other terminal status: resubmit the job.
        download(ticket_id, tar_gz_file)

    def _extract_archive():
        """Unpack ``tar_gz_file`` into ``out_dir``."""
        with tarfile.open(tar_gz_file) as tar_gz:
            # The archive comes from the network. Where available (Python 3.12+,
            # and backported patch releases), the "data" filter rejects absolute
            # paths, ".." traversal, and special files.
            if hasattr(tarfile, "data_filter"):
                tar_gz.extractall(out_dir, filter="data")
            else:
                tar_gz.extractall(out_dir)

    os.makedirs(path, exist_ok=True)
    # Order-preserving de-duplication: each distinct sequence is searched once.
    seqs_unique = list(dict.fromkeys(seqs))

    logger.info("Msa server is running.")
    if not os.path.isfile(tar_gz_file):
        _fetch_archive()
    elif not _missing_files():
        # Cached archive already extracted; nothing to do.
        return

    _extract_archive()
    missing = _missing_files()
    if missing:
        raise FileNotFoundError(
            f"Files 0.a3m, pdb70_220313_db.m8, and uniref_tax.m8 not found in the directory "
            f"{out_dir!r} (missing: {', '.join(missing)})."
        )
    print("Files downloaded and extracted successfully.")
|
|