"""Utilitaries methods for the commands module."""
import hashlib
from contextlib import contextmanager
from pathlib import Path

import pandas as pd
from rich.progress import Progress, SpinnerColumn, TextColumn

from folding_studio.api_call.upload_custom_files import (
    CustomFileType,
    upload_custom_files,
)
from folding_studio.console import console

@contextmanager
def success_fail_catch_spinner(message: str, spinner_name: str = "dots"):
"""Wrapper around a rich progress spinner that adapt its state icon.
Args:
message (str): message to show supporting rich format.
spinner_name (str, optional): rich SpinnerColumn spinner_name attribute. Defaults to "dots".
Examples:
```
with success_fail_catch_spinner("Running Task"):
...
        >>> Running Task ⠼  # spins as long as the context manager is running
        # if no error is raised, this transforms into
        >>> Running Task ✅
        # otherwise it transforms into
        >>> Running Task ❌
An error occurred: <ERROR>
...
```
"""
err = None
with Progress(
TextColumn("{task.description}"),
SpinnerColumn(spinner_name, finished_text=""),
console=console,
) as progress:
task_id = progress.add_task(message, total=1)
try:
yield
progress.update(
task_id, completed=1, description=f"{message} :white_check_mark:"
)
except Exception as e:
progress.update(task_id, completed=1, description=f"{message} :x:")
err = e
    # For coherent message ordering, the error print has to happen outside the Progress context manager.
if err is not None:
console.print(f"An error occurred: {err}")
raise err
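
# A minimal usage sketch for the spinner wrapper; `submit_job` is a
# hypothetical callable used only for illustration:
#
#     with success_fail_catch_spinner("Submitting job"):
#         submit_job()
#
# The spinner animates while the body runs, then settles on a check mark on
# success, or a cross plus the error message if the body raises.
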
@contextmanager
def success_fail_catch_print(*args, **kwargs):
"""Wrapper around rich `print` that adapts its state icon.
Examples:
```
with success_fail_catch_print("Running Task..."):
...
>>> Running Task...
        # if no error is raised, this transforms into
        >>> Running Task... ✅
        # otherwise it transforms into
>>> Running Task... ❌
An error occurred: <ERROR>
...
```
"""
console.print(*args, **kwargs, end=" ")
try:
yield
console.print(":white_check_mark:")
    except Exception as e:
        console.print(":x:")
        console.print(f"An error occurred: {e}")
        raise
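
# A minimal usage sketch for the print wrapper; `validate_inputs` is a
# hypothetical callable used only for illustration:
#
#     with success_fail_catch_print("Validating inputs..."):
#         validate_inputs()
#
# Unlike the spinner variant, this prints the message once and appends the
# state icon to the same line when the body finishes.
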
def a3m_to_aligned_pqt(directory: str) -> str:
"""
Finds .a3m files in a directory and merges them into a single aligned Parquet file.
Args:
directory (str): Path to the directory containing .a3m files.
Returns:
str: The path to the saved Parquet file.
Raises:
ValueError: If the directory is invalid, if no records are found in a file,
or if query sequences differ among files.
"""
dir_path = Path(directory)
if not dir_path.is_dir():
raise ValueError(f"{directory} is not a valid directory.")
    mapped_files = {}
    for file in dir_path.glob("*.a3m"):
        # Derive the source database from the file stem, e.g. "uniref90_hits.a3m" -> "uniref90";
        # fall back to "uniref90" when the stripped stem is empty.
        dbname = file.stem.replace("_hits", "").replace("hits_", "")
        source = dbname.lower() if dbname else "uniref90"
        mapped_files[file] = source
def parse_a3m(file_path: Path, source: str) -> pd.DataFrame:
"""
Parses a simple FASTA file.
The first record is flagged with source "query"; subsequent records use the provided source.
Uses the header both as a comment and (if desired) as a pairing key.
"""
with open(file_path, "r") as f:
lines = f.read().splitlines()
records = []
header = None
seq_lines = []
for line in lines:
if line.startswith(">"):
if header is not None:
seq = "".join(seq_lines).strip()
record_source = "query" if not records else source
records.append(
{
"sequence": seq,
"source_database": record_source,
"pairing_key": header,
"comment": header,
}
)
header = line[1:].strip()
seq_lines = []
else:
seq_lines.append(line.strip())
if header is not None:
seq = "".join(seq_lines).strip()
record_source = "query" if not records else source
records.append(
{
"sequence": seq,
"source_database": record_source,
"pairing_key": header,
"comment": header,
}
)
if not records:
raise ValueError(f"No records found in {file_path}")
return pd.DataFrame.from_records(records)
dfs = {}
for file, source in mapped_files.items():
dfs[file] = parse_a3m(file, source)
query_set = {df.iloc[0]["sequence"] for df in dfs.values()}
if len(query_set) != 1:
raise ValueError("Query sequences differ among files.")
    # Keep the shared query row once, then append the hit rows from every file.
    merged_df = None
    for df in dfs.values():
        if merged_df is None:
            merged_df = df.iloc[0:1].copy()
        merged_df = pd.concat([merged_df, df.iloc[1:]], ignore_index=True)
query_seq = merged_df.iloc[0]["sequence"]
    def hash_sequence(seq: str) -> str:
        # Case-normalize so the same sequence always yields the same file name.
        return hashlib.sha256(seq.upper().encode()).hexdigest()

    output_filename = f"{hash_sequence(query_seq)}.aligned.pqt"
dir_path.mkdir(exist_ok=True, parents=True)
out_path = dir_path / output_filename
merged_df.to_parquet(out_path, index=False)
return str(out_path)
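
# A minimal usage sketch for the A3M merger, assuming a directory holding
# per-database hit files such as "uniref90_hits.a3m" and "mgnify_hits.a3m"
# (hypothetical names):
#
#     pqt_path = a3m_to_aligned_pqt("msas/")
#     merged = pd.read_parquet(pqt_path)
#
# Because the output name is the SHA-256 hash of the query sequence, the same
# query always maps to the same Parquet file.
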
def process_uploaded_msas(msa_files, headers):
    """
    Uploads the given MSA files and returns a dictionary mapping file names to their uploaded values.

    Args:
        msa_files: Paths of the MSA files to upload.
        headers: Request headers forwarded to the upload call.

    Returns:
        dict: Mapping from each file's name to its uploaded value.
    """
uploaded = upload_custom_files(
headers=headers, paths=msa_files, file_type=CustomFileType.MSA
)
    msa_paths = {}
    for f in msa_files:
        # The upload response may be keyed by full path or by bare file name; accept either.
        msa_paths[f.name] = uploaded.get(str(f)) or uploaded.get(f.name)
return msa_paths
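
# A minimal usage sketch for the upload helper; the header contents are an
# assumption and depend on how the caller authenticates:
#
#     msa_files = list(Path("msas/").glob("*.a3m"))
#     msa_paths = process_uploaded_msas(
#         msa_files, headers={"Authorization": "Bearer <TOKEN>"}
#     )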