"""Utility functions for the commands module."""

import hashlib
from contextlib import contextmanager
from pathlib import Path

import pandas as pd
from rich.progress import Progress, SpinnerColumn, TextColumn

from folding_studio.api_call.upload_custom_files import (
    CustomFileType,
    upload_custom_files,
)
from folding_studio.console import console


@contextmanager
def success_fail_catch_spinner(message: str, spinner_name: str = "dots"):
    """Wrapper around a rich progress spinner that adapts its state icon.

    Args:
        message (str): Message to show; supports rich formatting.
        spinner_name (str, optional): rich SpinnerColumn spinner_name attribute.
            Defaults to "dots".

    Examples:
        ```
        with success_fail_catch_spinner("Running Task"):
            ...

        >>> Running Task ⠼  # spins as long as the context manager is running
        # if no error is raised, it transforms into
        >>> Running Task ✅
        # otherwise it transforms into
        >>> Running Task ❌
        An error occurred: <ERROR>
        ...
        ```
    """
    err = None
    with Progress(
        TextColumn("{task.description}"),
        SpinnerColumn(spinner_name, finished_text=""),
        console=console,
    ) as progress:
        task_id = progress.add_task(message, total=1)

        try:
            yield
            progress.update(
                task_id, completed=1, description=f"{message} :white_check_mark:"
            )
        except Exception as e:
            # Mark the task as failed but defer re-raising until the spinner
            # has closed, so the error message prints below it.
            progress.update(task_id, completed=1, description=f"{message} :x:")
            err = e

    if err is not None:
        console.print(f"An error occurred: {err}")
        raise err


@contextmanager
def success_fail_catch_print(*args, **kwargs):
    """Wrapper around rich `print` that adapts its state icon.

    Examples:
        ```
        with success_fail_catch_print("Running Task..."):
            ...

        >>> Running Task...
        # if no error is raised, it transforms into
        >>> Running Task... ✅
        # otherwise it transforms into
        >>> Running Task... ❌
        An error occurred: <ERROR>
        ...
        ```
    """
    console.print(*args, **kwargs, end=" ")
    try:
        yield
        console.print(":white_check_mark:")
    except Exception as e:
        console.print(":x:")
        console.print(f"An error occurred: {e}")
        raise e


def a3m_to_aligned_pqt(directory: str) -> str:
    """Finds .a3m files in a directory and merges them into a single aligned Parquet file.

    Args:
        directory (str): Path to the directory containing .a3m files.

    Returns:
        str: The path to the saved Parquet file.

    Raises:
        ValueError: If the directory is invalid or contains no .a3m files,
            if no records are found in a file, or if query sequences differ
            among files.
    """
    dir_path = Path(directory)
    if not dir_path.is_dir():
        raise ValueError(f"{directory} is not a valid directory.")

    # Derive a source database name from each file name,
    # e.g. "uniref90_hits.a3m" -> "uniref90".
    mapped_files = {}
    for file in dir_path.glob("*.a3m"):
        dbname = file.stem.replace("_hits", "").replace("hits_", "")
        source = dbname.lower() if dbname else "uniref90"
        mapped_files[file] = source

    if not mapped_files:
        # Fail early with a clear message instead of the misleading
        # "query sequences differ" error further down.
        raise ValueError(f"No .a3m files found in {directory}.")

    def parse_a3m(file_path: Path, source: str) -> pd.DataFrame:
        """Parses a simple FASTA-style file.

        The first record is flagged with source "query"; subsequent records use
        the provided source. The header is used both as a comment and (if
        desired) as a pairing key.
        """
        with open(file_path, "r") as f:
            lines = f.read().splitlines()

        records = []
        header = None
        seq_lines = []

        def add_record():
            # The first record of a file is always the query sequence.
            seq = "".join(seq_lines).strip()
            record_source = "query" if not records else source
            records.append(
                {
                    "sequence": seq,
                    "source_database": record_source,
                    "pairing_key": header,
                    "comment": header,
                }
            )

        for line in lines:
            if line.startswith(">"):
                # A new header closes the previous record, if any.
                if header is not None:
                    add_record()
                header = line[1:].strip()
                seq_lines = []
            else:
                seq_lines.append(line.strip())
        # Close the final record.
        if header is not None:
            add_record()

        if not records:
            raise ValueError(f"No records found in {file_path}")
        return pd.DataFrame.from_records(records)
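
    # Illustrative sketch (not from the source) of the simple A3M/FASTA layout
    # parse_a3m expects:
    #
    #   >query_header
    #   MKTAYIAK...
    #   >hit_1
    #   MKT-YIAK...
    #
    # This yields one "query" record followed by one record per hit, each
    # tagged with the file's source database.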

    dfs = {}
    for file, source in mapped_files.items():
        dfs[file] = parse_a3m(file, source)

    # All files must share the same query sequence (their first record).
    query_set = {df.iloc[0]["sequence"] for df in dfs.values()}
    if len(query_set) != 1:
        raise ValueError("Query sequences differ among files.")

    # Keep a single copy of the query row, then append every file's hit rows.
    merged_df = None
    for df in dfs.values():
        if merged_df is None:
            merged_df = df.iloc[0:1].copy()
        merged_df = pd.concat([merged_df, df.iloc[1:]], ignore_index=True)

    query_seq = merged_df.iloc[0]["sequence"]

    def hash_sequence(seq: str) -> str:
        return hashlib.sha256(seq.upper().encode()).hexdigest()

    # Name the output after the hash of the (case-normalized) query sequence.
    output_filename = f"{hash_sequence(query_seq)}.aligned.pqt"

    dir_path.mkdir(exist_ok=True, parents=True)
    out_path = dir_path / output_filename

    merged_df.to_parquet(out_path, index=False)
    return str(out_path)
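
# Minimal usage sketch for a3m_to_aligned_pqt, assuming a hypothetical "msas/"
# directory containing files such as uniref90_hits.a3m and mgnify_hits.a3m:
#
#   pqt_path = a3m_to_aligned_pqt("msas")
#   df = pd.read_parquet(pqt_path)
#   print(df["source_database"].value_counts())  # query row + hits per database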


def process_uploaded_msas(msa_files, headers):
    """Uploads the given MSA files and returns a dictionary mapping file names
    to their uploaded values.
    """
    uploaded = upload_custom_files(
        headers=headers, paths=msa_files, file_type=CustomFileType.MSA
    )
    msa_paths = {}
    for f in msa_files:
        # The upload response may be keyed by full path or by bare file name;
        # accept either.
        msa_paths[f.name] = uploaded.get(str(f)) or uploaded.get(f.name)
    return msa_paths
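
# Minimal usage sketch for process_uploaded_msas; the paths and headers below
# are hypothetical:
#
#   msa_files = [Path("msas/uniref90_hits.a3m"), Path("msas/mgnify_hits.a3m")]
#   headers = {"Authorization": "Bearer <TOKEN>"}
#   msa_paths = process_uploaded_msas(msa_files, headers)
#   # -> {"uniref90_hits.a3m": "<uploaded value>", ...}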