"""Utility methods for the commands module."""

import hashlib
from contextlib import contextmanager
from pathlib import Path

import pandas as pd
from rich.progress import Progress, SpinnerColumn, TextColumn

from folding_studio.api_call.upload_custom_files import (
    CustomFileType,
    upload_custom_files,
)
from folding_studio.console import console


@contextmanager
def success_fail_catch_spinner(message: str, spinner_name: str = "dots"):
    """Wrapper around a rich progress spinner that adapts its state icon.

    The spinner runs while the ``with`` body executes. On success the
    description gets a check mark; on failure it gets a cross, the error is
    printed, and the original exception is re-raised.

    Args:
        message (str): message to show, supporting rich format.
        spinner_name (str, optional): rich SpinnerColumn spinner_name
            attribute. Defaults to "dots".

    Raises:
        Exception: whatever exception the wrapped body raised.

    Examples:
        ```
        with success_fail_catch_spinner("Running Task"):
            ...

        >>> Running Task ⠼
        # spins as long as the context manager is running
        # if no error raised then transforms into
        >>> Running Task ✅
        # otherwise transforms into
        >>> Running Task ❌
        An error occurred: ...
        ```
    """
    err = None
    with Progress(
        TextColumn("{task.description}"),
        SpinnerColumn(spinner_name, finished_text=""),
        console=console,
    ) as progress:
        task_id = progress.add_task(message, total=1)
        try:
            yield
            progress.update(
                task_id, completed=1, description=f"{message} :white_check_mark:"
            )
        except Exception as e:
            progress.update(task_id, completed=1, description=f"{message} :x:")
            # Defer reporting: printing here would interleave with the live
            # Progress display.
            err = e
    # For a coherent message order the print has to be made outside the
    # Progress context manager.
    if err is not None:
        console.print(f"An error occurred: {err}")
        raise err


@contextmanager
def success_fail_catch_print(*args, **kwargs):
    """Wrapper around rich `print` that adapts its state icon.

    Prints the given arguments without a trailing newline, runs the ``with``
    body, then appends a check mark on success or a cross plus the error
    message on failure. The exception is re-raised after being reported.

    Args:
        *args: positional arguments forwarded to ``console.print``.
        **kwargs: keyword arguments forwarded to ``console.print``. ``end``
            is set by this wrapper and must not be passed.

    Raises:
        Exception: whatever exception the wrapped body raised.

    Examples:
        ```
        with success_fail_catch_print("Running Task..."):
            ...

        >>> Running Task...
        # if no error raised then transforms into
        >>> Running Task... ✅
        # otherwise transforms into
        >>> Running Task... ❌
        An error occurred: ...
        ```
    """
    # end=" " keeps the cursor on the same line so the icon follows the text.
    console.print(*args, **kwargs, end=" ")
    try:
        yield
        console.print(":white_check_mark:")
    except Exception as e:
        console.print(":x:")
        console.print(f"An error occurred: {e}")
        raise


def _parse_a3m(file_path: Path, source: str) -> pd.DataFrame:
    """Parse a FASTA-like .a3m file into a DataFrame with one row per record.

    The first record is flagged with source "query"; subsequent records use
    the provided source. The header (text after ">") is stored both as the
    pairing key and as the comment. Lines that appear before the first header
    are ignored.

    Args:
        file_path (Path): path of the .a3m file to read.
        source (str): value of "source_database" for every non-query record.

    Returns:
        pd.DataFrame: columns "sequence", "source_database", "pairing_key"
        and "comment".

    Raises:
        ValueError: if the file contains no records.
    """
    with open(file_path, "r") as f:
        lines = f.read().splitlines()

    records = []

    def add_record(header: str, seq_lines: list) -> None:
        # "records" is still empty while the first row is being built, which
        # is what marks that row as the query.
        records.append(
            {
                "sequence": "".join(seq_lines).strip(),
                "source_database": "query" if not records else source,
                "pairing_key": header,
                "comment": header,
            }
        )

    header = None
    seq_lines = []
    for line in lines:
        if line.startswith(">"):
            if header is not None:
                add_record(header, seq_lines)
            header = line[1:].strip()
            seq_lines = []
        else:
            seq_lines.append(line.strip())
    # Flush the last record, which has no following header to trigger it.
    if header is not None:
        add_record(header, seq_lines)

    if not records:
        raise ValueError(f"No records found in {file_path}")
    return pd.DataFrame.from_records(records)


def _hash_sequence(seq: str) -> str:
    """Return the SHA-256 hex digest of the upper-cased sequence."""
    return hashlib.sha256(seq.upper().encode()).hexdigest()


def a3m_to_aligned_pqt(directory: str) -> str:
    """Find .a3m files in a directory and merge them into one Parquet file.

    Each file's source database name is derived from its stem with "_hits" /
    "hits_" removed and lower-cased, falling back to "uniref90" when nothing
    is left. The merged table holds a single query row (taken from the first
    file in sorted order) followed by the non-query rows of every file. The
    output is written to ``<directory>/<sha256 of query>.aligned.pqt``.

    Args:
        directory (str): Path to the directory containing .a3m files.

    Returns:
        str: The path to the saved Parquet file.

    Raises:
        ValueError: If the directory is invalid, if it contains no .a3m
            files, if no records are found in a file, or if query sequences
            differ among files.
    """
    dir_path = Path(directory)
    if not dir_path.is_dir():
        raise ValueError(f"{directory} is not a valid directory.")

    # glob() yields files in arbitrary order; sort for a reproducible output.
    a3m_files = sorted(dir_path.glob("*.a3m"))
    if not a3m_files:
        raise ValueError(f"No .a3m files found in {directory}.")

    frames = []
    for a3m_file in a3m_files:
        dbname = a3m_file.stem.replace("_hits", "").replace("hits_", "")
        source = dbname.lower() if dbname else "uniref90"
        frames.append(_parse_a3m(a3m_file, source))

    query_sequences = {frame.iloc[0]["sequence"] for frame in frames}
    if len(query_sequences) != 1:
        raise ValueError("Query sequences differ among files.")

    # Keep one query row, then append every file's hits in a single concat.
    merged_df = pd.concat(
        [frames[0].iloc[0:1], *(frame.iloc[1:] for frame in frames)],
        ignore_index=True,
    )

    query_seq = merged_df.iloc[0]["sequence"]
    out_path = dir_path / f"{_hash_sequence(query_seq)}.aligned.pqt"
    merged_df.to_parquet(out_path, index=False)
    return str(out_path)


def process_uploaded_msas(msa_files, headers):
    """Upload the given MSA files and map each file name to its uploaded value.

    Args:
        msa_files: files to upload. Each item must expose a ``name``
            attribute and a meaningful ``str()`` (presumably
            ``pathlib.Path`` objects; confirm against callers).
        headers: passed through unchanged to ``upload_custom_files``.

    Returns:
        dict: ``file.name`` mapped to the value found in the upload result
        under ``str(file)``, or under ``file.name`` when the former is
        missing or falsy.
    """
    uploaded = upload_custom_files(
        headers=headers, paths=msa_files, file_type=CustomFileType.MSA
    )
    return {f.name: uploaded.get(str(f)) or uploaded.get(f.name) for f in msa_files}