Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9.smi +1 -0
- data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.npy +3 -0
- data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.xyz +11 -0
- data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c.smi +1 -0
- data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.npy +3 -0
- data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.xyz +21 -0
- data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939.smi +1 -0
- data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.npy +3 -0
- data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.xyz +21 -0
- data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6.smi +1 -0
- data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.npy +3 -0
- data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.xyz +13 -0
- data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d.smi +1 -0
- data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.npy +3 -0
- data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.xyz +25 -0
- data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0.smi +1 -0
- data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.npy +3 -0
- data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.xyz +13 -0
- data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8.smi +1 -0
- data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.npy +3 -0
- data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.xyz +15 -0
- data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb.smi +1 -0
- data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.npy +3 -0
- data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.xyz +18 -0
- data/datasets/moleculenet/qm9/qm9.csv +3 -0
- data/datasets/moleculenet/qm9/test.csv +0 -0
- data/datasets/moleculenet/qm9/train.csv +3 -0
- data/datasets/moleculenet/qm9/valid.csv +0 -0
- finetune/args.py +40 -0
- finetune/dataset/__init__.py +7 -0
- finetune/dataset/default.py +35 -0
- finetune/finetune_regression.py +92 -0
- finetune/run_finetune_qm9_alpha.sh +24 -0
- finetune/run_finetune_qm9_cv.sh +24 -0
- finetune/run_finetune_qm9_g298.sh +24 -0
- finetune/run_finetune_qm9_gap.sh +24 -0
- finetune/run_finetune_qm9_h298.sh +24 -0
- finetune/run_finetune_qm9_homo.sh +24 -0
- finetune/run_finetune_qm9_lumo.sh +24 -0
- finetune/run_finetune_qm9_mu.sh +24 -0
- finetune/run_finetune_qm9_r2.sh +24 -0
- finetune/run_finetune_qm9_u0.sh +24 -0
- finetune/run_finetune_qm9_u298.sh +24 -0
- finetune/run_finetune_qm9_zpve.sh +24 -0
- finetune/trainers.py +359 -0
- finetune/utils.py +126 -0
- images/3dgridvqgan_architecture.png +3 -0
- inference/run_embeddings_eval_xgboost.sh +2 -0
- inference/run_extract_embeddings.sh +8 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/datasets/moleculenet/qm9/qm9.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
data/datasets/moleculenet/qm9/train.csv filter=lfs diff=lfs merge=lfs -text
|
38 |
+
images/3dgridvqgan_architecture.png filter=lfs diff=lfs merge=lfs -text
|
data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
O=CNC1=NOC=N1
|
data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2a17ceb4553e1861d34ebaf0d0b2e5f33975f41632eef13f662ed6098e06eb35
|
3 |
+
size 338816
|
data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.xyz
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
O -0.3433117873 0.0542505386 0.0030523182
|
2 |
+
C -0.0862462134 1.2279140618 0.0081946405
|
3 |
+
N 1.1846004685 1.7722542978 -0.0009450675
|
4 |
+
C 2.386606203 1.0948306964 -0.0156497688
|
5 |
+
N 2.5353255646 -0.2087260828 -0.0269119898
|
6 |
+
O 3.9407365203 -0.3429161461 -0.0391639287
|
7 |
+
C 4.4419993701 0.8855820068 -0.0338427395
|
8 |
+
N 3.5551776308 1.8291029217 -0.0192066804
|
9 |
+
H -0.8486325877 2.0303875837 0.0201485951
|
10 |
+
H 1.2766287385 2.7775349452 0.0053395306
|
11 |
+
H 5.5174865125 0.985073027 -0.0416438097
|
data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
C#CCC1CC1(C)CO
|
data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c211d6096850f4fd8bb7e862cb0e3cf09492ed5fb165348829ac37ed9d3c344a
|
3 |
+
size 2612864
|
data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.xyz
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C 0.0095240887 1.5961806115 -0.0858857917
|
2 |
+
C -0.0675156706 0.086549314 0.0511974473
|
3 |
+
C -1.3201762415 -0.5125406846 -0.5508961783
|
4 |
+
O -1.3261797579 -0.2420461247 -1.9477734763
|
5 |
+
C 1.1982985053 -0.7358014855 -0.0482283446
|
6 |
+
C 0.4787861145 -0.6243589949 1.2724834132
|
7 |
+
C 1.0934534769 0.1242004421 2.4490763186
|
8 |
+
C 1.9615545513 -0.7265702717 3.2588924687
|
9 |
+
C 2.6652427105 -1.4441035487 3.9181710713
|
10 |
+
H 0.9721361795 1.9954968448 0.2425890448
|
11 |
+
H -0.1285829928 1.8750984225 -1.1338899379
|
12 |
+
H -0.7769315564 2.0893384777 0.4988274146
|
13 |
+
H -1.3405285423 -1.5971735621 -0.3600593089
|
14 |
+
H -2.2081101606 -0.0754460014 -0.0643042439
|
15 |
+
H -2.18303183 -0.5001196615 -2.2994922363
|
16 |
+
H 1.1667370542 -1.6792797399 -0.5822025643
|
17 |
+
H 2.1360390085 -0.1991832668 -0.1544507441
|
18 |
+
H -0.0724271026 -1.5082395772 1.583953566
|
19 |
+
H 1.6685961108 0.9854383459 2.091700259
|
20 |
+
H 0.2979810602 0.5352503178 3.0854850798
|
21 |
+
H 3.2930319243 -2.0701853274 4.5019859333
|
data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CCC1C=C2CCC2O1
|
data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5adbb7b36ff91928b2ecec9e548962d45c5241d3949d35e5274511b0af7aa574
|
3 |
+
size 2471168
|
data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.xyz
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C -0.0055109461 1.5153095531 -0.1082270888
|
2 |
+
C -0.0178026585 -0.0131266806 -0.0290282794
|
3 |
+
C -1.426157367 -0.5876103682 0.1341883634
|
4 |
+
O -2.2512193772 -0.169709702 -0.9848466475
|
5 |
+
C -3.0265062489 -1.2762275569 -1.3904613892
|
6 |
+
C -2.8104411225 -2.0637374738 -2.7387249862
|
7 |
+
C -2.7598726311 -3.3919927241 -1.8878534159
|
8 |
+
C -2.4407002514 -2.4767203713 -0.7256518392
|
9 |
+
C -1.467181607 -2.1129314787 0.098813579
|
10 |
+
H 1.012088138 1.8942209817 -0.2429905156
|
11 |
+
H -0.615905307 1.8589202784 -0.9470659107
|
12 |
+
H -0.4113398154 1.9616853934 0.8065539523
|
13 |
+
H 0.4283203565 -0.4428051608 -0.9340972166
|
14 |
+
H 0.5897595853 -0.3532294279 0.8196759242
|
15 |
+
H -1.8681962495 -0.175448141 1.0583687871
|
16 |
+
H -4.0959419742 -1.0853286499 -1.2201143184
|
17 |
+
H -3.5930759641 -1.9863512547 -3.4970500576
|
18 |
+
H -1.8354825793 -1.8302996855 -3.1700309799
|
19 |
+
H -3.7465941351 -3.8605591083 -1.8233285422
|
20 |
+
H -2.0210691777 -4.1500431884 -2.1526779895
|
21 |
+
H -0.7049926879 -2.7257531143 0.5639707007
|
data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CN(C=O)C(N)=O
|
data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37286a82f528eedeca6eb4b02d5a487fb80526fac79b8d4c16d374ee389aa2f4
|
3 |
+
size 933248
|
data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.xyz
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C -0.0561551802 1.4880114224 0.0115672899
|
2 |
+
N -0.012314902 0.0308990979 -0.0074242366
|
3 |
+
C 0.0394071908 -0.6483703608 1.2066473859
|
4 |
+
O 0.0604983378 -0.0950939226 2.2806600492
|
5 |
+
C -0.042694924 -0.7260819108 -1.1975216915
|
6 |
+
N -0.0688282818 0.0226384605 -2.3461653613
|
7 |
+
O -0.0805938084 -1.9390867198 -1.2010568454
|
8 |
+
H 0.8560644678 1.9243363476 -0.4123282977
|
9 |
+
H -0.9270762996 1.8594106628 -0.5357636429
|
10 |
+
H -0.1279995908 1.7881285544 1.0569523777
|
11 |
+
H 0.0576097412 -1.7352809462 1.059547977
|
12 |
+
H 0.2506516948 0.9756066207 -2.3628772038
|
13 |
+
H 0.0545094644 -0.513342856 -3.1899679505
|
data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CCCC(CC)C(C)=O
|
data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:daf1f661f359abd9637a71410cfbd09e692daa396b32e047c6606e03874c317a
|
3 |
+
size 2903168
|
data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.xyz
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C 0.3762415949 1.4903515185 0.051015763
|
2 |
+
C 0.16294672 -0.0262712548 0.0322514822
|
3 |
+
C 0.675840727 -0.6751333193 -1.2593347251
|
4 |
+
C 0.5134796126 -2.2163401834 -1.316721532
|
5 |
+
C -0.9478193302 -2.6564989579 -1.4944772157
|
6 |
+
C -1.1489111104 -4.1684742305 -1.3528914865
|
7 |
+
C 1.3845485491 -2.7207559592 -2.4735115772
|
8 |
+
C 2.8132544677 -3.1065708435 -2.1332551236
|
9 |
+
O 0.960203994 -2.7885031108 -3.6052090651
|
10 |
+
H 1.4383754499 1.7429425026 -0.0425294685
|
11 |
+
H -0.1517461232 1.9744952727 -0.7778578838
|
12 |
+
H 0.0095987923 1.9321754361 0.9827049486
|
13 |
+
H 0.6755907641 -0.4799924509 0.890858941
|
14 |
+
H -0.9033038021 -0.2426745921 0.1645631834
|
15 |
+
H 0.1643795151 -0.2410453696 -2.127887566
|
16 |
+
H 1.7371611953 -0.4154116304 -1.3751070827
|
17 |
+
H 0.9136495565 -2.6311564942 -0.3805859669
|
18 |
+
H -1.5741060384 -2.1375358991 -0.7612854955
|
19 |
+
H -1.2777225554 -2.3336553985 -2.4874152623
|
20 |
+
H -0.5803069276 -4.7154607671 -2.1113564713
|
21 |
+
H -0.8279135541 -4.5236840194 -0.3667438252
|
22 |
+
H -2.2026624643 -4.437369643 -1.4740466355
|
23 |
+
H 2.8024629719 -4.041189801 -1.5589057987
|
24 |
+
H 3.3927775837 -3.2531585575 -3.0456419515
|
25 |
+
H 3.2903552816 -2.3529525778 -1.4982282451
|
data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
C1=NC2=CC=NN2C=N1
|
data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:08a45fe3a80e5c8cc22b3a0f59d531353b6d401f88661f593abdb69dc44b45c4
|
3 |
+
size 380288
|
data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.xyz
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C 0.0592192692 1.368762351 0.0093929006
|
2 |
+
N 1.3122139041 1.8377169369 -0.0001856573
|
3 |
+
N 2.0617385494 0.7170865182 -0.013873644
|
4 |
+
C 3.4112525731 0.634008202 -0.0275595021
|
5 |
+
N 4.0198493068 -0.5162508406 -0.0400141125
|
6 |
+
C 3.2187933425 -1.6270159107 -0.0384447767
|
7 |
+
N 1.909989902 -1.6610752725 -0.0258242213
|
8 |
+
C 1.292800839 -0.4606631607 -0.0130004803
|
9 |
+
C -0.0242733248 -0.0346384436 0.0023033991
|
10 |
+
H -0.7575875186 2.0764585031 0.0213803521
|
11 |
+
H 3.9530882705 1.5750751636 -0.0275449692
|
12 |
+
H 3.7430423978 -2.5779490055 -0.0489488368
|
13 |
+
H -0.9034361909 -0.6555191213 0.0074057982
|
data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
C1=C(C2CC2)NN=N1
|
data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2a8ab9d5c9c61e0547e3a35f52a60354f12a348d681f1d7dc90dc42fed3bb577
|
3 |
+
size 1382528
|
data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.xyz
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C 0.002730801 1.5364883046 -0.0652481333
|
2 |
+
C 1.3129515531 0.8070922114 -0.076629951
|
3 |
+
C 0.0187818375 0.0226915472 0.0386968119
|
4 |
+
C -0.3380018988 -0.6381328815 1.3049549011
|
5 |
+
C -0.1608829834 -0.3235247554 2.6381345568
|
6 |
+
N -0.6971640624 -1.3104697961 3.4082804995
|
7 |
+
N -1.2012930219 -2.2290873625 2.6387797732
|
8 |
+
N -0.9871902476 -1.8273143514 1.3659063658
|
9 |
+
H -0.2724161486 2.0867325654 0.827333906
|
10 |
+
H -0.3566469563 1.9707288421 -0.9912789551
|
11 |
+
H 1.8591050209 0.7372804887 -1.0105007132
|
12 |
+
H 1.9370129991 0.8562070416 0.8084014071
|
13 |
+
H -0.2904338773 -0.518969552 -0.8499640332
|
14 |
+
H 0.3118506955 0.5391482945 3.0790387057
|
15 |
+
H -1.3069889708 -2.4074874964 0.6066944785
|
data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb.smi
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CNC1=NC(O)=C(C)N1
|
data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04ca6c042fdc6c92bad68d750b4770bab1cbba2f1c6a016350429da115f7145f
|
3 |
+
size 2628416
|
data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.xyz
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C 0.2021885028 1.2833147668 0.1828347559
|
2 |
+
N -0.2183533194 -0.0600235049 -0.1763289174
|
3 |
+
C 0.4792782987 -0.7562350227 -1.1452089495
|
4 |
+
N 0.5549235456 -2.0713645176 -1.2135880027
|
5 |
+
C 1.2651968247 -2.3183417769 -2.3633449011
|
6 |
+
O 1.5374666899 -3.5896693927 -2.7487695103
|
7 |
+
C 1.6425751596 -1.1756799637 -3.0064836641
|
8 |
+
C 2.4261233754 -0.9183617354 -4.2431958005
|
9 |
+
N 1.1239546418 -0.1576445941 -2.1801339353
|
10 |
+
H 1.2655154169 1.3576381934 0.4589582859
|
11 |
+
H 0.0149717259 1.9764828738 -0.6460648109
|
12 |
+
H -0.4014962835 1.6279653806 1.0247916362
|
13 |
+
H -0.4239665093 -0.6625739292 0.6073847095
|
14 |
+
H 1.1295866096 -4.1597040968 -2.0868842735
|
15 |
+
H 1.8581192728 -0.3597192603 -5.0000767893
|
16 |
+
H 3.3521095167 -0.3582763401 -4.0525753219
|
17 |
+
H 2.7082322022 -1.875782489 -4.6877635002
|
18 |
+
H 1.1534957895 0.8267204188 -2.3790310507
|
data/datasets/moleculenet/qm9/qm9.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b256baa2f324e03d73c6160f9b573d752a8e16512e6939f0f3f04bb3d7dd8c60
|
3 |
+
size 39132032
|
data/datasets/moleculenet/qm9/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/datasets/moleculenet/qm9/train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc96b329d6e0efe1d59359ddb96881209743518ac2910ac3b4a184f70b866f2d
|
3 |
+
size 31696921
|
data/datasets/moleculenet/qm9/valid.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
finetune/args.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
|
4 |
+
def get_parser(parser=None):
|
5 |
+
if parser is None:
|
6 |
+
parser = argparse.ArgumentParser()
|
7 |
+
|
8 |
+
parser.add_argument("--data_root", type=str, required=False, default="")
|
9 |
+
parser.add_argument("--grid_path", type=str, required=False, default="")
|
10 |
+
|
11 |
+
parser.add_argument(
|
12 |
+
"--lr_start", type=float, default=3 * 1e-4, help="Initial lr value"
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--max_epochs", type=int, required=False, default=1, help="max number of epochs"
|
16 |
+
)
|
17 |
+
|
18 |
+
parser.add_argument("--num_workers", type=int, default=0, required=False)
|
19 |
+
parser.add_argument("--dropout", type=float, default=0.1, required=False)
|
20 |
+
parser.add_argument("--n_batch", type=int, default=512, help="Batch size")
|
21 |
+
parser.add_argument("--dataset_name", type=str, required=False, default="sol")
|
22 |
+
parser.add_argument("--measure_name", type=str, required=False, default="measure")
|
23 |
+
parser.add_argument("--checkpoints_folder", type=str, required=True)
|
24 |
+
parser.add_argument("--model_path", type=str, default="./smi_ted/")
|
25 |
+
parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
|
26 |
+
parser.add_argument("--restart_filename", type=str, default="")
|
27 |
+
parser.add_argument('--n_output', type=int, default=1)
|
28 |
+
parser.add_argument("--save_every_epoch", type=int, default=0)
|
29 |
+
parser.add_argument("--save_ckpt", type=int, default=1)
|
30 |
+
parser.add_argument("--start_seed", type=int, default=0)
|
31 |
+
parser.add_argument("--target_metric", type=str, default="rmse")
|
32 |
+
parser.add_argument("--loss_fn", type=str, default="mae")
|
33 |
+
|
34 |
+
return parser
|
35 |
+
|
36 |
+
|
37 |
+
def parse_args():
|
38 |
+
parser = get_parser()
|
39 |
+
args = parser.parse_args()
|
40 |
+
return args
|
finetune/dataset/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from dataset.breast_uka import BreastUKA
|
2 |
+
# from dataset.mrnet import MRNetDataset
|
3 |
+
# from dataset.brats import BRATSDataset
|
4 |
+
# from dataset.adni import ADNIDataset
|
5 |
+
# from dataset.duke import DUKEDataset
|
6 |
+
# from dataset.lidc import LIDCDataset
|
7 |
+
from dataset.default import GridDataset
|
finetune/dataset/default.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.multiprocessing as mp
|
7 |
+
|
8 |
+
class GridDataset(Dataset):
|
9 |
+
def __init__(self, dataset, target: str, root_dir: str, internal_resolution: int):
|
10 |
+
super().__init__()
|
11 |
+
self.dataset = dataset
|
12 |
+
self.target = target
|
13 |
+
self.root_dir = root_dir
|
14 |
+
self.internal_resolution = internal_resolution
|
15 |
+
|
16 |
+
def __len__(self):
|
17 |
+
return len(self.dataset)
|
18 |
+
|
19 |
+
def __getitem__(self, idx: int):
|
20 |
+
target = self.dataset.iloc[idx][self.target]
|
21 |
+
filename = self.dataset.iloc[idx]['3d_grid']
|
22 |
+
try:
|
23 |
+
numpy_file = np.load(os.path.join(self.root_dir, filename))
|
24 |
+
torch_np = torch.from_numpy(numpy_file)
|
25 |
+
torch_np = torch_np.unsqueeze(0).unsqueeze(0).float() # Convert to float and move to appropriate device
|
26 |
+
interpolated_data = F.interpolate(input=torch_np, size=(self.internal_resolution, self.internal_resolution, self.internal_resolution), mode='trilinear')
|
27 |
+
|
28 |
+
# Apply tanh and log operations
|
29 |
+
# interpolated_data_tanh = torch.tanh(interpolated_data)
|
30 |
+
interpolated_data_log = torch.log(interpolated_data + 1).squeeze(0) # Adding 1 to avoid log(0)
|
31 |
+
|
32 |
+
return interpolated_data_log, target
|
33 |
+
except Exception as e:
|
34 |
+
print(f"Error loading file '{filename}': {e}")
|
35 |
+
return None
|
finetune/finetune_regression.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deep learning
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch import optim
|
5 |
+
from trainers import TrainerRegressor
|
6 |
+
from vq_gan_3d.model.vqgan_DDP import load_VQGAN
|
7 |
+
from utils import init_weights, RMSELoss
|
8 |
+
|
9 |
+
# Parallel
|
10 |
+
from torch.distributed import init_process_group, destroy_process_group
|
11 |
+
|
12 |
+
# Data
|
13 |
+
import pandas as pd
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
# Standard library
|
17 |
+
import math
|
18 |
+
import args
|
19 |
+
import os
|
20 |
+
|
21 |
+
|
22 |
+
def ddp_setup():
|
23 |
+
init_process_group(backend="nccl")
|
24 |
+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
25 |
+
|
26 |
+
|
27 |
+
def main(config):
|
28 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
29 |
+
ddp_setup()
|
30 |
+
|
31 |
+
# load dataset
|
32 |
+
df_train = pd.read_csv(f"{config.data_root}/train.csv")
|
33 |
+
df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
|
34 |
+
df_test = pd.read_csv(f"{config.data_root}/test.csv")
|
35 |
+
|
36 |
+
# load model
|
37 |
+
model = load_VQGAN(folder=config.model_path, filename=config.ckpt_filename)
|
38 |
+
model.net.apply(init_weights)
|
39 |
+
print(model.net)
|
40 |
+
|
41 |
+
# disable gradients to frozen parts
|
42 |
+
for param in model.decoder.parameters(): # decoder
|
43 |
+
param.requires_grad = False
|
44 |
+
for param in model.post_vq_conv.parameters(): # after codebook
|
45 |
+
param.requires_grad = False
|
46 |
+
for param in model.codebook.parameters(): # codebook
|
47 |
+
param.requires_grad = False
|
48 |
+
for param in model.image_discriminator.parameters(): # GAN discriminator
|
49 |
+
param.requires_grad = False
|
50 |
+
|
51 |
+
if config.loss_fn == 'rmse':
|
52 |
+
loss_function = RMSELoss()
|
53 |
+
elif config.loss_fn == 'mae':
|
54 |
+
loss_function = nn.L1Loss()
|
55 |
+
|
56 |
+
# init trainer
|
57 |
+
trainer = TrainerRegressor(
|
58 |
+
raw_data=(df_train, df_valid, df_test),
|
59 |
+
grids_path=config.grid_path,
|
60 |
+
dataset_name=config.dataset_name,
|
61 |
+
target=config.measure_name,
|
62 |
+
batch_size=config.n_batch,
|
63 |
+
hparams=config,
|
64 |
+
internal_resolution=model.config['model']['internal_resolution'],
|
65 |
+
target_metric=config.target_metric,
|
66 |
+
seed=config.start_seed,
|
67 |
+
num_workers=config.num_workers,
|
68 |
+
checkpoints_folder=config.checkpoints_folder,
|
69 |
+
restart_filename=config.restart_filename,
|
70 |
+
device=device,
|
71 |
+
save_every_epoch=bool(config.save_every_epoch),
|
72 |
+
save_ckpt=bool(config.save_ckpt)
|
73 |
+
)
|
74 |
+
trainer.compile(
|
75 |
+
model=model,
|
76 |
+
optimizer=optim.AdamW(
|
77 |
+
list(model.encoder.parameters())
|
78 |
+
+list(model.pre_vq_conv.parameters())
|
79 |
+
+list(model.net.parameters()),
|
80 |
+
lr=config.lr_start, betas=(0.9, 0.999)
|
81 |
+
),
|
82 |
+
loss_fn=loss_function
|
83 |
+
)
|
84 |
+
trainer.fit(max_epochs=config.max_epochs)
|
85 |
+
trainer.evaluate()
|
86 |
+
destroy_process_group()
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == '__main__':
|
90 |
+
parser = args.get_parser()
|
91 |
+
config = parser.parse_args()
|
92 |
+
main(config)
|
finetune/run_finetune_qm9_alpha.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'alpha' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/alpha' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_cv.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'cv' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/cv' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_g298.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'g298' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/g298' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_gap.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'gap' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/gap' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_h298.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'h298' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/h298' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_homo.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'homo' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/homo' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_lumo.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'lumo' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/lumo' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_mu.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'mu' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/mu' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_r2.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'r2' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/r2' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_u0.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'u0' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/u0' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_u298.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'u298' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/u298' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/run_finetune_qm9_zpve.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
torchrun \
|
3 |
+
--standalone \
|
4 |
+
--nnodes=1 \
|
5 |
+
--nproc_per_node=1 \
|
6 |
+
finetune_regression.py \
|
7 |
+
--n_batch 8 \
|
8 |
+
--dropout 0.1 \
|
9 |
+
--lr_start 3e-5 \
|
10 |
+
--num_workers 16 \
|
11 |
+
--max_epochs 100 \
|
12 |
+
--model_path '../data/checkpoints/pretrained' \
|
13 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
14 |
+
--data_root '../data/datasets/moleculenet/qm9' \
|
15 |
+
--grid_path '/data_npy/qm9' \
|
16 |
+
--dataset_name qm9 \
|
17 |
+
--measure_name 'zpve' \
|
18 |
+
--checkpoints_folder '../data/checkpoints/finetuned/qm9/zpve' \
|
19 |
+
--loss_fn 'mae' \
|
20 |
+
--target_metric 'mae' \
|
21 |
+
--save_ckpt 1 \
|
22 |
+
--start_seed 0 \
|
23 |
+
--save_every_epoch 0 \
|
24 |
+
--restart_filename '' \
|
finetune/trainers.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deep learning
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.backends.cudnn as cudnn
|
6 |
+
from torch.utils.data.distributed import DistributedSampler
|
7 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from dataset.default import GridDataset
|
10 |
+
from utils import RMSELoss
|
11 |
+
|
12 |
+
# Data
|
13 |
+
import pandas as pd
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
# Standard library
|
17 |
+
import random
|
18 |
+
import args
|
19 |
+
import os
|
20 |
+
import copy
|
21 |
+
import shutil
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
# Machine Learning
|
25 |
+
from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve
|
26 |
+
from scipy import stats
|
27 |
+
from utils import RMSE, sensitivity, specificity
|
28 |
+
|
29 |
+
|
30 |
+
class Trainer:
|
31 |
+
|
32 |
+
def __init__(self, raw_data, grids_path, dataset_name, target, batch_size, hparams, internal_resolution,
|
33 |
+
target_metric='rmse', seed=0, num_workers=0, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
|
34 |
+
# data
|
35 |
+
self.df_train = raw_data[0]
|
36 |
+
self.df_valid = raw_data[1]
|
37 |
+
self.df_test = raw_data[2]
|
38 |
+
self.grids_path = grids_path
|
39 |
+
self.dataset_name = dataset_name
|
40 |
+
self.target = target
|
41 |
+
self.batch_size = batch_size
|
42 |
+
self.hparams = hparams
|
43 |
+
self.internal_resolution = internal_resolution
|
44 |
+
self.num_workers = num_workers
|
45 |
+
self._prepare_data()
|
46 |
+
|
47 |
+
# config
|
48 |
+
self.target_metric = target_metric
|
49 |
+
self.seed = seed
|
50 |
+
self.checkpoints_folder = checkpoints_folder
|
51 |
+
self.restart_filename = restart_filename
|
52 |
+
self.start_epoch = 1
|
53 |
+
self.save_every_epoch = save_every_epoch
|
54 |
+
self.save_ckpt = save_ckpt
|
55 |
+
self.best_vloss = float('inf')
|
56 |
+
self.last_filename = None
|
57 |
+
self._set_seed(seed)
|
58 |
+
|
59 |
+
# multi-gpu
|
60 |
+
self.local_rank = int(os.environ["LOCAL_RANK"])
|
61 |
+
self.global_rank = int(os.environ["RANK"])
|
62 |
+
|
63 |
+
def _prepare_data(self):
|
64 |
+
train_dataset = GridDataset(
|
65 |
+
dataset=self.df_train,
|
66 |
+
target=self.target,
|
67 |
+
root_dir=self.grids_path,
|
68 |
+
internal_resolution=self.internal_resolution,
|
69 |
+
)
|
70 |
+
valid_dataset = GridDataset(
|
71 |
+
dataset=self.df_valid,
|
72 |
+
target=self.target,
|
73 |
+
root_dir=self.grids_path,
|
74 |
+
internal_resolution=self.internal_resolution,
|
75 |
+
)
|
76 |
+
test_dataset = GridDataset(
|
77 |
+
dataset=self.df_test,
|
78 |
+
target=self.target,
|
79 |
+
root_dir=self.grids_path,
|
80 |
+
internal_resolution=self.internal_resolution,
|
81 |
+
)
|
82 |
+
|
83 |
+
# create dataloader
|
84 |
+
self.train_loader = DataLoader(
|
85 |
+
train_dataset,
|
86 |
+
batch_size=self.batch_size,
|
87 |
+
num_workers=self.num_workers,
|
88 |
+
sampler=DistributedSampler(train_dataset),
|
89 |
+
shuffle=False,
|
90 |
+
pin_memory=True
|
91 |
+
)
|
92 |
+
self.valid_loader = DataLoader(
|
93 |
+
valid_dataset,
|
94 |
+
batch_size=self.batch_size,
|
95 |
+
num_workers=self.num_workers,
|
96 |
+
sampler=DistributedSampler(valid_dataset),
|
97 |
+
shuffle=False,
|
98 |
+
pin_memory=True
|
99 |
+
)
|
100 |
+
self.test_loader = DataLoader(
|
101 |
+
test_dataset,
|
102 |
+
batch_size=self.batch_size,
|
103 |
+
num_workers=self.num_workers,
|
104 |
+
sampler=DistributedSampler(test_dataset),
|
105 |
+
shuffle=False,
|
106 |
+
pin_memory=True
|
107 |
+
)
|
108 |
+
|
109 |
+
def compile(self, model, optimizer, loss_fn):
|
110 |
+
self.model = model.to(self.local_rank)
|
111 |
+
self.optimizer = optimizer
|
112 |
+
self.loss_fn = loss_fn
|
113 |
+
self._print_configuration()
|
114 |
+
|
115 |
+
if self.restart_filename:
|
116 |
+
self._load_checkpoint(self.restart_filename)
|
117 |
+
print('Checkpoint restored!')
|
118 |
+
|
119 |
+
self.model = DDP(self.model, device_ids=[self.local_rank])
|
120 |
+
|
121 |
+
def fit(self, max_epochs=500):
|
122 |
+
for epoch in range(self.start_epoch, max_epochs+1):
|
123 |
+
print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
|
124 |
+
|
125 |
+
# training
|
126 |
+
self.model.train()
|
127 |
+
self.train_loader.sampler.set_epoch(epoch)
|
128 |
+
train_loss = self._train_one_epoch()
|
129 |
+
|
130 |
+
# validation
|
131 |
+
self.model.eval()
|
132 |
+
val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader)
|
133 |
+
tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader)
|
134 |
+
|
135 |
+
if self.global_rank == 0:
|
136 |
+
for m in val_metrics.keys():
|
137 |
+
print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
|
138 |
+
print('-'*64)
|
139 |
+
for m in tst_metrics.keys():
|
140 |
+
print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
|
141 |
+
|
142 |
+
############################### Save Finetune checkpoint #######################################
|
143 |
+
if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt and self.global_rank == 0:
|
144 |
+
# remove old checkpoint
|
145 |
+
if (self.last_filename != None) and (not self.save_every_epoch):
|
146 |
+
os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
|
147 |
+
|
148 |
+
# filename
|
149 |
+
model_name = f'{str(self.model.module)}-Finetune'
|
150 |
+
self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
|
151 |
+
|
152 |
+
# update best loss
|
153 |
+
self.best_vloss = val_loss
|
154 |
+
|
155 |
+
# save checkpoint
|
156 |
+
print('Saving checkpoint...')
|
157 |
+
self._save_checkpoint(epoch, self.last_filename)
|
158 |
+
|
159 |
+
def evaluate(self, verbose=True):
|
160 |
+
if verbose:
|
161 |
+
print("\n=====Test Evaluation=====")
|
162 |
+
|
163 |
+
# set model evaluation mode
|
164 |
+
model_inf = copy.deepcopy(self.model)
|
165 |
+
model_inf.eval()
|
166 |
+
|
167 |
+
# evaluate on test set
|
168 |
+
tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf)
|
169 |
+
|
170 |
+
if verbose and self.global_rank == 0:
|
171 |
+
# show metrics
|
172 |
+
for m in tst_metrics.keys():
|
173 |
+
print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
|
174 |
+
|
175 |
+
# save predictions
|
176 |
+
pd.DataFrame(tst_preds).to_csv(
|
177 |
+
os.path.join(
|
178 |
+
self.checkpoints_folder,
|
179 |
+
f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'
|
180 |
+
),
|
181 |
+
index=False
|
182 |
+
)
|
183 |
+
|
184 |
+
def _train_one_epoch(self):
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
def _validate_one_epoch(self, data_loader, model=None):
|
188 |
+
raise NotImplementedError
|
189 |
+
|
190 |
+
def _print_configuration(self):
|
191 |
+
print('----Finetune information----')
|
192 |
+
print('Dataset:\t', self.dataset_name)
|
193 |
+
print('Target:\t\t', self.target)
|
194 |
+
print('Batch size:\t', self.batch_size)
|
195 |
+
print('LR:\t\t', self._get_lr())
|
196 |
+
print('Device:\t\t', self.local_rank)
|
197 |
+
print('Optimizer:\t', self.optimizer.__class__.__name__)
|
198 |
+
print('Loss function:\t', self.loss_fn.__class__.__name__)
|
199 |
+
print('Seed:\t\t', self.seed)
|
200 |
+
print('Train size:\t', self.df_train.shape[0])
|
201 |
+
print('Valid size:\t', self.df_valid.shape[0])
|
202 |
+
print('Test size:\t', self.df_test.shape[0])
|
203 |
+
|
204 |
+
def _load_checkpoint(self, filename):
|
205 |
+
ckpt_path = os.path.join(self.checkpoints_folder, filename)
|
206 |
+
ckpt_dict = torch.load(ckpt_path, map_location='cpu')
|
207 |
+
self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
|
208 |
+
self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
|
209 |
+
self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
|
210 |
+
|
211 |
+
def _save_checkpoint(self, current_epoch, filename):
|
212 |
+
if not os.path.exists(self.checkpoints_folder):
|
213 |
+
os.makedirs(self.checkpoints_folder)
|
214 |
+
|
215 |
+
self.model.module.config['finetune'] = vars(self.hparams)
|
216 |
+
hparams = self.model.module.config
|
217 |
+
|
218 |
+
ckpt_dict = {
|
219 |
+
'MODEL_STATE': self.model.module.state_dict(),
|
220 |
+
'EPOCHS_RUN': current_epoch,
|
221 |
+
'hparams': hparams,
|
222 |
+
'finetune_info': {
|
223 |
+
'dataset': self.dataset_name,
|
224 |
+
'target`': self.target,
|
225 |
+
'batch_size': self.batch_size,
|
226 |
+
'lr': self._get_lr(),
|
227 |
+
'device': self.local_rank,
|
228 |
+
'optim': self.optimizer.__class__.__name__,
|
229 |
+
'loss_fn': self.loss_fn.__class__.__name__,
|
230 |
+
'train_size': self.df_train.shape[0],
|
231 |
+
'valid_size': self.df_valid.shape[0],
|
232 |
+
'test_size': self.df_test.shape[0],
|
233 |
+
'best_vloss': self.best_vloss,
|
234 |
+
},
|
235 |
+
'seed': self.seed,
|
236 |
+
}
|
237 |
+
|
238 |
+
assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed']
|
239 |
+
|
240 |
+
torch.save(ckpt_dict, os.path.join(self.checkpoints_folder, filename))
|
241 |
+
|
242 |
+
def _set_seed(self, value):
|
243 |
+
random.seed(value)
|
244 |
+
torch.manual_seed(value)
|
245 |
+
np.random.seed(value)
|
246 |
+
if torch.cuda.is_available():
|
247 |
+
torch.cuda.manual_seed(value)
|
248 |
+
torch.cuda.manual_seed_all(value)
|
249 |
+
cudnn.deterministic = True
|
250 |
+
cudnn.benchmark = False
|
251 |
+
|
252 |
+
def _get_lr(self):
|
253 |
+
for param_group in self.optimizer.param_groups:
|
254 |
+
return param_group['lr']
|
255 |
+
|
256 |
+
|
257 |
+
class TrainerRegressor(Trainer):
|
258 |
+
|
259 |
+
def __init__(self, raw_data, grids_path, dataset_name, target, batch_size, hparams, internal_resolution,
|
260 |
+
target_metric='rmse', seed=0, num_workers=0, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
|
261 |
+
super().__init__(raw_data, grids_path, dataset_name, target, batch_size, hparams, internal_resolution,
|
262 |
+
target_metric, seed, num_workers, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
|
263 |
+
|
264 |
+
def _train_one_epoch(self):
|
265 |
+
running_loss = 0.0
|
266 |
+
|
267 |
+
if self.global_rank == 0:
|
268 |
+
pbar = tqdm(total=len(self.train_loader))
|
269 |
+
for idx, data in enumerate(self.train_loader):
|
270 |
+
# Every data instance is an input + label pair
|
271 |
+
grids, targets = data
|
272 |
+
targets = targets.to(self.local_rank)
|
273 |
+
grids = grids.to(self.local_rank)
|
274 |
+
|
275 |
+
# zero the parameter gradients (otherwise they are accumulated)
|
276 |
+
self.optimizer.zero_grad()
|
277 |
+
|
278 |
+
# Make predictions for this batch
|
279 |
+
embeddings = self.model.module.feature_extraction(grids)
|
280 |
+
outputs = self.model.module.net(embeddings).squeeze(1)
|
281 |
+
|
282 |
+
# Compute the loss and its gradients
|
283 |
+
loss = self.loss_fn(outputs, targets)
|
284 |
+
loss.backward()
|
285 |
+
|
286 |
+
# Adjust learning weights
|
287 |
+
self.optimizer.step()
|
288 |
+
|
289 |
+
# print statistics
|
290 |
+
running_loss += loss.item()
|
291 |
+
|
292 |
+
# progress bar
|
293 |
+
if self.global_rank == 0:
|
294 |
+
pbar.update(1)
|
295 |
+
pbar.set_description('[TRAINING]')
|
296 |
+
pbar.set_postfix(loss=running_loss/(idx+1))
|
297 |
+
pbar.refresh()
|
298 |
+
if self.global_rank == 0:
|
299 |
+
pbar.close()
|
300 |
+
|
301 |
+
return running_loss / len(self.train_loader)
|
302 |
+
|
303 |
+
def _validate_one_epoch(self, data_loader, model=None):
|
304 |
+
data_targets = []
|
305 |
+
data_preds = []
|
306 |
+
running_loss = 0.0
|
307 |
+
|
308 |
+
model = self.model if model is None else model
|
309 |
+
|
310 |
+
if self.global_rank == 0:
|
311 |
+
pbar = tqdm(total=len(data_loader))
|
312 |
+
with torch.no_grad():
|
313 |
+
for idx, data in enumerate(data_loader):
|
314 |
+
# Every data instance is an input + label pair
|
315 |
+
grids, targets = data
|
316 |
+
targets = targets.to(self.local_rank)
|
317 |
+
grids = grids.to(self.local_rank)
|
318 |
+
|
319 |
+
# Make predictions for this batch
|
320 |
+
embeddings = model.module.feature_extraction(grids)
|
321 |
+
predictions = model.module.net(embeddings).squeeze(1)
|
322 |
+
|
323 |
+
# Compute the loss
|
324 |
+
loss = self.loss_fn(predictions, targets)
|
325 |
+
|
326 |
+
data_targets.append(targets.view(-1))
|
327 |
+
data_preds.append(predictions.view(-1))
|
328 |
+
|
329 |
+
# print statistics
|
330 |
+
running_loss += loss.item()
|
331 |
+
|
332 |
+
# progress bar
|
333 |
+
if self.global_rank == 0:
|
334 |
+
pbar.update(1)
|
335 |
+
pbar.set_description('[EVALUATION]')
|
336 |
+
pbar.set_postfix(loss=running_loss/(idx+1))
|
337 |
+
pbar.refresh()
|
338 |
+
if self.global_rank == 0:
|
339 |
+
pbar.close()
|
340 |
+
|
341 |
+
# Put together predictions and labels from batches
|
342 |
+
preds = torch.cat(data_preds, dim=0).cpu().numpy()
|
343 |
+
tgts = torch.cat(data_targets, dim=0).cpu().numpy()
|
344 |
+
|
345 |
+
# Compute metrics
|
346 |
+
mae = mean_absolute_error(tgts, preds)
|
347 |
+
r2 = r2_score(tgts, preds)
|
348 |
+
rmse = RMSE(preds, tgts)
|
349 |
+
spearman = stats.spearmanr(tgts, preds).correlation # scipy 1.12.0
|
350 |
+
|
351 |
+
# Rearange metrics
|
352 |
+
metrics = {
|
353 |
+
'mae': mae,
|
354 |
+
'r2': r2,
|
355 |
+
'rmse': rmse,
|
356 |
+
'spearman': spearman,
|
357 |
+
}
|
358 |
+
|
359 |
+
return preds, running_loss / len(data_loader), metrics
|
finetune/utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deep learning
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from sklearn.metrics import confusion_matrix
|
6 |
+
|
7 |
+
# Data
|
8 |
+
import pandas as pd
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
# Standard library
|
12 |
+
import os
|
13 |
+
|
14 |
+
# Chemistry
|
15 |
+
from rdkit import Chem
|
16 |
+
from rdkit.Chem import PandasTools
|
17 |
+
from rdkit.Chem import Descriptors
|
18 |
+
PandasTools.RenderImagesInAllDataFrames(True)
|
19 |
+
|
20 |
+
|
21 |
+
def normalize_smiles(smi, canonical=True, isomeric=False):
|
22 |
+
try:
|
23 |
+
normalized = Chem.MolToSmiles(
|
24 |
+
Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
|
25 |
+
)
|
26 |
+
except:
|
27 |
+
normalized = None
|
28 |
+
return normalized
|
29 |
+
|
30 |
+
|
31 |
+
class RMSELoss:
|
32 |
+
def __init__(self):
|
33 |
+
pass
|
34 |
+
|
35 |
+
def __call__(self, yhat, y):
|
36 |
+
return torch.sqrt(torch.mean((yhat-y)**2))
|
37 |
+
|
38 |
+
|
39 |
+
def RMSE(predictions, targets):
|
40 |
+
return np.sqrt(((predictions - targets) ** 2).mean())
|
41 |
+
|
42 |
+
|
43 |
+
def sensitivity(y_true, y_pred):
|
44 |
+
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
|
45 |
+
return (tp/(tp+fn))
|
46 |
+
|
47 |
+
|
48 |
+
def specificity(y_true, y_pred):
|
49 |
+
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
|
50 |
+
return (tn/(tn+fp))
|
51 |
+
|
52 |
+
|
53 |
+
def init_weights(module):
|
54 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
55 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
56 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
57 |
+
module.bias.data.zero_()
|
58 |
+
elif isinstance(module, nn.LayerNorm):
|
59 |
+
module.bias.data.zero_()
|
60 |
+
module.weight.data.fill_(1.0)
|
61 |
+
|
62 |
+
|
63 |
+
def get_optim_groups(module, keep_decoder=False):
|
64 |
+
# setup optimizer
|
65 |
+
# separate out all parameters to those that will and won't experience regularizing weight decay
|
66 |
+
decay = set()
|
67 |
+
no_decay = set()
|
68 |
+
whitelist_weight_modules = (torch.nn.Linear,)
|
69 |
+
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
70 |
+
for mn, m in module.named_modules():
|
71 |
+
for pn, p in m.named_parameters():
|
72 |
+
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
73 |
+
|
74 |
+
if not keep_decoder and 'decoder' in fpn: # exclude decoder components
|
75 |
+
continue
|
76 |
+
|
77 |
+
if pn.endswith('bias'):
|
78 |
+
# all biases will not be decayed
|
79 |
+
no_decay.add(fpn)
|
80 |
+
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
81 |
+
# weights of whitelist modules will be weight decayed
|
82 |
+
decay.add(fpn)
|
83 |
+
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
84 |
+
# weights of blacklist modules will NOT be weight decayed
|
85 |
+
no_decay.add(fpn)
|
86 |
+
|
87 |
+
# validate that we considered every parameter
|
88 |
+
param_dict = {pn: p for pn, p in module.named_parameters()}
|
89 |
+
|
90 |
+
# create the pytorch optimizer object
|
91 |
+
optim_groups = [
|
92 |
+
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
|
93 |
+
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
94 |
+
]
|
95 |
+
|
96 |
+
return optim_groups
|
97 |
+
|
98 |
+
|
99 |
+
class CustomDataset(Dataset):
|
100 |
+
def __init__(self, dataset, target):
|
101 |
+
self.dataset = dataset
|
102 |
+
self.target = target
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return len(self.dataset)
|
106 |
+
|
107 |
+
def __getitem__(self, idx):
|
108 |
+
smiles = self.dataset['canon_smiles'].iloc[idx]
|
109 |
+
labels = self.dataset[self.target].iloc[idx]
|
110 |
+
return smiles, labels
|
111 |
+
|
112 |
+
|
113 |
+
class CustomDatasetMultitask(Dataset):
|
114 |
+
def __init__(self, dataset, targets):
|
115 |
+
self.dataset = dataset
|
116 |
+
self.targets = targets
|
117 |
+
|
118 |
+
def __len__(self):
|
119 |
+
return len(self.dataset)
|
120 |
+
|
121 |
+
def __getitem__(self, idx):
|
122 |
+
smiles = self.dataset['canon_smiles'].iloc[idx]
|
123 |
+
labels = self.dataset[self.targets].iloc[idx].to_numpy()
|
124 |
+
mask = [0.0 if np.isnan(x) else 1.0 for x in labels]
|
125 |
+
labels = [0.0 if np.isnan(x) else x for x in labels]
|
126 |
+
return smiles, torch.tensor(labels, dtype=torch.float32), torch.tensor(mask)
|
images/3dgridvqgan_architecture.png
ADDED
![]() |
Git LFS Details
|
inference/run_embeddings_eval_xgboost.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
python -u ./scripts/evaluate_embeddings_xgboost.py --task zpve
|
inference/run_extract_embeddings.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
python -u ./scripts/extract_embeddings.py \
|
3 |
+
--dataset_path '../data/datasets/moleculenet/qm9.csv' \
|
4 |
+
--save_dataset_path '../data/embeddings/qm9_embeddings.csv' \
|
5 |
+
--ckpt_filename 'VQGAN_43.pt' \
|
6 |
+
--data_dir '../data/sample_data_schema' \
|
7 |
+
--batch_size 2 \
|
8 |
+
--num_workers 0 \
|