jfaustin commited on
Commit
41f7b15
·
1 Parent(s): f321ade

fix boltz params

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +11 -12
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .envrc
2
 
 
 
1
  .envrc
2
 
3
+ boltz_results/
app.py CHANGED
@@ -5,7 +5,7 @@ from Bio.PDB import MMCIFParser, PDBIO
5
  import logging
6
  import os
7
  from folding_studio.client import Client
8
- from folding_studio.query.boltz import BoltzQuery
9
 
10
 
11
  from molecule import molecule
@@ -37,13 +37,13 @@ def convert_cif_to_pdb(cif_path, pdb_path):
37
  io.set_structure(structure)
38
  io.save(pdb_path)
39
 
40
- def call_boltz(seq_file: Path, api_key: str, output_dir: Path) -> None:
41
  """Call Boltz prediction."""
42
  # Initialize parameters with CLI-provided values
43
  parameters = {
44
  "recycling_steps": 3,
45
- "sampling_steps": 1,
46
- "diffusion_samples": 100,
47
  "step_scale": 1.638,
48
  "msa_pairing_strategy": "greedy",
49
  "write_full_pae": False,
@@ -52,15 +52,17 @@ def call_boltz(seq_file: Path, api_key: str, output_dir: Path) -> None:
52
  "seed": 0,
53
  "custom_msa_paths": None,
54
  }
55
-
56
  # Create a client using API key
57
  logger.info("Authenticating client with API key")
58
  client = Client.from_api_key(api_key=api_key)
59
 
60
  # Define query
61
- query = BoltzQuery.from_file(seq_file, **parameters)
 
62
  query.save_parameters(output_dir)
63
 
 
64
 
65
  # Send a request
66
  logger.info("Sending request to Folding Studio API")
@@ -84,12 +86,10 @@ def predict(sequence: str, api_key: str) -> str:
84
  str: HTML iframe containing 3D molecular visualization
85
  """
86
 
87
- # Create FASTA file with sequence
88
- seq_file = Path("sequence.fasta")
89
- _write_fasta_file(seq_file, sequence)
90
-
91
  # Set up unique output directory based on sequence hash
92
  seq_id = hashlib.sha1(sequence.encode()).hexdigest()
 
 
93
  output_dir = Path(f"sequence_{seq_id}")
94
  output_dir.mkdir(parents=True, exist_ok=True)
95
 
@@ -103,8 +103,7 @@ def predict(sequence: str, api_key: str) -> str:
103
  else:
104
  logger.info("Prediction already exists. Output directory: %s", output_dir)
105
 
106
- # # TODO: remove this
107
- # output_dir = Path("boltz_results")
108
  # Convert output CIF to PDB
109
  pred_cif = list(output_dir.rglob("*_model_0.cif"))[0]
110
  logger.info("Output file: %s", pred_cif)
 
5
  import logging
6
  import os
7
  from folding_studio.client import Client
8
+ from folding_studio.query.boltz import BoltzQuery, BoltzParameters
9
 
10
 
11
  from molecule import molecule
 
37
  io.set_structure(structure)
38
  io.save(pdb_path)
39
 
40
+ def call_boltz(seq_file: Path | str, api_key: str, output_dir: Path) -> None:
41
  """Call Boltz prediction."""
42
  # Initialize parameters with CLI-provided values
43
  parameters = {
44
  "recycling_steps": 3,
45
+ "sampling_steps": 200,
46
+ "diffusion_samples": 1,
47
  "step_scale": 1.638,
48
  "msa_pairing_strategy": "greedy",
49
  "write_full_pae": False,
 
52
  "seed": 0,
53
  "custom_msa_paths": None,
54
  }
55
+
56
  # Create a client using API key
57
  logger.info("Authenticating client with API key")
58
  client = Client.from_api_key(api_key=api_key)
59
 
60
  # Define query
61
+ seq_file = Path(seq_file)
62
+ query = BoltzQuery.from_file(seq_file, query_name="gradio", parameters=BoltzParameters(**parameters))
63
  query.save_parameters(output_dir)
64
 
65
+ logger.info("Payload: %s", query.payload)
66
 
67
  # Send a request
68
  logger.info("Sending request to Folding Studio API")
 
86
  str: HTML iframe containing 3D molecular visualization
87
  """
88
 
 
 
 
 
89
  # Set up unique output directory based on sequence hash
90
  seq_id = hashlib.sha1(sequence.encode()).hexdigest()
91
+ seq_file = Path(f"sequence_{seq_id}.fasta")
92
+ _write_fasta_file(seq_file, sequence)
93
  output_dir = Path(f"sequence_{seq_id}")
94
  output_dir.mkdir(parents=True, exist_ok=True)
95
 
 
103
  else:
104
  logger.info("Prediction already exists. Output directory: %s", output_dir)
105
 
106
+ # output_dir = Path("boltz_results") # debug
 
107
  # Convert output CIF to PDB
108
  pred_cif = list(output_dir.rglob("*_model_0.cif"))[0]
109
  logger.info("Output file: %s", pred_cif)