### Preparing the datasets
Downloading the [wwPDB dataset](https://www.wwpdb.org/) and the preprocessed training data requires at least 1 TB of free disk space.
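
If you are unsure whether your target volume has enough room, a quick check before downloading (this assumes the data will live under `/af3-dev`, as in the commands below, and that the directory already exists):

```bash
# Show free space on the volume that will hold the data; roughly 1 TB is needed
df -h /af3-dev
```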

Use the following commands to download and extract the preprocessed wwPDB training data:

```bash
wget -P /af3-dev/release_data/ https://af3-dev.tos-cn-beijing.volces.com/release_data.tar.gz
tar -xzvf /af3-dev/release_data/release_data.tar.gz -C /af3-dev/release_data/
rm /af3-dev/release_data/release_data.tar.gz
```
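
After extraction, you can optionally compare the on-disk sizes against the listing further below, for example:

```bash
# Optional sanity check: per-directory sizes should roughly match the listing below
du -sh /af3-dev/release_data/*
```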


The data should be placed in the `/af3-dev/release_data/` directory. You can also download it to a different directory, but remember to update `DATA_ROOT_DIR` in [configs/configs_data.py](../configs/configs_data.py) accordingly. The data hierarchy after extraction is as follows:

  ```bash
  β”œβ”€β”€ components.v20240608.cif [408M] # ccd source file
  β”œβ”€β”€ components.v20240608.cif.rdkit_mol.pkl [121M] # rdkit Mol object generated by ccd source file
  β”œβ”€β”€ indices [33M] # chain or interface entries
  β”œβ”€β”€ mmcif [283G]  # raw mmcif data
  β”œβ”€β”€ mmcif_bioassembly [36G] # preprocessed wwPDB structural data
  β”œβ”€β”€ mmcif_msa [450G] # msa files
  β”œβ”€β”€ posebusters_bioassembly [42M] # preprocessed posebusters structural data
  β”œβ”€β”€ posebusters_mmcif [361M] # raw mmcif data
  β”œβ”€β”€ recentPDB_bioassembly [1.5G] # preprocessed recentPDB structural data
  └── seq_to_pdb_index.json [45M] # sequence to pdb id mapping file
  ```
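
If you keep the data in a custom location, one possible way to update the path mentioned above is an in-place substitution; this is only a sketch, `/my/data_root` is a hypothetical directory, and it assumes the default path string appears literally in the config:

```bash
# Sketch: point the data root in configs/configs_data.py at a custom location
sed -i 's#/af3-dev/release_data#/my/data_root#g' configs/configs_data.py
```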

The data processing scripts have also been released. Refer to [prepare_training_data.md](./prepare_training_data.md) for generating `{dataset}_bioassembly` and `indices`, and to [msa_pipeline.md](./msa_pipeline.md) for the pipelines that produce `mmcif_msa` and `seq_to_pdb_index.json`.

### Training demo
After installation and data preparation, you can run the following command to train the model from scratch:

  ```bash
  bash train_demo.sh 
  ```
Key arguments in this script are explained below:
* `dtype`: data type used in training. Valid options are `"bf16"` and `"fp32"`.
  * `--dtype fp32`: the model is trained in full FP32 precision.
  * `--dtype bf16`: the model is trained in BF16 mixed precision. By default, the `SampleDiffusion`, `ConfidenceHead`, `Mini-rollout`, and `Loss` parts are still trained in FP32 precision. If you want to train and run inference in full BF16 mixed precision, pass the following arguments to [train_demo.sh](../train_demo.sh):
    ```bash
    --skip_amp.sample_diffusion_training false \
    --skip_amp.confidence_head false \
    --skip_amp.sample_diffusion false \
    --skip_amp.loss false \
    ```
* `ema_decay`: the decay rate of the EMA; the default is 0.999.
* `sample_diffusion.N_step`: during evaluation, the number of diffusion steps is reduced to 20 to improve efficiency.

* `data.train_sets/data.test_sets`: the datasets used for training and evaluation. If there are multiple datasets, separate them with commas.
* Some settings follow those in the [AlphaFold 3](https://www.nature.com/articles/s41586-024-07487-w) paper. The table in [model_performance.md](../docs/model_performance.md) shows the training settings and memory usage for the different training stages.
* In this version, we do not use the template or RNA MSA features for training, as reflected by the default settings in [configs/configs_base.py](../configs/configs_base.py) and [configs/configs_data.py](../configs/configs_data.py):
  ```bash
  --model.template_embedder.n_blocks 0 \
  --data.msa.enable_rna_msa false \
  ```
  This will be considered in our future work.

* The model also supports distributed training with PyTorch’s [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html). For example, if you’re running distributed training on a single node with 4 GPUs, you can use:
  ```bash
  torchrun --nproc_per_node=4 runner/train.py
  ```
  You can also pass additional arguments in the form `--<ARGS_KEY> <ARGS_VALUE>` as needed; a combined example follows this list.
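
For reference, here is a sketch that combines the overrides described in this list with a distributed launch; the flag values are examples only and the dataset names are placeholders:

```bash
# Example only: combine the argument overrides above with a 4-GPU torchrun launch
torchrun --nproc_per_node=4 runner/train.py \
    --dtype bf16 \
    --ema_decay 0.999 \
    --sample_diffusion.N_step 20 \
    --data.train_sets <train_set_1>,<train_set_2> \
    --data.test_sets <test_set>
```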


If you want to speed up training, see the [setting up kernels documentation](./kernels.md).

### Finetune demo

If you want to fine-tune the model on a specific subset, such as an antibody dataset, you only need to provide a PDB list file and load the pretrained weights as [finetune_demo.sh](../finetune_demo.sh) shows:
    
```bash
# wget -P /af3-dev/release_model/ https://af3-dev.tos-cn-beijing.volces.com/release_model/model_v0.2.0.pt
checkpoint_path="/af3-dev/release_model/model_v0.2.0.pt"
...

--load_checkpoint_path ${checkpoint_path} \
--load_checkpoint_ema_path ${checkpoint_path} \
--data.weightedPDB_before2109_wopb_nometalc_0925.base_info.pdb_list examples/subset.txt \
```

Here, `subset.txt` is a file containing PDB IDs, one per line:
```bash
6hvq
5mqc
5zin
3ew0
5akv
```