# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import tarfile
import time
from typing import List, Tuple

import requests
from requests.auth import HTTPBasicAuth
from tqdm import tqdm

TQDM_BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} estimate remaining: {remaining}]"

logger = logging.getLogger(__name__)

# NOTE(review): placeholder credentials sent as HTTP basic auth on every
# request. These presumably should come from configuration / environment
# rather than source code — confirm with the deployment owner.
username = "example_user"
password = "example_password"

# Timeout (seconds) for every request. Per the requests docs it is "good
# practice to set connect timeouts to slightly larger than a multiple of 3".
_REQUEST_TIMEOUT = 6.02
# Number of non-timeout request failures tolerated before the error is re-raised.
_MAX_REQUEST_ERRORS = 5
# Seconds to wait after a non-timeout request failure before retrying.
_ERROR_RETRY_DELAY = 5


def _request_with_retry(send, url, timeout_message, **kwargs):
    """Send one HTTP request, retrying on transient failures.

    Args:
        send: ``requests.get`` or ``requests.post`` (any callable with the
            same ``(url, timeout=..., **kwargs)`` signature).
        url: Target URL.
        timeout_message: Warning logged each time the request times out.
        **kwargs: Forwarded to ``send`` (``headers``, ``data``, ``auth``, ...).

    Returns:
        The ``requests.Response`` of the first attempt that did not raise.

    Raises:
        requests.exceptions.RequestException: the last error, once more than
            ``_MAX_REQUEST_ERRORS`` non-timeout failures have occurred.
    """
    # Kept outside the loop so the budget is shared by all attempts and the
    # limit actually terminates the retry loop.
    error_count = 0
    while True:
        try:
            return send(url, timeout=_REQUEST_TIMEOUT, **kwargs)
        except requests.exceptions.Timeout:
            # Timeouts are retried immediately and without limit.
            logger.warning(timeout_message)
        except requests.exceptions.RequestException as e:
            error_count += 1
            logger.warning(f"Error: {e}")
            if error_count > _MAX_REQUEST_ERRORS:
                raise
            logger.warning(
                f"Error while fetching result from MSA server. Retrying... ({error_count}/{_MAX_REQUEST_ERRORS})"
            )
            time.sleep(_ERROR_RETRY_DELAY)


def _json_or_error(res):
    """Decode a JSON response; a non-JSON reply becomes ``{"status": "ERROR"}``."""
    try:
        return res.json()
    except ValueError:
        logger.error(f"Server didn't reply with json: {res.text}")
        return {"status": "ERROR"}


def _safe_extractall(tar, dest):
    """Extract all members of *tar* into *dest*, refusing paths that escape it.

    The archive comes from a remote server, so member names are untrusted.
    """
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ and security backports: the "data" filter rejects
        # absolute paths, ".." traversal, links leaving dest and special files.
        tar.extractall(dest, filter="data")
        return
    # Best-effort fallback for older interpreters: check member paths only.
    dest_root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise tarfile.TarError(
                f"Refusing to extract {member.name!r} outside {dest!r}"
            )
    tar.extractall(dest)


def run_mmseqs2_service(
    x,
    prefix,
    use_env=True,
    use_filter=True,
    use_templates=False,
    filter=None,
    use_pairing=False,
    pairing_strategy="greedy",
    host_url="https://api.colabfold.com",
    user_agent: str = "",
    email: str = "",
) -> None:
    """Query an MMseqs2 MSA server and extract its result archive into *prefix*.

    Submits the deduplicated query sequences and polls the ticket until it
    completes. It then downloads ``{prefix}/out.tar.gz`` and extracts it into
    the same directory. If that archive already exists, all network steps are
    skipped and the existing archive is re-extracted.

    Args:
        x: A single sequence string or a list of sequences.
        prefix: Output directory. It is created, including parents, if missing.
        use_env: Choose the "env" modes instead of "all" / "nofilter".
        use_filter: Choose the filtered modes ("env"/"all") instead of the
            unfiltered ones ("env-nofilter"/"nofilter").
        use_templates: Accepted for compatibility; not used by this function.
        filter: Legacy alias for *use_filter*; overrides it when not None.
        use_pairing: Submit to ``ticket/pair`` with a pairing mode instead of
            ``ticket/msa``.
        pairing_strategy: ``"greedy"`` -> ``"pairgreedy"``, ``"complete"`` ->
            ``"paircomplete"``. Any other value sends an empty mode.
        host_url: Base URL of the server.
        user_agent: Sent as the User-Agent header. A warning is logged if empty.
        email: Sent with the submission form.

    Returns:
        None. The results are left on disk in *prefix*.

    Raises:
        Exception: if the server reports ERROR or MAINTENANCE.
        FileNotFoundError: if ``0.a3m``, ``pdb70_220313_db.m8`` or
            ``uniref_tax.m8`` is missing after extraction.
        requests.exceptions.RequestException: after repeated request failures.
        tarfile.TarError: if the archive is unreadable or contains unsafe paths.
    """
    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    headers = {}
    if user_agent != "":
        headers["User-Agent"] = user_agent
    else:
        logger.warning(
            "No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future."
        )
    auth = HTTPBasicAuth(username, password)

    def submit(seqs, mode, N=101):
        """POST *seqs* as a new ticket and return the decoded server reply."""
        # NOTE(review): sequences are sent newline-separated without FASTA
        # headers, and N is otherwise unused. This looks like ">{N+i}" header
        # lines may have been dropped — confirm the server contract before
        # changing the payload.
        query = "".join(f"{seq}\n" for seq in seqs)
        res = _request_with_retry(
            requests.post,
            f"{host_url}/{submission_endpoint}",
            "Timeout while submitting to MSA server. Retrying...",
            data={"q": query, "mode": mode, "email": email},
            headers=headers,
            auth=auth,
        )
        return _json_or_error(res)

    def status(ID):
        """GET the current state of ticket *ID* as a decoded dict."""
        res = _request_with_retry(
            requests.get,
            f"{host_url}/ticket/{ID}",
            "Timeout while fetching status from MSA server. Retrying...",
            headers=headers,
            auth=auth,
        )
        return _json_or_error(res)

    def download(ID, path):
        """Download the result archive of ticket *ID* and write it to *path*."""
        res = _request_with_retry(
            requests.get,
            f"{host_url}/result/download/{ID}",
            "Timeout while fetching result from MSA server. Retrying...",
            headers=headers,
            auth=auth,
        )
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # compatibility to old option
    if filter is not None:
        use_filter = filter

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    if use_pairing:
        # greedy is default, complete was the previous behavior
        if pairing_strategy == "greedy":
            mode = "pairgreedy"
        elif pairing_strategy == "complete":
            mode = "paircomplete"
        else:
            mode = ""

    # define path
    path = prefix
    os.makedirs(path, exist_ok=True)
    tar_gz_file = f"{path}/out.tar.gz"
    N = 101

    # Deduplicate in O(n) while keeping first-occurrence order.
    seqs_unique = list(dict.fromkeys(seqs))

    logger.info("Msa server is running.")
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 100
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            REDO = True
            # A finished poll with a status other than COMPLETE/ERROR
            # triggers a fresh submission.
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 60
                    logger.warning(
                        f"Sleeping for {sleep_time}s. Reason: {out['status']}"
                    )
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    raise Exception(
                        "MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later."
                    )

                if out["status"] == "MAINTENANCE":
                    raise Exception(
                        "MMseqs2 API is undergoing maintenance. Please try again in a few minutes."
                    )

                # wait for job to finish
                ID, TIME = out["id"], 0
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 60
                    logger.info(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        # Progress assumes roughly 30 minutes of RUNNING time,
                        # capped at 99% until the job reports COMPLETE.
                        pbar.n = min(99, int(100 * TIME / (30.0 * 60)))
                        pbar.refresh()
                if out["status"] == "COMPLETE":
                    pbar.n = 100
                    pbar.refresh()
                    REDO = False

                if out["status"] == "ERROR":
                    raise Exception(
                        "MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later."
                    )

        # Download results
        download(ID, tar_gz_file)

    extract_dir = os.path.dirname(tar_gz_file)
    with tarfile.open(tar_gz_file) as tar_gz:
        _safe_extractall(tar_gz, extract_dir)

    files = os.listdir(extract_dir)
    required = ("0.a3m", "pdb70_220313_db.m8", "uniref_tax.m8")
    if any(name not in files for name in required):
        raise FileNotFoundError(
            "Files 0.a3m, pdb70_220313_db.m8, and uniref_tax.m8 not found in the directory."
        )
    print("Files downloaded and extracted successfully.")