vshirasuna committed
Commit 9123ba9 · verified · Parent: 7d1fe42

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete set.
Files changed (50)
  1. .gitattributes +3 -0
  2. data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9.smi +1 -0
  3. data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.npy +3 -0
  4. data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.xyz +11 -0
  5. data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c.smi +1 -0
  6. data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.npy +3 -0
  7. data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.xyz +21 -0
  8. data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939.smi +1 -0
  9. data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.npy +3 -0
  10. data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.xyz +21 -0
  11. data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6.smi +1 -0
  12. data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.npy +3 -0
  13. data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.xyz +13 -0
  14. data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d.smi +1 -0
  15. data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.npy +3 -0
  16. data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.xyz +25 -0
  17. data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0.smi +1 -0
  18. data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.npy +3 -0
  19. data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.xyz +13 -0
  20. data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8.smi +1 -0
  21. data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.npy +3 -0
  22. data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.xyz +15 -0
  23. data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb.smi +1 -0
  24. data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.npy +3 -0
  25. data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.xyz +18 -0
  26. data/datasets/moleculenet/qm9/qm9.csv +3 -0
  27. data/datasets/moleculenet/qm9/test.csv +0 -0
  28. data/datasets/moleculenet/qm9/train.csv +3 -0
  29. data/datasets/moleculenet/qm9/valid.csv +0 -0
  30. finetune/args.py +40 -0
  31. finetune/dataset/__init__.py +7 -0
  32. finetune/dataset/default.py +35 -0
  33. finetune/finetune_regression.py +92 -0
  34. finetune/run_finetune_qm9_alpha.sh +24 -0
  35. finetune/run_finetune_qm9_cv.sh +24 -0
  36. finetune/run_finetune_qm9_g298.sh +24 -0
  37. finetune/run_finetune_qm9_gap.sh +24 -0
  38. finetune/run_finetune_qm9_h298.sh +24 -0
  39. finetune/run_finetune_qm9_homo.sh +24 -0
  40. finetune/run_finetune_qm9_lumo.sh +24 -0
  41. finetune/run_finetune_qm9_mu.sh +24 -0
  42. finetune/run_finetune_qm9_r2.sh +24 -0
  43. finetune/run_finetune_qm9_u0.sh +24 -0
  44. finetune/run_finetune_qm9_u298.sh +24 -0
  45. finetune/run_finetune_qm9_zpve.sh +24 -0
  46. finetune/trainers.py +359 -0
  47. finetune/utils.py +126 -0
  48. images/3dgridvqgan_architecture.png +3 -0
  49. inference/run_embeddings_eval_xgboost.sh +2 -0
  50. inference/run_extract_embeddings.sh +8 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/datasets/moleculenet/qm9/qm9.csv filter=lfs diff=lfs merge=lfs -text
+ data/datasets/moleculenet/qm9/train.csv filter=lfs diff=lfs merge=lfs -text
+ images/3dgridvqgan_architecture.png filter=lfs diff=lfs merge=lfs -text
data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9.smi ADDED
@@ -0,0 +1 @@
+ O=CNC1=NOC=N1
data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a17ceb4553e1861d34ebaf0d0b2e5f33975f41632eef13f662ed6098e06eb35
+ size 338816
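The .npy grid files in this commit are stored as Git LFS pointers rather than raw arrays; the three-line format above (spec version, sha256 oid, byte size) is the standard LFS pointer layout. A minimal sketch of parsing such a pointer file, assuming a local path (the actual array must be fetched via `git lfs pull` before `np.load` will work):

```python
# Minimal sketch: parse a Git LFS pointer file into its key/value fields.
def parse_lfs_pointer(pointer_path: str) -> dict:
    fields = {}
    with open(pointer_path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    # e.g. {'version': 'https://...', 'oid': 'sha256:2a17...', 'size': '338816'}
    return fields
```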
data/3d_grids_sample/00b781f5a45a2cd1cdff1a582f5650f9_0.xyz ADDED
@@ -0,0 +1,11 @@
+ O -0.3433117873 0.0542505386 0.0030523182
+ C -0.0862462134 1.2279140618 0.0081946405
+ N 1.1846004685 1.7722542978 -0.0009450675
+ C 2.386606203 1.0948306964 -0.0156497688
+ N 2.5353255646 -0.2087260828 -0.0269119898
+ O 3.9407365203 -0.3429161461 -0.0391639287
+ C 4.4419993701 0.8855820068 -0.0338427395
+ N 3.5551776308 1.8291029217 -0.0192066804
+ H -0.8486325877 2.0303875837 0.0201485951
+ H 1.2766287385 2.7775349452 0.0053395306
+ H 5.5174865125 0.985073027 -0.0416438097
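Note that these sample .xyz files contain only atom records (element symbol plus Cartesian coordinates), without the usual two-line XYZ header (atom count and comment line). A minimal reader sketch under that assumption:

```python
# Minimal sketch: read the header-less .xyz files in data/3d_grids_sample
# into element symbols and an (N, 3) coordinate array.
import numpy as np

def read_xyz(path: str):
    symbols, coords = [], []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 4:  # element x y z
                symbols.append(parts[0])
                coords.append([float(v) for v in parts[1:]])
    return symbols, np.array(coords)
```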
data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c.smi ADDED
@@ -0,0 +1 @@
+ C#CCC1CC1(C)CO
data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c211d6096850f4fd8bb7e862cb0e3cf09492ed5fb165348829ac37ed9d3c344a
+ size 2612864
data/3d_grids_sample/0a18e0f64cbaf1508b32834ece70933c_0.xyz ADDED
@@ -0,0 +1,21 @@
+ C 0.0095240887 1.5961806115 -0.0858857917
+ C -0.0675156706 0.086549314 0.0511974473
+ C -1.3201762415 -0.5125406846 -0.5508961783
+ O -1.3261797579 -0.2420461247 -1.9477734763
+ C 1.1982985053 -0.7358014855 -0.0482283446
+ C 0.4787861145 -0.6243589949 1.2724834132
+ C 1.0934534769 0.1242004421 2.4490763186
+ C 1.9615545513 -0.7265702717 3.2588924687
+ C 2.6652427105 -1.4441035487 3.9181710713
+ H 0.9721361795 1.9954968448 0.2425890448
+ H -0.1285829928 1.8750984225 -1.1338899379
+ H -0.7769315564 2.0893384777 0.4988274146
+ H -1.3405285423 -1.5971735621 -0.3600593089
+ H -2.2081101606 -0.0754460014 -0.0643042439
+ H -2.18303183 -0.5001196615 -2.2994922363
+ H 1.1667370542 -1.6792797399 -0.5822025643
+ H 2.1360390085 -0.1991832668 -0.1544507441
+ H -0.0724271026 -1.5082395772 1.583953566
+ H 1.6685961108 0.9854383459 2.091700259
+ H 0.2979810602 0.5352503178 3.0854850798
+ H 3.2930319243 -2.0701853274 4.5019859333
data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939.smi ADDED
@@ -0,0 +1 @@
+ CCC1C=C2CCC2O1
data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5adbb7b36ff91928b2ecec9e548962d45c5241d3949d35e5274511b0af7aa574
+ size 2471168
data/3d_grids_sample/0aa6c786f42a0d43d8d1dfc7e9ae4939_0.xyz ADDED
@@ -0,0 +1,21 @@
+ C -0.0055109461 1.5153095531 -0.1082270888
+ C -0.0178026585 -0.0131266806 -0.0290282794
+ C -1.426157367 -0.5876103682 0.1341883634
+ O -2.2512193772 -0.169709702 -0.9848466475
+ C -3.0265062489 -1.2762275569 -1.3904613892
+ C -2.8104411225 -2.0637374738 -2.7387249862
+ C -2.7598726311 -3.3919927241 -1.8878534159
+ C -2.4407002514 -2.4767203713 -0.7256518392
+ C -1.467181607 -2.1129314787 0.098813579
+ H 1.012088138 1.8942209817 -0.2429905156
+ H -0.615905307 1.8589202784 -0.9470659107
+ H -0.4113398154 1.9616853934 0.8065539523
+ H 0.4283203565 -0.4428051608 -0.9340972166
+ H 0.5897595853 -0.3532294279 0.8196759242
+ H -1.8681962495 -0.175448141 1.0583687871
+ H -4.0959419742 -1.0853286499 -1.2201143184
+ H -3.5930759641 -1.9863512547 -3.4970500576
+ H -1.8354825793 -1.8302996855 -3.1700309799
+ H -3.7465941351 -3.8605591083 -1.8233285422
+ H -2.0210691777 -4.1500431884 -2.1526779895
+ H -0.7049926879 -2.7257531143 0.5639707007
data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6.smi ADDED
@@ -0,0 +1 @@
+ CN(C=O)C(N)=O
data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:37286a82f528eedeca6eb4b02d5a487fb80526fac79b8d4c16d374ee389aa2f4
+ size 933248
data/3d_grids_sample/0ae6f92c43122633a151eff5089e8da6_0.xyz ADDED
@@ -0,0 +1,13 @@
+ C -0.0561551802 1.4880114224 0.0115672899
+ N -0.012314902 0.0308990979 -0.0074242366
+ C 0.0394071908 -0.6483703608 1.2066473859
+ O 0.0604983378 -0.0950939226 2.2806600492
+ C -0.042694924 -0.7260819108 -1.1975216915
+ N -0.0688282818 0.0226384605 -2.3461653613
+ O -0.0805938084 -1.9390867198 -1.2010568454
+ H 0.8560644678 1.9243363476 -0.4123282977
+ H -0.9270762996 1.8594106628 -0.5357636429
+ H -0.1279995908 1.7881285544 1.0569523777
+ H 0.0576097412 -1.7352809462 1.059547977
+ H 0.2506516948 0.9756066207 -2.3628772038
+ H 0.0545094644 -0.513342856 -3.1899679505
data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d.smi ADDED
@@ -0,0 +1 @@
+ CCCC(CC)C(C)=O
data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:daf1f661f359abd9637a71410cfbd09e692daa396b32e047c6606e03874c317a
+ size 2903168
data/3d_grids_sample/0b4020095a14325f0f174bc8a43f625d_0.xyz ADDED
@@ -0,0 +1,25 @@
+ C 0.3762415949 1.4903515185 0.051015763
+ C 0.16294672 -0.0262712548 0.0322514822
+ C 0.675840727 -0.6751333193 -1.2593347251
+ C 0.5134796126 -2.2163401834 -1.316721532
+ C -0.9478193302 -2.6564989579 -1.4944772157
+ C -1.1489111104 -4.1684742305 -1.3528914865
+ C 1.3845485491 -2.7207559592 -2.4735115772
+ C 2.8132544677 -3.1065708435 -2.1332551236
+ O 0.960203994 -2.7885031108 -3.6052090651
+ H 1.4383754499 1.7429425026 -0.0425294685
+ H -0.1517461232 1.9744952727 -0.7778578838
+ H 0.0095987923 1.9321754361 0.9827049486
+ H 0.6755907641 -0.4799924509 0.890858941
+ H -0.9033038021 -0.2426745921 0.1645631834
+ H 0.1643795151 -0.2410453696 -2.127887566
+ H 1.7371611953 -0.4154116304 -1.3751070827
+ H 0.9136495565 -2.6311564942 -0.3805859669
+ H -1.5741060384 -2.1375358991 -0.7612854955
+ H -1.2777225554 -2.3336553985 -2.4874152623
+ H -0.5803069276 -4.7154607671 -2.1113564713
+ H -0.8279135541 -4.5236840194 -0.3667438252
+ H -2.2026624643 -4.437369643 -1.4740466355
+ H 2.8024629719 -4.041189801 -1.5589057987
+ H 3.3927775837 -3.2531585575 -3.0456419515
+ H 3.2903552816 -2.3529525778 -1.4982282451
data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0.smi ADDED
@@ -0,0 +1 @@
+ C1=NC2=CC=NN2C=N1
data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08a45fe3a80e5c8cc22b3a0f59d531353b6d401f88661f593abdb69dc44b45c4
+ size 380288
data/3d_grids_sample/0b77c16e04a8ac8f84ae9ccaf9b1aaa0_0.xyz ADDED
@@ -0,0 +1,13 @@
+ C 0.0592192692 1.368762351 0.0093929006
+ N 1.3122139041 1.8377169369 -0.0001856573
+ N 2.0617385494 0.7170865182 -0.013873644
+ C 3.4112525731 0.634008202 -0.0275595021
+ N 4.0198493068 -0.5162508406 -0.0400141125
+ C 3.2187933425 -1.6270159107 -0.0384447767
+ N 1.909989902 -1.6610752725 -0.0258242213
+ C 1.292800839 -0.4606631607 -0.0130004803
+ C -0.0242733248 -0.0346384436 0.0023033991
+ H -0.7575875186 2.0764585031 0.0213803521
+ H 3.9530882705 1.5750751636 -0.0275449692
+ H 3.7430423978 -2.5779490055 -0.0489488368
+ H -0.9034361909 -0.6555191213 0.0074057982
data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8.smi ADDED
@@ -0,0 +1 @@
+ C1=C(C2CC2)NN=N1
data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a8ab9d5c9c61e0547e3a35f52a60354f12a348d681f1d7dc90dc42fed3bb577
+ size 1382528
data/3d_grids_sample/0b88d98ac218831353fb8c61aea0cfe8_0.xyz ADDED
@@ -0,0 +1,15 @@
+ C 0.002730801 1.5364883046 -0.0652481333
+ C 1.3129515531 0.8070922114 -0.076629951
+ C 0.0187818375 0.0226915472 0.0386968119
+ C -0.3380018988 -0.6381328815 1.3049549011
+ C -0.1608829834 -0.3235247554 2.6381345568
+ N -0.6971640624 -1.3104697961 3.4082804995
+ N -1.2012930219 -2.2290873625 2.6387797732
+ N -0.9871902476 -1.8273143514 1.3659063658
+ H -0.2724161486 2.0867325654 0.827333906
+ H -0.3566469563 1.9707288421 -0.9912789551
+ H 1.8591050209 0.7372804887 -1.0105007132
+ H 1.9370129991 0.8562070416 0.8084014071
+ H -0.2904338773 -0.518969552 -0.8499640332
+ H 0.3118506955 0.5391482945 3.0790387057
+ H -1.3069889708 -2.4074874964 0.6066944785
data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb.smi ADDED
@@ -0,0 +1 @@
+ CNC1=NC(O)=C(C)N1
data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:04ca6c042fdc6c92bad68d750b4770bab1cbba2f1c6a016350429da115f7145f
+ size 2628416
data/3d_grids_sample/0b94b07e1ec5e58964bfc7e670a359fb_0.xyz ADDED
@@ -0,0 +1,18 @@
+ C 0.2021885028 1.2833147668 0.1828347559
+ N -0.2183533194 -0.0600235049 -0.1763289174
+ C 0.4792782987 -0.7562350227 -1.1452089495
+ N 0.5549235456 -2.0713645176 -1.2135880027
+ C 1.2651968247 -2.3183417769 -2.3633449011
+ O 1.5374666899 -3.5896693927 -2.7487695103
+ C 1.6425751596 -1.1756799637 -3.0064836641
+ C 2.4261233754 -0.9183617354 -4.2431958005
+ N 1.1239546418 -0.1576445941 -2.1801339353
+ H 1.2655154169 1.3576381934 0.4589582859
+ H 0.0149717259 1.9764828738 -0.6460648109
+ H -0.4014962835 1.6279653806 1.0247916362
+ H -0.4239665093 -0.6625739292 0.6073847095
+ H 1.1295866096 -4.1597040968 -2.0868842735
+ H 1.8581192728 -0.3597192603 -5.0000767893
+ H 3.3521095167 -0.3582763401 -4.0525753219
+ H 2.7082322022 -1.875782489 -4.6877635002
+ H 1.1534957895 0.8267204188 -2.3790310507
data/datasets/moleculenet/qm9/qm9.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b256baa2f324e03d73c6160f9b573d752a8e16512e6939f0f3f04bb3d7dd8c60
+ size 39132032
data/datasets/moleculenet/qm9/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/datasets/moleculenet/qm9/train.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc96b329d6e0efe1d59359ddb96881209743518ac2910ac3b4a184f70b866f2d
+ size 31696921
data/datasets/moleculenet/qm9/valid.csv ADDED
The diff for this file is too large to render. See raw diff
 
finetune/args.py ADDED
@@ -0,0 +1,40 @@
+ import argparse
+
+
+ def get_parser(parser=None):
+     if parser is None:
+         parser = argparse.ArgumentParser()
+
+     parser.add_argument("--data_root", type=str, required=False, default="")
+     parser.add_argument("--grid_path", type=str, required=False, default="")
+
+     parser.add_argument(
+         "--lr_start", type=float, default=3 * 1e-4, help="Initial lr value"
+     )
+     parser.add_argument(
+         "--max_epochs", type=int, required=False, default=1, help="max number of epochs"
+     )
+
+     parser.add_argument("--num_workers", type=int, default=0, required=False)
+     parser.add_argument("--dropout", type=float, default=0.1, required=False)
+     parser.add_argument("--n_batch", type=int, default=512, help="Batch size")
+     parser.add_argument("--dataset_name", type=str, required=False, default="sol")
+     parser.add_argument("--measure_name", type=str, required=False, default="measure")
+     parser.add_argument("--checkpoints_folder", type=str, required=True)
+     parser.add_argument("--model_path", type=str, default="./smi_ted/")
+     parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
+     parser.add_argument("--restart_filename", type=str, default="")
+     parser.add_argument("--n_output", type=int, default=1)
+     parser.add_argument("--save_every_epoch", type=int, default=0)
+     parser.add_argument("--save_ckpt", type=int, default=1)
+     parser.add_argument("--start_seed", type=int, default=0)
+     parser.add_argument("--target_metric", type=str, default="rmse")
+     parser.add_argument("--loss_fn", type=str, default="mae")
+
+     return parser
+
+
+ def parse_args():
+     parser = get_parser()
+     args = parser.parse_args()
+     return args
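A short usage sketch for the parser above, passing an explicit argv list that mirrors the flags the run scripts supply (only --checkpoints_folder is required; the other values here are illustrative):

```python
# Sketch: build and use the argument parser programmatically.
from args import get_parser

parser = get_parser()
config = parser.parse_args([
    "--checkpoints_folder", "../data/checkpoints/finetuned/qm9/alpha",
    "--measure_name", "alpha",
    "--n_batch", "8",
    "--lr_start", "3e-5",
])
print(config.measure_name, config.lr_start)  # alpha 3e-05
```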
finetune/dataset/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # from dataset.breast_uka import BreastUKA
+ # from dataset.mrnet import MRNetDataset
+ # from dataset.brats import BRATSDataset
+ # from dataset.adni import ADNIDataset
+ # from dataset.duke import DUKEDataset
+ # from dataset.lidc import LIDCDataset
+ from dataset.default import GridDataset
finetune/dataset/default.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ import torch.nn.functional as F
+ import torch.multiprocessing as mp
+
+
+ class GridDataset(Dataset):
+     def __init__(self, dataset, target: str, root_dir: str, internal_resolution: int):
+         super().__init__()
+         self.dataset = dataset
+         self.target = target
+         self.root_dir = root_dir
+         self.internal_resolution = internal_resolution
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx: int):
+         target = self.dataset.iloc[idx][self.target]
+         filename = self.dataset.iloc[idx]['3d_grid']
+         try:
+             numpy_file = np.load(os.path.join(self.root_dir, filename))
+             torch_np = torch.from_numpy(numpy_file)
+             torch_np = torch_np.unsqueeze(0).unsqueeze(0).float()  # add batch and channel dims, convert to float
+             interpolated_data = F.interpolate(input=torch_np, size=(self.internal_resolution, self.internal_resolution, self.internal_resolution), mode='trilinear')
+
+             # Apply log operation (tanh variant left commented out)
+             # interpolated_data_tanh = torch.tanh(interpolated_data)
+             interpolated_data_log = torch.log(interpolated_data + 1).squeeze(0)  # adding 1 to avoid log(0)
+
+             return interpolated_data_log, target
+         except Exception as e:
+             print(f"Error loading file '{filename}': {e}")
+             return None  # callers must filter these out; the default collate cannot batch None
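Because `__getitem__` returns None when a grid fails to load, PyTorch's default collate would raise on such a batch. A minimal sketch (an assumption, not part of this commit) of a filtering `collate_fn` that drops failed samples before batching:

```python
# Sketch: drop None samples emitted by GridDataset before default batching.
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def skip_none_collate(batch):
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)  # note: still errors if every sample failed

# loader = DataLoader(dataset, batch_size=8, collate_fn=skip_none_collate)
```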
finetune/finetune_regression.py ADDED
@@ -0,0 +1,92 @@
+ # Deep learning
+ import torch
+ import torch.nn as nn
+ from torch import optim
+ from trainers import TrainerRegressor
+ from vq_gan_3d.model.vqgan_DDP import load_VQGAN
+ from utils import init_weights, RMSELoss
+
+ # Parallel
+ from torch.distributed import init_process_group, destroy_process_group
+
+ # Data
+ import pandas as pd
+ import numpy as np
+
+ # Standard library
+ import math
+ import args
+ import os
+
+
+ def ddp_setup():
+     init_process_group(backend="nccl")
+     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+
+ def main(config):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     ddp_setup()
+
+     # load dataset
+     df_train = pd.read_csv(f"{config.data_root}/train.csv")
+     df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
+     df_test = pd.read_csv(f"{config.data_root}/test.csv")
+
+     # load model
+     model = load_VQGAN(folder=config.model_path, filename=config.ckpt_filename)
+     model.net.apply(init_weights)
+     print(model.net)
+
+     # disable gradients for frozen parts
+     for param in model.decoder.parameters():  # decoder
+         param.requires_grad = False
+     for param in model.post_vq_conv.parameters():  # after codebook
+         param.requires_grad = False
+     for param in model.codebook.parameters():  # codebook
+         param.requires_grad = False
+     for param in model.image_discriminator.parameters():  # GAN discriminator
+         param.requires_grad = False
+
+     if config.loss_fn == 'rmse':
+         loss_function = RMSELoss()
+     elif config.loss_fn == 'mae':
+         loss_function = nn.L1Loss()
+     else:
+         raise ValueError(f"Unsupported --loss_fn: {config.loss_fn}")
+
+     # init trainer
+     trainer = TrainerRegressor(
+         raw_data=(df_train, df_valid, df_test),
+         grids_path=config.grid_path,
+         dataset_name=config.dataset_name,
+         target=config.measure_name,
+         batch_size=config.n_batch,
+         hparams=config,
+         internal_resolution=model.config['model']['internal_resolution'],
+         target_metric=config.target_metric,
+         seed=config.start_seed,
+         num_workers=config.num_workers,
+         checkpoints_folder=config.checkpoints_folder,
+         restart_filename=config.restart_filename,
+         device=device,
+         save_every_epoch=bool(config.save_every_epoch),
+         save_ckpt=bool(config.save_ckpt)
+     )
+     trainer.compile(
+         model=model,
+         optimizer=optim.AdamW(
+             list(model.encoder.parameters())
+             + list(model.pre_vq_conv.parameters())
+             + list(model.net.parameters()),
+             lr=config.lr_start, betas=(0.9, 0.999)
+         ),
+         loss_fn=loss_function
+     )
+     trainer.fit(max_epochs=config.max_epochs)
+     trainer.evaluate()
+     destroy_process_group()
+
+
+ if __name__ == '__main__':
+     parser = args.get_parser()
+     config = parser.parse_args()
+     main(config)
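`ddp_setup` and the Trainer read `LOCAL_RANK` and `RANK` from the environment, which `torchrun` injects; the script therefore cannot be run with plain `python`. A sketch of the variables involved, assuming a one-GPU debug session where they are set by hand instead of by the launcher:

```python
# Sketch: env vars torchrun normally provides and this script depends on.
import os

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")        # global rank of this process
os.environ.setdefault("LOCAL_RANK", "0")  # GPU index on this node
os.environ.setdefault("WORLD_SIZE", "1")
```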
finetune/run_finetune_qm9_alpha.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'alpha' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/alpha' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
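The eleven scripts that follow are identical to this one apart from `--measure_name` and `--checkpoints_folder`. A sketch (hypothetical helper, not part of the commit) that launches all twelve QM9 targets sequentially:

```python
# Sketch: run every per-target finetuning script in sequence (one GPU).
import subprocess

QM9_TARGETS = ["alpha", "cv", "g298", "gap", "h298", "homo",
               "lumo", "mu", "r2", "u0", "u298", "zpve"]

for target in QM9_TARGETS:
    subprocess.run(["bash", f"run_finetune_qm9_{target}.sh"], check=True)
```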
finetune/run_finetune_qm9_cv.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'cv' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/cv' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_g298.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'g298' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/g298' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_gap.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'gap' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/gap' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_h298.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'h298' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/h298' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_homo.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'homo' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/homo' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_lumo.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'lumo' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/lumo' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_mu.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'mu' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/mu' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_r2.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'r2' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/r2' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_u0.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'u0' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/u0' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_u298.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'u298' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/u298' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/run_finetune_qm9_zpve.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+ torchrun \
+     --standalone \
+     --nnodes=1 \
+     --nproc_per_node=1 \
+     finetune_regression.py \
+     --n_batch 8 \
+     --dropout 0.1 \
+     --lr_start 3e-5 \
+     --num_workers 16 \
+     --max_epochs 100 \
+     --model_path '../data/checkpoints/pretrained' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_root '../data/datasets/moleculenet/qm9' \
+     --grid_path '/data_npy/qm9' \
+     --dataset_name qm9 \
+     --measure_name 'zpve' \
+     --checkpoints_folder '../data/checkpoints/finetuned/qm9/zpve' \
+     --loss_fn 'mae' \
+     --target_metric 'mae' \
+     --save_ckpt 1 \
+     --start_seed 0 \
+     --save_every_epoch 0 \
+     --restart_filename '' \
finetune/trainers.py ADDED
@@ -0,0 +1,359 @@
+ # Deep learning
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data import DataLoader
+ from dataset.default import GridDataset
+ from utils import RMSELoss
+
+ # Data
+ import pandas as pd
+ import numpy as np
+
+ # Standard library
+ import random
+ import args
+ import os
+ import copy
+ import shutil
+ from tqdm import tqdm
+
+ # Machine Learning
+ from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve
+ from scipy import stats
+ from utils import RMSE, sensitivity, specificity
+
+
+ class Trainer:
+
+     def __init__(self, raw_data, grids_path, dataset_name, target, batch_size, hparams, internal_resolution,
+                  target_metric='rmse', seed=0, num_workers=0, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
+         # data
+         self.df_train = raw_data[0]
+         self.df_valid = raw_data[1]
+         self.df_test = raw_data[2]
+         self.grids_path = grids_path
+         self.dataset_name = dataset_name
+         self.target = target
+         self.batch_size = batch_size
+         self.hparams = hparams
+         self.internal_resolution = internal_resolution
+         self.num_workers = num_workers
+         self._prepare_data()
+
+         # config
+         self.target_metric = target_metric
+         self.seed = seed
+         self.checkpoints_folder = checkpoints_folder
+         self.restart_filename = restart_filename
+         self.start_epoch = 1
+         self.save_every_epoch = save_every_epoch
+         self.save_ckpt = save_ckpt
+         self.best_vloss = float('inf')
+         self.last_filename = None
+         self._set_seed(seed)
+
+         # multi-gpu
+         self.local_rank = int(os.environ["LOCAL_RANK"])
+         self.global_rank = int(os.environ["RANK"])
+
+     def _prepare_data(self):
+         train_dataset = GridDataset(
+             dataset=self.df_train,
+             target=self.target,
+             root_dir=self.grids_path,
+             internal_resolution=self.internal_resolution,
+         )
+         valid_dataset = GridDataset(
+             dataset=self.df_valid,
+             target=self.target,
+             root_dir=self.grids_path,
+             internal_resolution=self.internal_resolution,
+         )
+         test_dataset = GridDataset(
+             dataset=self.df_test,
+             target=self.target,
+             root_dir=self.grids_path,
+             internal_resolution=self.internal_resolution,
+         )
+
+         # create dataloaders
+         self.train_loader = DataLoader(
+             train_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             sampler=DistributedSampler(train_dataset),
+             shuffle=False,
+             pin_memory=True
+         )
+         self.valid_loader = DataLoader(
+             valid_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             sampler=DistributedSampler(valid_dataset),
+             shuffle=False,
+             pin_memory=True
+         )
+         self.test_loader = DataLoader(
+             test_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             sampler=DistributedSampler(test_dataset),
+             shuffle=False,
+             pin_memory=True
+         )
+
+     def compile(self, model, optimizer, loss_fn):
+         self.model = model.to(self.local_rank)
+         self.optimizer = optimizer
+         self.loss_fn = loss_fn
+         self._print_configuration()
+
+         if self.restart_filename:
+             self._load_checkpoint(self.restart_filename)
+             print('Checkpoint restored!')
+
+         self.model = DDP(self.model, device_ids=[self.local_rank])
+
+     def fit(self, max_epochs=500):
+         for epoch in range(self.start_epoch, max_epochs + 1):
+             print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
+
+             # training
+             self.model.train()
+             self.train_loader.sampler.set_epoch(epoch)
+             train_loss = self._train_one_epoch()
+
+             # validation
+             self.model.eval()
+             val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader)
+             tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader)
+
+             if self.global_rank == 0:
+                 for m in val_metrics.keys():
+                     print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
+                 print('-' * 64)
+                 for m in tst_metrics.keys():
+                     print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
+
+             ############################### Save Finetune checkpoint #######################################
+             if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt and self.global_rank == 0:
+                 # remove old checkpoint
+                 if (self.last_filename is not None) and (not self.save_every_epoch):
+                     os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
+
+                 # filename
+                 model_name = f'{str(self.model.module)}-Finetune'
+                 self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
+
+                 # update best loss
+                 self.best_vloss = val_loss
+
+                 # save checkpoint
+                 print('Saving checkpoint...')
+                 self._save_checkpoint(epoch, self.last_filename)
+
+     def evaluate(self, verbose=True):
+         if verbose:
+             print("\n=====Test Evaluation=====")
+
+         # set model to evaluation mode
+         model_inf = copy.deepcopy(self.model)
+         model_inf.eval()
+
+         # evaluate on test set
+         tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf)
+
+         if verbose and self.global_rank == 0:
+             # show metrics
+             for m in tst_metrics.keys():
+                 print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
+
+             # save predictions
+             pd.DataFrame(tst_preds).to_csv(
+                 os.path.join(
+                     self.checkpoints_folder,
+                     f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'
+                 ),
+                 index=False
+             )
+
+     def _train_one_epoch(self):
+         raise NotImplementedError
+
+     def _validate_one_epoch(self, data_loader, model=None):
+         raise NotImplementedError
+
+     def _print_configuration(self):
+         print('----Finetune information----')
+         print('Dataset:\t', self.dataset_name)
+         print('Target:\t\t', self.target)
+         print('Batch size:\t', self.batch_size)
+         print('LR:\t\t', self._get_lr())
+         print('Device:\t\t', self.local_rank)
+         print('Optimizer:\t', self.optimizer.__class__.__name__)
+         print('Loss function:\t', self.loss_fn.__class__.__name__)
+         print('Seed:\t\t', self.seed)
+         print('Train size:\t', self.df_train.shape[0])
+         print('Valid size:\t', self.df_valid.shape[0])
+         print('Test size:\t', self.df_test.shape[0])
+
+     def _load_checkpoint(self, filename):
+         ckpt_path = os.path.join(self.checkpoints_folder, filename)
+         ckpt_dict = torch.load(ckpt_path, map_location='cpu')
+         self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
+         self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
+         self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
+
+     def _save_checkpoint(self, current_epoch, filename):
+         if not os.path.exists(self.checkpoints_folder):
+             os.makedirs(self.checkpoints_folder)
+
+         self.model.module.config['finetune'] = vars(self.hparams)
+         hparams = self.model.module.config
+
+         ckpt_dict = {
+             'MODEL_STATE': self.model.module.state_dict(),
+             'EPOCHS_RUN': current_epoch,
+             'hparams': hparams,
+             'finetune_info': {
+                 'dataset': self.dataset_name,
+                 'target': self.target,
+                 'batch_size': self.batch_size,
+                 'lr': self._get_lr(),
+                 'device': self.local_rank,
+                 'optim': self.optimizer.__class__.__name__,
+                 'loss_fn': self.loss_fn.__class__.__name__,
+                 'train_size': self.df_train.shape[0],
+                 'valid_size': self.df_valid.shape[0],
+                 'test_size': self.df_test.shape[0],
+                 'best_vloss': self.best_vloss,
+             },
+             'seed': self.seed,
+         }
+
+         assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed']
+
+         torch.save(ckpt_dict, os.path.join(self.checkpoints_folder, filename))
+
+     def _set_seed(self, value):
+         random.seed(value)
+         torch.manual_seed(value)
+         np.random.seed(value)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(value)
+             torch.cuda.manual_seed_all(value)
+             cudnn.deterministic = True
+             cudnn.benchmark = False
+
+     def _get_lr(self):
+         for param_group in self.optimizer.param_groups:
+             return param_group['lr']
+
+
+ class TrainerRegressor(Trainer):
+
+     def __init__(self, raw_data, grids_path, dataset_name, target, batch_size, hparams, internal_resolution,
+                  target_metric='rmse', seed=0, num_workers=0, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
+         super().__init__(raw_data, grids_path, dataset_name, target, batch_size, hparams, internal_resolution,
+                          target_metric, seed, num_workers, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
+
+     def _train_one_epoch(self):
+         running_loss = 0.0
+
+         if self.global_rank == 0:
+             pbar = tqdm(total=len(self.train_loader))
+         for idx, data in enumerate(self.train_loader):
+             # Every data instance is an input + label pair
+             grids, targets = data
+             targets = targets.to(self.local_rank)
+             grids = grids.to(self.local_rank)
+
+             # zero the parameter gradients (otherwise they are accumulated)
+             self.optimizer.zero_grad()
+
+             # Make predictions for this batch
+             embeddings = self.model.module.feature_extraction(grids)
+             outputs = self.model.module.net(embeddings).squeeze(1)
+
+             # Compute the loss and its gradients
+             loss = self.loss_fn(outputs, targets)
+             loss.backward()
+
+             # Adjust learning weights
+             self.optimizer.step()
+
+             # accumulate statistics
+             running_loss += loss.item()
+
+             # progress bar
+             if self.global_rank == 0:
+                 pbar.update(1)
+                 pbar.set_description('[TRAINING]')
+                 pbar.set_postfix(loss=running_loss / (idx + 1))
+                 pbar.refresh()
+         if self.global_rank == 0:
+             pbar.close()
+
+         return running_loss / len(self.train_loader)
+
+     def _validate_one_epoch(self, data_loader, model=None):
+         data_targets = []
+         data_preds = []
+         running_loss = 0.0
+
+         model = self.model if model is None else model
+
+         if self.global_rank == 0:
+             pbar = tqdm(total=len(data_loader))
+         with torch.no_grad():
+             for idx, data in enumerate(data_loader):
+                 # Every data instance is an input + label pair
+                 grids, targets = data
+                 targets = targets.to(self.local_rank)
+                 grids = grids.to(self.local_rank)
+
+                 # Make predictions for this batch
+                 embeddings = model.module.feature_extraction(grids)
+                 predictions = model.module.net(embeddings).squeeze(1)
+
+                 # Compute the loss
+                 loss = self.loss_fn(predictions, targets)
+
+                 data_targets.append(targets.view(-1))
+                 data_preds.append(predictions.view(-1))
+
+                 # accumulate statistics
+                 running_loss += loss.item()
+
+                 # progress bar
+                 if self.global_rank == 0:
+                     pbar.update(1)
+                     pbar.set_description('[EVALUATION]')
+                     pbar.set_postfix(loss=running_loss / (idx + 1))
+                     pbar.refresh()
+         if self.global_rank == 0:
+             pbar.close()
+
+         # Put together predictions and labels from batches
+         preds = torch.cat(data_preds, dim=0).cpu().numpy()
+         tgts = torch.cat(data_targets, dim=0).cpu().numpy()
+
+         # Compute metrics
+         mae = mean_absolute_error(tgts, preds)
+         r2 = r2_score(tgts, preds)
+         rmse = RMSE(preds, tgts)
+         spearman = stats.spearmanr(tgts, preds).correlation  # scipy 1.12.0
+
+         # Arrange metrics
+         metrics = {
+             'mae': mae,
+             'r2': r2,
+             'rmse': rmse,
+             'spearman': spearman,
+         }
+
+         return preds, running_loss / len(data_loader), metrics
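The checkpoint schema written by `_save_checkpoint` above is a plain dict; a minimal sketch of reading one back (the file path is hypothetical, and building the model before `load_state_dict` is left out):

```python
# Sketch: inspect a checkpoint produced by Trainer._save_checkpoint.
import torch

ckpt = torch.load("path/to/checkpoint.pt", map_location="cpu")
print(ckpt["EPOCHS_RUN"], ckpt["seed"])
print(ckpt["finetune_info"]["best_vloss"])
# model.load_state_dict(ckpt["MODEL_STATE"])  # after constructing the model
```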
finetune/utils.py ADDED
@@ -0,0 +1,126 @@
+ # Deep learning
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+ from sklearn.metrics import confusion_matrix
+
+ # Data
+ import pandas as pd
+ import numpy as np
+
+ # Standard library
+ import os
+
+ # Chemistry
+ from rdkit import Chem
+ from rdkit.Chem import PandasTools
+ from rdkit.Chem import Descriptors
+ PandasTools.RenderImagesInAllDataFrames(True)
+
+
+ def normalize_smiles(smi, canonical=True, isomeric=False):
+     try:
+         normalized = Chem.MolToSmiles(
+             Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+         )
+     except Exception:
+         normalized = None
+     return normalized
+
+
+ class RMSELoss:
+     def __init__(self):
+         pass
+
+     def __call__(self, yhat, y):
+         return torch.sqrt(torch.mean((yhat - y) ** 2))
+
+
+ def RMSE(predictions, targets):
+     return np.sqrt(((predictions - targets) ** 2).mean())
+
+
+ def sensitivity(y_true, y_pred):
+     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+     return tp / (tp + fn)
+
+
+ def specificity(y_true, y_pred):
+     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+     return tn / (tn + fp)
+
+
+ def init_weights(module):
+     if isinstance(module, (nn.Linear, nn.Embedding)):
+         module.weight.data.normal_(mean=0.0, std=0.02)
+         if isinstance(module, nn.Linear) and module.bias is not None:
+             module.bias.data.zero_()
+     elif isinstance(module, nn.LayerNorm):
+         module.bias.data.zero_()
+         module.weight.data.fill_(1.0)
+
+
+ def get_optim_groups(module, keep_decoder=False):
+     # setup optimizer:
+     # separate out all parameters into those that will and won't experience regularizing weight decay
+     decay = set()
+     no_decay = set()
+     whitelist_weight_modules = (torch.nn.Linear,)
+     blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+     for mn, m in module.named_modules():
+         for pn, p in m.named_parameters():
+             fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
+
+             if not keep_decoder and 'decoder' in fpn:  # exclude decoder components
+                 continue
+
+             if pn.endswith('bias'):
+                 # all biases will not be decayed
+                 no_decay.add(fpn)
+             elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                 # weights of whitelist modules will be weight decayed
+                 decay.add(fpn)
+             elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                 # weights of blacklist modules will NOT be weight decayed
+                 no_decay.add(fpn)
+
+     # collect every parameter by full name
+     param_dict = {pn: p for pn, p in module.named_parameters()}
+
+     # create the pytorch optimizer parameter groups
+     # (note: both groups are given weight_decay=0.0 here, so decay is effectively disabled as written)
+     optim_groups = [
+         {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
+         {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+     ]
+
+     return optim_groups
+
+
+ class CustomDataset(Dataset):
+     def __init__(self, dataset, target):
+         self.dataset = dataset
+         self.target = target
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         smiles = self.dataset['canon_smiles'].iloc[idx]
+         labels = self.dataset[self.target].iloc[idx]
+         return smiles, labels
+
+
+ class CustomDatasetMultitask(Dataset):
+     def __init__(self, dataset, targets):
+         self.dataset = dataset
+         self.targets = targets
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         smiles = self.dataset['canon_smiles'].iloc[idx]
+         labels = self.dataset[self.targets].iloc[idx].to_numpy()
+         mask = [0.0 if np.isnan(x) else 1.0 for x in labels]
+         labels = [0.0 if np.isnan(x) else x for x in labels]
+         return smiles, torch.tensor(labels, dtype=torch.float32), torch.tensor(mask)
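A sketch of wiring `get_optim_groups` into AdamW (the toy model here is illustrative, not from the repo). As noted in the code, both groups carry `weight_decay=0.0` as committed; raising the first group's value would enable decay on the whitelisted Linear weights:

```python
# Sketch: feed the decay/no-decay parameter groups into AdamW.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 1))
optim_groups = get_optim_groups(model, keep_decoder=False)
optimizer = torch.optim.AdamW(optim_groups, lr=3e-5, betas=(0.9, 0.999))
```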
images/3dgridvqgan_architecture.png ADDED

Git LFS Details

  • SHA256: 7ff246442abb10516e041e090acee9f0738323252438b908edc8c8424cbe4ba5
  • Pointer size: 131 Bytes
  • Size of remote file: 730 kB
inference/run_embeddings_eval_xgboost.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ python -u ./scripts/evaluate_embeddings_xgboost.py --task zpve
inference/run_extract_embeddings.sh ADDED
@@ -0,0 +1,8 @@
+ #!/bin/bash
+ python -u ./scripts/extract_embeddings.py \
+     --dataset_path '../data/datasets/moleculenet/qm9.csv' \
+     --save_dataset_path '../data/embeddings/qm9_embeddings.csv' \
+     --ckpt_filename 'VQGAN_43.pt' \
+     --data_dir '../data/sample_data_schema' \
+     --batch_size 2 \
+     --num_workers 0 \