wenruifan commited on
Commit
a256709
·
verified ·
1 Parent(s): 834db5b

Upload 115 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml +34 -0
  2. PreTrain_MeDSLIP/data_file/observation explanation.json +77 -0
  3. PreTrain_MeDSLIP/data_file/preprocessing/adj_matrix.py +891 -0
  4. PreTrain_MeDSLIP/data_file/preprocessing/radgraph_itemized.py +139 -0
  5. PreTrain_MeDSLIP/data_file/preprocessing/radgraph_parsed.py +322 -0
  6. PreTrain_MeDSLIP/dataset/dataset.py +310 -0
  7. PreTrain_MeDSLIP/dataset/randaugment.py +346 -0
  8. PreTrain_MeDSLIP/models/__init__.py +0 -0
  9. PreTrain_MeDSLIP/models/model_MeDSLIP.py +530 -0
  10. PreTrain_MeDSLIP/models/tokenization_bert.py +578 -0
  11. PreTrain_MeDSLIP/models/transformer.py +210 -0
  12. PreTrain_MeDSLIP/optim/__init__.py +13 -0
  13. PreTrain_MeDSLIP/optim/adafactor.py +206 -0
  14. PreTrain_MeDSLIP/optim/adahessian.py +207 -0
  15. PreTrain_MeDSLIP/optim/adamp.py +133 -0
  16. PreTrain_MeDSLIP/optim/adamw.py +131 -0
  17. PreTrain_MeDSLIP/optim/lookahead.py +96 -0
  18. PreTrain_MeDSLIP/optim/nadam.py +108 -0
  19. PreTrain_MeDSLIP/optim/novograd.py +90 -0
  20. PreTrain_MeDSLIP/optim/nvnovograd.py +132 -0
  21. PreTrain_MeDSLIP/optim/optim_factory.py +138 -0
  22. PreTrain_MeDSLIP/optim/radam.py +170 -0
  23. PreTrain_MeDSLIP/optim/rmsprop_tf.py +160 -0
  24. PreTrain_MeDSLIP/optim/sgdp.py +123 -0
  25. PreTrain_MeDSLIP/scheduler/__init__.py +5 -0
  26. PreTrain_MeDSLIP/scheduler/cosine_lr.py +136 -0
  27. PreTrain_MeDSLIP/scheduler/plateau_lr.py +116 -0
  28. PreTrain_MeDSLIP/scheduler/scheduler.py +120 -0
  29. PreTrain_MeDSLIP/scheduler/scheduler_factory.py +87 -0
  30. PreTrain_MeDSLIP/scheduler/step_lr.py +73 -0
  31. PreTrain_MeDSLIP/scheduler/tanh_lr.py +141 -0
  32. PreTrain_MeDSLIP/train_MeDSLIP.py +446 -0
  33. PreTrain_MeDSLIP/utils.py +277 -0
  34. README.md +49 -3
  35. Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml +17 -0
  36. Sample_Finetuning_SIIMACR/I1_classification/dataset/dataset_siim_acr.py +124 -0
  37. Sample_Finetuning_SIIMACR/I1_classification/dataset/randaugment.py +346 -0
  38. Sample_Finetuning_SIIMACR/I1_classification/models/resnet.py +88 -0
  39. Sample_Finetuning_SIIMACR/I1_classification/optim/__init__.py +13 -0
  40. Sample_Finetuning_SIIMACR/I1_classification/optim/adafactor.py +206 -0
  41. Sample_Finetuning_SIIMACR/I1_classification/optim/adahessian.py +207 -0
  42. Sample_Finetuning_SIIMACR/I1_classification/optim/adamp.py +133 -0
  43. Sample_Finetuning_SIIMACR/I1_classification/optim/adamw.py +131 -0
  44. Sample_Finetuning_SIIMACR/I1_classification/optim/lookahead.py +96 -0
  45. Sample_Finetuning_SIIMACR/I1_classification/optim/nadam.py +108 -0
  46. Sample_Finetuning_SIIMACR/I1_classification/optim/novograd.py +90 -0
  47. Sample_Finetuning_SIIMACR/I1_classification/optim/nvnovograd.py +132 -0
  48. Sample_Finetuning_SIIMACR/I1_classification/optim/optim_factory.py +138 -0
  49. Sample_Finetuning_SIIMACR/I1_classification/optim/radam.py +170 -0
  50. Sample_Finetuning_SIIMACR/I1_classification/optim/rmsprop_tf.py +160 -0
PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_file: "setting/rad_graph_metric_train_local.json"
2
+ valid_file: "setting/rad_graph_metric_validate_local.json"
3
+ test_file: "setting/rad_graph_metric_test_local.json"
4
+ label_file: "setting/landmark_observation_adj_mtx.npy"
5
+ pathology_book: "PreTrain_MeDSLIP/data_file/observation explanation.json"
6
+
7
+ image_res: 224
8
+ patch_size: 16
9
+ num_sentences: 12
10
+ num_tokens: 32
11
+ vision_width: 768
12
+ fea_width: 197
13
+ embed_dim: 256
14
+ batch_size: 64
15
+ test_batch_size: 32
16
+ temp: 0.07
17
+ mlm_probability: 0.15
18
+ queue_size: 8192
19
+ momentum: 0.995
20
+ alpha: 0.4
21
+ d_model: 256
22
+ res_base_model: "resnet50"
23
+ num_queries: 75
24
+ dropout: 0.1
25
+ attribute_set_size: 2
26
+ N: 4
27
+ H: 4
28
+ no_cl: False
29
+
30
+ exclude_class: False
31
+ text_encoder: "emilyalsentzer/Bio_ClinicalBERT"
32
+ shuffle_ratio: 0.5
33
+ optimizer: {opt: adamW, lr: 1e-4, weight_decay: 0.02}
34
+ schedular: {sched: cosine, lr: 1e-4, epochs: 100, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 5, cooldown_epochs: 0}
PreTrain_MeDSLIP/data_file/observation explanation.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "normal": "It means the absence of diseases and infirmity, indicating the structure is normal.",
3
+ "clear": "The lungs are clear and normal. No evidence for other diseases on lung.",
4
+ "sharp": "This means that an anatomical structure's boundary or edge is clear and normal, meaning it is free of diseases.",
5
+ "sharply": "\u2018Sharply seen\u2019 means that an anatomical structure is clearly visible.",
6
+ "unremarkable": "This represents some anatomical structures are normal, usually modifying cardiac and mediastinal silhouettes.",
7
+ "intact": "The bonny structure is complete and normal, meaning no fractures.",
8
+ "stable": "The modified anatomical structures are normal and stable. No evidence for diseases.",
9
+ "free": "It usually refers to free air and is associate with pneumothorax,atelectasis,pneumoperitoneum and emphysema.",
10
+ "effusion": "A pleural effusion is accumulation of excessive fluid in the pleural space, the potential space that surrounds each lung. A pleural effusion infiltrates the space between the visceral pleura and the parietal pleura",
11
+ "opacity": "It is defined as an area of hazy opacification due to air displacement by fluid, airway collapse, fibrosis, or a neoplastic process. It is causes include infections, interstitial lung disease, and pulmonary edema.",
12
+ "pneumothorax": "A pneumothorax is an abnormal collection of air in the pleural space between the lung and the chest wall. It may be caused by pneumonia or fibrosis and other diseases. ",
13
+ "edema": "Pulmonary edema, also known as pulmonary congestion, is excessive liquid accumulation in the tissue and air spaces of the lungs. It will show fluid in the alveolar walls",
14
+ "atelectasis": "It is the collapse or closure of a lung resulting in reduced or absent gas exchange. Findings can include lung opacification and loss of lung volume.",
15
+ "tube": "it is a surgical drain that is inserted through the chest wall and into the pleural space or the mediastinum to remove undesired substances such as air (pneumothorax), excess fluid (pleural effusion or hydrothorax), blood (hemothorax), chyle (chylothorax) or pus (empyema) from the intrathoracic space.",
16
+ "consolidation": "it is a region of normally compressible lung tissue that has filled with liquid instead of air. Consolidation must be present to diagnose pneumonia: the signs of lobar pneumonia are characteristic and clinically referred to as consolidation.",
17
+ "process": "Acute process' means there is abnormality in the anotomy structure. ",
18
+ "abnormality": "It means the exist of diseases and infirmity, indicating the structure is abnormal.",
19
+ "enlarge": "It usually modifies cardiac silhouette and heart. Cardiomegaly is a medical condition in which the heart is enlarged.",
20
+ "tip": "It refers to the top head of the tube.",
21
+ "low": "The presence of low lung volumes may be a sign of a restrictive lung condition such as pulmonary fibrosis or sarcoidosis.",
22
+ "eumonia": "Pneumonia is an inflammatory condition of the lung primarily affecting the small air sacs known as alveoli. Pneumonia may present with opacities. Complications such as pleural effusion may also be found increasing the diagnostic accuracy of lung consolidation and pleural effusion",
23
+ "line": "It refers to venous access line ot PICC lines.",
24
+ "congestion": "Pulmonary congestion is defined as accumulation of fluid in the lungs, resulting in impaired gas exchange and arterial hypoxemia. ",
25
+ "catheter": "catheter is a tube placed in the body to drain and collect urine from the bladder",
26
+ "cardiomegaly": "Cardiomegaly (sometimes megacardia or megalocardia) is a medical condition in which the heart is enlarged. ",
27
+ "fracture": "fracture is a break in a rib bone.",
28
+ "air": "It refers to the free air or gas in pleural space, indicating pneumothorax. Air displacement by fluid may lead to opacity.",
29
+ "tortuous": "the Aorta is slightly tortuous. Sometimes it may refer to varicose veins",
30
+ "lead": "It refers to the leading head of the tube.",
31
+ "disease": "It means the exist of diseases and abnormalty, indicating the structure is abnormal. ",
32
+ "calcification": "Pulmonary calcification is a common asymptomatic finding. Pulmonary calcifications are caused mainly by two mechanisms: the dystrophic form and the metastatic form",
33
+ "prominence": "It means the exist of some observation.",
34
+ "device": "It refer to some equipments like picc tub, valve catheter, pacemaker hardware, arthroplastmarker icd defib, device support equipment and mediport'",
35
+ "engorgement": "pulmonary vascular engorgement means obstruction of the normal flux of blood within the blood vessel network of the lung resulting in engorgement of pulmonary vessels",
36
+ "picc": "A peripherally inserted central catheter (PICC), also called a PICC line, is a long, thin tube that's inserted through a vein in your arm and passed through to the larger veins near your heart. ",
37
+ "clip": "Surgical clips or vascular clips usually represent the one kind of medical equipments.",
38
+ "elevation": "If tissues or anatomical structures are elevated, they are raised up higher than the normal location.",
39
+ "expand": "It means the lungs are normally expanded and clear, indicating the absence of pneumothorax.",
40
+ "nodule": "A lung nodule or pulmonary nodule is a relatively small focal density in the lung. it may be confused with the projection of a structure of the chest wall or skin, such as a nipple, a healing rib fracture or lung cancer.",
41
+ "wire": "sternotomis wires means the center line of the chest.",
42
+ "fluid": "It refers to the water of liquid in the lung and it may indicate edema and other diseases.",
43
+ "degenerative": "Degenerative disease is the result of a continuous process based on degenerative cell changes",
44
+ "pacemaker": "pacemaker device usually represents the one kind of medical equipments.",
45
+ "thicken": "Pleural thickening is an increase in the bulkiness of one or both of the pulmonary pleurae. It may cause by pulmonary Infection, empyema, tuberculosis or lung cancer.",
46
+ "marking": "It represents interstitial markings or bronchovascular markings",
47
+ "scar": "A scar (or scar tissue) is an area of fibrous tissue that replaces normal tissues after an injury.",
48
+ "hyperinflate": "Hyperinflated lungs are larger-than-normal lungs as a result of trapped air.",
49
+ "blunt": "Blunting of the costophrenic angles is usually caused by a pleural effusion, as already discussed. Other causes of costophrenic angle blunting include lung disease in the region of the costophrenic angle, and lung hyperexpansion.",
50
+ "loss": "The etiology of lung volume loss can be listed as follows: airway obstruction or compression, obesity, scoliosis, restrictive diseases such as pulmonary fibrosis and interstitial lung disease, tuberculosis, sarcoidosis, pleural effusions, rib injury (fractures or diaphragm paralysis), and heart failure",
51
+ "widen": "The mediastinum is not widened or enlarged",
52
+ "collapse": "collapse lung refers to pneumothorax or atelectasis.",
53
+ "density": "The density (more precisely, the volumetric mass density; also known as specific mass), of a substance is its mass per unit volume. ",
54
+ "emphysema": "Emphysema, or pulmonary emphysema, is a lower respiratory tract disease, characterized by air-filled spaces (pneumatosis) in the lungs, that can vary in size and may be very large",
55
+ "aerate": "Aeration (also called aerification or aeriation) is the process by which air is circulated through, mixed with or dissolved in a liquid or other substances that act as a fluid (such as soil).",
56
+ "mass": "A lung mass is an abnormal growth or area in the lungs and it can also view as lung cancer.",
57
+ "crowd": "Crowding of the bronchovascular structures is an important direct sign of volume loss. The atelectatic lung enhances densely after contrast administration because of closeness of the pulmonary arteries and arterioles within the collapsed lobe.",
58
+ "infiltrate": "A pulmonary infiltrate is a substance denser than air, such as pus, blood, or protein, which lingers within the parenchyma of the lungs. Pulmonary infiltrates are associated with pneumonia, tuberculosis and sarcoidosis.",
59
+ "obscure": "Some anatomy structures are not clear and is difficult to understand or see",
60
+ "deformity": "It means some body parts are abnormal or unjuried.",
61
+ "hernia": "Lung hernia (Sibson hernia) is a protrusion of lung outside of thoracic wall. the hernia is noted after chest trauma, thoracic surgery or certain pulmonary diseases",
62
+ "drainage": "Tube drainage represents the one kind of medical equipment",
63
+ "distention": "Distension generally refers to an enlargement, dilation, or ballooning effect. It may refer to: Abdominal distension,",
64
+ "shift": "The mediastinal shift is the deviation of the mediastinal structures towards one side of the chest cavity, usually seen on chest radiograph. It indicates a severe asymmetry of intrathoracic pressures.",
65
+ "stent": "tracheal stent represents the one kind of medical equipments",
66
+ "pressure": "Pulmonary venous pressure is intermediate between mean PAP and LAP over all physiologic pressures",
67
+ "lesion": "Lung nodules, pulmonary nodules, white spots, lesions\u2014these terms all describe the same phenomenon: an abnormality in the lungs.",
68
+ "finding": "Some observation on body parts, usually indicating abnormalty.",
69
+ "borderline": "borderline size of the cardiac silhouette means the cardiac silhouette is not enlarged and normal.",
70
+ "hardware": "It represents the one kind of medical equipments.",
71
+ "dilation": "the state of being larger or more open than normal",
72
+ "chf": "Heart failure \u2014 sometimes known as congestive heart failure \u2014 occurs when the heart muscle doesn't pump blood as well as it should. When this happens, blood often backs up and fluid can build up in the lungs, causing shortness of breath.",
73
+ "redistribution": "If the pulmonary edema is due to heart failure or fluid overload, you may also see cardiomegaly and distension of the pulmonary veins, particularly in the upper lung fields.",
74
+ "aspiration": "Aspiration pneumonia occurs when food or liquid is breathed into the airways or lungs, instead of being swallowed. ",
75
+ "tail_abnorm_obs": "Some very rare diseases.",
76
+ "excluded_obs": "Some observations that seldom appear in the reports."
77
+ }
PreTrain_MeDSLIP/data_file/preprocessing/adj_matrix.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code copied from AGXNet:
3
+ https://github.com/batmanlab/AGXNet
4
+ """
5
+
6
+ """Create adjacency matrix for representing the relations between anatomical landmarks and observations."""
7
+
8
+ import argparse
9
+ import pandas as pd
10
+ import numpy as np
11
+ import pickle
12
+
13
+ from tqdm import tqdm, trange
14
+ from torch.utils.data import Dataset, DataLoader
15
+
16
+
17
+ parser = argparse.ArgumentParser(description="Create Adjacency matrix Matrix.")
18
+
19
+ parser.add_argument(
20
+ "--input-path",
21
+ default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-sentence-parsed.csv",
22
+ help="Itemized input data path.",
23
+ )
24
+
25
+ # List of most common normal observations
26
+ NORM_OBS = [
27
+ "normal",
28
+ "clear",
29
+ "sharp",
30
+ "sharply",
31
+ "unremarkable",
32
+ "intact",
33
+ "stable",
34
+ "free",
35
+ ]
36
+
37
+ # exclude
38
+ EXCLUDED_OBS = [
39
+ "none",
40
+ "unchanged",
41
+ "change",
42
+ "great",
43
+ "similar",
44
+ "large",
45
+ "small",
46
+ "moderate",
47
+ "mild",
48
+ "median",
49
+ "decrease",
50
+ "bad",
51
+ "more",
52
+ "constant",
53
+ "worsen",
54
+ "new",
55
+ "improve",
56
+ "status",
57
+ "position",
58
+ "sternotomy",
59
+ "cabg",
60
+ "replacement",
61
+ "postoperative",
62
+ "assessment",
63
+ "patient",
64
+ ]
65
+
66
+ # top 90% abnormal observations
67
+ ABNORM_OBS = [
68
+ "effusion",
69
+ "opacity",
70
+ "pneumothorax",
71
+ "edema",
72
+ "atelectasis",
73
+ "tube",
74
+ "consolidation",
75
+ "process",
76
+ "abnormality",
77
+ "enlarge",
78
+ "tip",
79
+ "low",
80
+ "pneumonia",
81
+ "line",
82
+ "congestion",
83
+ "catheter",
84
+ "cardiomegaly",
85
+ "fracture",
86
+ "air",
87
+ "tortuous",
88
+ "lead",
89
+ "disease",
90
+ "calcification",
91
+ "prominence",
92
+ "device",
93
+ "engorgement",
94
+ "picc",
95
+ "clip",
96
+ "elevation",
97
+ "expand",
98
+ "nodule",
99
+ "wire",
100
+ "fluid",
101
+ "degenerative",
102
+ "pacemaker",
103
+ "thicken",
104
+ "marking",
105
+ "scar",
106
+ "hyperinflate",
107
+ "blunt",
108
+ "loss",
109
+ "widen",
110
+ "collapse",
111
+ "density",
112
+ "emphysema",
113
+ "aerate",
114
+ "mass",
115
+ "crowd",
116
+ "infiltrate",
117
+ "obscure",
118
+ "deformity",
119
+ "hernia",
120
+ "drainage",
121
+ "distention",
122
+ "shift",
123
+ "stent",
124
+ "pressure",
125
+ "lesion",
126
+ "finding",
127
+ "borderline",
128
+ "hardware",
129
+ "dilation",
130
+ "chf",
131
+ "redistribution",
132
+ "aspiration",
133
+ ]
134
+
135
+ # final row and column names in adjacent matrix
136
+ LANDMARK_NAME = [
137
+ "trachea",
138
+ "left_hilar",
139
+ "right_hilar",
140
+ "hilar_unspec",
141
+ "left_pleural",
142
+ "right_pleural",
143
+ "pleural_unspec",
144
+ "heart_size",
145
+ "heart_border",
146
+ "left_diaphragm",
147
+ "right_diaphragm",
148
+ "diaphragm_unspec",
149
+ "retrocardiac",
150
+ "lower_left_lobe",
151
+ "upper_left_lobe",
152
+ "lower_right_lobe",
153
+ "middle_right_lobe",
154
+ "upper_right_lobe",
155
+ "left_lower_lung",
156
+ "left_mid_lung",
157
+ "left_upper_lung",
158
+ "left_apical_lung",
159
+ "left_lung_unspec",
160
+ "right_lower_lung",
161
+ "right_mid_lung",
162
+ "right_upper_lung",
163
+ "right_apical_lung",
164
+ "right_lung_unspec",
165
+ "lung_apices",
166
+ "lung_bases",
167
+ "left_costophrenic",
168
+ "right_costophrenic",
169
+ "costophrenic_unspec",
170
+ "cardiophrenic_sulcus",
171
+ "mediastinal",
172
+ "spine",
173
+ "clavicle",
174
+ "rib",
175
+ "stomach",
176
+ "right_atrium",
177
+ "right_ventricle",
178
+ "aorta",
179
+ "svc",
180
+ "interstitium",
181
+ "parenchymal",
182
+ "cavoatrial_junction",
183
+ "cardiopulmonary",
184
+ "pulmonary",
185
+ "lung_volumes",
186
+ "unspecified",
187
+ "other",
188
+ ]
189
+
190
+ OBSERVATION_CLASS = [
191
+ "normal",
192
+ "clear",
193
+ "sharp",
194
+ "sharply",
195
+ "unremarkable",
196
+ "intact",
197
+ "stable",
198
+ "free",
199
+ "effusion",
200
+ "opacity",
201
+ "pneumothorax",
202
+ "edema",
203
+ "atelectasis",
204
+ "tube",
205
+ "consolidation",
206
+ "process",
207
+ "abnormality",
208
+ "enlarge",
209
+ "tip",
210
+ "low",
211
+ "pneumonia",
212
+ "line",
213
+ "congestion",
214
+ "catheter",
215
+ "cardiomegaly",
216
+ "fracture",
217
+ "air",
218
+ "tortuous",
219
+ "lead",
220
+ "disease",
221
+ "calcification",
222
+ "prominence",
223
+ "device",
224
+ "engorgement",
225
+ "picc",
226
+ "clip",
227
+ "elevation",
228
+ "expand",
229
+ "nodule",
230
+ "wire",
231
+ "fluid",
232
+ "degenerative",
233
+ "pacemaker",
234
+ "thicken",
235
+ "marking",
236
+ "scar",
237
+ "hyperinflate",
238
+ "blunt",
239
+ "loss",
240
+ "widen",
241
+ "collapse",
242
+ "density",
243
+ "emphysema",
244
+ "aerate",
245
+ "mass",
246
+ "crowd",
247
+ "infiltrate",
248
+ "obscure",
249
+ "deformity",
250
+ "hernia",
251
+ "drainage",
252
+ "distention",
253
+ "shift",
254
+ "stent",
255
+ "pressure",
256
+ "lesion",
257
+ "finding",
258
+ "borderline",
259
+ "hardware",
260
+ "dilation",
261
+ "chf",
262
+ "redistribution",
263
+ "aspiration",
264
+ "tail_abnorm_obs",
265
+ "excluded_obs",
266
+ ]
267
+
268
+ DICT_ANATOMICAL_LANDMARKS = {
269
+ "trachea": {"a": ["trachea", "tracheal"], "m1": [], "m2": [], "sc": [], "t": "m0"},
270
+ "left_hilar": {
271
+ "a": ["hilar", "hilum", "perihilar", "infrahilar"],
272
+ "m1": ["left"],
273
+ "m2": ["right"],
274
+ "sc": [],
275
+ "t": "m1+m2-",
276
+ },
277
+ "right_hilar": {
278
+ "a": ["hilar", "hilum", "perihilar", "infrahilar"],
279
+ "m1": ["right"],
280
+ "m2": ["left"],
281
+ "sc": [],
282
+ "t": "m1+m2-",
283
+ },
284
+ "hilar_unspec": {
285
+ "a": ["hilar", "hilum", "perihilar", "infrahilar"],
286
+ "m1": ["left", "right"],
287
+ "m2": [],
288
+ "sc": ["hila", "perihilar|right|left", "perihilar|left|right"],
289
+ "t": "m1-",
290
+ },
291
+ "left_pleural": {
292
+ "a": ["pleural"],
293
+ "m1": ["left"],
294
+ "m2": ["right"],
295
+ "sc": [],
296
+ "t": "m1+m2-",
297
+ },
298
+ "right_pleural": {
299
+ "a": ["pleural"],
300
+ "m1": ["right"],
301
+ "m2": ["left"],
302
+ "sc": [],
303
+ "t": "m1+m2-",
304
+ },
305
+ "pleural_unspec": {
306
+ "a": ["pleural"],
307
+ "m1": ["left", "right"],
308
+ "m2": [],
309
+ "sc": [
310
+ "pleural|left|right",
311
+ "pleural|right|left",
312
+ "pleural|bilateral|right|left",
313
+ "pleural|bilateral|left|right",
314
+ ],
315
+ "t": "m1-",
316
+ },
317
+ "heart_size": {
318
+ "a": ["heart", "cardiac"],
319
+ "m1": ["border", "borders"],
320
+ "m2": [],
321
+ "sc": [],
322
+ "t": "m1-",
323
+ },
324
+ "heart_border": {
325
+ "a": ["heart", "cardiac"],
326
+ "m1": ["border", "borders"],
327
+ "m2": [],
328
+ "sc": [],
329
+ "t": "m1+",
330
+ },
331
+ "left_diaphragm": {
332
+ "a": ["diaphragm", "hemidiaphragm"],
333
+ "m1": ["left"],
334
+ "m2": ["right"],
335
+ "sc": [],
336
+ "t": "m1+m2-",
337
+ },
338
+ "right_diaphragm": {
339
+ "a": ["diaphragm", "hemidiaphragm"],
340
+ "m1": ["right"],
341
+ "m2": ["left"],
342
+ "sc": [],
343
+ "t": "m1+m2-",
344
+ },
345
+ "diaphragm_unspec": {
346
+ "a": ["diaphragm", "diaphragms", "hemidiaphragms", "hemidiaphragm"],
347
+ "m1": ["left", "right"],
348
+ "m2": [],
349
+ "sc": ["hemidiaphragm|left|right", "hemidiaphragm|right|left"],
350
+ "t": "m1-",
351
+ },
352
+ "retrocardiac": {"a": ["retrocardiac"], "m1": [], "m2": [], "sc": [], "t": "m0"},
353
+ "lower_left_lobe": {
354
+ "a": ["lobe"],
355
+ "m1": ["left"],
356
+ "m2": ["lower"],
357
+ "sc": [],
358
+ "t": "m1+m2+",
359
+ },
360
+ "upper_left_lobe": {
361
+ "a": ["lobe"],
362
+ "m1": ["left"],
363
+ "m2": ["upper"],
364
+ "sc": ["lingula", "lingular"],
365
+ "t": "m1+m2+",
366
+ },
367
+ "lower_right_lobe": {
368
+ "a": ["lobe"],
369
+ "m1": ["right"],
370
+ "m2": ["lower"],
371
+ "sc": [],
372
+ "t": "m1+m2+",
373
+ },
374
+ "middle_right_lobe": {
375
+ "a": ["lobe"],
376
+ "m1": ["right"],
377
+ "m2": ["middle"],
378
+ "sc": [],
379
+ "t": "m1+m2+",
380
+ },
381
+ "upper_right_lobe": {
382
+ "a": ["lobe"],
383
+ "m1": ["right"],
384
+ "m2": ["upper"],
385
+ "sc": [],
386
+ "t": "m1+m2+",
387
+ },
388
+ "left_lower_lung": {
389
+ "a": ["lung"],
390
+ "m1": ["left"],
391
+ "m2": ["lower", "base", "basilar", "basal", "basis"],
392
+ "sc": ["base|left", "basilar|left", "basal|left", "lung|left|bases"],
393
+ "t": "m1+m2+",
394
+ },
395
+ "left_mid_lung": {
396
+ "a": ["lung"],
397
+ "m1": ["left"],
398
+ "m2": ["middle", "mid"],
399
+ "sc": ["midlung|left"],
400
+ "t": "m1+m2+",
401
+ },
402
+ "left_upper_lung": {
403
+ "a": ["lung"],
404
+ "m1": ["left"],
405
+ "m2": ["upper"],
406
+ "sc": [],
407
+ "t": "m1+m2+",
408
+ },
409
+ "left_apical_lung": {
410
+ "a": ["apex", "apical", "apical", "apicolateral"],
411
+ "m1": ["left"],
412
+ "m2": ["right"],
413
+ "sc": [],
414
+ "t": "m1+m2-",
415
+ },
416
+ "left_lung_unspec": {
417
+ "a": ["lung", "hemithorax"],
418
+ "m1": ["left", "left-sided"],
419
+ "m2": [
420
+ "volume",
421
+ "volumes",
422
+ "right",
423
+ "lower",
424
+ "base",
425
+ "bases",
426
+ "basilar",
427
+ "basilar",
428
+ "basal",
429
+ "basis",
430
+ "middle",
431
+ "mid",
432
+ "upper",
433
+ "apex",
434
+ "apical",
435
+ "perihilar",
436
+ ],
437
+ "sc": ["left", "left side", "thorax|left|hemi"],
438
+ "t": "m1+m2-",
439
+ },
440
+ "right_lower_lung": {
441
+ "a": ["lung"],
442
+ "m1": ["right"],
443
+ "m2": ["lower", "base", "basilar", "basal", "basis"],
444
+ "sc": ["base|right", "basilar|right", "basal|right", "lung|right|bases"],
445
+ "t": "m1+m2+",
446
+ },
447
+ "right_mid_lung": {
448
+ "a": ["lung"],
449
+ "m1": ["right"],
450
+ "m2": ["middle", "mid"],
451
+ "sc": [],
452
+ "t": "m1+m2+",
453
+ },
454
+ "right_upper_lung": {
455
+ "a": ["lung"],
456
+ "m1": ["right"],
457
+ "m2": ["upper"],
458
+ "sc": [],
459
+ "t": "m1+m2+",
460
+ },
461
+ "right_apical_lung": {
462
+ "a": ["apex", "apical", "apical", "apicolateral"],
463
+ "m1": ["right"],
464
+ "m2": ["left"],
465
+ "sc": [],
466
+ "t": "m1+m2-",
467
+ },
468
+ "right_lung_unspec": {
469
+ "a": ["lung", "hemithorax"],
470
+ "m1": ["right", "right-sided"],
471
+ "m2": [
472
+ "volume",
473
+ "volumes",
474
+ "left",
475
+ "lower",
476
+ "base",
477
+ "bases",
478
+ "basilar",
479
+ "basilar",
480
+ "basal",
481
+ "basis",
482
+ "middle",
483
+ "mid",
484
+ "upper",
485
+ "apex",
486
+ "apical",
487
+ "perihilar",
488
+ ],
489
+ "sc": ["right", "right side", "thorax|right|hemi"],
490
+ "t": "m1+m2-",
491
+ },
492
+ "lung_apices": {
493
+ "a": ["apices", "apical"],
494
+ "m1": ["left", "right"],
495
+ "m2": [],
496
+ "sc": ["biapical", "lungs|upper"],
497
+ "t": "m1-",
498
+ },
499
+ "lung_bases": {
500
+ "a": ["lung", "lungs"],
501
+ "m1": ["left", "right"],
502
+ "m2": ["bibasilar", "basilar", "base", "bases", "bibasal", "basal"],
503
+ "sc": [
504
+ "lung|lower",
505
+ "lungs|lower",
506
+ "bibasilar",
507
+ "basilar",
508
+ "bases",
509
+ "bibasal",
510
+ "basal",
511
+ "basal|bilateral",
512
+ "lobe|lower",
513
+ "lobes|lower",
514
+ "lobe|bilateral|lower",
515
+ "bases|both",
516
+ "bibasilar|left|right",
517
+ "bibasilar|right|left",
518
+ ],
519
+ "t": "m1-m2+",
520
+ },
521
+ "left_costophrenic": {
522
+ "a": ["costophrenic"],
523
+ "m1": ["left"],
524
+ "m2": ["right"],
525
+ "sc": [],
526
+ "t": "m1+m2-",
527
+ },
528
+ "right_costophrenic": {
529
+ "a": ["costophrenic"],
530
+ "m1": ["right"],
531
+ "m2": ["left"],
532
+ "sc": [],
533
+ "t": "m1+m2-",
534
+ },
535
+ "costophrenic_unspec": {
536
+ "a": ["costophrenic"],
537
+ "m1": ["left", "right"],
538
+ "m2": [],
539
+ "sc": [],
540
+ "t": "m1-",
541
+ },
542
+ "cardiophrenic_sulcus": {
543
+ "a": ["cardiophrenic"],
544
+ "m1": [],
545
+ "m2": [],
546
+ "sc": [],
547
+ "t": "m0",
548
+ },
549
+ "mediastinal": {
550
+ "a": ["mediastinal", "cardiomediastinal", "mediastinum", "cardiomediastinum"],
551
+ "m1": [],
552
+ "m2": [],
553
+ "sc": [],
554
+ "t": "m0",
555
+ },
556
+ "spine": {"a": ["spine", "spinal"], "m1": [], "m2": [], "sc": [], "t": "m0"},
557
+ "clavicle": {
558
+ "a": ["clavicle", "clavicles"],
559
+ "m1": [],
560
+ "m2": [],
561
+ "sc": [],
562
+ "t": "m0",
563
+ },
564
+ "rib": {"a": ["rib", "ribs"], "m1": [], "m2": [], "sc": [], "t": "m0"},
565
+ "stomach": {
566
+ "a": ["stomach", "abdomen", "abdominal"],
567
+ "m1": [],
568
+ "m2": [],
569
+ "sc": [],
570
+ "t": "m0",
571
+ },
572
+ "right_atrium": {
573
+ "a": ["atrium", "atrial"],
574
+ "m1": ["right"],
575
+ "m2": ["left"],
576
+ "sc": [],
577
+ "t": "m1+m2-",
578
+ },
579
+ "right_ventricle": {
580
+ "a": ["ventricle", "ventricular"],
581
+ "m1": ["right"],
582
+ "m2": ["left"],
583
+ "sc": [],
584
+ "t": "m1+m2-",
585
+ },
586
+ "aorta": {"a": ["aorta", "aortic"], "m1": [], "m2": [], "sc": [], "t": "m0"},
587
+ "svc": {"a": ["svc"], "m1": [], "m2": [], "sc": [], "t": "m0"},
588
+ "interstitium": {
589
+ "a": ["interstitium", "interstitial"],
590
+ "m1": [],
591
+ "m2": [],
592
+ "sc": [],
593
+ "t": "m0",
594
+ },
595
+ "parenchymal": {"a": ["parenchymal"], "m1": [], "m2": [], "sc": [], "t": "m0"},
596
+ "cavoatrial_junction": {
597
+ "a": ["cavoatrial junction"],
598
+ "m1": [],
599
+ "m2": [],
600
+ "sc": [],
601
+ "t": "m0",
602
+ },
603
+ "cardiopulmonary": {
604
+ "a": ["cardiopulmonary"],
605
+ "m1": [],
606
+ "m2": [],
607
+ "sc": [],
608
+ "t": "m0",
609
+ },
610
+ "pulmonary": {"a": ["pulmonary"], "m1": [], "m2": [], "sc": [], "t": "m0"},
611
+ "lung_volumes": {
612
+ "a": ["lungs", "lung", "volume", "volumes"],
613
+ "m1": [
614
+ "left",
615
+ "right",
616
+ "lower",
617
+ "base",
618
+ "bases",
619
+ "basilar",
620
+ "basal",
621
+ "basis",
622
+ "middle",
623
+ "mid",
624
+ "upper",
625
+ "apex",
626
+ "apical",
627
+ "apical",
628
+ ],
629
+ "m2": [],
630
+ "sc": [],
631
+ "t": "m1-",
632
+ },
633
+ }
634
+
635
+
636
+ class LandmarkObservationAdjacentMatrix(Dataset):
637
+ def __init__(self, LANDMARK_NAME, OBSERVATION_CLASS, df_anatomy_label):
638
+ self.LANDMARK_NAME = LANDMARK_NAME
639
+ self.OBSERVATION_CLASS = OBSERVATION_CLASS
640
+ self.df_anatomy_label = df_anatomy_label
641
+
642
+ # get all study ids
643
+ self.sids = list(self.df_anatomy_label["study_id"].unique())
644
+
645
+ def __getitem__(self, idx):
646
+ sid = self.sids[idx]
647
+ df_sid = self.df_anatomy_label[self.df_anatomy_label["study_id"] == sid]
648
+ landmark_observation_adj_mtx = (
649
+ np.zeros((len(LANDMARK_NAME), len(OBSERVATION_CLASS))) - 1.0
650
+ )
651
+ for index, row in df_sid.iterrows():
652
+ try:
653
+ observation_idx = self.OBSERVATION_CLASS.index(
654
+ row.obs_lemma_grp
655
+ ) # if a rare observation, skip this instance
656
+ landmark_idx = self.LANDMARK_NAME.index(row.landmark_name)
657
+
658
+ curr_val = landmark_observation_adj_mtx[landmark_idx, observation_idx]
659
+
660
+ # for obs_lemma_grp, such as tail_abnorm_obs
661
+ # if one observation is DP, then 1.0
662
+ if row.label == "OBS-DP":
663
+ landmark_observation_adj_mtx[landmark_idx, observation_idx] = 1.0
664
+ elif row.label == "OBS-DA":
665
+ landmark_observation_adj_mtx[
666
+ landmark_idx, observation_idx
667
+ ] = np.maximum(curr_val, 0.0)
668
+ except:
669
+ pass
670
+ return sid, landmark_observation_adj_mtx
671
+
672
+ def __len__(self):
673
+ return len(self.sids)
674
+
675
+
676
+ def anatomy_to_landmark(x, a, m1=[], m2=[], sc=[], t="m0"):
677
+ """
678
+ Args:
679
+ x: input anatomy, e.g., "lobe|left|lower"
680
+ a: base anatomy set, e.g., ["hilar", "hilum", "perihilar"]
681
+ m1: level 1 modifier, e.g., ["left", "right"]
682
+ m2: level 2 modifier, e.g., ["upper", "middle", "lower"]
683
+ s: special cases, e.g., ["chest"]
684
+ t: type, ["m2+", "m1+m2-"]
685
+ Return:
686
+ flag: boolean, matched or not matched
687
+ """
688
+ s = set(x.split("|"))
689
+ if t == "m1+m2+":
690
+ flag = (len(s & set(a)) > 0) & (len(s & set(m1)) > 0) & (len(s & set(m2)) > 0)
691
+ elif t == "m1+m2-":
692
+ flag = (len(s & set(a)) > 0) & (len(s & set(m1)) > 0) & (len(s & set(m2)) == 0)
693
+ elif t == "m1-m2+":
694
+ flag = (len(s & set(a)) > 0) & (len(s & set(m1)) == 0) & (len(s & set(m2)) > 0)
695
+ elif t == "m1-m2-":
696
+ flag = (len(s & set(a)) > 0) & (len(s & set(m1)) == 0) & (len(s & set(m2)) == 0)
697
+ elif t == "m1+":
698
+ flag = (len(s & set(a)) > 0) & (len(s & set(m1)) > 0)
699
+ elif t == "m2+":
700
+ flag = (len(s & set(a)) > 0) & (len(s & set(m2)) > 0)
701
+ elif t == "m1-":
702
+ flag = (len(s & set(a)) > 0) & (len(s & set(m1)) == 0)
703
+ elif t == "m2-":
704
+ flag = (len(s & set(a)) > 0) & (len(s & set(m2)) == 0)
705
+ elif t == "m0":
706
+ flag = len(s & set(a)) > 0
707
+
708
+ if sc:
709
+ flag = flag | (x in sc)
710
+ return flag
711
+
712
+
713
+ def create_adj_matrix(args):
714
+ # load anatomy label table, text table and master table
715
+ print("Loading parsed RadGraph data...")
716
+ df_anatomy_label = pd.read_csv(args.input_path, dtype=str)
717
+
718
+ # manual lemmatization correction
719
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["enlargement", "increase"])
720
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "enlarge"
721
+
722
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["engorge"])
723
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "engorgement"
724
+
725
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["opacification", "opacity-"])
726
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "opacity"
727
+
728
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["calcify"])
729
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "calcification"
730
+
731
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["effusion ;"])
732
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "effusion"
733
+
734
+ idx_replace = df_anatomy_label["obs_lemma"].isin(
735
+ ["atelectatic", "atelectasis ;", "atelectase"]
736
+ )
737
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "atelectasis"
738
+
739
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["aeration"])
740
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "aerate"
741
+
742
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["distend", "distension"])
743
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "distention"
744
+
745
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["wide"])
746
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "widen"
747
+
748
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["prominent"])
749
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "prominence"
750
+
751
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["haze"])
752
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "haziness"
753
+
754
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["masse"])
755
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "mass"
756
+
757
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["kyphotic"])
758
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "kyphosis"
759
+
760
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["degenerate"])
761
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "degenerative"
762
+
763
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["obscuration"])
764
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "obscure"
765
+
766
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["fibrotic"])
767
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "fibrosis"
768
+
769
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["nodular", "nodularity"])
770
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "nodule"
771
+
772
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["ventilate"])
773
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "ventilation"
774
+
775
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["tortuosity"])
776
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "tortuous"
777
+
778
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["elongate"])
779
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "elongation"
780
+
781
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["elevate"])
782
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "elevation"
783
+
784
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["drain"])
785
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "drainage"
786
+
787
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["deviate"])
788
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "deviation"
789
+
790
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["consolidative", "consolidate"])
791
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "consolidation"
792
+
793
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["dilate", "dilatation"])
794
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "dilation"
795
+
796
+ idx_replace = df_anatomy_label["obs_lemma"].isin(
797
+ ["hydropneumothorax", "pneumothoraces", "pneumothorace"]
798
+ )
799
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "pneumothorax"
800
+
801
+ idx_replace = df_anatomy_label["obs_lemma"].isin(["improvement", "improved"])
802
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "improve"
803
+
804
+ idx_replace = df_anatomy_label["obs_lemma"].isin(
805
+ [
806
+ "can not be assess",
807
+ "can not be evaluate",
808
+ "not well see",
809
+ "not well assess",
810
+ "can not be accurately assess",
811
+ "not well evaluate",
812
+ "not well visualize",
813
+ "difficult to evaluate",
814
+ "poorly see",
815
+ ]
816
+ )
817
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "difficult to assess"
818
+
819
+ idx_replace = df_anatomy_label["obs_lemma"] == "pacer"
820
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "pacemaker"
821
+
822
+ idx_replace = df_anatomy_label["obs_lemma"].isin(
823
+ ["infection", "infectious", "infectious process"]
824
+ )
825
+ df_anatomy_label.loc[idx_replace, "obs_lemma"] = "pneumonia"
826
+
827
+ df_anatomy_label.loc[df_anatomy_label["label"].isna(), "label"] = "OBS-NA"
828
+
829
+ # step 1: map anatomy name to landmark name
830
+ landmark_name = []
831
+ for index, row in tqdm(
832
+ df_anatomy_label.iterrows(), total=df_anatomy_label.shape[0]
833
+ ):
834
+ x = row.anatomy
835
+ flag = False
836
+ for k, v in DICT_ANATOMICAL_LANDMARKS.items():
837
+ flag = anatomy_to_landmark(x, v["a"], v["m1"], v["m2"], v["sc"], v["t"])
838
+ if flag:
839
+ landmark_name.append(k)
840
+ break
841
+ if (not flag) & (row.anatomy == "unspecified"):
842
+ landmark_name.append("unspecified")
843
+ elif (not flag) & (row.anatomy != "unspecified"):
844
+ landmark_name.append("other")
845
+
846
+ df_anatomy_label["landmark_name"] = landmark_name
847
+
848
+ # create a new obs_lemma column to grouop other abnormal observation class
849
+ df_anatomy_label["obs_lemma_grp"] = df_anatomy_label["obs_lemma"]
850
+
851
+ idx1 = df_anatomy_label["obs_lemma"].isin(NORM_OBS)
852
+ idx2 = df_anatomy_label["obs_lemma"].isin(ABNORM_OBS)
853
+ idx3 = df_anatomy_label["obs_lemma"].isin(EXCLUDED_OBS)
854
+
855
+ df_anatomy_label.loc[idx3, "obs_lemma_grp"] = "excluded_obs"
856
+
857
+ idx = (~idx1) & (~idx2) & (~idx3) # abnormal observations that are in the tail
858
+ df_anatomy_label.loc[idx, "obs_lemma_grp"] = "tail_abnorm_obs"
859
+
860
+ # step 2: get landmark - observation adjacent matrix
861
+ dataset = LandmarkObservationAdjacentMatrix(
862
+ LANDMARK_NAME, OBSERVATION_CLASS, df_anatomy_label
863
+ )
864
+ loader = DataLoader(
865
+ dataset, batch_size=32, shuffle=False, num_workers=8, drop_last=False
866
+ )
867
+
868
+ sid_lst = []
869
+ adj_mtx_lst = []
870
+ for index, data in tqdm(enumerate(loader), total=len(loader)):
871
+ sid, landmark_observation_adj_mtx = data
872
+ sid_lst.append(sid)
873
+ adj_mtx_lst.append(landmark_observation_adj_mtx)
874
+
875
+ # step 3: convert outputs to a dictionary and then save to a pickel file
876
+ full_sids = np.concatenate(sid_lst, axis=0)
877
+ full_adj_mtx = np.concatenate(adj_mtx_lst, axis=0)
878
+ dict_adj_mtx = {}
879
+ for i in trange(len(full_sids)):
880
+ sid = full_sids[i]
881
+ dict_adj_mtx[sid] = full_adj_mtx[i]
882
+
883
+ np.save("landmark_observation_sids.npy", full_sids)
884
+ print("landmark_observation_sids.npy has been saved!")
885
+ np.save("landmark_observation_adj_mtx.npy", full_adj_mtx)
886
+ print("landmark_observation_sids.npy has been saved!")
887
+
888
+
889
+ if __name__ == "__main__":
890
+ args = parser.parse_args()
891
+ create_adj_matrix(args)
PreTrain_MeDSLIP/data_file/preprocessing/radgraph_itemized.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code copied from AGXNet:
3
+ https://github.com/batmanlab/AGXNet
4
+ """
5
+
6
+ import argparse
7
+ import pandas as pd
8
+ import json
9
+ from tqdm import tqdm
10
+ import nltk
11
+
12
+
13
+ parser = argparse.ArgumentParser(description="Itemize RadGraph Dataset.")
14
+
15
+ parser.add_argument(
16
+ "--data-path",
17
+ default="/PATH TO RADGRAPH DATA/RadGraph/physionet.org/files/radgraph/1.0.0/MIMIC-CXR_graphs.json",
18
+ help="RadGraph data path.",
19
+ )
20
+ parser.add_argument(
21
+ "--output-path",
22
+ default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-itemized.csv",
23
+ help="Output path for itemized RadGraph data.",
24
+ )
25
+
26
+
27
+ def get_ids(key):
28
+ """Convert keys in the RadGraph file into IDs"""
29
+ lst = key.split("/")
30
+ partition = lst[0] # dataset partition
31
+ pid = lst[1][1:] # patient id
32
+ sid = lst[2].split(".")[0][1:] # study id, remove .txt
33
+ return partition, pid, sid
34
+
35
+
36
+ def get_sen_from_token_ix(text, ix):
37
+ """get the sentence to which the input token index belongs."""
38
+ sen_lst = nltk.sent_tokenize(text)
39
+ dict_ws = {}
40
+ ix_w = 0
41
+ ix_s = 0
42
+ for s in sen_lst:
43
+ words = nltk.word_tokenize(s)
44
+ for w in words:
45
+ dict_ws[ix_w] = ix_s
46
+ ix_w += 1
47
+ ix_s += 1
48
+ return dict_ws[ix], sen_lst[dict_ws[ix]]
49
+
50
+
51
+ def get_entity_relation(value):
52
+ """itemize each relation"""
53
+ source_lst = []
54
+ target_lst = []
55
+ token_lst = []
56
+ token_ix_lst = []
57
+ label_lst = []
58
+ relation_lst = []
59
+ sen_lst = []
60
+ sen_ix_lst = []
61
+
62
+ text = value["text"]
63
+
64
+ entities = value["entities"]
65
+ for k, v in entities.items():
66
+ six, sen = get_sen_from_token_ix(text, v["start_ix"])
67
+ relations = v["relations"]
68
+
69
+ # source node has no out going edge
70
+ if (len(relations) == 0) or (relations[0] is None):
71
+ source_lst.append(k)
72
+ token_ix_lst.append(v["start_ix"])
73
+ token_lst.append(v["tokens"])
74
+ label_lst.append(v["label"])
75
+ relation_lst.append(None)
76
+ target_lst.append(None)
77
+ sen_ix_lst.append(six)
78
+ sen_lst.append(sen)
79
+ else:
80
+ for r in relations:
81
+ source_lst.append(k)
82
+ token_ix_lst.append(v["start_ix"])
83
+ token_lst.append(v["tokens"])
84
+ label_lst.append(v["label"])
85
+ relation_lst.append(r[0])
86
+ target_lst.append(r[1])
87
+ sen_ix_lst.append(six)
88
+ sen_lst.append(sen)
89
+
90
+ # save outputs in a dataframe
91
+ return pd.DataFrame(
92
+ {
93
+ "source": source_lst,
94
+ "token": token_lst,
95
+ "token_ix": token_ix_lst,
96
+ "label": label_lst,
97
+ "relation": relation_lst,
98
+ "target": target_lst,
99
+ "sentence_ix": sen_ix_lst,
100
+ "sentence": sen_lst,
101
+ }
102
+ )
103
+
104
+
105
+ def radgraph_itemize(args):
106
+ """Convert nested RadGraph data to itemized examples."""
107
+
108
+ print("Loading RadGraph data...")
109
+ f = open(args.data_path)
110
+ data = json.load(f)
111
+ print("RadGraph data is loaded.")
112
+
113
+ # create itemized RadGraph data
114
+ df_lst = []
115
+ pid_lst = []
116
+ sid_lst = []
117
+ text_lst = []
118
+ print("Itemizing RadGraph data...")
119
+ for key, value in tqdm(data.items()):
120
+ _, pid, sid = get_ids(key)
121
+ pid_lst.append(pid)
122
+ sid_lst.append(sid)
123
+ text_lst.append(data[key]["text"])
124
+ df = get_entity_relation(value)
125
+ df["subject_id"] = pid
126
+ df["study_id"] = sid
127
+ df_lst.append(df)
128
+
129
+ # entity level dataframe
130
+ df_itemized = pd.concat(df_lst)
131
+
132
+ # save dataframes to a .csv file
133
+ df_itemized.to_csv(args.output_path, index=False)
134
+ print("Outputs have been saved!")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ args = parser.parse_args()
139
+ radgraph_itemize(args)
PreTrain_MeDSLIP/data_file/preprocessing/radgraph_parsed.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code copied from AGXNet:
3
+ https://github.com/batmanlab/AGXNet
4
+ """
5
+
6
+ import argparse
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+ import spacy
10
+
11
+ sp = spacy.load("en_core_web_sm")
12
+
13
+ parser = argparse.ArgumentParser(description="Pharse RadGraph Relations.")
14
+
15
+ parser.add_argument(
16
+ "--input-path",
17
+ default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-itemized.csv",
18
+ help="Itemized input data path.",
19
+ )
20
+ parser.add_argument(
21
+ "--output-path",
22
+ default="/PROJECT DIR/preprocessing/mimic-cxr-radgraph-sentence-parsed.csv",
23
+ help="Output path for parsed relations.",
24
+ )
25
+
26
+
27
+ def obs_lemmatization(x):
28
+ """
29
+ Lemmatize observation
30
+ Args:
31
+ x: a observation token
32
+ Return:
33
+ normalized observation
34
+ """
35
+ w_lst = []
36
+ for word in sp(str(x)):
37
+ w_lst.append(word.lemma_)
38
+ return " ".join(w_lst)
39
+
40
+
41
+ def radgraph_parse(args):
42
+ """Pharse RadGraph relations."""
43
+
44
+ print("Loading itemized RadGraph data...")
45
+ df_itemized = pd.read_csv(args.input_path)
46
+
47
+ # get all study_id
48
+ sid_lst = list(df_itemized["study_id"].unique())
49
+
50
+ tuple_lst = []
51
+ print("Preprocessing sentences...")
52
+ for sid in tqdm(sid_lst):
53
+ idx_s = df_itemized["study_id"] == sid
54
+ df_sid = df_itemized[idx_s]
55
+
56
+ # unique sentence index
57
+ sen_ids = list(df_sid["sentence_ix"].unique())
58
+
59
+ for si in sen_ids:
60
+ idx_sen = df_sid["sentence_ix"] == si
61
+ df_sen = df_sid[idx_sen]
62
+ sen = df_sen["sentence"].iloc[0]
63
+
64
+ # step 1, select all target anatomy entities (e.g., lobe) with label = ANAT-DP and target = NaN
65
+ idx_a = (df_sen["label"] == "ANAT-DP") & (df_sen["target"].isnull())
66
+ df_a = df_sen[idx_a]
67
+
68
+ if sum(idx_a) > 0:
69
+ for _, row_a in df_a.iterrows():
70
+ anatomy_source_keys = []
71
+ sen = row_a.sentence
72
+ source_key = row_a.source
73
+
74
+ # step 2, get detailed target anatomy (e.g., lower left lobe)
75
+ token_a = [row_a["token"].lower()]
76
+ anatomy_source_keys.append(source_key)
77
+ idx_t = (df_sen["label"] == "ANAT-DP") & (
78
+ df_sen["target"] == source_key
79
+ )
80
+ if sum(idx_t) > 0:
81
+ df_t = df_sen[idx_t]
82
+ for _, row in df_t.iterrows():
83
+ token_a += [row["token"].lower()]
84
+ anatomy_source_keys.append(
85
+ row["source"]
86
+ ) # save keys of all anatomy token, i.e., lower, left, lobe
87
+ anatomy = "|".join(token_a)
88
+
89
+ else:
90
+ anatomy = row_a["token"].lower()
91
+
92
+ # step 3: get observations associated with the target anatomy (e.g., normal, effusion)
93
+ idx_o = (
94
+ (df_sen["label"].isin(["OBS-DA", "OBS-DP", "OBS-U"]))
95
+ & (df_sen["target"].isin(anatomy_source_keys))
96
+ & (df_sen["relation"] == "located_at")
97
+ )
98
+ if sum(idx_o) > 0:
99
+ df_o = df_sen[idx_o]
100
+
101
+ anatomy_lst = []
102
+ obs_lst = []
103
+ label_lst = []
104
+ obs_modify_lst = []
105
+ obs_suggestive_lst = []
106
+
107
+ for _, row_o in df_o.iterrows():
108
+ anatomy_lst.append(anatomy)
109
+ obs_lst.append(row_o["token"].lower())
110
+ label_lst.append(row_o["label"])
111
+
112
+ # step 4: get obs modification
113
+ idx_o_m = (df_sen["target"] == row_o.source) & (
114
+ df_sen["relation"] == "modify"
115
+ )
116
+ obs_modify = None
117
+ if sum(idx_o_m) > 0:
118
+ df_o_m = df_sen[idx_o_m]
119
+ temp_lst = []
120
+ for _, row_om in df_o_m.iterrows():
121
+ # if the modification is present
122
+ if row_om.label == "OBS-DP":
123
+ temp_lst.append(row_om["token"].lower())
124
+ if len(temp_lst) > 0:
125
+ obs_modify = "|".join(temp_lst)
126
+ obs_modify_lst.append(obs_modify)
127
+
128
+ # step 5: get suggestive of obs
129
+ idx_o_s = (df_sen["target"] == row_o.source) & (
130
+ df_sen["relation"] == "suggestive_of"
131
+ )
132
+ obs_suggestive = None
133
+ if sum(idx_o_s) > 0:
134
+ df_o_s = df_sen[idx_o_s]
135
+ temp_lst = []
136
+ for _, row_os in df_o_s.iterrows():
137
+ # if the modification is present
138
+ if row_os.label == "OBS-DP":
139
+ temp_lst.append(row_os["token"].lower())
140
+ if len(temp_lst) > 0:
141
+ obs_suggestive = "|".join(temp_lst)
142
+ obs_suggestive_lst.append(obs_suggestive)
143
+
144
+ else:
145
+ anatomy_lst = [anatomy]
146
+ obs_lst = [None]
147
+ label_lst = [None]
148
+ obs_modify_lst = [None]
149
+ obs_suggestive_lst = [None]
150
+
151
+ # step 4: get observations that are not associated with the target anatomy
152
+ idx_oo = (
153
+ (df_sen["label"].isin(["OBS-DA", "OBS-DP", "OBS-U"]))
154
+ & (df_sen["target"].isna())
155
+ & (df_sen["relation"].isna())
156
+ )
157
+ if sum(idx_oo) > 0:
158
+ df_oo = df_sen[idx_oo]
159
+ for _, row_oo in df_oo.iterrows():
160
+ anatomy_lst.append("unspecified")
161
+ obs_lst.append(row_oo["token"].lower())
162
+ label_lst.append(row_oo["label"])
163
+ # obs_modify_lst.append(None)
164
+ # obs_suggestive_lst.append(None)
165
+
166
+ # step 5: get obs modification
167
+ idx_o_m = (df_sen["target"] == row_oo.source) & (
168
+ df_sen["relation"] == "modify"
169
+ )
170
+ obs_modify = None
171
+ if sum(idx_o_m) > 0:
172
+ df_o_m = df_sen[idx_o_m]
173
+ temp_lst = []
174
+ for _, row_om in df_o_m.iterrows():
175
+ # if the modification is present
176
+ if row_om.label == "OBS-DP":
177
+ temp_lst.append(row_om["token"].lower())
178
+ if len(temp_lst) > 0:
179
+ obs_modify = "|".join(temp_lst)
180
+ obs_modify_lst.append(obs_modify)
181
+
182
+ # step 5: get suggestive of obs
183
+ idx_o_s = (df_sen["target"] == row_oo.source) & (
184
+ df_sen["relation"] == "suggestive_of"
185
+ )
186
+ obs_suggestive = None
187
+ if sum(idx_o_s) > 0:
188
+ df_o_s = df_sen[idx_o_s]
189
+ temp_lst = []
190
+ for _, row_os in df_o_s.iterrows():
191
+ # if the modification is present
192
+ if row_os.label == "OBS-DP":
193
+ temp_lst.append(row_os["token"].lower())
194
+ if len(temp_lst) > 0:
195
+ obs_suggestive = "|".join(temp_lst)
196
+ obs_suggestive_lst.append(obs_suggestive)
197
+
198
+ # step 6: create tuple of 7 values (sid, sentence_id, sentence, anatomy, obs, label)
199
+ t_lst = []
200
+ for i in range(len(obs_lst)):
201
+ t_lst.append(
202
+ (
203
+ sid,
204
+ si,
205
+ sen,
206
+ anatomy_lst[i],
207
+ obs_lst[i],
208
+ label_lst[i],
209
+ obs_modify_lst[i],
210
+ obs_suggestive_lst[i],
211
+ )
212
+ )
213
+
214
+ # remove duplicates caused by 1 obs "located_at" multiple anatomies
215
+ tuple_lst.append(list(set(t_lst)))
216
+
217
+ # if the sentence does not have any ANATOMY token
218
+ else:
219
+ idx_o = (df_sen["label"].isin(["OBS-DA", "OBS-DP", "OBS-U"])) & (
220
+ df_sen["target"].isnull()
221
+ )
222
+ if sum(idx_o) > 0:
223
+ df_o = df_sen[idx_o]
224
+
225
+ obs_lst = []
226
+ label_lst = []
227
+ obs_modify_lst = []
228
+ obs_suggestive_lst = []
229
+
230
+ for _, row_o in df_o.iterrows():
231
+ obs_lst.append(row_o["token"].lower())
232
+ label_lst.append(row_o["label"])
233
+
234
+ # step 4: get obs modification
235
+ idx_o_m = (df_sen["target"] == row_o.source) & (
236
+ df_sen["relation"] == "modify"
237
+ )
238
+ obs_modify = None
239
+ if sum(idx_o_m) > 0:
240
+ df_o_m = df_sen[idx_o_m]
241
+ temp_lst = []
242
+ for _, row_om in df_o_m.iterrows():
243
+ # if the modification is present
244
+ if row_om.label == "OBS-DP":
245
+ temp_lst.append(row_om["token"].lower())
246
+ if len(temp_lst) > 0:
247
+ obs_modify = "|".join(temp_lst)
248
+ obs_modify_lst.append(obs_modify)
249
+
250
+ # step 5: get suggestive of obs
251
+ idx_o_s = (df_sen["target"] == row_o.source) & (
252
+ df_sen["relation"] == "suggestive_of"
253
+ )
254
+ obs_suggestive = None
255
+ if sum(idx_o_s) > 0:
256
+ df_o_s = df_sen[idx_o_s]
257
+ temp_lst = []
258
+ for _, row_os in df_o_s.iterrows():
259
+ # if the modification is present
260
+ if row_os.label == "OBS-DP":
261
+ temp_lst.append(row_os["token"].lower())
262
+ if len(temp_lst) > 0:
263
+ obs_suggestive = "|".join(temp_lst)
264
+ obs_suggestive_lst.append(obs_suggestive)
265
+ else:
266
+ obs_lst = [None]
267
+ label_lst = [None]
268
+ obs_modify_lst = [None]
269
+ obs_suggestive_lst = [None]
270
+
271
+ # step 6: create tuple of 7 values (sid, sentence_id, sentence, anatomy, obs, label)
272
+ t_lst = []
273
+ for i in range(len(obs_lst)):
274
+ t_lst.append(
275
+ (
276
+ sid,
277
+ si,
278
+ sen,
279
+ "unspecified",
280
+ obs_lst[i],
281
+ label_lst[i],
282
+ obs_modify_lst[i],
283
+ obs_suggestive_lst[i],
284
+ )
285
+ )
286
+
287
+ # remove duplicates if existing
288
+ tuple_lst.append(list(set(t_lst)))
289
+
290
+ # flatten nested list
291
+ df_lst = [item for sublist in tuple_lst for item in sublist]
292
+ df_anatomy_label = pd.DataFrame(
293
+ df_lst,
294
+ columns=[
295
+ "study_id",
296
+ "sen_id",
297
+ "sentence",
298
+ "anatomy",
299
+ "observation",
300
+ "label",
301
+ "obs_modify",
302
+ "obs_suggestive",
303
+ ],
304
+ )
305
+
306
+ # lemmatize observation tokens (e.g., normalize opacities to opacity)
307
+ obs_lemma_lst = []
308
+ print("Lemmatizing observation tokens...")
309
+ for t in tqdm(df_lst):
310
+ obs = t[4]
311
+ obs_lemma = obs_lemmatization(obs)
312
+ obs_lemma_lst.append(obs_lemma)
313
+
314
+ # save preprocessed sentence level data
315
+ df_anatomy_label["obs_lemma"] = obs_lemma_lst
316
+ df_anatomy_label.to_csv(args.output_path, index=False)
317
+ print("Output file has been saved!")
318
+
319
+
320
+ if __name__ == "__main__":
321
+ args = parser.parse_args()
322
+ radgraph_parse(args)
PreTrain_MeDSLIP/dataset/dataset.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from torch.utils.data import DataLoader
3
+ import PIL
4
+ from torch.utils.data import Dataset
5
+ import numpy as np
6
+ import pandas as pd
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import random
10
+ from dataset.randaugment import RandomAugment
11
+
12
+
13
+ class MeDSLIP_Dataset(Dataset):
14
+ def __init__(self, csv_path, np_path, mode="train", num_neg_samples=7):
15
+ self.num_neg_samples = num_neg_samples
16
+ self.ann = json.load(open(csv_path, "r"))
17
+ self.img_path_list = list(self.ann)
18
+ self.anaomy_list = [
19
+ "trachea",
20
+ "left_hilar",
21
+ "right_hilar",
22
+ "hilar_unspec",
23
+ "left_pleural",
24
+ "right_pleural",
25
+ "pleural_unspec",
26
+ "heart_size",
27
+ "heart_border",
28
+ "left_diaphragm",
29
+ "right_diaphragm",
30
+ "diaphragm_unspec",
31
+ "retrocardiac",
32
+ "lower_left_lobe",
33
+ "upper_left_lobe",
34
+ "lower_right_lobe",
35
+ "middle_right_lobe",
36
+ "upper_right_lobe",
37
+ "left_lower_lung",
38
+ "left_mid_lung",
39
+ "left_upper_lung",
40
+ "left_apical_lung",
41
+ "left_lung_unspec",
42
+ "right_lower_lung",
43
+ "right_mid_lung",
44
+ "right_upper_lung",
45
+ "right_apical_lung",
46
+ "right_lung_unspec",
47
+ "lung_apices",
48
+ "lung_bases",
49
+ "left_costophrenic",
50
+ "right_costophrenic",
51
+ "costophrenic_unspec",
52
+ "cardiophrenic_sulcus",
53
+ "mediastinal",
54
+ "spine",
55
+ "clavicle",
56
+ "rib",
57
+ "stomach",
58
+ "right_atrium",
59
+ "right_ventricle",
60
+ "aorta",
61
+ "svc",
62
+ "interstitium",
63
+ "parenchymal",
64
+ "cavoatrial_junction",
65
+ "cardiopulmonary",
66
+ "pulmonary",
67
+ "lung_volumes",
68
+ "unspecified",
69
+ "other",
70
+ ]
71
+ self.obs_list = [
72
+ "normal",
73
+ "clear",
74
+ "sharp",
75
+ "sharply",
76
+ "unremarkable",
77
+ "intact",
78
+ "stable",
79
+ "free",
80
+ "effusion",
81
+ "opacity",
82
+ "pneumothorax",
83
+ "edema",
84
+ "atelectasis",
85
+ "tube",
86
+ "consolidation",
87
+ "process",
88
+ "abnormality",
89
+ "enlarge",
90
+ "tip",
91
+ "low",
92
+ "pneumonia",
93
+ "line",
94
+ "congestion",
95
+ "catheter",
96
+ "cardiomegaly",
97
+ "fracture",
98
+ "air",
99
+ "tortuous",
100
+ "lead",
101
+ "disease",
102
+ "calcification",
103
+ "prominence",
104
+ "device",
105
+ "engorgement",
106
+ "picc",
107
+ "clip",
108
+ "elevation",
109
+ "expand",
110
+ "nodule",
111
+ "wire",
112
+ "fluid",
113
+ "degenerative",
114
+ "pacemaker",
115
+ "thicken",
116
+ "marking",
117
+ "scar",
118
+ "hyperinflate",
119
+ "blunt",
120
+ "loss",
121
+ "widen",
122
+ "collapse",
123
+ "density",
124
+ "emphysema",
125
+ "aerate",
126
+ "mass",
127
+ "crowd",
128
+ "infiltrate",
129
+ "obscure",
130
+ "deformity",
131
+ "hernia",
132
+ "drainage",
133
+ "distention",
134
+ "shift",
135
+ "stent",
136
+ "pressure",
137
+ "lesion",
138
+ "finding",
139
+ "borderline",
140
+ "hardware",
141
+ "dilation",
142
+ "chf",
143
+ "redistribution",
144
+ "aspiration",
145
+ "tail_abnorm_obs",
146
+ "excluded_obs",
147
+ ]
148
+ self.rad_graph_results = np.load(np_path)
149
+ normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
150
+ if mode == "train":
151
+ self.transform = transforms.Compose(
152
+ [
153
+ transforms.RandomResizedCrop(
154
+ 224, scale=(0.2, 1.0), interpolation=Image.BICUBIC
155
+ ),
156
+ transforms.RandomHorizontalFlip(),
157
+ RandomAugment(
158
+ 2,
159
+ 7,
160
+ isPIL=True,
161
+ augs=[
162
+ "Identity",
163
+ "AutoContrast",
164
+ "Equalize",
165
+ "Brightness",
166
+ "Sharpness",
167
+ "ShearX",
168
+ "ShearY",
169
+ "TranslateX",
170
+ "TranslateY",
171
+ "Rotate",
172
+ ],
173
+ ),
174
+ transforms.ToTensor(),
175
+ normalize,
176
+ ]
177
+ )
178
+ if mode == "test":
179
+ self.transform = transforms.Compose(
180
+ [
181
+ transforms.Resize([224, 224]),
182
+ transforms.ToTensor(),
183
+ normalize,
184
+ ]
185
+ )
186
+
187
+ def __getitem__(self, index):
188
+ img_path = self.img_path_list[index]
189
+ class_label = self.rad_graph_results[
190
+ self.ann[img_path]["labels_id"], :, :
191
+ ]
192
+ labels_pathology = np.zeros(class_label.shape[-1]) - 1
193
+ labels_anatomy = np.zeros(class_label.shape[0]) - 1
194
+ labels_pathology, index_list_pathology = self.triplet_extraction_pathology(
195
+ class_label
196
+ )
197
+ labels_anatomy, index_list_anatomy = self.triplet_extraction_anatomy(
198
+ class_label
199
+ )
200
+ index_list_pathology = np.array(index_list_pathology)
201
+ index_list_anatomy = np.array(index_list_anatomy)
202
+
203
+ img = PIL.Image.open(img_path).convert("RGB")
204
+ image = self.transform(img)
205
+
206
+ return {
207
+ "image": image,
208
+ "label_pathology": labels_pathology,
209
+ "index_pathology": index_list_pathology,
210
+ "label_anatomy": labels_anatomy,
211
+ "index_anatomy": index_list_anatomy,
212
+ "matrix": class_label,
213
+ }
214
+
215
+ def triplet_extraction_pathology(self, class_label):
216
+ """
217
+ This is for ProtoCL. Therefore, we need to extract anatomies to use in pathology stream.
218
+ """
219
+
220
+ exist_labels = np.zeros(class_label.shape[-1]) - 1
221
+ anatomy_list = []
222
+ for i in range(class_label.shape[1]):
223
+ temp_list = []
224
+ ### extract the exist label for each pathology and maintain -1 if not mentioned. ###
225
+ if 0 in class_label[:, i]:
226
+ exist_labels[i] = 0
227
+
228
+ if 1 in class_label[:, i]:
229
+ exist_labels[i] = 1
230
+ ### if the pathology exists try to get its anatomy.###
231
+ ### Note that, the contrastive loss will only be caculated on exist pathology as it is meaningless to predict their anatomy for the non-exist entities###
232
+ temp_list.append(-1)
233
+
234
+ try:
235
+ temp_list = temp_list + random.sample(
236
+ np.where(class_label[:, i] != 1)[0].tolist(),
237
+ self.num_neg_samples,
238
+ )
239
+ except:
240
+ print("fatal error")
241
+ if temp_list == []:
242
+ temp_list = temp_list + random.sample(
243
+ np.where(class_label[:, i] != 1)[0].tolist(),
244
+ self.num_neg_samples + 1,
245
+ )
246
+ anatomy_list.append(temp_list)
247
+
248
+ return exist_labels, anatomy_list
249
+
250
+ def triplet_extraction_anatomy(self, class_label):
251
+ """
252
+ This is for ProtoCL. Therefore, we need to extract pathological labels to use in anatomy stream.
253
+ """
254
+ exist_labels = np.zeros(class_label.shape[0]) - 1
255
+ pathology_list = []
256
+ for i in range(class_label.shape[0]):
257
+ temp_list = []
258
+ ### extract the exist label for each pathology and maintain -1 if not mentioned. ###
259
+ if 0 in class_label[i, :]:
260
+ exist_labels[i] = 0
261
+
262
+ if 1 in class_label[i, :]:
263
+ exist_labels[i] = 1
264
+ ### if the pathology exists try to get its anatomy.###
265
+ ### Note that, the contrastive loss will only be caculated on exist pathology as it is meaningless to predict their anatomy for the non-exist entities###
266
+ temp_list.append(-1)
267
+
268
+ try:
269
+ temp_list = temp_list + random.sample(
270
+ np.where(class_label[i, :] != 1)[0].tolist(),
271
+ self.num_neg_samples,
272
+ )
273
+ except:
274
+ print("fatal error")
275
+ if temp_list == []:
276
+ temp_list = temp_list + random.sample(
277
+ np.where(class_label[i, :] != 1)[0].tolist(),
278
+ self.num_neg_samples + 1,
279
+ )
280
+ pathology_list.append(temp_list)
281
+
282
+ return exist_labels, pathology_list
283
+
284
+ def __len__(self):
285
+ return len(self.ann)
286
+
287
+
288
+ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
289
+ loaders = []
290
+ for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
291
+ datasets, samplers, batch_size, num_workers, is_trains, collate_fns
292
+ ):
293
+ if is_train:
294
+ shuffle = sampler is None
295
+ drop_last = True
296
+ else:
297
+ shuffle = False
298
+ drop_last = False
299
+ loader = DataLoader(
300
+ dataset,
301
+ batch_size=bs,
302
+ num_workers=n_worker,
303
+ pin_memory=True,
304
+ sampler=sampler,
305
+ shuffle=shuffle,
306
+ collate_fn=collate_fn,
307
+ drop_last=drop_last,
308
+ )
309
+ loaders.append(loader)
310
+ return loaders
PreTrain_MeDSLIP/dataset/randaugment.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ ## aug functions
6
+ def identity_func(img):
7
+ return img
8
+
9
+
10
+ def autocontrast_func(img, cutoff=0):
11
+ """
12
+ same output as PIL.ImageOps.autocontrast
13
+ """
14
+ n_bins = 256
15
+
16
+ def tune_channel(ch):
17
+ n = ch.size
18
+ cut = cutoff * n // 100
19
+ if cut == 0:
20
+ high, low = ch.max(), ch.min()
21
+ else:
22
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23
+ low = np.argwhere(np.cumsum(hist) > cut)
24
+ low = 0 if low.shape[0] == 0 else low[0]
25
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27
+ if high <= low:
28
+ table = np.arange(n_bins)
29
+ else:
30
+ scale = (n_bins - 1) / (high - low)
31
+ offset = -low * scale
32
+ table = np.arange(n_bins) * scale + offset
33
+ table[table < 0] = 0
34
+ table[table > n_bins - 1] = n_bins - 1
35
+ table = table.clip(0, 255).astype(np.uint8)
36
+ return table[ch]
37
+
38
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
39
+ out = cv2.merge(channels)
40
+ return out
41
+
42
+
43
+ def equalize_func(img):
44
+ """
45
+ same output as PIL.ImageOps.equalize
46
+ PIL's implementation is different from cv2.equalize
47
+ """
48
+ n_bins = 256
49
+
50
+ def tune_channel(ch):
51
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52
+ non_zero_hist = hist[hist != 0].reshape(-1)
53
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54
+ if step == 0:
55
+ return ch
56
+ n = np.empty_like(hist)
57
+ n[0] = step // 2
58
+ n[1:] = hist[:-1]
59
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
60
+ return table[ch]
61
+
62
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
63
+ out = cv2.merge(channels)
64
+ return out
65
+
66
+
67
+ def rotate_func(img, degree, fill=(0, 0, 0)):
68
+ """
69
+ like PIL, rotate by degree, not radians
70
+ """
71
+ H, W = img.shape[0], img.shape[1]
72
+ center = W / 2, H / 2
73
+ M = cv2.getRotationMatrix2D(center, degree, 1)
74
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
75
+ return out
76
+
77
+
78
+ def solarize_func(img, thresh=128):
79
+ """
80
+ same output as PIL.ImageOps.posterize
81
+ """
82
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
83
+ table = table.clip(0, 255).astype(np.uint8)
84
+ out = table[img]
85
+ return out
86
+
87
+
88
+ def color_func(img, factor):
89
+ """
90
+ same output as PIL.ImageEnhance.Color
91
+ """
92
+ ## implementation according to PIL definition, quite slow
93
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
94
+ # out = blend(degenerate, img, factor)
95
+ # M = (
96
+ # np.eye(3) * factor
97
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
98
+ # )[np.newaxis, np.newaxis, :]
99
+ M = np.float32(
100
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
101
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
102
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
103
+ return out
104
+
105
+
106
+ def contrast_func(img, factor):
107
+ """
108
+ same output as PIL.ImageEnhance.Contrast
109
+ """
110
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
111
+ table = (
112
+ np.array([(el - mean) * factor + mean for el in range(256)])
113
+ .clip(0, 255)
114
+ .astype(np.uint8)
115
+ )
116
+ out = table[img]
117
+ return out
118
+
119
+
120
+ def brightness_func(img, factor):
121
+ """
122
+ same output as PIL.ImageEnhance.Contrast
123
+ """
124
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
125
+ out = table[img]
126
+ return out
127
+
128
+
129
+ def sharpness_func(img, factor):
130
+ """
131
+ The differences the this result and PIL are all on the 4 boundaries, the center
132
+ areas are same
133
+ """
134
+ kernel = np.ones((3, 3), dtype=np.float32)
135
+ kernel[1][1] = 5
136
+ kernel /= 13
137
+ degenerate = cv2.filter2D(img, -1, kernel)
138
+ if factor == 0.0:
139
+ out = degenerate
140
+ elif factor == 1.0:
141
+ out = img
142
+ else:
143
+ out = img.astype(np.float32)
144
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
145
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
146
+ out = out.astype(np.uint8)
147
+ return out
148
+
149
+
150
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
151
+ H, W = img.shape[0], img.shape[1]
152
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
153
+ out = cv2.warpAffine(
154
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
155
+ ).astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
160
+ """
161
+ same output as PIL.Image.transform
162
+ """
163
+ H, W = img.shape[0], img.shape[1]
164
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
165
+ out = cv2.warpAffine(
166
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
167
+ ).astype(np.uint8)
168
+ return out
169
+
170
+
171
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
172
+ """
173
+ same output as PIL.Image.transform
174
+ """
175
+ H, W = img.shape[0], img.shape[1]
176
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
177
+ out = cv2.warpAffine(
178
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
179
+ ).astype(np.uint8)
180
+ return out
181
+
182
+
183
+ def posterize_func(img, bits):
184
+ """
185
+ same output as PIL.ImageOps.posterize
186
+ """
187
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
188
+ return out
189
+
190
+
191
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
192
+ H, W = img.shape[0], img.shape[1]
193
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
194
+ out = cv2.warpAffine(
195
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
196
+ ).astype(np.uint8)
197
+ return out
198
+
199
+
200
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
201
+ replace = np.array(replace, dtype=np.uint8)
202
+ H, W = img.shape[0], img.shape[1]
203
+ rh, rw = np.random.random(2)
204
+ pad_size = pad_size // 2
205
+ ch, cw = int(rh * H), int(rw * W)
206
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
207
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
208
+ out = img.copy()
209
+ out[x1:x2, y1:y2, :] = replace
210
+ return out
211
+
212
+
213
+ ### level to args
214
+ def enhance_level_to_args(MAX_LEVEL):
215
+ def level_to_args(level):
216
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
217
+
218
+ return level_to_args
219
+
220
+
221
+ def shear_level_to_args(MAX_LEVEL, replace_value):
222
+ def level_to_args(level):
223
+ level = (level / MAX_LEVEL) * 0.3
224
+ if np.random.random() > 0.5:
225
+ level = -level
226
+ return (level, replace_value)
227
+
228
+ return level_to_args
229
+
230
+
231
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
232
+ def level_to_args(level):
233
+ level = (level / MAX_LEVEL) * float(translate_const)
234
+ if np.random.random() > 0.5:
235
+ level = -level
236
+ return (level, replace_value)
237
+
238
+ return level_to_args
239
+
240
+
241
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
242
+ def level_to_args(level):
243
+ level = int((level / MAX_LEVEL) * cutout_const)
244
+ return (level, replace_value)
245
+
246
+ return level_to_args
247
+
248
+
249
+ def solarize_level_to_args(MAX_LEVEL):
250
+ def level_to_args(level):
251
+ level = int((level / MAX_LEVEL) * 256)
252
+ return (level,)
253
+
254
+ return level_to_args
255
+
256
+
257
+ def none_level_to_args(level):
258
+ return ()
259
+
260
+
261
+ def posterize_level_to_args(MAX_LEVEL):
262
+ def level_to_args(level):
263
+ level = int((level / MAX_LEVEL) * 4)
264
+ return (level,)
265
+
266
+ return level_to_args
267
+
268
+
269
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
270
+ def level_to_args(level):
271
+ level = (level / MAX_LEVEL) * 30
272
+ if np.random.random() < 0.5:
273
+ level = -level
274
+ return (level, replace_value)
275
+
276
+ return level_to_args
277
+
278
+
279
+ func_dict = {
280
+ "Identity": identity_func,
281
+ "AutoContrast": autocontrast_func,
282
+ "Equalize": equalize_func,
283
+ "Rotate": rotate_func,
284
+ "Solarize": solarize_func,
285
+ "Color": color_func,
286
+ "Contrast": contrast_func,
287
+ "Brightness": brightness_func,
288
+ "Sharpness": sharpness_func,
289
+ "ShearX": shear_x_func,
290
+ "TranslateX": translate_x_func,
291
+ "TranslateY": translate_y_func,
292
+ "Posterize": posterize_func,
293
+ "ShearY": shear_y_func,
294
+ }
295
+
296
+ translate_const = 10
297
+ MAX_LEVEL = 10
298
+ replace_value = (128, 128, 128)
299
+ arg_dict = {
300
+ "Identity": none_level_to_args,
301
+ "AutoContrast": none_level_to_args,
302
+ "Equalize": none_level_to_args,
303
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
304
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
305
+ "Color": enhance_level_to_args(MAX_LEVEL),
306
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
307
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
308
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
309
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
310
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
311
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
312
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
313
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
314
+ }
315
+
316
+
317
+ class RandomAugment(object):
318
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
319
+ self.N = N
320
+ self.M = M
321
+ self.isPIL = isPIL
322
+ if augs:
323
+ self.augs = augs
324
+ else:
325
+ self.augs = list(arg_dict.keys())
326
+
327
+ def get_random_ops(self):
328
+ sampled_ops = np.random.choice(self.augs, self.N)
329
+ return [(op, 0.5, self.M) for op in sampled_ops]
330
+
331
+ def __call__(self, img):
332
+ if self.isPIL:
333
+ img = np.array(img)
334
+ ops = self.get_random_ops()
335
+ for name, prob, level in ops:
336
+ if np.random.random() > prob:
337
+ continue
338
+ args = arg_dict[name](level)
339
+ img = func_dict[name](img, *args)
340
+ return img
341
+
342
+
343
+ if __name__ == "__main__":
344
+ a = RandomAugment()
345
+ img = np.random.randn(32, 32, 3)
346
+ a(img)
PreTrain_MeDSLIP/models/__init__.py ADDED
File without changes
PreTrain_MeDSLIP/models/model_MeDSLIP.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
2
+ from sklearn.metrics import log_loss
3
+ import torch.nn as nn
4
+ import torch
5
+ import math
6
+ import numpy as np
7
+ from torch.nn.utils.rnn import pad_sequence
8
+ import torch.nn.functional as F
9
+ from .transformer import *
10
+ import torchvision.models as models
11
+ from einops import rearrange
12
+ from transformers import AutoModel
13
+
14
+ """
15
+ args.N
16
+ args.d_model
17
+ args.res_base_model
18
+ args.H
19
+ args.num_queries
20
+ args.dropout
21
+ args.attribute_set_size
22
+ """
23
+
24
+
25
+ class MeDSLIP(nn.Module):
26
+ def __init__(
27
+ self, config, anatomy_book, pathology_book, mode="train",
28
+ ):
29
+ super(MeDSLIP, self).__init__()
30
+ self.mode = mode
31
+ self.d_model = config["d_model"]
32
+ # """ book embedding"""
33
+ with torch.no_grad():
34
+ bert_model = self._get_bert_basemodel(
35
+ config["text_encoder"], freeze_layers=None
36
+ ).to(anatomy_book["input_ids"].device)
37
+ self.anatomy_book = bert_model(
38
+ input_ids=anatomy_book["input_ids"],
39
+ attention_mask=anatomy_book["attention_mask"],
40
+ ) # (**encoded_inputs)
41
+ self.anatomy_book = self.anatomy_book.last_hidden_state[:, 0, :]
42
+ self.pathology_book = bert_model(
43
+ input_ids=pathology_book["input_ids"],
44
+ attention_mask=pathology_book["attention_mask"],
45
+ ) # (**encoded_inputs)
46
+ self.pathology_book = self.pathology_book.last_hidden_state[:, 0, :]
47
+ self.pathology_embedding_layer = nn.Linear(768, 256)
48
+ self.cl_fc_pathology = nn.Linear(256, 768)
49
+
50
+ self.pathology_name = [
51
+ "normal",
52
+ "clear",
53
+ "sharp",
54
+ "sharply",
55
+ "unremarkable",
56
+ "intact",
57
+ "stable",
58
+ "free",
59
+ "effusion",
60
+ "opacity",
61
+ "pneumothorax",
62
+ "edema",
63
+ "atelectasis",
64
+ "tube",
65
+ "consolidation",
66
+ "process",
67
+ "abnormality",
68
+ "enlarge",
69
+ "tip",
70
+ "low",
71
+ "pneumonia",
72
+ "line",
73
+ "congestion",
74
+ "catheter",
75
+ "cardiomegaly",
76
+ "fracture",
77
+ "air",
78
+ "tortuous",
79
+ "lead",
80
+ "pathology",
81
+ "calcification",
82
+ "prominence",
83
+ "device",
84
+ "engorgement",
85
+ "picc",
86
+ "clip",
87
+ "elevation",
88
+ "expand",
89
+ "nodule",
90
+ "wire",
91
+ "fluid",
92
+ "degenerative",
93
+ "pacemaker",
94
+ "thicken",
95
+ "marking",
96
+ "scar",
97
+ "hyperinflate",
98
+ "blunt",
99
+ "loss",
100
+ "widen",
101
+ "coll_eapse",
102
+ "density",
103
+ "emphysema",
104
+ "aerate",
105
+ "mass",
106
+ "crowd",
107
+ "infiltrate",
108
+ "obscure",
109
+ "deformity",
110
+ "hernia",
111
+ "drainage",
112
+ "distention",
113
+ "shift",
114
+ "stent",
115
+ "pressure",
116
+ "lesion",
117
+ "finding",
118
+ "borderline",
119
+ "hardware",
120
+ "dilation",
121
+ "chf",
122
+ "redistribution",
123
+ "aspiration",
124
+ "tail_abnorm_obs",
125
+ "excluded_obs",
126
+ ]
127
+
128
+ self.excluded_pathology = [
129
+ "pneumonia",
130
+ "infiltrate",
131
+ "mass",
132
+ "nodule",
133
+ "emphysema",
134
+ "fibrosis",
135
+ "thicken",
136
+ "hernia",
137
+ ]
138
+
139
+ self.keep_class_dim_pathology = [
140
+ self.pathology_name.index(i)
141
+ for i in self.pathology_name
142
+ if i not in self.excluded_pathology
143
+ ]
144
+ """ visual backbone"""
145
+ self.resnet_dict = {
146
+ "resnet18": models.resnet18(pretrained=False),
147
+ "resnet50": models.resnet50(pretrained=False),
148
+ }
149
+ resnet = self._get_res_basemodel(config["res_base_model"])
150
+ num_ftrs = int(resnet.fc.in_features / 2)
151
+ self.res_features = nn.Sequential(*list(resnet.children())[:-3])
152
+
153
+ self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
154
+ self.res_l2_pathology = nn.Linear(num_ftrs, self.d_model)
155
+
156
+ self.cl_fc_anatomy = nn.Linear(256, 768)
157
+ self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
158
+ self.res_l2_anatomy = nn.Linear(num_ftrs, self.d_model)
159
+
160
+ self.mask_generator = nn.Linear(num_ftrs, num_ftrs)
161
+
162
+ ###################################
163
+ """ Query Decoder"""
164
+ ###################################
165
+
166
+ self.H = config["H"]
167
+ decoder_layer = TransformerDecoderLayer(
168
+ self.d_model, config["H"], 1024, 0.1, "relu", normalize_before=True
169
+ )
170
+ decoder_norm = nn.LayerNorm(self.d_model)
171
+ self.decoder_anatomy = TransformerDecoder(
172
+ decoder_layer, config["N"], decoder_norm, return_intermediate=False
173
+ )
174
+ self.decoder_pathology = TransformerDecoder(
175
+ decoder_layer, config["N"], decoder_norm, return_intermediate=False
176
+ )
177
+
178
+ # Learnable Queries
179
+ self.dropout_feas_anatomy = nn.Dropout(config["dropout"])
180
+ self.dropout_feas_pathology = nn.Dropout(config["dropout"])
181
+
182
+ # Attribute classifier
183
+ self.classifier_anatomy = nn.Linear(self.d_model, config["attribute_set_size"])
184
+ self.classifier_pathology = nn.Linear(
185
+ self.d_model, config["attribute_set_size"]
186
+ )
187
+
188
+ self.apply(self._init_weights)
189
+
190
+ def _get_res_basemodel(self, res_model_name):
191
+ try:
192
+ res_model = self.resnet_dict[res_model_name]
193
+ print("Image feature extractor:", res_model_name)
194
+ return res_model
195
+ except:
196
+ raise (
197
+ "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
198
+ )
199
+
200
+ def _get_bert_basemodel(self, bert_model_name, freeze_layers):
201
+ try:
202
+ model = AutoModel.from_pretrained(bert_model_name)
203
+ print("text feature extractor:", bert_model_name)
204
+ except:
205
+ raise (
206
+ "Invalid model name. Check the config file and pass a BERT model from transformers lybrary"
207
+ )
208
+
209
+ if freeze_layers is not None:
210
+ for layer_idx in freeze_layers:
211
+ for param in list(model.encoder.layer[layer_idx].parameters()):
212
+ param.requires_grad = False
213
+ return model
214
+
215
+ def image_encoder(self, xis):
216
+ # patch features
217
+ """
218
+ 16 torch.Size([16, 1024, 14, 14])
219
+ torch.Size([16, 196, 1024])
220
+ torch.Size([3136, 1024])
221
+ torch.Size([16, 196, 256])
222
+ """
223
+ batch_size = xis.shape[0]
224
+ res_fea = self.res_features(xis) # batch_size,feature_size,patch_num,patch_num
225
+ res_fea = rearrange(res_fea, "b d n1 n2 -> b (n1 n2) d")
226
+ x = rearrange(res_fea, "b n d -> (b n) d")
227
+
228
+ mask = self.mask_generator(x)
229
+ x_pathology = mask * x
230
+ x_anatomy = (1 - mask) * x
231
+
232
+ x_pathology = self.res_l1_pathology(x_pathology)
233
+ x_anatomy = self.res_l1_anatomy(x_anatomy)
234
+ x_pathology = F.relu(x_pathology)
235
+ x_anatomy = F.relu(x_anatomy)
236
+
237
+ x_pathology = self.res_l2_pathology(x_pathology)
238
+ x_anatomy = self.res_l2_anatomy(x_anatomy)
239
+
240
+ out_emb_pathology = rearrange(x_pathology, "(b n) d -> b n d", b=batch_size)
241
+ out_emb_anatomy = rearrange(x_anatomy, "(b n) d -> b n d", b=batch_size)
242
+ return out_emb_pathology, out_emb_anatomy
243
+
244
+ def forward(
245
+ self,
246
+ images,
247
+ labels_pathology=None,
248
+ labels_anatomy=None,
249
+ matrix=None,
250
+ sample_index_pathology=None,
251
+ sample_index_anatomy=None,
252
+ is_train=True,
253
+ text_gen=False,
254
+ no_cl=False,
255
+ exclude_class=False,
256
+ ):
257
+
258
+ B = images.shape[0]
259
+ device = images.device
260
+ """ Visual Backbone """
261
+ x_pathology, x_anatomy = self.image_encoder(images) # batch_size,patch_num,dim
262
+
263
+ features_pathology = x_pathology.transpose(0, 1) # patch_num b dim
264
+ features_anatomy = x_anatomy.transpose(0, 1) # patch_num b dim
265
+
266
+ query_embed_pathology = self.pathology_embedding_layer(self.pathology_book)
267
+ query_embed_anatomy = self.pathology_embedding_layer(self.anatomy_book)
268
+ query_embed_pathology = query_embed_pathology.unsqueeze(1).repeat(1, B, 1)
269
+ query_embed_anatomy = query_embed_anatomy.unsqueeze(1).repeat(1, B, 1)
270
+
271
+ features_pathology, ws_pathology = self.decoder_pathology(
272
+ query_embed_pathology,
273
+ features_pathology,
274
+ memory_key_padding_mask=None,
275
+ pos=None,
276
+ query_pos=None,
277
+ )
278
+ features_anatomy, ws_anatomy = self.decoder_anatomy(
279
+ query_embed_anatomy,
280
+ features_anatomy,
281
+ memory_key_padding_mask=None,
282
+ pos=None,
283
+ query_pos=None,
284
+ )
285
+
286
+ ap_pathology = features_pathology
287
+ ap_anatomy = features_anatomy
288
+
289
+ ap_logits = torch.bmm(
290
+ ap_pathology.transpose(0, 1), ap_anatomy.transpose(0, 1).transpose(1, 2)
291
+ ).transpose(
292
+ 1, 2
293
+ )
294
+ if text_gen:
295
+ output_logits = ap_logits
296
+ matrix_zero = matrix
297
+
298
+ masks = matrix_zero >= 0
299
+ ap_logits = ap_logits[masks]
300
+ matrix_zero = matrix_zero[masks]
301
+
302
+ loss_ap = F.binary_cross_entropy_with_logits(
303
+ ap_logits.float(), matrix_zero.float()
304
+ )
305
+
306
+ out_pathology = self.dropout_feas_pathology(features_pathology)
307
+ out_anatomy = self.dropout_feas_anatomy(features_anatomy)
308
+
309
+ if is_train == True and no_cl == False:
310
+
311
+ # get anatomytomy query
312
+ anatomytomy_query = torch.zeros(
313
+ [
314
+ sample_index_pathology.shape[0],
315
+ sample_index_pathology.shape[1],
316
+ sample_index_pathology.shape[2],
317
+ self.anatomy_book.shape[-1],
318
+ ]
319
+ ).to(
320
+ device
321
+ )
322
+ entity_query = torch.zeros(
323
+ [
324
+ sample_index_anatomy.shape[0],
325
+ sample_index_anatomy.shape[1],
326
+ sample_index_anatomy.shape[2],
327
+ self.pathology_book.shape[-1],
328
+ ]
329
+ ).to(device)
330
+
331
+ anatomytomy_query = self.anatomy_book[sample_index_pathology, :] * (
332
+ sample_index_pathology != -1
333
+ ).int().unsqueeze(-1).repeat(
334
+ 1, 1, 1, 768
335
+ ) # batch, Q , position_num ,dim
336
+ entity_query = self.pathology_book[sample_index_anatomy, :] * (
337
+ sample_index_anatomy != -1
338
+ ).int().unsqueeze(-1).repeat(1, 1, 1, 768)
339
+
340
+ matrix_zero_pathology = matrix
341
+ matrix_zero_anatomy = matrix.transpose(1, 2)
342
+ matrix_zero_pathology[matrix_zero_pathology < 1] = 0
343
+ matrix_zero_anatomy[matrix_zero_anatomy < 1] = 0
344
+ matrix_zero_pathology = matrix_zero_pathology.unsqueeze(3).repeat(
345
+ 1, 1, 1, anatomytomy_query.shape[-1]
346
+ )
347
+ matrix_zero_anatomy = matrix_zero_anatomy.unsqueeze(3).repeat(
348
+ 1, 1, 1, entity_query.shape[-1]
349
+ )
350
+
351
+ anatomy_temp = self.anatomy_book
352
+ pathology_temp = self.pathology_book
353
+ anatomy_temp = anatomy_temp.unsqueeze(0).repeat(
354
+ anatomytomy_query.shape[0], 1, 1
355
+ )
356
+ pathology_temp = pathology_temp.unsqueeze(0).repeat(
357
+ entity_query.shape[0], 1, 1
358
+ )
359
+ anatomy_temp = anatomy_temp.unsqueeze(2).repeat(
360
+ 1, 1, anatomytomy_query.shape[1], 1
361
+ )
362
+ pathology_temp = pathology_temp.unsqueeze(2).repeat(
363
+ 1, 1, entity_query.shape[1], 1
364
+ )
365
+
366
+ posi_matrix_pathology = (matrix_zero_pathology * anatomy_temp).transpose(
367
+ 1, 2
368
+ )
369
+ posi_matrix_anatomy = (matrix_zero_anatomy * pathology_temp).transpose(1, 2)
370
+
371
+ for i in range(anatomytomy_query.shape[0]):
372
+ for j in range(anatomytomy_query.shape[1]):
373
+ if (posi_matrix_pathology[i, j] != 0).sum() > 0:
374
+ num_posi = (
375
+ torch.nonzero(posi_matrix_pathology[i, j], as_tuple=True)[0]
376
+ .unique()
377
+ .shape[0]
378
+ )
379
+ assert anatomytomy_query[i, j, 0, :].sum() == 0
380
+ anatomytomy_query[i, j, 0, :] = (
381
+ posi_matrix_pathology[i, j, :, :].sum(dim=0) / num_posi
382
+ )
383
+
384
+ for i in range(entity_query.shape[0]):
385
+ for j in range(entity_query.shape[1]):
386
+ if (posi_matrix_anatomy[i, j] != 0).sum() > 0:
387
+ num_posi = (
388
+ torch.nonzero(posi_matrix_anatomy[i, j], as_tuple=True)[0]
389
+ .unique()
390
+ .shape[0]
391
+ )
392
+ assert entity_query[i, j, 0, :].sum() == 0
393
+ entity_query[i, j, 0, :] = (
394
+ posi_matrix_anatomy[i, j, :, :].sum(dim=0) / num_posi
395
+ )
396
+ # Got anatomytomy query
397
+
398
+ # [Q,B,A]
399
+ ll_pathology = out_pathology.transpose(0, 1) # B Q A
400
+ ll_anatomy = out_anatomy.transpose(0, 1) # B Q A
401
+
402
+ Q_pathology = ll_pathology.shape[1]
403
+ Q_anatomy = ll_anatomy.shape[1]
404
+
405
+ ll_pathology = ll_pathology.reshape(
406
+ ll_pathology.shape[0] * ll_pathology.shape[1], -1
407
+ )
408
+ ll_anatomy = ll_anatomy.reshape(
409
+ ll_anatomy.shape[0] * ll_anatomy.shape[1], -1
410
+ )
411
+
412
+ ll_pathology = self.cl_fc_pathology(ll_pathology)
413
+ ll_anatomy = self.cl_fc_anatomy(ll_anatomy)
414
+
415
+ ll_pathology = ll_pathology.unsqueeze(dim=-1)
416
+ ll_anatomy = ll_anatomy.unsqueeze(dim=-1)
417
+
418
+ anatomytomy_query = anatomytomy_query.reshape(B * Q_pathology, 8, 768)
419
+ entity_query = entity_query.reshape(B * Q_anatomy, 8, 768)
420
+
421
+ ll_pathology = torch.bmm(
422
+ anatomytomy_query, ll_pathology
423
+ ).squeeze() # B Q position_num
424
+ ll_anatomy = torch.bmm(
425
+ entity_query, ll_anatomy
426
+ ).squeeze() # B Q position_num
427
+
428
+ cl_labels_pathology = torch.zeros((ll_pathology.shape[0])).to(device)
429
+ cl_labels_anatomy = torch.zeros((ll_anatomy.shape[0])).to(device)
430
+
431
+ if exclude_class == True:
432
+ cl_labels_pathology = cl_labels_pathology.reshape(B, Q_pathology)
433
+ cl_labels_anatomy = cl_labels_anatomy.reshape(B, Q_anatomy)
434
+
435
+ cl_labels_pathology = cl_labels_pathology[
436
+ :, self.keep_class_dim_pathology
437
+ ]
438
+ cl_labels_anatomy = cl_labels_anatomy[:, self.keep_class_dim_pathology]
439
+
440
+ cl_labels_pathology = cl_labels_pathology.reshape(-1)
441
+ cl_labels_anatomy = cl_labels_anatomy.reshape(-1)
442
+
443
+ ll_pathology = ll_pathology.reshape(B, Q_pathology, -1)
444
+ ll_anatomy = ll_anatomy.reshape(B, Q_anatomy, -1)
445
+
446
+ ll_pathology = ll_pathology[:, self.keep_class_dim_pathology, :]
447
+ ll_pathology = ll_pathology.reshape(
448
+ B * (len(self.keep_class_dim_pathology)), -1
449
+ )
450
+ ll_anatomy = ll_anatomy.reshape(B * Q_anatomy, -1)
451
+
452
+ x_pathology = self.classifier_pathology(out_pathology).transpose(0, 1)
453
+ x_anatomy = self.classifier_anatomy(out_anatomy).transpose(
454
+ 0, 1
455
+ ) # B query Atributes
456
+
457
+ if exclude_class == True:
458
+ labels_pathology = labels_pathology[:, self.keep_class_dim_pathology]
459
+ x_pathology = x_pathology[:, self.keep_class_dim_pathology, :]
460
+
461
+ labels_pathology = labels_pathology.reshape(-1, 1)
462
+ labels_anatomy = labels_anatomy.reshape(-1, 1)
463
+ logits_pathology = x_pathology.reshape(-1, x_pathology.shape[-1])
464
+ logits_anatomy = x_anatomy.reshape(-1, x_anatomy.shape[-1])
465
+ Mask_pathology = ((labels_pathology != -1) & (labels_pathology != 2)).squeeze()
466
+ Mask_anatomy = ((labels_anatomy != -1) & (labels_anatomy != 2)).squeeze()
467
+
468
+ cl_mask_pathology = (labels_pathology == 1).squeeze()
469
+ cl_mask_anatomy = (labels_anatomy == 1).squeeze()
470
+ if is_train == True:
471
+ labels_pathology = labels_pathology[Mask_pathology].long()
472
+ labels_anatomy = labels_anatomy[Mask_anatomy].long()
473
+ logits_pathology = logits_pathology[Mask_pathology]
474
+ logits_anatomy = logits_anatomy[Mask_anatomy]
475
+ loss_ce_pathology = F.cross_entropy(
476
+ logits_pathology, labels_pathology[:, 0]
477
+ )
478
+ loss_ce_anatomy = F.cross_entropy(logits_anatomy, labels_anatomy[:, 0])
479
+ if no_cl == False:
480
+ cl_labels_pathology = cl_labels_pathology[cl_mask_pathology].long()
481
+ cl_labels_anatomy = cl_labels_anatomy[cl_mask_anatomy].long()
482
+ ll_pathology = ll_pathology[cl_mask_pathology]
483
+ ll_anatomy = ll_anatomy[cl_mask_anatomy]
484
+ loss_cl_pathology = F.cross_entropy(ll_pathology, cl_labels_pathology)
485
+ loss_cl_anatomy = F.cross_entropy(ll_anatomy, cl_labels_anatomy)
486
+ loss_ce = loss_ce_pathology + loss_ce_anatomy
487
+ loss_cl = loss_cl_pathology + loss_cl_anatomy
488
+ loss = loss_ce + loss_cl + loss_ap
489
+ else:
490
+ loss_cl = torch.tensor(0)
491
+ loss = loss_ce_pathology + loss_ce_anatomy + loss_ap
492
+ else:
493
+ loss = 0
494
+ if is_train == True:
495
+ if text_gen:
496
+ return (
497
+ loss,
498
+ x_pathology,
499
+ ws_pathology,
500
+ x_anatomy,
501
+ ws_anatomy,
502
+ output_logits,
503
+ )
504
+ else:
505
+ return (
506
+ loss,
507
+ loss_ce_pathology,
508
+ loss_cl_pathology,
509
+ loss_ce_anatomy,
510
+ loss_cl_anatomy,
511
+ loss_ap,
512
+ )
513
+ else:
514
+ return loss, x_pathology, ws_pathology, x_anatomy, ws_anatomy
515
+
516
+ @staticmethod
517
+ def _init_weights(module):
518
+ r"""Initialize weights like BERT - N(0.0, 0.02), bias = 0."""
519
+
520
+ if isinstance(module, nn.Linear):
521
+ module.weight.data.normal_(mean=0.0, std=0.02)
522
+
523
+ elif isinstance(module, nn.MultiheadAttention):
524
+ module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
525
+ module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
526
+
527
+ elif isinstance(module, nn.Embedding):
528
+ module.weight.data.normal_(mean=0.0, std=0.02)
529
+ if module.padding_idx is not None:
530
+ module.weight.data[module.padding_idx].zero_()
PreTrain_MeDSLIP/models/tokenization_bert.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Bert."""
16
+
17
+
18
+ import collections
19
+ import os
20
+ import unicodedata
21
+ from typing import List, Optional, Tuple
22
+
23
+ from transformers.tokenization_utils import (
24
+ PreTrainedTokenizer,
25
+ _is_control,
26
+ _is_punctuation,
27
+ _is_whitespace,
28
+ )
29
+ from transformers.utils import logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {
37
+ "vocab_file": {
38
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
39
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
40
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
41
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
42
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
43
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
44
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
45
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
46
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
47
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
48
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
49
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
50
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
51
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
52
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
53
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
54
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
55
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
56
+ }
57
+ }
58
+
59
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
60
+ "bert-base-uncased": 512,
61
+ "bert-large-uncased": 512,
62
+ "bert-base-cased": 512,
63
+ "bert-large-cased": 512,
64
+ "bert-base-multilingual-uncased": 512,
65
+ "bert-base-multilingual-cased": 512,
66
+ "bert-base-chinese": 512,
67
+ "bert-base-german-cased": 512,
68
+ "bert-large-uncased-whole-word-masking": 512,
69
+ "bert-large-cased-whole-word-masking": 512,
70
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
71
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
72
+ "bert-base-cased-finetuned-mrpc": 512,
73
+ "bert-base-german-dbmdz-cased": 512,
74
+ "bert-base-german-dbmdz-uncased": 512,
75
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
76
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
77
+ "wietsedv/bert-base-dutch-cased": 512,
78
+ }
79
+
80
+ PRETRAINED_INIT_CONFIGURATION = {
81
+ "bert-base-uncased": {"do_lower_case": True},
82
+ "bert-large-uncased": {"do_lower_case": True},
83
+ "bert-base-cased": {"do_lower_case": False},
84
+ "bert-large-cased": {"do_lower_case": False},
85
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
86
+ "bert-base-multilingual-cased": {"do_lower_case": False},
87
+ "bert-base-chinese": {"do_lower_case": False},
88
+ "bert-base-german-cased": {"do_lower_case": False},
89
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
90
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
91
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
92
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
93
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
94
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
95
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
96
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
97
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
98
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
99
+ }
100
+
101
+
102
+ def load_vocab(vocab_file):
103
+ """Loads a vocabulary file into a dictionary."""
104
+ vocab = collections.OrderedDict()
105
+ with open(vocab_file, "r", encoding="utf-8") as reader:
106
+ tokens = reader.readlines()
107
+ for index, token in enumerate(tokens):
108
+ token = token.rstrip("\n")
109
+ vocab[token] = index
110
+ return vocab
111
+
112
+
113
+ def whitespace_tokenize(text):
114
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
115
+ text = text.strip()
116
+ if not text:
117
+ return []
118
+ tokens = text.split()
119
+ return tokens
120
+
121
+
122
+ class BertTokenizer(PreTrainedTokenizer):
123
+ r"""
124
+ Construct a BERT tokenizer. Based on WordPiece.
125
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
126
+ Users should refer to this superclass for more information regarding those methods.
127
+ Args:
128
+ vocab_file (:obj:`str`):
129
+ File containing the vocabulary.
130
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
131
+ Whether or not to lowercase the input when tokenizing.
132
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
133
+ Whether or not to do basic tokenization before WordPiece.
134
+ never_split (:obj:`Iterable`, `optional`):
135
+ Collection of tokens which will never be split during tokenization. Only has an effect when
136
+ :obj:`do_basic_tokenize=True`
137
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
138
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
139
+ token instead.
140
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
141
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
142
+ sequence classification or for a text and a question for question answering. It is also used as the last
143
+ token of a sequence built with special tokens.
144
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
145
+ The token used for padding, for example when batching sequences of different lengths.
146
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
147
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
148
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
149
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
150
+ The token used for masking values. This is the token used when training this model with masked language
151
+ modeling. This is the token which the model will try to predict.
152
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
153
+ Whether or not to tokenize Chinese characters.
154
+ This should likely be deactivated for Japanese (see this `issue
155
+ <https://github.com/huggingface/transformers/issues/328>`__).
156
+ strip_accents: (:obj:`bool`, `optional`):
157
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
158
+ value for :obj:`lowercase` (as in the original BERT).
159
+ """
160
+
161
+ vocab_files_names = VOCAB_FILES_NAMES
162
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
163
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
164
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
165
+
166
+ def __init__(
167
+ self,
168
+ vocab_file,
169
+ do_lower_case=True,
170
+ do_basic_tokenize=True,
171
+ never_split=None,
172
+ unk_token="[UNK]",
173
+ sep_token="[SEP]",
174
+ pad_token="[PAD]",
175
+ cls_token="[CLS]",
176
+ mask_token="[MASK]",
177
+ tokenize_chinese_chars=True,
178
+ strip_accents=None,
179
+ **kwargs
180
+ ):
181
+ super().__init__(
182
+ do_lower_case=do_lower_case,
183
+ do_basic_tokenize=do_basic_tokenize,
184
+ never_split=never_split,
185
+ unk_token=unk_token,
186
+ sep_token=sep_token,
187
+ pad_token=pad_token,
188
+ cls_token=cls_token,
189
+ mask_token=mask_token,
190
+ tokenize_chinese_chars=tokenize_chinese_chars,
191
+ strip_accents=strip_accents,
192
+ **kwargs,
193
+ )
194
+
195
+ if not os.path.isfile(vocab_file):
196
+ raise ValueError(
197
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
198
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
199
+ vocab_file
200
+ )
201
+ )
202
+ self.vocab = load_vocab(vocab_file)
203
+ self.ids_to_tokens = collections.OrderedDict(
204
+ [(ids, tok) for tok, ids in self.vocab.items()]
205
+ )
206
+ self.do_basic_tokenize = do_basic_tokenize
207
+ if do_basic_tokenize:
208
+ self.basic_tokenizer = BasicTokenizer(
209
+ do_lower_case=do_lower_case,
210
+ never_split=never_split,
211
+ tokenize_chinese_chars=tokenize_chinese_chars,
212
+ strip_accents=strip_accents,
213
+ )
214
+ self.wordpiece_tokenizer = WordpieceTokenizer(
215
+ vocab=self.vocab, unk_token=self.unk_token
216
+ )
217
+
218
+ @property
219
+ def do_lower_case(self):
220
+ return self.basic_tokenizer.do_lower_case
221
+
222
+ @property
223
+ def vocab_size(self):
224
+ return len(self.vocab)
225
+
226
+ def get_vocab(self):
227
+ return dict(self.vocab, **self.added_tokens_encoder)
228
+
229
+ def _tokenize(self, text):
230
+ split_tokens = []
231
+ if self.do_basic_tokenize:
232
+ for token in self.basic_tokenizer.tokenize(
233
+ text, never_split=self.all_special_tokens
234
+ ):
235
+
236
+ # If the token is part of the never_split set
237
+ if token in self.basic_tokenizer.never_split:
238
+ split_tokens.append(token)
239
+ else:
240
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
241
+ else:
242
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
243
+ return split_tokens
244
+
245
+ def _convert_token_to_id(self, token):
246
+ """ Converts a token (str) in an id using the vocab. """
247
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
248
+
249
+ def _convert_id_to_token(self, index):
250
+ """Converts an index (integer) in a token (str) using the vocab."""
251
+ return self.ids_to_tokens.get(index, self.unk_token)
252
+
253
+ def convert_tokens_to_string(self, tokens):
254
+ """ Converts a sequence of tokens (string) in a single string. """
255
+ out_string = " ".join(tokens).replace(" ##", "").strip()
256
+ return out_string
257
+
258
+ def build_inputs_with_special_tokens(
259
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
260
+ ) -> List[int]:
261
+ """
262
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
263
+ adding special tokens. A BERT sequence has the following format:
264
+ - single sequence: ``[CLS] X ``
265
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
266
+ Args:
267
+ token_ids_0 (:obj:`List[int]`):
268
+ List of IDs to which the special tokens will be added.
269
+ token_ids_1 (:obj:`List[int]`, `optional`):
270
+ Optional second list of IDs for sequence pairs.
271
+ Returns:
272
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
273
+ """
274
+ if token_ids_1 is None:
275
+ return [self.cls_token_id] + token_ids_0
276
+ cls = [self.cls_token_id]
277
+ sep = [self.sep_token_id]
278
+ return cls + token_ids_0 + sep + token_ids_1 + sep
279
+
280
+ def get_special_tokens_mask(
281
+ self,
282
+ token_ids_0: List[int],
283
+ token_ids_1: Optional[List[int]] = None,
284
+ already_has_special_tokens: bool = False,
285
+ ) -> List[int]:
286
+ """
287
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
288
+ special tokens using the tokenizer ``prepare_for_model`` method.
289
+ Args:
290
+ token_ids_0 (:obj:`List[int]`):
291
+ List of IDs.
292
+ token_ids_1 (:obj:`List[int]`, `optional`):
293
+ Optional second list of IDs for sequence pairs.
294
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
295
+ Whether or not the token list is already formatted with special tokens for the model.
296
+ Returns:
297
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
298
+ """
299
+
300
+ if already_has_special_tokens:
301
+ if token_ids_1 is not None:
302
+ raise ValueError(
303
+ "You should not supply a second sequence if the provided sequence of "
304
+ "ids is already formatted with special tokens for the model."
305
+ )
306
+ return list(
307
+ map(
308
+ lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
309
+ token_ids_0,
310
+ )
311
+ )
312
+
313
+ if token_ids_1 is not None:
314
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
315
+ return [1] + ([0] * len(token_ids_0)) + [1]
316
+
317
+ def create_token_type_ids_from_sequences(
318
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
319
+ ) -> List[int]:
320
+ """
321
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
322
+ pair mask has the following format:
323
+ ::
324
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
325
+ | first sequence | second sequence |
326
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
327
+ Args:
328
+ token_ids_0 (:obj:`List[int]`):
329
+ List of IDs.
330
+ token_ids_1 (:obj:`List[int]`, `optional`):
331
+ Optional second list of IDs for sequence pairs.
332
+ Returns:
333
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
334
+ sequence(s).
335
+ """
336
+ sep = [self.sep_token_id]
337
+ cls = [self.cls_token_id]
338
+ if token_ids_1 is None:
339
+ return len(cls + token_ids_0 + sep) * [0]
340
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
341
+
342
+ def save_vocabulary(
343
+ self, save_directory: str, filename_prefix: Optional[str] = None
344
+ ) -> Tuple[str]:
345
+ index = 0
346
+ if os.path.isdir(save_directory):
347
+ vocab_file = os.path.join(
348
+ save_directory,
349
+ (filename_prefix + "-" if filename_prefix else "")
350
+ + VOCAB_FILES_NAMES["vocab_file"],
351
+ )
352
+ else:
353
+ vocab_file = (
354
+ filename_prefix + "-" if filename_prefix else ""
355
+ ) + save_directory
356
+ with open(vocab_file, "w", encoding="utf-8") as writer:
357
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
358
+ if index != token_index:
359
+ logger.warning(
360
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
361
+ " Please check that the vocabulary is not corrupted!".format(
362
+ vocab_file
363
+ )
364
+ )
365
+ index = token_index
366
+ writer.write(token + "\n")
367
+ index += 1
368
+ return (vocab_file,)
369
+
370
+
371
+ class BasicTokenizer(object):
372
+ """
373
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
374
+ Args:
375
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
376
+ Whether or not to lowercase the input when tokenizing.
377
+ never_split (:obj:`Iterable`, `optional`):
378
+ Collection of tokens which will never be split during tokenization. Only has an effect when
379
+ :obj:`do_basic_tokenize=True`
380
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
381
+ Whether or not to tokenize Chinese characters.
382
+ This should likely be deactivated for Japanese (see this `issue
383
+ <https://github.com/huggingface/transformers/issues/328>`__).
384
+ strip_accents: (:obj:`bool`, `optional`):
385
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
386
+ value for :obj:`lowercase` (as in the original BERT).
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ do_lower_case=True,
392
+ never_split=None,
393
+ tokenize_chinese_chars=True,
394
+ strip_accents=None,
395
+ ):
396
+ if never_split is None:
397
+ never_split = []
398
+ self.do_lower_case = do_lower_case
399
+ self.never_split = set(never_split)
400
+ self.tokenize_chinese_chars = tokenize_chinese_chars
401
+ self.strip_accents = strip_accents
402
+
403
+ def tokenize(self, text, never_split=None):
404
+ """
405
+ Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
406
+ WordPieceTokenizer.
407
+ Args:
408
+ **never_split**: (`optional`) list of str
409
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
410
+ :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
411
+ """
412
+ # union() returns a new set by concatenating the two sets.
413
+ never_split = (
414
+ self.never_split.union(set(never_split))
415
+ if never_split
416
+ else self.never_split
417
+ )
418
+ text = self._clean_text(text)
419
+
420
+ # This was added on November 1st, 2018 for the multilingual and Chinese
421
+ # models. This is also applied to the English models now, but it doesn't
422
+ # matter since the English models were not trained on any Chinese data
423
+ # and generally don't have any Chinese data in them (there are Chinese
424
+ # characters in the vocabulary because Wikipedia does have some Chinese
425
+ # words in the English Wikipedia.).
426
+ if self.tokenize_chinese_chars:
427
+ text = self._tokenize_chinese_chars(text)
428
+ orig_tokens = whitespace_tokenize(text)
429
+ split_tokens = []
430
+ for token in orig_tokens:
431
+ if token not in never_split:
432
+ if self.do_lower_case:
433
+ token = token.lower()
434
+ if self.strip_accents is not False:
435
+ token = self._run_strip_accents(token)
436
+ elif self.strip_accents:
437
+ token = self._run_strip_accents(token)
438
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
439
+
440
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
441
+ return output_tokens
442
+
443
+ def _run_strip_accents(self, text):
444
+ """Strips accents from a piece of text."""
445
+ text = unicodedata.normalize("NFD", text)
446
+ output = []
447
+ for char in text:
448
+ cat = unicodedata.category(char)
449
+ if cat == "Mn":
450
+ continue
451
+ output.append(char)
452
+ return "".join(output)
453
+
454
+ def _run_split_on_punc(self, text, never_split=None):
455
+ """Splits punctuation on a piece of text."""
456
+ if never_split is not None and text in never_split:
457
+ return [text]
458
+ chars = list(text)
459
+ i = 0
460
+ start_new_word = True
461
+ output = []
462
+ while i < len(chars):
463
+ char = chars[i]
464
+ if _is_punctuation(char):
465
+ output.append([char])
466
+ start_new_word = True
467
+ else:
468
+ if start_new_word:
469
+ output.append([])
470
+ start_new_word = False
471
+ output[-1].append(char)
472
+ i += 1
473
+
474
+ return ["".join(x) for x in output]
475
+
476
+ def _tokenize_chinese_chars(self, text):
477
+ """Adds whitespace around any CJK character."""
478
+ output = []
479
+ for char in text:
480
+ cp = ord(char)
481
+ if self._is_chinese_char(cp):
482
+ output.append(" ")
483
+ output.append(char)
484
+ output.append(" ")
485
+ else:
486
+ output.append(char)
487
+ return "".join(output)
488
+
489
+ def _is_chinese_char(self, cp):
490
+ """Checks whether CP is the codepoint of a CJK character."""
491
+ # This defines a "chinese character" as anything in the CJK Unicode block:
492
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
493
+ #
494
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
495
+ # despite its name. The modern Korean Hangul alphabet is a different block,
496
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
497
+ # space-separated words, so they are not treated specially and handled
498
+ # like the all of the other languages.
499
+ if (
500
+ (cp >= 0x4E00 and cp <= 0x9FFF)
501
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
502
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
503
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
504
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
505
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
506
+ or (cp >= 0xF900 and cp <= 0xFAFF)
507
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
508
+ ): #
509
+ return True
510
+
511
+ return False
512
+
513
+ def _clean_text(self, text):
514
+ """Performs invalid character removal and whitespace cleanup on text."""
515
+ output = []
516
+ for char in text:
517
+ cp = ord(char)
518
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
519
+ continue
520
+ if _is_whitespace(char):
521
+ output.append(" ")
522
+ else:
523
+ output.append(char)
524
+ return "".join(output)
525
+
526
+
527
+ class WordpieceTokenizer(object):
528
+ """Runs WordPiece tokenization."""
529
+
530
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
531
+ self.vocab = vocab
532
+ self.unk_token = unk_token
533
+ self.max_input_chars_per_word = max_input_chars_per_word
534
+
535
+ def tokenize(self, text):
536
+ """
537
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
538
+ tokenization using the given vocabulary.
539
+ For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
540
+ Args:
541
+ text: A single token or whitespace separated tokens. This should have
542
+ already been passed through `BasicTokenizer`.
543
+ Returns:
544
+ A list of wordpiece tokens.
545
+ """
546
+
547
+ output_tokens = []
548
+ for token in whitespace_tokenize(text):
549
+ chars = list(token)
550
+ if len(chars) > self.max_input_chars_per_word:
551
+ output_tokens.append(self.unk_token)
552
+ continue
553
+
554
+ is_bad = False
555
+ start = 0
556
+ sub_tokens = []
557
+ while start < len(chars):
558
+ end = len(chars)
559
+ cur_substr = None
560
+ while start < end:
561
+ substr = "".join(chars[start:end])
562
+ if start > 0:
563
+ substr = "##" + substr
564
+ if substr in self.vocab:
565
+ cur_substr = substr
566
+ break
567
+ end -= 1
568
+ if cur_substr is None:
569
+ is_bad = True
570
+ break
571
+ sub_tokens.append(cur_substr)
572
+ start = end
573
+
574
+ if is_bad:
575
+ output_tokens.append(self.unk_token)
576
+ else:
577
+ output_tokens.extend(sub_tokens)
578
+ return output_tokens
PreTrain_MeDSLIP/models/transformer.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code modified from DETR tranformer:
3
+ https://github.com/facebookresearch/detr
4
+ Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
5
+ """
6
+
7
+ import copy
8
+ from typing import Optional, List
9
+ import pickle as cp
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn, Tensor
14
+
15
+
16
+ class TransformerDecoder(nn.Module):
17
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
18
+ super().__init__()
19
+ self.layers = _get_clones(decoder_layer, num_layers)
20
+ self.num_layers = num_layers
21
+ self.norm = norm
22
+ self.return_intermediate = return_intermediate
23
+
24
+ def forward(
25
+ self,
26
+ tgt,
27
+ memory,
28
+ tgt_mask: Optional[Tensor] = None,
29
+ memory_mask: Optional[Tensor] = None,
30
+ tgt_key_padding_mask: Optional[Tensor] = None,
31
+ memory_key_padding_mask: Optional[Tensor] = None,
32
+ pos: Optional[Tensor] = None,
33
+ query_pos: Optional[Tensor] = None,
34
+ ):
35
+ output = tgt
36
+ T, B, C = memory.shape
37
+ intermediate = []
38
+ atten_layers = []
39
+ for n, layer in enumerate(self.layers):
40
+
41
+ residual = True
42
+ output, ws = layer(
43
+ output,
44
+ memory,
45
+ tgt_mask=tgt_mask,
46
+ memory_mask=memory_mask,
47
+ tgt_key_padding_mask=tgt_key_padding_mask,
48
+ memory_key_padding_mask=memory_key_padding_mask,
49
+ pos=pos,
50
+ query_pos=query_pos,
51
+ residual=residual,
52
+ )
53
+ atten_layers.append(ws)
54
+ if self.return_intermediate:
55
+ intermediate.append(self.norm(output))
56
+ if self.norm is not None:
57
+ output = self.norm(output)
58
+ if self.return_intermediate:
59
+ intermediate.pop()
60
+ intermediate.append(output)
61
+
62
+ if self.return_intermediate:
63
+ return torch.stack(intermediate)
64
+ return output, atten_layers
65
+
66
+
67
+ class TransformerDecoderLayer(nn.Module):
68
+ def __init__(
69
+ self,
70
+ d_model,
71
+ nhead,
72
+ dim_feedforward=2048,
73
+ dropout=0.1,
74
+ activation="relu",
75
+ normalize_before=False,
76
+ ):
77
+ super().__init__()
78
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
79
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
80
+ # Implementation of Feedforward model
81
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
82
+ self.dropout = nn.Dropout(dropout)
83
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
84
+
85
+ self.norm1 = nn.LayerNorm(d_model)
86
+ self.norm2 = nn.LayerNorm(d_model)
87
+ self.norm3 = nn.LayerNorm(d_model)
88
+ self.dropout1 = nn.Dropout(dropout)
89
+ self.dropout2 = nn.Dropout(dropout)
90
+ self.dropout3 = nn.Dropout(dropout)
91
+
92
+ self.activation = _get_activation_fn(activation)
93
+ self.normalize_before = normalize_before
94
+
95
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
96
+ return tensor if pos is None else tensor + pos
97
+
98
+ def forward_post(
99
+ self,
100
+ tgt,
101
+ memory,
102
+ tgt_mask: Optional[Tensor] = None,
103
+ memory_mask: Optional[Tensor] = None,
104
+ tgt_key_padding_mask: Optional[Tensor] = None,
105
+ memory_key_padding_mask: Optional[Tensor] = None,
106
+ pos: Optional[Tensor] = None,
107
+ query_pos: Optional[Tensor] = None,
108
+ residual=True,
109
+ ):
110
+ q = k = self.with_pos_embed(tgt, query_pos)
111
+ tgt2, ws = self.self_attn(
112
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
113
+ )
114
+ tgt = self.norm1(tgt)
115
+ tgt2, ws = self.multihead_attn(
116
+ query=self.with_pos_embed(tgt, query_pos),
117
+ key=self.with_pos_embed(memory, pos),
118
+ value=memory,
119
+ attn_mask=memory_mask,
120
+ key_padding_mask=memory_key_padding_mask,
121
+ )
122
+
123
+ # attn_weights [B,NUM_Q,T]
124
+ tgt = tgt + self.dropout2(tgt2)
125
+ tgt = self.norm2(tgt)
126
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
127
+ tgt = tgt + self.dropout3(tgt2)
128
+ tgt = self.norm3(tgt)
129
+ return tgt, ws
130
+
131
+ def forward_pre(
132
+ self,
133
+ tgt,
134
+ memory,
135
+ tgt_mask: Optional[Tensor] = None,
136
+ memory_mask: Optional[Tensor] = None,
137
+ tgt_key_padding_mask: Optional[Tensor] = None,
138
+ memory_key_padding_mask: Optional[Tensor] = None,
139
+ pos: Optional[Tensor] = None,
140
+ query_pos: Optional[Tensor] = None,
141
+ ):
142
+ tgt2 = self.norm1(tgt)
143
+ q = k = self.with_pos_embed(tgt2, query_pos)
144
+ tgt2, ws = self.self_attn(
145
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
146
+ )
147
+ tgt = tgt + self.dropout1(tgt2)
148
+ tgt2 = self.norm2(tgt)
149
+ tgt2, attn_weights = self.multihead_attn(
150
+ query=self.with_pos_embed(tgt2, query_pos),
151
+ key=self.with_pos_embed(memory, pos),
152
+ value=memory,
153
+ attn_mask=memory_mask,
154
+ key_padding_mask=memory_key_padding_mask,
155
+ )
156
+ tgt = tgt + self.dropout2(tgt2)
157
+ tgt2 = self.norm3(tgt)
158
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
159
+ tgt = tgt + self.dropout3(tgt2)
160
+ return tgt, attn_weights
161
+
162
+ def forward(
163
+ self,
164
+ tgt,
165
+ memory,
166
+ tgt_mask: Optional[Tensor] = None,
167
+ memory_mask: Optional[Tensor] = None,
168
+ tgt_key_padding_mask: Optional[Tensor] = None,
169
+ memory_key_padding_mask: Optional[Tensor] = None,
170
+ pos: Optional[Tensor] = None,
171
+ query_pos: Optional[Tensor] = None,
172
+ residual=True,
173
+ ):
174
+ if self.normalize_before:
175
+ return self.forward_pre(
176
+ tgt,
177
+ memory,
178
+ tgt_mask,
179
+ memory_mask,
180
+ tgt_key_padding_mask,
181
+ memory_key_padding_mask,
182
+ pos,
183
+ query_pos,
184
+ )
185
+ return self.forward_post(
186
+ tgt,
187
+ memory,
188
+ tgt_mask,
189
+ memory_mask,
190
+ tgt_key_padding_mask,
191
+ memory_key_padding_mask,
192
+ pos,
193
+ query_pos,
194
+ residual,
195
+ )
196
+
197
+
198
+ def _get_clones(module, N):
199
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
200
+
201
+
202
+ def _get_activation_fn(activation):
203
+ """Return an activation function given a string"""
204
+ if activation == "relu":
205
+ return F.relu
206
+ if activation == "gelu":
207
+ return F.gelu
208
+ if activation == "glu":
209
+ return F.glu
210
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
PreTrain_MeDSLIP/optim/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .adamp import AdamP
2
+ from .adamw import AdamW
3
+ from .adafactor import Adafactor
4
+ from .adahessian import Adahessian
5
+ from .lookahead import Lookahead
6
+ from .nadam import Nadam
7
+ from .novograd import NovoGrad
8
+ from .nvnovograd import NvNovoGrad
9
+ from .radam import RAdam
10
+ from .rmsprop_tf import RMSpropTF
11
+ from .sgdp import SGDP
12
+
13
+ from .optim_factory import create_optimizer
PreTrain_MeDSLIP/optim/adafactor.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adafactor Optimizer
2
+
3
+ Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
4
+
5
+ Original header/copyright below.
6
+
7
+ """
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ #
10
+ # This source code is licensed under the MIT license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+ import torch
13
+ import math
14
+
15
+
16
+ class Adafactor(torch.optim.Optimizer):
17
+ """Implements Adafactor algorithm.
18
+ This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
19
+ (see https://arxiv.org/abs/1804.04235)
20
+
21
+ Note that this optimizer internally adjusts the learning rate depending on the
22
+ *scale_parameter*, *relative_step* and *warmup_init* options.
23
+
24
+ To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
25
+ `relative_step=False`.
26
+
27
+ Arguments:
28
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
29
+ lr (float, optional): external learning rate (default: None)
30
+ eps (tuple[float, float]): regularization constants for square gradient
31
+ and parameter scale respectively (default: (1e-30, 1e-3))
32
+ clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
33
+ decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
34
+ beta1 (float): coefficient used for computing running averages of gradient (default: None)
35
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
36
+ scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
37
+ relative_step (bool): if True, time-dependent learning rate is computed
38
+ instead of external learning rate (default: True)
39
+ warmup_init (bool): time-dependent learning rate computation depends on
40
+ whether warm-up initialization is being used (default: False)
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ params,
46
+ lr=None,
47
+ eps=1e-30,
48
+ eps_scale=1e-3,
49
+ clip_threshold=1.0,
50
+ decay_rate=-0.8,
51
+ betas=None,
52
+ weight_decay=0.0,
53
+ scale_parameter=True,
54
+ warmup_init=False,
55
+ ):
56
+ relative_step = lr is None
57
+ if warmup_init and not relative_step:
58
+ raise ValueError("warmup_init requires relative_step=True")
59
+
60
+ beta1 = (
61
+ None if betas is None else betas[0]
62
+ ) # make it compat with standard betas arg
63
+ defaults = dict(
64
+ lr=lr,
65
+ eps=eps,
66
+ eps_scale=eps_scale,
67
+ clip_threshold=clip_threshold,
68
+ decay_rate=decay_rate,
69
+ beta1=beta1,
70
+ weight_decay=weight_decay,
71
+ scale_parameter=scale_parameter,
72
+ relative_step=relative_step,
73
+ warmup_init=warmup_init,
74
+ )
75
+ super(Adafactor, self).__init__(params, defaults)
76
+
77
+ @staticmethod
78
+ def _get_lr(param_group, param_state):
79
+ if param_group["relative_step"]:
80
+ min_step = (
81
+ 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
82
+ )
83
+ lr_t = min(min_step, 1.0 / math.sqrt(param_state["step"]))
84
+ param_scale = 1.0
85
+ if param_group["scale_parameter"]:
86
+ param_scale = max(param_group["eps_scale"], param_state["RMS"])
87
+ param_group["lr"] = lr_t * param_scale
88
+ return param_group["lr"]
89
+
90
+ @staticmethod
91
+ def _get_options(param_group, param_shape):
92
+ factored = len(param_shape) >= 2
93
+ use_first_moment = param_group["beta1"] is not None
94
+ return factored, use_first_moment
95
+
96
+ @staticmethod
97
+ def _rms(tensor):
98
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
99
+
100
+ def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
101
+ r_factor = (
102
+ (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
103
+ .rsqrt_()
104
+ .unsqueeze(-1)
105
+ )
106
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
107
+ return torch.mul(r_factor, c_factor)
108
+
109
+ def step(self, closure=None):
110
+ """Performs a single optimization step.
111
+ Arguments:
112
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
113
+ """
114
+ loss = None
115
+ if closure is not None:
116
+ loss = closure()
117
+
118
+ for group in self.param_groups:
119
+ for p in group["params"]:
120
+ if p.grad is None:
121
+ continue
122
+ grad = p.grad.data
123
+ if grad.dtype in {torch.float16, torch.bfloat16}:
124
+ grad = grad.float()
125
+ if grad.is_sparse:
126
+ raise RuntimeError("Adafactor does not support sparse gradients.")
127
+
128
+ state = self.state[p]
129
+ grad_shape = grad.shape
130
+
131
+ factored, use_first_moment = self._get_options(group, grad_shape)
132
+ # State Initialization
133
+ if len(state) == 0:
134
+ state["step"] = 0
135
+
136
+ if use_first_moment:
137
+ # Exponential moving average of gradient values
138
+ state["exp_avg"] = torch.zeros_like(grad)
139
+ if factored:
140
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
141
+ state["exp_avg_sq_col"] = torch.zeros(
142
+ grad_shape[:-2] + grad_shape[-1:]
143
+ ).to(grad)
144
+ else:
145
+ state["exp_avg_sq"] = torch.zeros_like(grad)
146
+
147
+ state["RMS"] = 0
148
+ else:
149
+ if use_first_moment:
150
+ state["exp_avg"] = state["exp_avg"].to(grad)
151
+ if factored:
152
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
153
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
154
+ else:
155
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
156
+
157
+ p_data_fp32 = p.data
158
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
159
+ p_data_fp32 = p_data_fp32.float()
160
+
161
+ state["step"] += 1
162
+ state["RMS"] = self._rms(p_data_fp32)
163
+ lr_t = self._get_lr(group, state)
164
+
165
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
166
+ update = grad ** 2 + group["eps"]
167
+ if factored:
168
+ exp_avg_sq_row = state["exp_avg_sq_row"]
169
+ exp_avg_sq_col = state["exp_avg_sq_col"]
170
+
171
+ exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
172
+ exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
173
+ # exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
174
+ # exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
175
+
176
+ # Approximation of exponential moving average of square of gradient
177
+ update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
178
+ update.mul_(grad)
179
+ else:
180
+ exp_avg_sq = state["exp_avg_sq"]
181
+
182
+ exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
183
+ # exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
184
+ update = exp_avg_sq.rsqrt().mul_(grad)
185
+
186
+ update.div_(
187
+ (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
188
+ )
189
+ update.mul_(lr_t)
190
+
191
+ if use_first_moment:
192
+ exp_avg = state["exp_avg"]
193
+ exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
194
+ # exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
195
+ update = exp_avg
196
+
197
+ if group["weight_decay"] != 0:
198
+ p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
199
+ # p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
200
+
201
+ p_data_fp32.add_(-update)
202
+
203
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
204
+ p.data.copy_(p_data_fp32)
205
+
206
+ return loss
PreTrain_MeDSLIP/optim/adahessian.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AdaHessian Optimizer
2
+
3
+ Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
4
+ Originally licensed MIT, Copyright 2020, David Samuel
5
+ """
6
+ import torch
7
+
8
+
9
+ class Adahessian(torch.optim.Optimizer):
10
+ """
11
+ Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
12
+
13
+ Arguments:
14
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15
+ lr (float, optional): learning rate (default: 0.1)
16
+ betas ((float, float), optional): coefficients used for computing running averages of gradient and the
17
+ squared hessian trace (default: (0.9, 0.999))
18
+ eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
19
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
20
+ hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
21
+ update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
22
+ (to save time) (default: 1)
23
+ n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ params,
29
+ lr=0.1,
30
+ betas=(0.9, 0.999),
31
+ eps=1e-8,
32
+ weight_decay=0.0,
33
+ hessian_power=1.0,
34
+ update_each=1,
35
+ n_samples=1,
36
+ avg_conv_kernel=False,
37
+ ):
38
+ if not 0.0 <= lr:
39
+ raise ValueError(f"Invalid learning rate: {lr}")
40
+ if not 0.0 <= eps:
41
+ raise ValueError(f"Invalid epsilon value: {eps}")
42
+ if not 0.0 <= betas[0] < 1.0:
43
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
44
+ if not 0.0 <= betas[1] < 1.0:
45
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
46
+ if not 0.0 <= hessian_power <= 1.0:
47
+ raise ValueError(f"Invalid Hessian power value: {hessian_power}")
48
+
49
+ self.n_samples = n_samples
50
+ self.update_each = update_each
51
+ self.avg_conv_kernel = avg_conv_kernel
52
+
53
+ # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
54
+ self.seed = 2147483647
55
+ self.generator = torch.Generator().manual_seed(self.seed)
56
+
57
+ defaults = dict(
58
+ lr=lr,
59
+ betas=betas,
60
+ eps=eps,
61
+ weight_decay=weight_decay,
62
+ hessian_power=hessian_power,
63
+ )
64
+ super(Adahessian, self).__init__(params, defaults)
65
+
66
+ for p in self.get_params():
67
+ p.hess = 0.0
68
+ self.state[p]["hessian step"] = 0
69
+
70
+ @property
71
+ def is_second_order(self):
72
+ return True
73
+
74
+ def get_params(self):
75
+ """
76
+ Gets all parameters in all param_groups with gradients
77
+ """
78
+
79
+ return (
80
+ p for group in self.param_groups for p in group["params"] if p.requires_grad
81
+ )
82
+
83
+ def zero_hessian(self):
84
+ """
85
+ Zeros out the accumalated hessian traces.
86
+ """
87
+
88
+ for p in self.get_params():
89
+ if (
90
+ not isinstance(p.hess, float)
91
+ and self.state[p]["hessian step"] % self.update_each == 0
92
+ ):
93
+ p.hess.zero_()
94
+
95
+ @torch.no_grad()
96
+ def set_hessian(self):
97
+ """
98
+ Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
99
+ """
100
+
101
+ params = []
102
+ for p in filter(lambda p: p.grad is not None, self.get_params()):
103
+ if (
104
+ self.state[p]["hessian step"] % self.update_each == 0
105
+ ): # compute the trace only each `update_each` step
106
+ params.append(p)
107
+ self.state[p]["hessian step"] += 1
108
+
109
+ if len(params) == 0:
110
+ return
111
+
112
+ if (
113
+ self.generator.device != params[0].device
114
+ ): # hackish way of casting the generator to the right device
115
+ self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
116
+
117
+ grads = [p.grad for p in params]
118
+
119
+ for i in range(self.n_samples):
120
+ # Rademacher distribution {-1.0, 1.0}
121
+ zs = [
122
+ torch.randint(0, 2, p.size(), generator=self.generator, device=p.device)
123
+ * 2.0
124
+ - 1.0
125
+ for p in params
126
+ ]
127
+ h_zs = torch.autograd.grad(
128
+ grads,
129
+ params,
130
+ grad_outputs=zs,
131
+ only_inputs=True,
132
+ retain_graph=i < self.n_samples - 1,
133
+ )
134
+ for h_z, z, p in zip(h_zs, zs, params):
135
+ p.hess += (
136
+ h_z * z / self.n_samples
137
+ ) # approximate the expected values of z*(H@z)
138
+
139
+ @torch.no_grad()
140
+ def step(self, closure=None):
141
+ """
142
+ Performs a single optimization step.
143
+ Arguments:
144
+ closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
145
+ """
146
+
147
+ loss = None
148
+ if closure is not None:
149
+ loss = closure()
150
+
151
+ self.zero_hessian()
152
+ self.set_hessian()
153
+
154
+ for group in self.param_groups:
155
+ for p in group["params"]:
156
+ if p.grad is None or p.hess is None:
157
+ continue
158
+
159
+ if self.avg_conv_kernel and p.dim() == 4:
160
+ p.hess = (
161
+ torch.abs(p.hess)
162
+ .mean(dim=[2, 3], keepdim=True)
163
+ .expand_as(p.hess)
164
+ .clone()
165
+ )
166
+
167
+ # Perform correct stepweight decay as in AdamW
168
+ p.mul_(1 - group["lr"] * group["weight_decay"])
169
+
170
+ state = self.state[p]
171
+
172
+ # State initialization
173
+ if len(state) == 1:
174
+ state["step"] = 0
175
+ # Exponential moving average of gradient values
176
+ state["exp_avg"] = torch.zeros_like(p)
177
+ # Exponential moving average of Hessian diagonal square values
178
+ state["exp_hessian_diag_sq"] = torch.zeros_like(p)
179
+
180
+ exp_avg, exp_hessian_diag_sq = (
181
+ state["exp_avg"],
182
+ state["exp_hessian_diag_sq"],
183
+ )
184
+ beta1, beta2 = group["betas"]
185
+ state["step"] += 1
186
+
187
+ # Decay the first and second moment running average coefficient
188
+ exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
189
+ exp_hessian_diag_sq.mul_(beta2).addcmul_(
190
+ p.hess, p.hess, value=1 - beta2
191
+ )
192
+
193
+ bias_correction1 = 1 - beta1 ** state["step"]
194
+ bias_correction2 = 1 - beta2 ** state["step"]
195
+
196
+ k = group["hessian_power"]
197
+ denom = (
198
+ (exp_hessian_diag_sq / bias_correction2)
199
+ .pow_(k / 2)
200
+ .add_(group["eps"])
201
+ )
202
+
203
+ # make update
204
+ step_size = group["lr"] / bias_correction1
205
+ p.addcdiv_(exp_avg, denom, value=-step_size)
206
+
207
+ return loss
PreTrain_MeDSLIP/optim/adamp.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3
+
4
+ Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5
+ Code: https://github.com/clovaai/AdamP
6
+
7
+ Copyright (c) 2020-present NAVER Corp.
8
+ MIT license
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.optim.optimizer import Optimizer, required
14
+ import math
15
+
16
+
17
+ class AdamP(Optimizer):
18
+ def __init__(
19
+ self,
20
+ params,
21
+ lr=1e-3,
22
+ betas=(0.9, 0.999),
23
+ eps=1e-8,
24
+ weight_decay=0,
25
+ delta=0.1,
26
+ wd_ratio=0.1,
27
+ nesterov=False,
28
+ ):
29
+ defaults = dict(
30
+ lr=lr,
31
+ betas=betas,
32
+ eps=eps,
33
+ weight_decay=weight_decay,
34
+ delta=delta,
35
+ wd_ratio=wd_ratio,
36
+ nesterov=nesterov,
37
+ )
38
+ super(AdamP, self).__init__(params, defaults)
39
+
40
+ def _channel_view(self, x):
41
+ return x.view(x.size(0), -1)
42
+
43
+ def _layer_view(self, x):
44
+ return x.view(1, -1)
45
+
46
+ def _cosine_similarity(self, x, y, eps, view_func):
47
+ x = view_func(x)
48
+ y = view_func(y)
49
+
50
+ x_norm = x.norm(dim=1).add_(eps)
51
+ y_norm = y.norm(dim=1).add_(eps)
52
+ dot = (x * y).sum(dim=1)
53
+
54
+ return dot.abs() / x_norm / y_norm
55
+
56
+ def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
57
+ wd = 1
58
+ expand_size = [-1] + [1] * (len(p.shape) - 1)
59
+ for view_func in [self._channel_view, self._layer_view]:
60
+
61
+ cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
62
+
63
+ if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
64
+ p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
65
+ perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
66
+ wd = wd_ratio
67
+
68
+ return perturb, wd
69
+
70
+ return perturb, wd
71
+
72
+ def step(self, closure=None):
73
+ loss = None
74
+ if closure is not None:
75
+ loss = closure()
76
+
77
+ for group in self.param_groups:
78
+ for p in group["params"]:
79
+ if p.grad is None:
80
+ continue
81
+
82
+ grad = p.grad.data
83
+ beta1, beta2 = group["betas"]
84
+ nesterov = group["nesterov"]
85
+
86
+ state = self.state[p]
87
+
88
+ # State initialization
89
+ if len(state) == 0:
90
+ state["step"] = 0
91
+ state["exp_avg"] = torch.zeros_like(p.data)
92
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
93
+
94
+ # Adam
95
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
96
+
97
+ state["step"] += 1
98
+ bias_correction1 = 1 - beta1 ** state["step"]
99
+ bias_correction2 = 1 - beta2 ** state["step"]
100
+
101
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
102
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
103
+
104
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
105
+ group["eps"]
106
+ )
107
+ step_size = group["lr"] / bias_correction1
108
+
109
+ if nesterov:
110
+ perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
111
+ else:
112
+ perturb = exp_avg / denom
113
+
114
+ # Projection
115
+ wd_ratio = 1
116
+ if len(p.shape) > 1:
117
+ perturb, wd_ratio = self._projection(
118
+ p,
119
+ grad,
120
+ perturb,
121
+ group["delta"],
122
+ group["wd_ratio"],
123
+ group["eps"],
124
+ )
125
+
126
+ # Weight decay
127
+ if group["weight_decay"] > 0:
128
+ p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio)
129
+
130
+ # Step
131
+ p.data.add_(-step_size, perturb)
132
+
133
+ return loss
PreTrain_MeDSLIP/optim/adamw.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AdamW Optimizer
2
+ Impl copied from PyTorch master
3
+ """
4
+ import math
5
+ import torch
6
+ from torch.optim.optimizer import Optimizer
7
+
8
+
9
+ class AdamW(Optimizer):
10
+ r"""Implements AdamW algorithm.
11
+
12
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
13
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
14
+
15
+ Arguments:
16
+ params (iterable): iterable of parameters to optimize or dicts defining
17
+ parameter groups
18
+ lr (float, optional): learning rate (default: 1e-3)
19
+ betas (Tuple[float, float], optional): coefficients used for computing
20
+ running averages of gradient and its square (default: (0.9, 0.999))
21
+ eps (float, optional): term added to the denominator to improve
22
+ numerical stability (default: 1e-8)
23
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
24
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
26
+ (default: False)
27
+
28
+ .. _Adam\: A Method for Stochastic Optimization:
29
+ https://arxiv.org/abs/1412.6980
30
+ .. _Decoupled Weight Decay Regularization:
31
+ https://arxiv.org/abs/1711.05101
32
+ .. _On the Convergence of Adam and Beyond:
33
+ https://openreview.net/forum?id=ryQu7f-RZ
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ params,
39
+ lr=1e-3,
40
+ betas=(0.9, 0.999),
41
+ eps=1e-8,
42
+ weight_decay=1e-2,
43
+ amsgrad=False,
44
+ ):
45
+ if not 0.0 <= lr:
46
+ raise ValueError("Invalid learning rate: {}".format(lr))
47
+ if not 0.0 <= eps:
48
+ raise ValueError("Invalid epsilon value: {}".format(eps))
49
+ if not 0.0 <= betas[0] < 1.0:
50
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
51
+ if not 0.0 <= betas[1] < 1.0:
52
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
53
+ defaults = dict(
54
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
55
+ )
56
+ super(AdamW, self).__init__(params, defaults)
57
+
58
+ def __setstate__(self, state):
59
+ super(AdamW, self).__setstate__(state)
60
+ for group in self.param_groups:
61
+ group.setdefault("amsgrad", False)
62
+
63
+ def step(self, closure=None):
64
+ """Performs a single optimization step.
65
+
66
+ Arguments:
67
+ closure (callable, optional): A closure that reevaluates the model
68
+ and returns the loss.
69
+ """
70
+ loss = None
71
+ if closure is not None:
72
+ loss = closure()
73
+
74
+ for group in self.param_groups:
75
+ for p in group["params"]:
76
+ if p.grad is None:
77
+ continue
78
+
79
+ # Perform stepweight decay
80
+ p.data.mul_(1 - group["lr"] * group["weight_decay"])
81
+
82
+ # Perform optimization step
83
+ grad = p.grad.data
84
+ if grad.is_sparse:
85
+ raise RuntimeError(
86
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
87
+ )
88
+ amsgrad = group["amsgrad"]
89
+
90
+ state = self.state[p]
91
+
92
+ # State initialization
93
+ if len(state) == 0:
94
+ state["step"] = 0
95
+ # Exponential moving average of gradient values
96
+ state["exp_avg"] = torch.zeros_like(p.data)
97
+ # Exponential moving average of squared gradient values
98
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
99
+ if amsgrad:
100
+ # Maintains max of all exp. moving avg. of sq. grad. values
101
+ state["max_exp_avg_sq"] = torch.zeros_like(p.data)
102
+
103
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
104
+ if amsgrad:
105
+ max_exp_avg_sq = state["max_exp_avg_sq"]
106
+ beta1, beta2 = group["betas"]
107
+
108
+ state["step"] += 1
109
+ bias_correction1 = 1 - beta1 ** state["step"]
110
+ bias_correction2 = 1 - beta2 ** state["step"]
111
+
112
+ # Decay the first and second moment running average coefficient
113
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
114
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
115
+ if amsgrad:
116
+ # Maintains the maximum of all 2nd moment running avg. till now
117
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
118
+ # Use the max. for normalizing running avg. of gradient
119
+ denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
120
+ group["eps"]
121
+ )
122
+ else:
123
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
124
+ group["eps"]
125
+ )
126
+
127
+ step_size = group["lr"] / bias_correction1
128
+
129
+ p.data.addcdiv_(-step_size, exp_avg, denom)
130
+
131
+ return loss
PreTrain_MeDSLIP/optim/lookahead.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Lookahead Optimizer Wrapper.
2
+ Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3
+ Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import torch
8
+ from torch.optim.optimizer import Optimizer
9
+ from collections import defaultdict
10
+
11
+
12
+ class Lookahead(Optimizer):
13
+ def __init__(self, base_optimizer, alpha=0.5, k=6):
14
+ if not 0.0 <= alpha <= 1.0:
15
+ raise ValueError(f"Invalid slow update rate: {alpha}")
16
+ if not 1 <= k:
17
+ raise ValueError(f"Invalid lookahead steps: {k}")
18
+ defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
19
+ self.base_optimizer = base_optimizer
20
+ self.param_groups = self.base_optimizer.param_groups
21
+ self.defaults = base_optimizer.defaults
22
+ self.defaults.update(defaults)
23
+ self.state = defaultdict(dict)
24
+ # manually add our defaults to the param groups
25
+ for name, default in defaults.items():
26
+ for group in self.param_groups:
27
+ group.setdefault(name, default)
28
+
29
+ def update_slow(self, group):
30
+ for fast_p in group["params"]:
31
+ if fast_p.grad is None:
32
+ continue
33
+ param_state = self.state[fast_p]
34
+ if "slow_buffer" not in param_state:
35
+ param_state["slow_buffer"] = torch.empty_like(fast_p.data)
36
+ param_state["slow_buffer"].copy_(fast_p.data)
37
+ slow = param_state["slow_buffer"]
38
+ slow.add_(group["lookahead_alpha"], fast_p.data - slow)
39
+ fast_p.data.copy_(slow)
40
+
41
+ def sync_lookahead(self):
42
+ for group in self.param_groups:
43
+ self.update_slow(group)
44
+
45
+ def step(self, closure=None):
46
+ # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
47
+ loss = self.base_optimizer.step(closure)
48
+ for group in self.param_groups:
49
+ group["lookahead_step"] += 1
50
+ if group["lookahead_step"] % group["lookahead_k"] == 0:
51
+ self.update_slow(group)
52
+ return loss
53
+
54
+ def state_dict(self):
55
+ fast_state_dict = self.base_optimizer.state_dict()
56
+ slow_state = {
57
+ (id(k) if isinstance(k, torch.Tensor) else k): v
58
+ for k, v in self.state.items()
59
+ }
60
+ fast_state = fast_state_dict["state"]
61
+ param_groups = fast_state_dict["param_groups"]
62
+ return {
63
+ "state": fast_state,
64
+ "slow_state": slow_state,
65
+ "param_groups": param_groups,
66
+ }
67
+
68
+ def load_state_dict(self, state_dict):
69
+ fast_state_dict = {
70
+ "state": state_dict["state"],
71
+ "param_groups": state_dict["param_groups"],
72
+ }
73
+ self.base_optimizer.load_state_dict(fast_state_dict)
74
+
75
+ # We want to restore the slow state, but share param_groups reference
76
+ # with base_optimizer. This is a bit redundant but least code
77
+ slow_state_new = False
78
+ if "slow_state" not in state_dict:
79
+ print("Loading state_dict from optimizer without Lookahead applied.")
80
+ state_dict["slow_state"] = defaultdict(dict)
81
+ slow_state_new = True
82
+ slow_state_dict = {
83
+ "state": state_dict["slow_state"],
84
+ "param_groups": state_dict[
85
+ "param_groups"
86
+ ], # this is pointless but saves code
87
+ }
88
+ super(Lookahead, self).load_state_dict(slow_state_dict)
89
+ self.param_groups = (
90
+ self.base_optimizer.param_groups
91
+ ) # make both ref same container
92
+ if slow_state_new:
93
+ # reapply defaults to catch missing lookahead specific ones
94
+ for name, default in self.defaults.items():
95
+ for group in self.param_groups:
96
+ group.setdefault(name, default)
PreTrain_MeDSLIP/optim/nadam.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.optim import Optimizer
3
+
4
+
5
+ class Nadam(Optimizer):
6
+ """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
7
+
8
+ It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
9
+
10
+ Arguments:
11
+ params (iterable): iterable of parameters to optimize or dicts defining
12
+ parameter groups
13
+ lr (float, optional): learning rate (default: 2e-3)
14
+ betas (Tuple[float, float], optional): coefficients used for computing
15
+ running averages of gradient and its square
16
+ eps (float, optional): term added to the denominator to improve
17
+ numerical stability (default: 1e-8)
18
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19
+ schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
20
+
21
+ __ http://cs229.stanford.edu/proj2015/054_report.pdf
22
+ __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
23
+
24
+ Originally taken from: https://github.com/pytorch/pytorch/pull/1408
25
+ NOTE: Has potential issues but does work well on some problems.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ params,
31
+ lr=2e-3,
32
+ betas=(0.9, 0.999),
33
+ eps=1e-8,
34
+ weight_decay=0,
35
+ schedule_decay=4e-3,
36
+ ):
37
+ defaults = dict(
38
+ lr=lr,
39
+ betas=betas,
40
+ eps=eps,
41
+ weight_decay=weight_decay,
42
+ schedule_decay=schedule_decay,
43
+ )
44
+ super(Nadam, self).__init__(params, defaults)
45
+
46
+ def step(self, closure=None):
47
+ """Performs a single optimization step.
48
+
49
+ Arguments:
50
+ closure (callable, optional): A closure that reevaluates the model
51
+ and returns the loss.
52
+ """
53
+ loss = None
54
+ if closure is not None:
55
+ loss = closure()
56
+
57
+ for group in self.param_groups:
58
+ for p in group["params"]:
59
+ if p.grad is None:
60
+ continue
61
+ grad = p.grad.data
62
+ state = self.state[p]
63
+
64
+ # State initialization
65
+ if len(state) == 0:
66
+ state["step"] = 0
67
+ state["m_schedule"] = 1.0
68
+ state["exp_avg"] = grad.new().resize_as_(grad).zero_()
69
+ state["exp_avg_sq"] = grad.new().resize_as_(grad).zero_()
70
+
71
+ # Warming momentum schedule
72
+ m_schedule = state["m_schedule"]
73
+ schedule_decay = group["schedule_decay"]
74
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
75
+ beta1, beta2 = group["betas"]
76
+ eps = group["eps"]
77
+ state["step"] += 1
78
+ t = state["step"]
79
+
80
+ if group["weight_decay"] != 0:
81
+ grad = grad.add(group["weight_decay"], p.data)
82
+
83
+ momentum_cache_t = beta1 * (1.0 - 0.5 * (0.96 ** (t * schedule_decay)))
84
+ momentum_cache_t_1 = beta1 * (
85
+ 1.0 - 0.5 * (0.96 ** ((t + 1) * schedule_decay))
86
+ )
87
+ m_schedule_new = m_schedule * momentum_cache_t
88
+ m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
89
+ state["m_schedule"] = m_schedule_new
90
+
91
+ # Decay the first and second moment running average coefficient
92
+ exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
93
+ exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
94
+ exp_avg_sq_prime = exp_avg_sq / (1.0 - beta2 ** t)
95
+ denom = exp_avg_sq_prime.sqrt_().add_(eps)
96
+
97
+ p.data.addcdiv_(
98
+ -group["lr"] * (1.0 - momentum_cache_t) / (1.0 - m_schedule_new),
99
+ grad,
100
+ denom,
101
+ )
102
+ p.data.addcdiv_(
103
+ -group["lr"] * momentum_cache_t_1 / (1.0 - m_schedule_next),
104
+ exp_avg,
105
+ denom,
106
+ )
107
+
108
+ return loss
PreTrain_MeDSLIP/optim/novograd.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """NovoGrad Optimizer.
2
+ Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
3
+ Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
4
+ - https://arxiv.org/abs/1905.11286
5
+ """
6
+
7
+ import torch
8
+ from torch.optim.optimizer import Optimizer
9
+ import math
10
+
11
+
12
+ class NovoGrad(Optimizer):
13
+ def __init__(
14
+ self,
15
+ params,
16
+ grad_averaging=False,
17
+ lr=0.1,
18
+ betas=(0.95, 0.98),
19
+ eps=1e-8,
20
+ weight_decay=0,
21
+ ):
22
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
23
+ super(NovoGrad, self).__init__(params, defaults)
24
+ self._lr = lr
25
+ self._beta1 = betas[0]
26
+ self._beta2 = betas[1]
27
+ self._eps = eps
28
+ self._wd = weight_decay
29
+ self._grad_averaging = grad_averaging
30
+
31
+ self._momentum_initialized = False
32
+
33
+ def step(self, closure=None):
34
+ loss = None
35
+ if closure is not None:
36
+ loss = closure()
37
+
38
+ if not self._momentum_initialized:
39
+ for group in self.param_groups:
40
+ for p in group["params"]:
41
+ if p.grad is None:
42
+ continue
43
+ state = self.state[p]
44
+ grad = p.grad.data
45
+ if grad.is_sparse:
46
+ raise RuntimeError("NovoGrad does not support sparse gradients")
47
+
48
+ v = torch.norm(grad) ** 2
49
+ m = grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
50
+ state["step"] = 0
51
+ state["v"] = v
52
+ state["m"] = m
53
+ state["grad_ema"] = None
54
+ self._momentum_initialized = True
55
+
56
+ for group in self.param_groups:
57
+ for p in group["params"]:
58
+ if p.grad is None:
59
+ continue
60
+ state = self.state[p]
61
+ state["step"] += 1
62
+
63
+ step, v, m = state["step"], state["v"], state["m"]
64
+ grad_ema = state["grad_ema"]
65
+
66
+ grad = p.grad.data
67
+ g2 = torch.norm(grad) ** 2
68
+ grad_ema = (
69
+ g2
70
+ if grad_ema is None
71
+ else grad_ema * self._beta2 + g2 * (1.0 - self._beta2)
72
+ )
73
+ grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
74
+
75
+ if self._grad_averaging:
76
+ grad *= 1.0 - self._beta1
77
+
78
+ g2 = torch.norm(grad) ** 2
79
+ v = self._beta2 * v + (1.0 - self._beta2) * g2
80
+ m = self._beta1 * m + (
81
+ grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
82
+ )
83
+ bias_correction1 = 1 - self._beta1 ** step
84
+ bias_correction2 = 1 - self._beta2 ** step
85
+ step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
86
+
87
+ state["v"], state["m"] = v, m
88
+ state["grad_ema"] = grad_ema
89
+ p.data.add_(-step_size, m)
90
+ return loss
PreTrain_MeDSLIP/optim/nvnovograd.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Nvidia NovoGrad Optimizer.
2
+ Original impl by Nvidia from Jasper example:
3
+ - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4
+ Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5
+ - https://arxiv.org/abs/1905.11286
6
+ """
7
+
8
+ import torch
9
+ from torch.optim.optimizer import Optimizer
10
+ import math
11
+
12
+
13
+ class NvNovoGrad(Optimizer):
14
+ """
15
+ Implements Novograd algorithm.
16
+
17
+ Args:
18
+ params (iterable): iterable of parameters to optimize or dicts defining
19
+ parameter groups
20
+ lr (float, optional): learning rate (default: 1e-3)
21
+ betas (Tuple[float, float], optional): coefficients used for computing
22
+ running averages of gradient and its square (default: (0.95, 0.98))
23
+ eps (float, optional): term added to the denominator to improve
24
+ numerical stability (default: 1e-8)
25
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
26
+ grad_averaging: gradient averaging
27
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
28
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
29
+ (default: False)
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ params,
35
+ lr=1e-3,
36
+ betas=(0.95, 0.98),
37
+ eps=1e-8,
38
+ weight_decay=0,
39
+ grad_averaging=False,
40
+ amsgrad=False,
41
+ ):
42
+ if not 0.0 <= lr:
43
+ raise ValueError("Invalid learning rate: {}".format(lr))
44
+ if not 0.0 <= eps:
45
+ raise ValueError("Invalid epsilon value: {}".format(eps))
46
+ if not 0.0 <= betas[0] < 1.0:
47
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
48
+ if not 0.0 <= betas[1] < 1.0:
49
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
50
+ defaults = dict(
51
+ lr=lr,
52
+ betas=betas,
53
+ eps=eps,
54
+ weight_decay=weight_decay,
55
+ grad_averaging=grad_averaging,
56
+ amsgrad=amsgrad,
57
+ )
58
+
59
+ super(NvNovoGrad, self).__init__(params, defaults)
60
+
61
+ def __setstate__(self, state):
62
+ super(NvNovoGrad, self).__setstate__(state)
63
+ for group in self.param_groups:
64
+ group.setdefault("amsgrad", False)
65
+
66
+ def step(self, closure=None):
67
+ """Performs a single optimization step.
68
+
69
+ Arguments:
70
+ closure (callable, optional): A closure that reevaluates the model
71
+ and returns the loss.
72
+ """
73
+ loss = None
74
+ if closure is not None:
75
+ loss = closure()
76
+
77
+ for group in self.param_groups:
78
+ for p in group["params"]:
79
+ if p.grad is None:
80
+ continue
81
+ grad = p.grad.data
82
+ if grad.is_sparse:
83
+ raise RuntimeError("Sparse gradients are not supported.")
84
+ amsgrad = group["amsgrad"]
85
+
86
+ state = self.state[p]
87
+
88
+ # State initialization
89
+ if len(state) == 0:
90
+ state["step"] = 0
91
+ # Exponential moving average of gradient values
92
+ state["exp_avg"] = torch.zeros_like(p.data)
93
+ # Exponential moving average of squared gradient values
94
+ state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
95
+ if amsgrad:
96
+ # Maintains max of all exp. moving avg. of sq. grad. values
97
+ state["max_exp_avg_sq"] = torch.zeros([]).to(
98
+ state["exp_avg"].device
99
+ )
100
+
101
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
102
+ if amsgrad:
103
+ max_exp_avg_sq = state["max_exp_avg_sq"]
104
+ beta1, beta2 = group["betas"]
105
+
106
+ state["step"] += 1
107
+
108
+ norm = torch.sum(torch.pow(grad, 2))
109
+
110
+ if exp_avg_sq == 0:
111
+ exp_avg_sq.copy_(norm)
112
+ else:
113
+ exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
114
+
115
+ if amsgrad:
116
+ # Maintains the maximum of all 2nd moment running avg. till now
117
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
118
+ # Use the max. for normalizing running avg. of gradient
119
+ denom = max_exp_avg_sq.sqrt().add_(group["eps"])
120
+ else:
121
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
122
+
123
+ grad.div_(denom)
124
+ if group["weight_decay"] != 0:
125
+ grad.add_(group["weight_decay"], p.data)
126
+ if group["grad_averaging"]:
127
+ grad.mul_(1 - beta1)
128
+ exp_avg.mul_(beta1).add_(grad)
129
+
130
+ p.data.add_(-group["lr"], exp_avg)
131
+
132
+ return loss
PreTrain_MeDSLIP/optim/optim_factory.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Optimizer Factory w/ Custom Weight Decay
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ from torch import optim as optim
6
+
7
+ from .adafactor import Adafactor
8
+ from .adahessian import Adahessian
9
+ from .adamp import AdamP
10
+ from .lookahead import Lookahead
11
+ from .nadam import Nadam
12
+ from .novograd import NovoGrad
13
+ from .nvnovograd import NvNovoGrad
14
+ from .radam import RAdam
15
+ from .rmsprop_tf import RMSpropTF
16
+ from .sgdp import SGDP
17
+
18
+ try:
19
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
20
+
21
+ has_apex = True
22
+ except ImportError:
23
+ has_apex = False
24
+
25
+
26
+ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
27
+ decay = []
28
+ no_decay = []
29
+ for name, param in model.named_parameters():
30
+ if not param.requires_grad:
31
+ continue # frozen weights
32
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
33
+ no_decay.append(param)
34
+ else:
35
+ decay.append(param)
36
+ return [
37
+ {"params": no_decay, "weight_decay": 0.0},
38
+ {"params": decay, "weight_decay": weight_decay},
39
+ ]
40
+
41
+
42
+ def create_optimizer(args, model, filter_bias_and_bn=True):
43
+ opt_lower = args.opt.lower()
44
+ weight_decay = args.weight_decay
45
+ if weight_decay and filter_bias_and_bn:
46
+ skip = {}
47
+ if hasattr(model, "no_weight_decay"):
48
+ skip = model.no_weight_decay()
49
+ parameters = add_weight_decay(model, weight_decay, skip)
50
+ weight_decay = 0.0
51
+ else:
52
+ parameters = filter(
53
+ lambda p: p.requires_grad, model.parameters()
54
+ ) # model.parameters()
55
+
56
+ if "fused" in opt_lower:
57
+ assert (
58
+ has_apex and torch.cuda.is_available()
59
+ ), "APEX and CUDA required for fused optimizers"
60
+
61
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
62
+ if hasattr(args, "opt_eps") and args.opt_eps is not None:
63
+ opt_args["eps"] = args.opt_eps
64
+ if hasattr(args, "opt_betas") and args.opt_betas is not None:
65
+ opt_args["betas"] = args.opt_betas
66
+ if hasattr(args, "opt_args") and args.opt_args is not None:
67
+ opt_args.update(args.opt_args)
68
+
69
+ opt_split = opt_lower.split("_")
70
+ opt_lower = opt_split[-1]
71
+ if opt_lower == "sgd" or opt_lower == "nesterov":
72
+ opt_args.pop("eps", None)
73
+ optimizer = optim.SGD(
74
+ parameters, momentum=args.momentum, nesterov=True, **opt_args
75
+ )
76
+ elif opt_lower == "momentum":
77
+ opt_args.pop("eps", None)
78
+ optimizer = optim.SGD(
79
+ parameters, momentum=args.momentum, nesterov=False, **opt_args
80
+ )
81
+ elif opt_lower == "adam":
82
+ optimizer = optim.Adam(parameters, **opt_args)
83
+ elif opt_lower == "adamw":
84
+ optimizer = optim.AdamW(parameters, **opt_args)
85
+ elif opt_lower == "nadam":
86
+ optimizer = Nadam(parameters, **opt_args)
87
+ elif opt_lower == "radam":
88
+ optimizer = RAdam(parameters, **opt_args)
89
+ elif opt_lower == "adamp":
90
+ optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
91
+ elif opt_lower == "sgdp":
92
+ optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
93
+ elif opt_lower == "adadelta":
94
+ optimizer = optim.Adadelta(parameters, **opt_args)
95
+ elif opt_lower == "adafactor":
96
+ if not args.lr:
97
+ opt_args["lr"] = None
98
+ optimizer = Adafactor(parameters, **opt_args)
99
+ elif opt_lower == "adahessian":
100
+ optimizer = Adahessian(parameters, **opt_args)
101
+ elif opt_lower == "rmsprop":
102
+ optimizer = optim.RMSprop(
103
+ parameters, alpha=0.9, momentum=args.momentum, **opt_args
104
+ )
105
+ elif opt_lower == "rmsproptf":
106
+ optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
107
+ elif opt_lower == "novograd":
108
+ optimizer = NovoGrad(parameters, **opt_args)
109
+ elif opt_lower == "nvnovograd":
110
+ optimizer = NvNovoGrad(parameters, **opt_args)
111
+ elif opt_lower == "fusedsgd":
112
+ opt_args.pop("eps", None)
113
+ optimizer = FusedSGD(
114
+ parameters, momentum=args.momentum, nesterov=True, **opt_args
115
+ )
116
+ elif opt_lower == "fusedmomentum":
117
+ opt_args.pop("eps", None)
118
+ optimizer = FusedSGD(
119
+ parameters, momentum=args.momentum, nesterov=False, **opt_args
120
+ )
121
+ elif opt_lower == "fusedadam":
122
+ optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
123
+ elif opt_lower == "fusedadamw":
124
+ optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
125
+ elif opt_lower == "fusedlamb":
126
+ optimizer = FusedLAMB(parameters, **opt_args)
127
+ elif opt_lower == "fusednovograd":
128
+ opt_args.setdefault("betas", (0.95, 0.98))
129
+ optimizer = FusedNovoGrad(parameters, **opt_args)
130
+ else:
131
+ assert False and "Invalid optimizer"
132
+ raise ValueError
133
+
134
+ if len(opt_split) > 1:
135
+ if opt_split[0] == "lookahead":
136
+ optimizer = Lookahead(optimizer)
137
+
138
+ return optimizer
PreTrain_MeDSLIP/optim/radam.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RAdam Optimizer.
2
+ Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
3
+ Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
4
+ """
5
+ import math
6
+ import torch
7
+ from torch.optim.optimizer import Optimizer, required
8
+
9
+
10
+ class RAdam(Optimizer):
11
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
12
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
13
+ self.buffer = [[None, None, None] for ind in range(10)]
14
+ super(RAdam, self).__init__(params, defaults)
15
+
16
+ def __setstate__(self, state):
17
+ super(RAdam, self).__setstate__(state)
18
+
19
+ def step(self, closure=None):
20
+
21
+ loss = None
22
+ if closure is not None:
23
+ loss = closure()
24
+
25
+ for group in self.param_groups:
26
+
27
+ for p in group["params"]:
28
+ if p.grad is None:
29
+ continue
30
+ grad = p.grad.data.float()
31
+ if grad.is_sparse:
32
+ raise RuntimeError("RAdam does not support sparse gradients")
33
+
34
+ p_data_fp32 = p.data.float()
35
+
36
+ state = self.state[p]
37
+
38
+ if len(state) == 0:
39
+ state["step"] = 0
40
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
41
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
42
+ else:
43
+ state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
44
+ state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
45
+
46
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
47
+ beta1, beta2 = group["betas"]
48
+
49
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
50
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
51
+
52
+ state["step"] += 1
53
+ buffered = self.buffer[int(state["step"] % 10)]
54
+ if state["step"] == buffered[0]:
55
+ N_sma, step_size = buffered[1], buffered[2]
56
+ else:
57
+ buffered[0] = state["step"]
58
+ beta2_t = beta2 ** state["step"]
59
+ N_sma_max = 2 / (1 - beta2) - 1
60
+ N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
61
+ buffered[1] = N_sma
62
+
63
+ # more conservative since it's an approximated value
64
+ if N_sma >= 5:
65
+ step_size = (
66
+ group["lr"]
67
+ * math.sqrt(
68
+ (1 - beta2_t)
69
+ * (N_sma - 4)
70
+ / (N_sma_max - 4)
71
+ * (N_sma - 2)
72
+ / N_sma
73
+ * N_sma_max
74
+ / (N_sma_max - 2)
75
+ )
76
+ / (1 - beta1 ** state["step"])
77
+ )
78
+ else:
79
+ step_size = group["lr"] / (1 - beta1 ** state["step"])
80
+ buffered[2] = step_size
81
+
82
+ if group["weight_decay"] != 0:
83
+ p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
84
+
85
+ # more conservative since it's an approximated value
86
+ if N_sma >= 5:
87
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
88
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
89
+ else:
90
+ p_data_fp32.add_(-step_size, exp_avg)
91
+
92
+ p.data.copy_(p_data_fp32)
93
+
94
+ return loss
95
+
96
+
97
+ class PlainRAdam(Optimizer):
98
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
99
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
100
+
101
+ super(PlainRAdam, self).__init__(params, defaults)
102
+
103
+ def __setstate__(self, state):
104
+ super(PlainRAdam, self).__setstate__(state)
105
+
106
+ def step(self, closure=None):
107
+
108
+ loss = None
109
+ if closure is not None:
110
+ loss = closure()
111
+
112
+ for group in self.param_groups:
113
+
114
+ for p in group["params"]:
115
+ if p.grad is None:
116
+ continue
117
+ grad = p.grad.data.float()
118
+ if grad.is_sparse:
119
+ raise RuntimeError("RAdam does not support sparse gradients")
120
+
121
+ p_data_fp32 = p.data.float()
122
+
123
+ state = self.state[p]
124
+
125
+ if len(state) == 0:
126
+ state["step"] = 0
127
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
128
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
129
+ else:
130
+ state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
131
+ state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
132
+
133
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
134
+ beta1, beta2 = group["betas"]
135
+
136
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
137
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
138
+
139
+ state["step"] += 1
140
+ beta2_t = beta2 ** state["step"]
141
+ N_sma_max = 2 / (1 - beta2) - 1
142
+ N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
143
+
144
+ if group["weight_decay"] != 0:
145
+ p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
146
+
147
+ # more conservative since it's an approximated value
148
+ if N_sma >= 5:
149
+ step_size = (
150
+ group["lr"]
151
+ * math.sqrt(
152
+ (1 - beta2_t)
153
+ * (N_sma - 4)
154
+ / (N_sma_max - 4)
155
+ * (N_sma - 2)
156
+ / N_sma
157
+ * N_sma_max
158
+ / (N_sma_max - 2)
159
+ )
160
+ / (1 - beta1 ** state["step"])
161
+ )
162
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
163
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
164
+ else:
165
+ step_size = group["lr"] / (1 - beta1 ** state["step"])
166
+ p_data_fp32.add_(-step_size, exp_avg)
167
+
168
+ p.data.copy_(p_data_fp32)
169
+
170
+ return loss
PreTrain_MeDSLIP/optim/rmsprop_tf.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ RMSProp modified to behave like Tensorflow impl
2
+
3
+ Originally cut & paste from PyTorch RMSProp
4
+ https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
5
+ Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
6
+
7
+ Modifications Copyright 2020 Ross Wightman
8
+ """
9
+
10
+ import torch
11
+ from torch.optim import Optimizer
12
+
13
+
14
+ class RMSpropTF(Optimizer):
15
+ """Implements RMSprop algorithm (TensorFlow style epsilon)
16
+
17
+ NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
18
+ and a few other modifications to closer match Tensorflow for matching hyper-params.
19
+
20
+ Noteworthy changes include:
21
+ 1. Epsilon applied inside square-root
22
+ 2. square_avg initialized to ones
23
+ 3. LR scaling of update accumulated in momentum buffer
24
+
25
+ Proposed by G. Hinton in his
26
+ `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
27
+
28
+ The centered version first appears in `Generating Sequences
29
+ With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
30
+
31
+ Arguments:
32
+ params (iterable): iterable of parameters to optimize or dicts defining
33
+ parameter groups
34
+ lr (float, optional): learning rate (default: 1e-2)
35
+ momentum (float, optional): momentum factor (default: 0)
36
+ alpha (float, optional): smoothing (decay) constant (default: 0.9)
37
+ eps (float, optional): term added to the denominator to improve
38
+ numerical stability (default: 1e-10)
39
+ centered (bool, optional) : if ``True``, compute the centered RMSProp,
40
+ the gradient is normalized by an estimation of its variance
41
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
42
+ decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
43
+ lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
44
+ update as per defaults in Tensorflow
45
+
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ params,
51
+ lr=1e-2,
52
+ alpha=0.9,
53
+ eps=1e-10,
54
+ weight_decay=0,
55
+ momentum=0.0,
56
+ centered=False,
57
+ decoupled_decay=False,
58
+ lr_in_momentum=True,
59
+ ):
60
+ if not 0.0 <= lr:
61
+ raise ValueError("Invalid learning rate: {}".format(lr))
62
+ if not 0.0 <= eps:
63
+ raise ValueError("Invalid epsilon value: {}".format(eps))
64
+ if not 0.0 <= momentum:
65
+ raise ValueError("Invalid momentum value: {}".format(momentum))
66
+ if not 0.0 <= weight_decay:
67
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
68
+ if not 0.0 <= alpha:
69
+ raise ValueError("Invalid alpha value: {}".format(alpha))
70
+
71
+ defaults = dict(
72
+ lr=lr,
73
+ momentum=momentum,
74
+ alpha=alpha,
75
+ eps=eps,
76
+ centered=centered,
77
+ weight_decay=weight_decay,
78
+ decoupled_decay=decoupled_decay,
79
+ lr_in_momentum=lr_in_momentum,
80
+ )
81
+ super(RMSpropTF, self).__init__(params, defaults)
82
+
83
+ def __setstate__(self, state):
84
+ super(RMSpropTF, self).__setstate__(state)
85
+ for group in self.param_groups:
86
+ group.setdefault("momentum", 0)
87
+ group.setdefault("centered", False)
88
+
89
+ def step(self, closure=None):
90
+ """Performs a single optimization step.
91
+
92
+ Arguments:
93
+ closure (callable, optional): A closure that reevaluates the model
94
+ and returns the loss.
95
+ """
96
+ loss = None
97
+ if closure is not None:
98
+ loss = closure()
99
+
100
+ for group in self.param_groups:
101
+ for p in group["params"]:
102
+ if p.grad is None:
103
+ continue
104
+ grad = p.grad.data
105
+ if grad.is_sparse:
106
+ raise RuntimeError("RMSprop does not support sparse gradients")
107
+ state = self.state[p]
108
+
109
+ # State initialization
110
+ if len(state) == 0:
111
+ state["step"] = 0
112
+ state["square_avg"] = torch.ones_like(
113
+ p.data
114
+ ) # PyTorch inits to zero
115
+ if group["momentum"] > 0:
116
+ state["momentum_buffer"] = torch.zeros_like(p.data)
117
+ if group["centered"]:
118
+ state["grad_avg"] = torch.zeros_like(p.data)
119
+
120
+ square_avg = state["square_avg"]
121
+ one_minus_alpha = 1.0 - group["alpha"]
122
+
123
+ state["step"] += 1
124
+
125
+ if group["weight_decay"] != 0:
126
+ if "decoupled_decay" in group and group["decoupled_decay"]:
127
+ p.data.add_(-group["weight_decay"], p.data)
128
+ else:
129
+ grad = grad.add(group["weight_decay"], p.data)
130
+
131
+ # Tensorflow order of ops for updating squared avg
132
+ square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
133
+ # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
134
+
135
+ if group["centered"]:
136
+ grad_avg = state["grad_avg"]
137
+ grad_avg.add_(one_minus_alpha, grad - grad_avg)
138
+ # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
139
+ avg = (
140
+ square_avg.addcmul(-1, grad_avg, grad_avg)
141
+ .add(group["eps"])
142
+ .sqrt_()
143
+ ) # eps moved in sqrt
144
+ else:
145
+ avg = square_avg.add(group["eps"]).sqrt_() # eps moved in sqrt
146
+
147
+ if group["momentum"] > 0:
148
+ buf = state["momentum_buffer"]
149
+ # Tensorflow accumulates the LR scaling in the momentum buffer
150
+ if "lr_in_momentum" in group and group["lr_in_momentum"]:
151
+ buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg)
152
+ p.data.add_(-buf)
153
+ else:
154
+ # PyTorch scales the param update by LR
155
+ buf.mul_(group["momentum"]).addcdiv_(grad, avg)
156
+ p.data.add_(-group["lr"], buf)
157
+ else:
158
+ p.data.addcdiv_(-group["lr"], grad, avg)
159
+
160
+ return loss
PreTrain_MeDSLIP/optim/sgdp.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3
+
4
+ Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5
+ Code: https://github.com/clovaai/AdamP
6
+
7
+ Copyright (c) 2020-present NAVER Corp.
8
+ MIT license
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.optim.optimizer import Optimizer, required
14
+ import math
15
+
16
+
17
+ class SGDP(Optimizer):
18
+ def __init__(
19
+ self,
20
+ params,
21
+ lr=required,
22
+ momentum=0,
23
+ dampening=0,
24
+ weight_decay=0,
25
+ nesterov=False,
26
+ eps=1e-8,
27
+ delta=0.1,
28
+ wd_ratio=0.1,
29
+ ):
30
+ defaults = dict(
31
+ lr=lr,
32
+ momentum=momentum,
33
+ dampening=dampening,
34
+ weight_decay=weight_decay,
35
+ nesterov=nesterov,
36
+ eps=eps,
37
+ delta=delta,
38
+ wd_ratio=wd_ratio,
39
+ )
40
+ super(SGDP, self).__init__(params, defaults)
41
+
42
+ def _channel_view(self, x):
43
+ return x.view(x.size(0), -1)
44
+
45
+ def _layer_view(self, x):
46
+ return x.view(1, -1)
47
+
48
+ def _cosine_similarity(self, x, y, eps, view_func):
49
+ x = view_func(x)
50
+ y = view_func(y)
51
+
52
+ x_norm = x.norm(dim=1).add_(eps)
53
+ y_norm = y.norm(dim=1).add_(eps)
54
+ dot = (x * y).sum(dim=1)
55
+
56
+ return dot.abs() / x_norm / y_norm
57
+
58
+ def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
59
+ wd = 1
60
+ expand_size = [-1] + [1] * (len(p.shape) - 1)
61
+ for view_func in [self._channel_view, self._layer_view]:
62
+
63
+ cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
64
+
65
+ if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
66
+ p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
67
+ perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
68
+ wd = wd_ratio
69
+
70
+ return perturb, wd
71
+
72
+ return perturb, wd
73
+
74
+ def step(self, closure=None):
75
+ loss = None
76
+ if closure is not None:
77
+ loss = closure()
78
+
79
+ for group in self.param_groups:
80
+ weight_decay = group["weight_decay"]
81
+ momentum = group["momentum"]
82
+ dampening = group["dampening"]
83
+ nesterov = group["nesterov"]
84
+
85
+ for p in group["params"]:
86
+ if p.grad is None:
87
+ continue
88
+ grad = p.grad.data
89
+ state = self.state[p]
90
+
91
+ # State initialization
92
+ if len(state) == 0:
93
+ state["momentum"] = torch.zeros_like(p.data)
94
+
95
+ # SGD
96
+ buf = state["momentum"]
97
+ buf.mul_(momentum).add_(1 - dampening, grad)
98
+ if nesterov:
99
+ d_p = grad + momentum * buf
100
+ else:
101
+ d_p = buf
102
+
103
+ # Projection
104
+ wd_ratio = 1
105
+ if len(p.shape) > 1:
106
+ d_p, wd_ratio = self._projection(
107
+ p, grad, d_p, group["delta"], group["wd_ratio"], group["eps"]
108
+ )
109
+
110
+ # Weight decay
111
+ if weight_decay != 0:
112
+ p.data.mul_(
113
+ 1
114
+ - group["lr"]
115
+ * group["weight_decay"]
116
+ * wd_ratio
117
+ / (1 - momentum)
118
+ )
119
+
120
+ # Step
121
+ p.data.add_(-group["lr"], d_p)
122
+
123
+ return loss
PreTrain_MeDSLIP/scheduler/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .cosine_lr import CosineLRScheduler
2
+ from .plateau_lr import PlateauLRScheduler
3
+ from .step_lr import StepLRScheduler
4
+ from .tanh_lr import TanhLRScheduler
5
+ from .scheduler_factory import create_scheduler
PreTrain_MeDSLIP/scheduler/cosine_lr.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Cosine Scheduler
2
+
3
+ Cosine LR schedule with warmup, cycle/restarts, noise.
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import logging
8
+ import math
9
+ import numpy as np
10
+ import torch
11
+
12
+ from .scheduler import Scheduler
13
+
14
+ from pdb import set_trace as breakpoint
15
+
16
+ _logger = logging.getLogger(__name__)
17
+
18
+
19
+ class CosineLRScheduler(Scheduler):
20
+ """
21
+ Cosine decay with restarts.
22
+ This is described in the paper https://arxiv.org/abs/1608.03983.
23
+
24
+ Inspiration from
25
+ https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ optimizer: torch.optim.Optimizer,
31
+ t_initial: int,
32
+ t_mul: float = 1.0,
33
+ lr_min: float = 0.0,
34
+ decay_rate: float = 1.0,
35
+ warmup_t=0,
36
+ warmup_lr_init=0,
37
+ warmup_prefix=True,
38
+ cycle_limit=0,
39
+ t_in_epochs=True,
40
+ noise_range_t=None,
41
+ noise_pct=0.67,
42
+ noise_std=1.0,
43
+ noise_seed=42,
44
+ initialize=True,
45
+ ) -> None:
46
+ super().__init__(
47
+ optimizer,
48
+ param_group_field="lr",
49
+ noise_range_t=noise_range_t,
50
+ noise_pct=noise_pct,
51
+ noise_std=noise_std,
52
+ noise_seed=noise_seed,
53
+ initialize=initialize,
54
+ )
55
+
56
+ assert t_initial > 0
57
+ assert lr_min >= 0
58
+ if t_initial == 1 and t_mul == 1 and decay_rate == 1:
59
+ _logger.warning(
60
+ "Cosine annealing scheduler will have no effect on the learning "
61
+ "rate since t_initial = t_mul = eta_mul = 1."
62
+ )
63
+ self.t_initial = t_initial
64
+ self.t_mul = t_mul
65
+ self.lr_min = lr_min
66
+ self.decay_rate = decay_rate
67
+ self.cycle_limit = cycle_limit
68
+ self.warmup_t = warmup_t
69
+ self.warmup_lr_init = warmup_lr_init
70
+ self.warmup_prefix = warmup_prefix
71
+ self.t_in_epochs = t_in_epochs
72
+ if self.warmup_t:
73
+ self.warmup_steps = [
74
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
75
+ ]
76
+ super().update_groups(self.warmup_lr_init)
77
+ else:
78
+ self.warmup_steps = [1 for _ in self.base_values]
79
+
80
+ def _get_lr(self, t):
81
+ if t < self.warmup_t:
82
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
83
+ else:
84
+ if self.warmup_prefix:
85
+ t = t - self.warmup_t
86
+
87
+ if self.t_mul != 1:
88
+ i = math.floor(
89
+ math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)
90
+ )
91
+ t_i = self.t_mul ** i * self.t_initial
92
+ t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
93
+ else:
94
+ i = t // self.t_initial
95
+ t_i = self.t_initial
96
+ t_curr = t - (self.t_initial * i)
97
+
98
+ gamma = self.decay_rate ** i
99
+ lr_min = self.lr_min * gamma
100
+ lr_max_values = [v * gamma for v in self.base_values]
101
+
102
+ if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
103
+ lrs = [
104
+ lr_min
105
+ + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i))
106
+ for lr_max in lr_max_values
107
+ ]
108
+ else:
109
+ lrs = [self.lr_min for _ in self.base_values]
110
+
111
+ return lrs
112
+
113
+ def get_epoch_values(self, epoch: int):
114
+ if self.t_in_epochs:
115
+ return self._get_lr(epoch)
116
+ else:
117
+ return None
118
+
119
+ def get_update_values(self, num_updates: int):
120
+ if not self.t_in_epochs:
121
+ return self._get_lr(num_updates)
122
+ else:
123
+ return None
124
+
125
+ def get_cycle_length(self, cycles=0):
126
+ if not cycles:
127
+ cycles = self.cycle_limit
128
+ cycles = max(1, cycles)
129
+ if self.t_mul == 1.0:
130
+ return self.t_initial * cycles
131
+ else:
132
+ return int(
133
+ math.floor(
134
+ -self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)
135
+ )
136
+ )
PreTrain_MeDSLIP/scheduler/plateau_lr.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Plateau Scheduler
2
+
3
+ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import torch
8
+
9
+ from .scheduler import Scheduler
10
+
11
+
12
+ class PlateauLRScheduler(Scheduler):
13
+ """Decay the LR by a factor every time the validation loss plateaus."""
14
+
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ decay_rate=0.1,
19
+ patience_t=10,
20
+ verbose=True,
21
+ threshold=1e-4,
22
+ cooldown_t=0,
23
+ warmup_t=0,
24
+ warmup_lr_init=0,
25
+ lr_min=0,
26
+ mode="max",
27
+ noise_range_t=None,
28
+ noise_type="normal",
29
+ noise_pct=0.67,
30
+ noise_std=1.0,
31
+ noise_seed=None,
32
+ initialize=True,
33
+ ):
34
+ super().__init__(optimizer, "lr", initialize=initialize)
35
+
36
+ self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
37
+ self.optimizer,
38
+ patience=patience_t,
39
+ factor=decay_rate,
40
+ verbose=verbose,
41
+ threshold=threshold,
42
+ cooldown=cooldown_t,
43
+ mode=mode,
44
+ min_lr=lr_min,
45
+ )
46
+
47
+ self.noise_range = noise_range_t
48
+ self.noise_pct = noise_pct
49
+ self.noise_type = noise_type
50
+ self.noise_std = noise_std
51
+ self.noise_seed = noise_seed if noise_seed is not None else 42
52
+ self.warmup_t = warmup_t
53
+ self.warmup_lr_init = warmup_lr_init
54
+ if self.warmup_t:
55
+ self.warmup_steps = [
56
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
57
+ ]
58
+ super().update_groups(self.warmup_lr_init)
59
+ else:
60
+ self.warmup_steps = [1 for _ in self.base_values]
61
+ self.restore_lr = None
62
+
63
+ def state_dict(self):
64
+ return {
65
+ "best": self.lr_scheduler.best,
66
+ "last_epoch": self.lr_scheduler.last_epoch,
67
+ }
68
+
69
+ def load_state_dict(self, state_dict):
70
+ self.lr_scheduler.best = state_dict["best"]
71
+ if "last_epoch" in state_dict:
72
+ self.lr_scheduler.last_epoch = state_dict["last_epoch"]
73
+
74
+ # override the base class step fn completely
75
+ def step(self, epoch, metric=None):
76
+ if epoch <= self.warmup_t:
77
+ lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
78
+ super().update_groups(lrs)
79
+ else:
80
+ if self.restore_lr is not None:
81
+ # restore actual LR from before our last noise perturbation before stepping base
82
+ for i, param_group in enumerate(self.optimizer.param_groups):
83
+ param_group["lr"] = self.restore_lr[i]
84
+ self.restore_lr = None
85
+
86
+ self.lr_scheduler.step(metric, epoch) # step the base scheduler
87
+
88
+ if self.noise_range is not None:
89
+ if isinstance(self.noise_range, (list, tuple)):
90
+ apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
91
+ else:
92
+ apply_noise = epoch >= self.noise_range
93
+ if apply_noise:
94
+ self._apply_noise(epoch)
95
+
96
+ def _apply_noise(self, epoch):
97
+ g = torch.Generator()
98
+ g.manual_seed(self.noise_seed + epoch)
99
+ if self.noise_type == "normal":
100
+ while True:
101
+ # resample if noise out of percent limit, brute force but shouldn't spin much
102
+ noise = torch.randn(1, generator=g).item()
103
+ if abs(noise) < self.noise_pct:
104
+ break
105
+ else:
106
+ noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
107
+
108
+ # apply the noise on top of previous LR, cache the old value so we can restore for normal
109
+ # stepping of base scheduler
110
+ restore_lr = []
111
+ for i, param_group in enumerate(self.optimizer.param_groups):
112
+ old_lr = float(param_group["lr"])
113
+ restore_lr.append(old_lr)
114
+ new_lr = old_lr + old_lr * noise
115
+ param_group["lr"] = new_lr
116
+ self.restore_lr = restore_lr
PreTrain_MeDSLIP/scheduler/scheduler.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+
3
+ import torch
4
+
5
+
6
+ class Scheduler:
7
+ """ Parameter Scheduler Base Class
8
+ A scheduler base class that can be used to schedule any optimizer parameter groups.
9
+
10
+ Unlike the builtin PyTorch schedulers, this is intended to be consistently called
11
+ * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
12
+ * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
13
+
14
+ The schedulers built on this should try to remain as stateless as possible (for simplicity).
15
+
16
+ This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
17
+ and -1 values for special behaviour. All epoch and update counts must be tracked in the training
18
+ code and explicitly passed in to the schedulers on the corresponding step or step_update call.
19
+
20
+ Based on ideas from:
21
+ * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
22
+ * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ optimizer: torch.optim.Optimizer,
28
+ param_group_field: str,
29
+ noise_range_t=None,
30
+ noise_type="normal",
31
+ noise_pct=0.67,
32
+ noise_std=1.0,
33
+ noise_seed=None,
34
+ initialize: bool = True,
35
+ ) -> None:
36
+ self.optimizer = optimizer
37
+ self.param_group_field = param_group_field
38
+ self._initial_param_group_field = f"initial_{param_group_field}"
39
+ if initialize:
40
+ for i, group in enumerate(self.optimizer.param_groups):
41
+ if param_group_field not in group:
42
+ raise KeyError(
43
+ f"{param_group_field} missing from param_groups[{i}]"
44
+ )
45
+ group.setdefault(
46
+ self._initial_param_group_field, group[param_group_field]
47
+ )
48
+ else:
49
+ for i, group in enumerate(self.optimizer.param_groups):
50
+ if self._initial_param_group_field not in group:
51
+ raise KeyError(
52
+ f"{self._initial_param_group_field} missing from param_groups[{i}]"
53
+ )
54
+ self.base_values = [
55
+ group[self._initial_param_group_field]
56
+ for group in self.optimizer.param_groups
57
+ ]
58
+ self.metric = None # any point to having this for all?
59
+ self.noise_range_t = noise_range_t
60
+ self.noise_pct = noise_pct
61
+ self.noise_type = noise_type
62
+ self.noise_std = noise_std
63
+ self.noise_seed = noise_seed if noise_seed is not None else 42
64
+ self.update_groups(self.base_values)
65
+
66
+ def state_dict(self) -> Dict[str, Any]:
67
+ return {
68
+ key: value for key, value in self.__dict__.items() if key != "optimizer"
69
+ }
70
+
71
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
72
+ self.__dict__.update(state_dict)
73
+
74
+ def get_epoch_values(self, epoch: int):
75
+ return None
76
+
77
+ def get_update_values(self, num_updates: int):
78
+ return None
79
+
80
+ def step(self, epoch: int, metric: float = None) -> None:
81
+ self.metric = metric
82
+ values = self.get_epoch_values(epoch)
83
+ if values is not None:
84
+ values = self._add_noise(values, epoch)
85
+ self.update_groups(values)
86
+
87
+ def step_update(self, num_updates: int, metric: float = None):
88
+ self.metric = metric
89
+ values = self.get_update_values(num_updates)
90
+ if values is not None:
91
+ values = self._add_noise(values, num_updates)
92
+ self.update_groups(values)
93
+
94
+ def update_groups(self, values):
95
+ if not isinstance(values, (list, tuple)):
96
+ values = [values] * len(self.optimizer.param_groups)
97
+ for param_group, value in zip(self.optimizer.param_groups, values):
98
+ param_group[self.param_group_field] = value
99
+
100
+ def _add_noise(self, lrs, t):
101
+ if self.noise_range_t is not None:
102
+ if isinstance(self.noise_range_t, (list, tuple)):
103
+ apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
104
+ else:
105
+ apply_noise = t >= self.noise_range_t
106
+ if apply_noise:
107
+ g = torch.Generator()
108
+ g.manual_seed(self.noise_seed + t)
109
+ if self.noise_type == "normal":
110
+ while True:
111
+ # resample if noise out of percent limit, brute force but shouldn't spin much
112
+ noise = torch.randn(1, generator=g).item()
113
+ if abs(noise) < self.noise_pct:
114
+ break
115
+ else:
116
+ noise = (
117
+ 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
118
+ )
119
+ lrs = [v + v * noise for v in lrs]
120
+ return lrs
PreTrain_MeDSLIP/scheduler/scheduler_factory.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Scheduler Factory
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ from .cosine_lr import CosineLRScheduler
5
+ from .tanh_lr import TanhLRScheduler
6
+ from .step_lr import StepLRScheduler
7
+ from .plateau_lr import PlateauLRScheduler
8
+
9
+
10
+ def create_scheduler(args, optimizer):
11
+ num_epochs = args.epochs
12
+
13
+ if getattr(args, "lr_noise", None) is not None:
14
+ lr_noise = getattr(args, "lr_noise")
15
+ if isinstance(lr_noise, (list, tuple)):
16
+ noise_range = [n * num_epochs for n in lr_noise]
17
+ if len(noise_range) == 1:
18
+ noise_range = noise_range[0]
19
+ else:
20
+ noise_range = lr_noise * num_epochs
21
+ else:
22
+ noise_range = None
23
+
24
+ lr_scheduler = None
25
+ if args.sched == "cosine":
26
+ lr_scheduler = CosineLRScheduler(
27
+ optimizer,
28
+ t_initial=num_epochs,
29
+ t_mul=getattr(args, "lr_cycle_mul", 1.0),
30
+ lr_min=args.min_lr,
31
+ decay_rate=args.decay_rate,
32
+ warmup_lr_init=args.warmup_lr,
33
+ warmup_t=args.warmup_epochs,
34
+ cycle_limit=getattr(args, "lr_cycle_limit", 1),
35
+ t_in_epochs=True,
36
+ noise_range_t=noise_range,
37
+ noise_pct=getattr(args, "lr_noise_pct", 0.67),
38
+ noise_std=getattr(args, "lr_noise_std", 1.0),
39
+ noise_seed=getattr(args, "seed", 42),
40
+ )
41
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
42
+ elif args.sched == "tanh":
43
+ lr_scheduler = TanhLRScheduler(
44
+ optimizer,
45
+ t_initial=num_epochs,
46
+ t_mul=getattr(args, "lr_cycle_mul", 1.0),
47
+ lr_min=args.min_lr,
48
+ warmup_lr_init=args.warmup_lr,
49
+ warmup_t=args.warmup_epochs,
50
+ cycle_limit=getattr(args, "lr_cycle_limit", 1),
51
+ t_in_epochs=True,
52
+ noise_range_t=noise_range,
53
+ noise_pct=getattr(args, "lr_noise_pct", 0.67),
54
+ noise_std=getattr(args, "lr_noise_std", 1.0),
55
+ noise_seed=getattr(args, "seed", 42),
56
+ )
57
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
58
+ elif args.sched == "step":
59
+ lr_scheduler = StepLRScheduler(
60
+ optimizer,
61
+ decay_t=args.decay_epochs,
62
+ decay_rate=args.decay_rate,
63
+ warmup_lr_init=args.warmup_lr,
64
+ warmup_t=args.warmup_epochs,
65
+ noise_range_t=noise_range,
66
+ noise_pct=getattr(args, "lr_noise_pct", 0.67),
67
+ noise_std=getattr(args, "lr_noise_std", 1.0),
68
+ noise_seed=getattr(args, "seed", 42),
69
+ )
70
+ elif args.sched == "plateau":
71
+ mode = "min" if "loss" in getattr(args, "eval_metric", "") else "max"
72
+ lr_scheduler = PlateauLRScheduler(
73
+ optimizer,
74
+ decay_rate=args.decay_rate,
75
+ patience_t=args.patience_epochs,
76
+ lr_min=args.min_lr,
77
+ mode=mode,
78
+ warmup_lr_init=args.warmup_lr,
79
+ warmup_t=args.warmup_epochs,
80
+ cooldown_t=0,
81
+ noise_range_t=noise_range,
82
+ noise_pct=getattr(args, "lr_noise_pct", 0.67),
83
+ noise_std=getattr(args, "lr_noise_std", 1.0),
84
+ noise_seed=getattr(args, "seed", 42),
85
+ )
86
+
87
+ return lr_scheduler, num_epochs
PreTrain_MeDSLIP/scheduler/step_lr.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Step Scheduler
2
+
3
+ Basic step LR schedule with warmup, noise.
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import math
8
+ import torch
9
+
10
+ from .scheduler import Scheduler
11
+
12
+
13
+ class StepLRScheduler(Scheduler):
14
+ """
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ optimizer: torch.optim.Optimizer,
20
+ decay_t: float,
21
+ decay_rate: float = 1.0,
22
+ warmup_t=0,
23
+ warmup_lr_init=0,
24
+ t_in_epochs=True,
25
+ noise_range_t=None,
26
+ noise_pct=0.67,
27
+ noise_std=1.0,
28
+ noise_seed=42,
29
+ initialize=True,
30
+ ) -> None:
31
+ super().__init__(
32
+ optimizer,
33
+ param_group_field="lr",
34
+ noise_range_t=noise_range_t,
35
+ noise_pct=noise_pct,
36
+ noise_std=noise_std,
37
+ noise_seed=noise_seed,
38
+ initialize=initialize,
39
+ )
40
+
41
+ self.decay_t = decay_t
42
+ self.decay_rate = decay_rate
43
+ self.warmup_t = warmup_t
44
+ self.warmup_lr_init = warmup_lr_init
45
+ self.t_in_epochs = t_in_epochs
46
+ if self.warmup_t:
47
+ self.warmup_steps = [
48
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
49
+ ]
50
+ super().update_groups(self.warmup_lr_init)
51
+ else:
52
+ self.warmup_steps = [1 for _ in self.base_values]
53
+
54
+ def _get_lr(self, t):
55
+ if t < self.warmup_t:
56
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
57
+ else:
58
+ lrs = [
59
+ v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values
60
+ ]
61
+ return lrs
62
+
63
+ def get_epoch_values(self, epoch: int):
64
+ if self.t_in_epochs:
65
+ return self._get_lr(epoch)
66
+ else:
67
+ return None
68
+
69
+ def get_update_values(self, num_updates: int):
70
+ if not self.t_in_epochs:
71
+ return self._get_lr(num_updates)
72
+ else:
73
+ return None
PreTrain_MeDSLIP/scheduler/tanh_lr.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ TanH Scheduler
2
+
3
+ TanH schedule with warmup, cycle/restarts, noise.
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import logging
8
+ import math
9
+ import numpy as np
10
+ import torch
11
+
12
+ from .scheduler import Scheduler
13
+
14
+
15
+ _logger = logging.getLogger(__name__)
16
+
17
+
18
+ class TanhLRScheduler(Scheduler):
19
+ """
20
+ Hyberbolic-Tangent decay with restarts.
21
+ This is described in the paper https://arxiv.org/abs/1806.01593
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ optimizer: torch.optim.Optimizer,
27
+ t_initial: int,
28
+ lb: float = -6.0,
29
+ ub: float = 4.0,
30
+ t_mul: float = 1.0,
31
+ lr_min: float = 0.0,
32
+ decay_rate: float = 1.0,
33
+ warmup_t=0,
34
+ warmup_lr_init=0,
35
+ warmup_prefix=False,
36
+ cycle_limit=0,
37
+ t_in_epochs=True,
38
+ noise_range_t=None,
39
+ noise_pct=0.67,
40
+ noise_std=1.0,
41
+ noise_seed=42,
42
+ initialize=True,
43
+ ) -> None:
44
+ super().__init__(
45
+ optimizer,
46
+ param_group_field="lr",
47
+ noise_range_t=noise_range_t,
48
+ noise_pct=noise_pct,
49
+ noise_std=noise_std,
50
+ noise_seed=noise_seed,
51
+ initialize=initialize,
52
+ )
53
+
54
+ assert t_initial > 0
55
+ assert lr_min >= 0
56
+ assert lb < ub
57
+ assert cycle_limit >= 0
58
+ assert warmup_t >= 0
59
+ assert warmup_lr_init >= 0
60
+ self.lb = lb
61
+ self.ub = ub
62
+ self.t_initial = t_initial
63
+ self.t_mul = t_mul
64
+ self.lr_min = lr_min
65
+ self.decay_rate = decay_rate
66
+ self.cycle_limit = cycle_limit
67
+ self.warmup_t = warmup_t
68
+ self.warmup_lr_init = warmup_lr_init
69
+ self.warmup_prefix = warmup_prefix
70
+ self.t_in_epochs = t_in_epochs
71
+ if self.warmup_t:
72
+ t_v = (
73
+ self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
74
+ )
75
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
76
+ super().update_groups(self.warmup_lr_init)
77
+ else:
78
+ self.warmup_steps = [1 for _ in self.base_values]
79
+
80
+ def _get_lr(self, t):
81
+ if t < self.warmup_t:
82
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
83
+ else:
84
+ if self.warmup_prefix:
85
+ t = t - self.warmup_t
86
+
87
+ if self.t_mul != 1:
88
+ i = math.floor(
89
+ math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)
90
+ )
91
+ t_i = self.t_mul ** i * self.t_initial
92
+ t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
93
+ else:
94
+ i = t // self.t_initial
95
+ t_i = self.t_initial
96
+ t_curr = t - (self.t_initial * i)
97
+
98
+ if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
99
+ gamma = self.decay_rate ** i
100
+ lr_min = self.lr_min * gamma
101
+ lr_max_values = [v * gamma for v in self.base_values]
102
+
103
+ tr = t_curr / t_i
104
+ lrs = [
105
+ lr_min
106
+ + 0.5
107
+ * (lr_max - lr_min)
108
+ * (1 - math.tanh(self.lb * (1.0 - tr) + self.ub * tr))
109
+ for lr_max in lr_max_values
110
+ ]
111
+ else:
112
+ lrs = [
113
+ self.lr_min * (self.decay_rate ** self.cycle_limit)
114
+ for _ in self.base_values
115
+ ]
116
+ return lrs
117
+
118
+ def get_epoch_values(self, epoch: int):
119
+ if self.t_in_epochs:
120
+ return self._get_lr(epoch)
121
+ else:
122
+ return None
123
+
124
+ def get_update_values(self, num_updates: int):
125
+ if not self.t_in_epochs:
126
+ return self._get_lr(num_updates)
127
+ else:
128
+ return None
129
+
130
+ def get_cycle_length(self, cycles=0):
131
+ if not cycles:
132
+ cycles = self.cycle_limit
133
+ cycles = max(1, cycles)
134
+ if self.t_mul == 1.0:
135
+ return self.t_initial * cycles
136
+ else:
137
+ return int(
138
+ math.floor(
139
+ -self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)
140
+ )
141
+ )
PreTrain_MeDSLIP/train_MeDSLIP.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import ruamel_yaml as yaml
4
+ import numpy as np
5
+ import random
6
+ import time
7
+ import datetime
8
+ import json
9
+ from pathlib import Path
10
+ import warnings
11
+
12
+ warnings.filterwarnings("ignore")
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.data import DataLoader
18
+ import torch.backends.cudnn as cudnn
19
+
20
+ from tensorboardX import SummaryWriter
21
+
22
+ import utils
23
+ from scheduler import create_scheduler
24
+ from optim import create_optimizer
25
+ from dataset.dataset import MeDSLIP_Dataset
26
+ from models.model_MeDSLIP import MeDSLIP
27
+ from models.tokenization_bert import BertTokenizer
28
+
29
+
30
+ def get_tokenizer(tokenizer, target_text):
31
+
32
+ target_tokenizer = tokenizer(
33
+ list(target_text),
34
+ padding="max_length",
35
+ truncation=True,
36
+ max_length=128,
37
+ return_tensors="pt",
38
+ )
39
+
40
+ return target_tokenizer
41
+
42
+
43
+ def train(
44
+ model,
45
+ data_loader,
46
+ optimizer,
47
+ epoch,
48
+ warmup_steps,
49
+ device,
50
+ scheduler,
51
+ args,
52
+ config,
53
+ writer,
54
+ ):
55
+ model.train()
56
+ metric_logger = utils.MetricLogger(delimiter=" ")
57
+ metric_logger.add_meter(
58
+ "lr", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
59
+ )
60
+ metric_logger.add_meter(
61
+ "loss", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
62
+ )
63
+ metric_logger.add_meter(
64
+ "loss_ce_p", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
65
+ )
66
+ metric_logger.add_meter(
67
+ "loss_cl_p", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
68
+ )
69
+ metric_logger.add_meter(
70
+ "loss_ce_a", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
71
+ )
72
+ metric_logger.add_meter(
73
+ "loss_cl_a", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
74
+ )
75
+ metric_logger.add_meter(
76
+ "loss_ap", utils.SmoothedValue(window_size=50, fmt="{value:.6f}")
77
+ )
78
+ metric_logger.update(loss=1.0)
79
+ metric_logger.update(loss_ce_p=1.0)
80
+ metric_logger.update(loss_cl_p=1.0)
81
+ metric_logger.update(loss_ce_a=1.0)
82
+ metric_logger.update(loss_cl_a=1.0)
83
+ metric_logger.update(loss_ap=1.0)
84
+ metric_logger.update(lr=scheduler._get_lr(epoch)[0])
85
+
86
+ header = "Train Epoch: [{}]".format(epoch)
87
+ print_freq = 1
88
+ step_size = 100
89
+ warmup_iterations = warmup_steps * step_size
90
+ scalar_step = epoch * len(data_loader)
91
+
92
+ for i, sample in enumerate(
93
+ metric_logger.log_every(data_loader, print_freq, header)
94
+ ):
95
+
96
+ images = sample["image"].to(device)
97
+ labels_pathology = sample["label_pathology"].to(device)
98
+ labels_anatomy = sample["label_anatomy"].to(device)
99
+ index_pathology = sample["index_pathology"].to(device)
100
+ index_anatomy = sample["index_anatomy"].to(device)
101
+ matrix = sample["matrix"].to(device)
102
+
103
+ optimizer.zero_grad()
104
+
105
+ (
106
+ loss,
107
+ loss_ce_pathology,
108
+ loss_cl_pathology,
109
+ loss_ce_anatomy,
110
+ loss_cl_anatomy,
111
+ loss_ap,
112
+ ) = model(
113
+ images,
114
+ labels_pathology=labels_pathology,
115
+ labels_anatomy=labels_anatomy,
116
+ matrix=matrix,
117
+ sample_index_pathology=index_pathology,
118
+ sample_index_anatomy=index_anatomy,
119
+ is_train=True,
120
+ no_cl=config["no_cl"],
121
+ exclude_class=config["exclude_class"],
122
+ )
123
+ loss.backward()
124
+ optimizer.step()
125
+ writer.add_scalar("loss/loss", loss, scalar_step)
126
+ writer.add_scalar("loss/loss_ce_pathology", loss_ce_pathology, scalar_step)
127
+ writer.add_scalar("loss/loss_cl_pathology", loss_cl_pathology, scalar_step)
128
+ writer.add_scalar("loss/loss_ce_anatomy", loss_ce_anatomy, scalar_step)
129
+ writer.add_scalar("loss/loss_cl_anatomy", loss_cl_anatomy, scalar_step)
130
+ writer.add_scalar("loss/loss_ap", loss_ap, scalar_step)
131
+ scalar_step += 1
132
+ metric_logger.update(loss_ce_p=loss_ce_pathology.item())
133
+ metric_logger.update(loss_cl_p=loss_cl_pathology.item())
134
+ metric_logger.update(loss_ce_a=loss_ce_anatomy.item())
135
+ metric_logger.update(loss_cl_a=loss_cl_anatomy.item())
136
+ metric_logger.update(loss_ap=loss_ap.item())
137
+ metric_logger.update(loss=loss.item())
138
+ # metric_logger.update(loss_cl=loss_cl.item())
139
+ if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
140
+ scheduler.step(i // step_size)
141
+ metric_logger.update(lr=scheduler._get_lr(epoch)[0])
142
+
143
+ # gather the stats from all processes
144
+ metric_logger.synchronize_between_processes()
145
+ print("Averaged stats:", metric_logger.global_avg())
146
+ return {
147
+ k: "{:.3f}".format(meter.global_avg)
148
+ for k, meter in metric_logger.meters.items()
149
+ }
150
+
151
+
152
+ def valid(model, data_loader, epoch, device, config, writer):
153
+ model.eval()
154
+ val_scalar_step = epoch * len(data_loader)
155
+ val_loss = []
156
+ for i, sample in enumerate(data_loader):
157
+
158
+ images = sample["image"].to(device)
159
+ labels_pathology = sample["label_pathology"].to(device)
160
+ labels_anatomy = sample["label_anatomy"].to(device)
161
+ index_pathology = sample["index_pathology"].to(device)
162
+ index_anatomy = sample["index_anatomy"].to(device)
163
+ matrix = sample["matrix"].to(device)
164
+
165
+ with torch.no_grad():
166
+ (
167
+ loss,
168
+ loss_ce_pathology,
169
+ loss_cl_pathology,
170
+ loss_ce_anatomy,
171
+ loss_cl_anatomy,
172
+ loss_ap,
173
+ ) = model(
174
+ images,
175
+ labels_pathology=labels_pathology,
176
+ labels_anatomy=labels_anatomy,
177
+ matrix=matrix,
178
+ sample_index_pathology=index_pathology,
179
+ sample_index_anatomy=index_anatomy,
180
+ is_train=True,
181
+ no_cl=config["no_cl"],
182
+ exclude_class=config["exclude_class"],
183
+ )
184
+ val_loss.append(loss.item())
185
+ writer.add_scalar("val_loss/loss", loss, val_scalar_step)
186
+ writer.add_scalar(
187
+ "val_loss/loss_ce_pathology", loss_ce_pathology, val_scalar_step
188
+ )
189
+ writer.add_scalar(
190
+ "val_loss/loss_cl_pathology", loss_cl_pathology, val_scalar_step
191
+ )
192
+ writer.add_scalar(
193
+ "val_loss/loss_ce_anatomy", loss_ce_anatomy, val_scalar_step
194
+ )
195
+ writer.add_scalar(
196
+ "val_loss/loss_cl_anatomy", loss_cl_anatomy, val_scalar_step
197
+ )
198
+ writer.add_scalar("val_loss/loss_ap", loss_ap, val_scalar_step)
199
+ val_scalar_step += 1
200
+ avg_val_loss = np.array(val_loss).mean()
201
+ return avg_val_loss
202
+
203
+
204
+ def main(args, config):
205
+
206
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
207
+ if args.computing == "parallel":
208
+ world_size = torch.distributed.get_world_size()
209
+ rank = torch.distributed.get_rank()
210
+ device = torch.device("cuda", rank)
211
+ print("World size: ", world_size, "; Rank: ", rank)
212
+
213
+ print("Total CUDA devices: ", torch.cuda.device_count())
214
+ torch.set_default_tensor_type("torch.FloatTensor")
215
+ cudnn.benchmark = True
216
+
217
+ start_epoch = 0
218
+ max_epoch = config["schedular"]["epochs"]
219
+ warmup_steps = config["schedular"]["warmup_epochs"]
220
+
221
+ #### Dataset ####
222
+ print("Creating dataset")
223
+ train_datasets = MeDSLIP_Dataset(
224
+ config["train_file"], config["label_file"], mode="train"
225
+ )
226
+ val_datasets = MeDSLIP_Dataset(
227
+ config["valid_file"], config["label_file"], mode="train"
228
+ )
229
+ if args.computing == "parallel":
230
+ # shuffl
231
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
232
+ train_datasets, num_replicas=world_size, rank=rank, shuffle=True
233
+ )
234
+ val_sampler = torch.utils.data.distributed.DistributedSampler(
235
+ val_datasets, num_replicas=world_size, rank=rank, shuffle=True
236
+ )
237
+ else:
238
+ train_sampler = torch.utils.data.RandomSampler(train_datasets)
239
+ val_sampler = torch.utils.data.RandomSampler(val_datasets)
240
+ train_dataloader = DataLoader(
241
+ train_datasets,
242
+ batch_size=config["batch_size"],
243
+ num_workers=30,
244
+ pin_memory=True,
245
+ sampler=train_sampler,
246
+ collate_fn=None,
247
+ drop_last=True,
248
+ )
249
+
250
+ val_dataloader = DataLoader(
251
+ val_datasets,
252
+ batch_size=config["batch_size"],
253
+ num_workers=30,
254
+ pin_memory=True,
255
+ sampler=val_sampler,
256
+ collate_fn=None,
257
+ drop_last=True,
258
+ )
259
+
260
+ print("Creating book")
261
+ json_book = json.load(open(config["pathology_book"], "r"))
262
+ pathology_book = [json_book[i] for i in json_book]
263
+ anatomy_list = [
264
+ "trachea",
265
+ "left_hilar",
266
+ "right_hilar",
267
+ "hilar_unspec",
268
+ "left_pleural",
269
+ "right_pleural",
270
+ "pleural_unspec",
271
+ "heart_size",
272
+ "heart_border",
273
+ "left_diaphragm",
274
+ "right_diaphragm",
275
+ "diaphragm_unspec",
276
+ "retrocardiac",
277
+ "lower_left_lobe",
278
+ "upper_left_lobe",
279
+ "lower_right_lobe",
280
+ "middle_right_lobe",
281
+ "upper_right_lobe",
282
+ "left_lower_lung",
283
+ "left_mid_lung",
284
+ "left_upper_lung",
285
+ "left_apical_lung",
286
+ "left_lung_unspec",
287
+ "right_lower_lung",
288
+ "right_mid_lung",
289
+ "right_upper_lung",
290
+ "right_apical_lung",
291
+ "right_lung_unspec",
292
+ "lung_apices",
293
+ "lung_bases",
294
+ "left_costophrenic",
295
+ "right_costophrenic",
296
+ "costophrenic_unspec",
297
+ "cardiophrenic_sulcus",
298
+ "mediastinal",
299
+ "spine",
300
+ "clavicle",
301
+ "rib",
302
+ "stomach",
303
+ "right_atrium",
304
+ "right_ventricle",
305
+ "aorta",
306
+ "svc",
307
+ "interstitium",
308
+ "parenchymal",
309
+ "cavoatrial_junction",
310
+ "cardiopulmonary",
311
+ "pulmonary",
312
+ "lung_volumes",
313
+ "unspecified",
314
+ "other",
315
+ ]
316
+ anatomy_book = []
317
+ for i in anatomy_list:
318
+ anatomy_book.append("It is located at " + i + ". ")
319
+
320
+ tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
321
+ anatomy_book_tokenizer = get_tokenizer(tokenizer, anatomy_book).to(device)
322
+ pathology_book_tokenizer = get_tokenizer(tokenizer, pathology_book).to(device)
323
+ print("Creating model")
324
+ model = MeDSLIP(
325
+ config, anatomy_book_tokenizer, pathology_book_tokenizer, mode="train"
326
+ )
327
+ model = model.to(device)
328
+ if args.computing == "parallel":
329
+ model = nn.parallel.DistributedDataParallel(
330
+ model, device_ids=[rank], find_unused_parameters=True
331
+ )
332
+
333
+ arg_opt = utils.AttrDict(config["optimizer"])
334
+ optimizer = create_optimizer(arg_opt, model)
335
+ arg_sche = utils.AttrDict(config["schedular"])
336
+ lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
337
+
338
+ if args.checkpoint:
339
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
340
+ state_dict = checkpoint["model"]
341
+ optimizer.load_state_dict(checkpoint["optimizer"])
342
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
343
+ start_epoch = checkpoint["epoch"] + 1
344
+ model.load_state_dict(state_dict)
345
+ print("load checkpoint from %s" % args.checkpoint)
346
+
347
+ print("Start training")
348
+ start_time = time.time()
349
+
350
+ writer = SummaryWriter(os.path.join(args.output_dir, "log"))
351
+ for epoch in range(start_epoch, max_epoch):
352
+ if epoch > 0:
353
+ lr_scheduler.step(epoch + warmup_steps)
354
+ train_stats = train(
355
+ model,
356
+ train_dataloader,
357
+ optimizer,
358
+ epoch,
359
+ warmup_steps,
360
+ device,
361
+ lr_scheduler,
362
+ args,
363
+ config,
364
+ writer,
365
+ )
366
+
367
+ for k, v in train_stats.items():
368
+ train_loss_epoch = v
369
+
370
+ writer.add_scalar("loss/train_loss_epoch", float(train_loss_epoch), epoch)
371
+ writer.add_scalar("loss/leaning_rate", lr_scheduler._get_lr(epoch)[0], epoch)
372
+
373
+ val_loss = valid(model, val_dataloader, epoch, device, config, writer)
374
+ writer.add_scalar("loss/val_loss_epoch", val_loss, epoch)
375
+
376
+ if utils.is_main_process():
377
+ log_stats = {
378
+ **{f"train_{k}": v for k, v in train_stats.items()},
379
+ "epoch": epoch,
380
+ "val_loss": val_loss.item(),
381
+ }
382
+ save_obj = {
383
+ "model": model.state_dict(),
384
+ "optimizer": optimizer.state_dict(),
385
+ "lr_scheduler": lr_scheduler.state_dict(),
386
+ "config": config,
387
+ "epoch": epoch,
388
+ }
389
+ torch.save(save_obj, os.path.join(args.output_dir, "checkpoint_state.pth"))
390
+
391
+ with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
392
+ f.write(json.dumps(log_stats) + "\n")
393
+
394
+ if epoch % 1 == 0 and epoch > 15:
395
+ save_obj = {
396
+ "model": model.state_dict(),
397
+ "optimizer": optimizer.state_dict(),
398
+ "lr_scheduler": lr_scheduler.state_dict(),
399
+ "config": config,
400
+ "epoch": epoch,
401
+ }
402
+ torch.save(
403
+ save_obj,
404
+ os.path.join(args.output_dir, "checkpoint_" + str(epoch) + ".pth"),
405
+ )
406
+
407
+ total_time = time.time() - start_time
408
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
409
+ print("Training time {}".format(total_time_str))
410
+
411
+
412
+ if __name__ == "__main__":
413
+ parser = argparse.ArgumentParser()
414
+ parser.add_argument(
415
+ "--config", default="PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml"
416
+ )
417
+ parser.add_argument("--checkpoint", default="")
418
+ parser.add_argument("--output_dir", default="runs/")
419
+ parser.add_argument("--device", default="cuda")
420
+ parser.add_argument("--local_rank", default=0, type=int)
421
+ parser.add_argument("--world_size", default=1, type=int)
422
+ parser.add_argument(
423
+ "--computing", type=str, default="single", help="number of gpus"
424
+ )
425
+ args = parser.parse_args()
426
+ import datetime
427
+
428
+ args.output_dir = os.path.join(
429
+ args.output_dir, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
430
+ )
431
+
432
+ gpus = torch.cuda.device_count()
433
+ if gpus > 1:
434
+ args.computing = "parallel"
435
+
436
+ config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
437
+
438
+ if not Path(args.output_dir).exists():
439
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
440
+
441
+ yaml.dump(config, open(os.path.join(args.output_dir, "config.yaml"), "w"))
442
+
443
+ if args.computing == "parallel":
444
+ torch.distributed.init_process_group(backend="nccl", init_method="env://")
445
+
446
+ main(args, config)
PreTrain_MeDSLIP/utils.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import io
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+ import datetime
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tqdm import tqdm
11
+
12
+ import warnings
13
+
14
+ warnings.filterwarnings("ignore")
15
+
16
+
17
+ class SmoothedValue(object):
18
+ """Track a series of values and provide access to smoothed values over a
19
+ window or the global series average.
20
+ """
21
+
22
+ def __init__(self, window_size=20, fmt=None):
23
+ if fmt is None:
24
+ fmt = "{median:.4f} ({global_avg:.4f})"
25
+ self.deque = deque(maxlen=window_size)
26
+ self.total = 0.0
27
+ self.count = 0
28
+ self.fmt = fmt
29
+
30
+ def update(self, value, n=1):
31
+ self.deque.append(value)
32
+ self.count += n
33
+ self.total += value * n
34
+
35
+ def synchronize_between_processes(self):
36
+ """
37
+ Warning: does not synchronize the deque!
38
+ """
39
+ if not is_dist_avail_and_initialized():
40
+ return
41
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
42
+ dist.barrier()
43
+ dist.all_reduce(t)
44
+ t = t.tolist()
45
+ self.count = int(t[0])
46
+ self.total = t[1]
47
+
48
+ @property
49
+ def median(self):
50
+ d = torch.tensor(list(self.deque))
51
+ return d.median().item()
52
+
53
+ @property
54
+ def avg(self):
55
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
56
+ return d.mean().item()
57
+
58
+ @property
59
+ def global_avg(self):
60
+ if self.count == 0:
61
+ return self.total
62
+ else:
63
+ return self.total / self.count
64
+
65
+ @property
66
+ def max(self):
67
+ return max(self.deque)
68
+
69
+ @property
70
+ def value(self):
71
+ return self.deque[-1]
72
+
73
+ def __str__(self):
74
+ return self.fmt.format(
75
+ median=self.median,
76
+ avg=self.avg,
77
+ global_avg=self.global_avg,
78
+ max=self.max,
79
+ value=self.value,
80
+ )
81
+
82
+
83
+ class MetricLogger(object):
84
+ def __init__(self, delimiter="\t"):
85
+ self.meters = defaultdict(SmoothedValue)
86
+ self.delimiter = delimiter
87
+
88
+ def update(self, **kwargs):
89
+ for k, v in kwargs.items():
90
+ if isinstance(v, torch.Tensor):
91
+ v = v.item()
92
+ assert isinstance(v, (float, int))
93
+ self.meters[k].update(v)
94
+
95
+ def __getattr__(self, attr):
96
+ if attr in self.meters:
97
+ return self.meters[attr]
98
+ if attr in self.__dict__:
99
+ return self.__dict__[attr]
100
+ raise AttributeError(
101
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
102
+ )
103
+
104
+ def __str__(self):
105
+ loss_str = []
106
+ for name, meter in self.meters.items():
107
+ loss_str.append("{}: {}".format(name, str(meter)))
108
+ return self.delimiter.join(loss_str)
109
+
110
+ def global_avg(self):
111
+ loss_str = []
112
+ for name, meter in self.meters.items():
113
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
114
+ return self.delimiter.join(loss_str)
115
+
116
+ def synchronize_between_processes(self):
117
+ for meter in self.meters.values():
118
+ meter.synchronize_between_processes()
119
+
120
+ def add_meter(self, name, meter):
121
+ self.meters[name] = meter
122
+
123
+ def log_every(self, iterable, print_freq, header=None):
124
+ i = 0
125
+ if not header:
126
+ header = ""
127
+ start_time = time.time()
128
+ end = time.time()
129
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
130
+ data_time = SmoothedValue(fmt="{avg:.4f}")
131
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
132
+ log_msg = ["{meters}"]
133
+ if torch.cuda.is_available():
134
+ log_msg.append("max mem: {memory:.0f}")
135
+ log_msg = self.delimiter.join(log_msg)
136
+ MB = 1024.0 * 1024.0
137
+
138
+ loop = tqdm(iterable)
139
+ loop.set_description(header)
140
+
141
+ for obj in loop:
142
+ data_time.update(time.time() - end)
143
+ yield obj
144
+ iter_time.update(time.time() - end)
145
+ if i % print_freq == 0 or i == len(loop) - 1:
146
+ eta_seconds = iter_time.global_avg * (len(loop) - i)
147
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
148
+ if torch.cuda.is_available():
149
+ loop.set_postfix_str(
150
+ log_msg.format(
151
+ i,
152
+ len(loop),
153
+ eta=eta_string,
154
+ meters=str(self),
155
+ time=str(iter_time),
156
+ data=str(data_time),
157
+ memory=torch.cuda.max_memory_allocated() / MB,
158
+ )
159
+ )
160
+ else:
161
+ loop.set_postfix_str(
162
+ log_msg.format(
163
+ i,
164
+ len(loop),
165
+ eta=eta_string,
166
+ meters=str(self),
167
+ time=str(iter_time),
168
+ data=str(data_time),
169
+ )
170
+ )
171
+ i += 1
172
+ end = time.time()
173
+
174
+
175
+ class AttrDict(dict):
176
+ def __init__(self, *args, **kwargs):
177
+ super(AttrDict, self).__init__(*args, **kwargs)
178
+ self.__dict__ = self
179
+
180
+
181
+ def compute_acc(logits, label, reduction="mean"):
182
+ ret = (torch.argmax(logits, dim=1) == label).float()
183
+ if reduction == "none":
184
+ return ret.detach()
185
+ elif reduction == "mean":
186
+ return ret.mean().item()
187
+
188
+
189
+ def compute_n_params(model, return_str=True):
190
+ tot = 0
191
+ for p in model.parameters():
192
+ w = 1
193
+ for x in p.shape:
194
+ w *= x
195
+ tot += w
196
+ if return_str:
197
+ if tot >= 1e6:
198
+ return "{:.1f}M".format(tot / 1e6)
199
+ else:
200
+ return "{:.1f}K".format(tot / 1e3)
201
+ else:
202
+ return tot
203
+
204
+
205
+ def setup_for_distributed(is_master):
206
+ """
207
+ This function disables printing when not in master process
208
+ """
209
+ import builtins as __builtin__
210
+
211
+ builtin_print = __builtin__.print
212
+
213
+ def print(*args, **kwargs):
214
+ force = kwargs.pop("force", False)
215
+ if is_master or force:
216
+ builtin_print(*args, **kwargs)
217
+
218
+ __builtin__.print = print
219
+
220
+
221
+ def is_dist_avail_and_initialized():
222
+ if not dist.is_available():
223
+ return False
224
+ if not dist.is_initialized():
225
+ return False
226
+ return True
227
+
228
+
229
+ def get_world_size():
230
+ if not is_dist_avail_and_initialized():
231
+ return 1
232
+ return dist.get_world_size()
233
+
234
+
235
+ def get_rank():
236
+ if not is_dist_avail_and_initialized():
237
+ return 0
238
+ return dist.get_rank()
239
+
240
+
241
+ def is_main_process():
242
+ return get_rank() == 0
243
+
244
+
245
+ def save_on_master(*args, **kwargs):
246
+ if is_main_process():
247
+ torch.save(*args, **kwargs)
248
+
249
+
250
+ def init_distributed_mode(args):
251
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
252
+ args.rank = int(os.environ["RANK"])
253
+ args.world_size = int(os.environ["WORLD_SIZE"])
254
+ args.local_rank = int(os.environ["LOCAL_RANK"])
255
+ elif "SLURM_PROCID" in os.environ:
256
+ args.rank = int(os.environ["SLURM_PROCID"])
257
+ args.local_rank = args.rank % torch.cuda.device_count()
258
+ else:
259
+ print("Not using distributed mode")
260
+ args.distributed = False
261
+ return
262
+
263
+ args.distributed = True
264
+
265
+ torch.cuda.set_device(args.local_rank)
266
+ args.dist_backend = "nccl"
267
+ print(
268
+ "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
269
+ )
270
+ torch.distributed.init_process_group(
271
+ backend=args.dist_backend,
272
+ init_method=args.dist_url,
273
+ world_size=args.world_size,
274
+ rank=args.rank,
275
+ )
276
+ torch.distributed.barrier()
277
+ setup_for_distributed(args.rank == 0)
README.md CHANGED
@@ -1,3 +1,49 @@
1
- ---
2
- license: cc-by-nc-4.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MeDSLIP: Medical Knowledge Enhanced Language-Image Pre-Training in Radiology
2
+
3
+ ## Introduction:
4
+
5
+ The official implementation code for "MeDSLIP: Medical Knowledge Enhanced Language-Image Pre-Training in Radiology".
6
+
7
+ [**Arxiv Version**](https://arxiv.org/abs/2403.10635)
8
+
9
+ ## Quick Start:
10
+ Check checkpoints directory to download our pre-trained model from [Hugging Face: MeDSLIP](https://huggingface.co/pykale/MeDSLIP). It can be used for all zero-shot and finetuning tasks.
11
+
12
+ * **Zero-Shot Classification:**
13
+
14
+ We give an example on CXR14 in ```Sample_Zero-Shot_Classification_CXR14```. Change the data paths, and test our model by ```python test.py```.
15
+ We give an example on RSNA in ```Sample_Zero-Shot_Classification_RSNA```. Change the data paths, and test our model by ```python test.py```.
16
+
17
+ * **Zero-Shot Grounding:**
18
+
19
+ We give an example on RSNA_Pneumonia in ```Sample_Zero-Shot_Grounding_RSNA```. Change the data paths, and test our model by ```python test.py```.
20
+
21
+ * **Finetuning:**
22
+
23
+ We give segmentation and classification finetune code on SIIM_ACR dataset in ```Sample_Finetuning_SIIMACR```. Change the data paths, and finetune our model by ```python I1_classification/train_res_ft.py``` or ```python I2_segementation/train_res_ft.py```.
24
+
25
+ ## Pre-train:
26
+ ### Data Preparation
27
+ All files for data preparation files can be downloaded from [Hugging Face: MeDSLIP](https://huggingface.co/pykale/MeDSLIP).
28
+ - Extracted triplets: `landmark_observation_adj_mtx.npy`
29
+ - Training list: `train.json`
30
+ - Validation list: `valid.json`
31
+ - Test list: `test.json`
32
+
33
+ ### Pre-training
34
+ Our pre-train code is given in ```PreTrain_MeDSLIP```.
35
+ * Check the ```PreTrain_MeDSLIP/data_file``` dir and download the files for data preparation.
36
+ * Change the data and preparation files paths as you disire in ```PreTrain_MeDSLIP/configs/Pretrain_MeDSLIP.yaml```, and ```python PreTrain_MeDSLIP/train_MeDSLIP.py``` to pre-train.
37
+
38
+ ## Reference
39
+ ```
40
+ @article{fan2024medslip,
41
+ title={MeDSLIP: Medical Dual-Stream Language-Image Pre-training for Fine-grained Alignment},
42
+ author={Fan, Wenrui and Suvon, Mohammod Naimul Islam and Zhou, Shuo and Liu, Xianyuan and Alabed, Samer and Osmani, Venet and Swift, Andrew and Chen, Chen and Lu, Haiping},
43
+ journal={arXiv preprint arXiv:2403.10635},
44
+ year={2024}
45
+ }
46
+ ```
47
+
48
+ ## Contact
49
+ If you have any question, please feel free to contact [email protected].
Sample_Finetuning_SIIMACR/I1_classification/configs/Res_train.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_file: "SIIM-CLS/siim-acr-pneumothorax/stage_1_train_images.csv"
2
+ valid_file: "SIIM-CLS/siim-acr-pneumothorax/stage_1_test_images.csv"
3
+ test_file: "SIIM-CLS/siim-acr-pneumothorax/stage_1_test_images.csv"
4
+
5
+ image_res: 224
6
+ batch_size: 64
7
+ test_batch_size: 64
8
+ num_classes: 1
9
+ temp: 0.07
10
+ mlm_probability: 0.15
11
+ queue_size: 8192
12
+ momentum: 0.995
13
+ alpha: 0.4
14
+ percentage: 1.0
15
+
16
+ optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02}
17
+ schedular: {sched: cosine, lr: 1e-5, epochs: 200, min_lr: 1e-5, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 20, cooldown_epochs: 0}
Sample_Finetuning_SIIMACR/I1_classification/dataset/dataset_siim_acr.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cmath import nan
2
+ import csv
3
+ import json
4
+ import logging
5
+ import os
6
+ import sys
7
+ import pydicom
8
+
9
+ from abc import abstractmethod
10
+ from itertools import islice
11
+ from typing import List, Tuple, Dict, Any
12
+ from torch.utils.data import DataLoader
13
+ import PIL
14
+ from torch.utils.data import Dataset
15
+ import numpy as np
16
+ import pandas as pd
17
+ from torchvision import transforms
18
+ from PIL import Image
19
+ from skimage import exposure
20
+ import torch
21
+ from torchvision.transforms import InterpolationMode
22
+ from dataset.randaugment import RandomAugment
23
+
24
+
25
+ class SIIM_ACR_Dataset(Dataset):
26
+ def __init__(self, csv_path, is_train=True, percentage=0.01):
27
+ data_info = pd.read_csv(csv_path)
28
+ if is_train == True:
29
+ total_len = int(percentage * len(data_info))
30
+ choice_list = np.random.choice(
31
+ range(len(data_info)), size=total_len, replace=False
32
+ )
33
+ self.img_path_list = data_info["image_path"][choice_list].tolist()
34
+ else:
35
+ self.img_path_list = data_info["image_path"].tolist()
36
+
37
+ self.img_root = "SIIM-CLS/siim-acr-pneumothorax/png_images/"
38
+ self.seg_root = "SIIM-CLS/siim-acr-pneumothorax/png_masks/" # We have pre-processed the original SIIM_ACR data, you may change this to fix your data
39
+
40
+ normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
41
+
42
+ if is_train:
43
+ self.transform = transforms.Compose(
44
+ [
45
+ transforms.RandomResizedCrop(
46
+ 224, scale=(0.2, 1.0), interpolation=Image.BICUBIC
47
+ ),
48
+ transforms.RandomHorizontalFlip(),
49
+ RandomAugment(
50
+ 2,
51
+ 7,
52
+ isPIL=True,
53
+ augs=[
54
+ "Identity",
55
+ "AutoContrast",
56
+ "Equalize",
57
+ "Brightness",
58
+ "Sharpness",
59
+ "ShearX",
60
+ "ShearY",
61
+ "TranslateX",
62
+ "TranslateY",
63
+ "Rotate",
64
+ ],
65
+ ),
66
+ transforms.ToTensor(),
67
+ normalize,
68
+ ]
69
+ )
70
+ else:
71
+ self.transform = transforms.Compose(
72
+ [transforms.Resize([224, 224]), transforms.ToTensor(), normalize,]
73
+ )
74
+
75
+ self.seg_transfrom = transforms.Compose(
76
+ [
77
+ transforms.ToTensor(),
78
+ transforms.Resize([224, 224], interpolation=InterpolationMode.NEAREST),
79
+ ]
80
+ )
81
+
82
+ def __getitem__(self, index):
83
+ img_path = self.img_root + self.img_path_list[index].split("/")[-1] # + ".png"
84
+ seg_path = (
85
+ self.seg_root + self.img_path_list[index].split("/")[-1] # + ".png"
86
+ ) # We have pre-processed the original SIIM_ACR data, you may change this to fix your data
87
+ img = PIL.Image.open(img_path).convert("RGB")
88
+ image = self.transform(img)
89
+
90
+ seg_map = PIL.Image.open(seg_path)
91
+ seg_map = self.seg_transfrom(seg_map)
92
+ seg_map = (seg_map > 0).type(torch.int)
93
+ class_label = np.array([int(torch.sum(seg_map) > 0)])
94
+ return {"image": image, "label": class_label}
95
+
96
+ def __len__(self):
97
+ return len(self.img_path_list)
98
+
99
+
100
+ def create_loader_RSNA(
101
+ datasets, samplers, batch_size, num_workers, is_trains, collate_fns
102
+ ):
103
+ loaders = []
104
+ for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
105
+ datasets, samplers, batch_size, num_workers, is_trains, collate_fns
106
+ ):
107
+ if is_train:
108
+ shuffle = sampler is None
109
+ drop_last = True
110
+ else:
111
+ shuffle = False
112
+ drop_last = False
113
+ loader = DataLoader(
114
+ dataset,
115
+ batch_size=bs,
116
+ num_workers=n_worker,
117
+ pin_memory=True,
118
+ sampler=sampler,
119
+ shuffle=shuffle,
120
+ collate_fn=collate_fn,
121
+ drop_last=drop_last,
122
+ )
123
+ loaders.append(loader)
124
+ return loaders
Sample_Finetuning_SIIMACR/I1_classification/dataset/randaugment.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ ## aug functions
6
+ def identity_func(img):
7
+ return img
8
+
9
+
10
+ def autocontrast_func(img, cutoff=0):
11
+ """
12
+ same output as PIL.ImageOps.autocontrast
13
+ """
14
+ n_bins = 256
15
+
16
+ def tune_channel(ch):
17
+ n = ch.size
18
+ cut = cutoff * n // 100
19
+ if cut == 0:
20
+ high, low = ch.max(), ch.min()
21
+ else:
22
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23
+ low = np.argwhere(np.cumsum(hist) > cut)
24
+ low = 0 if low.shape[0] == 0 else low[0]
25
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27
+ if high <= low:
28
+ table = np.arange(n_bins)
29
+ else:
30
+ scale = (n_bins - 1) / (high - low)
31
+ offset = -low * scale
32
+ table = np.arange(n_bins) * scale + offset
33
+ table[table < 0] = 0
34
+ table[table > n_bins - 1] = n_bins - 1
35
+ table = table.clip(0, 255).astype(np.uint8)
36
+ return table[ch]
37
+
38
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
39
+ out = cv2.merge(channels)
40
+ return out
41
+
42
+
43
+ def equalize_func(img):
44
+ """
45
+ same output as PIL.ImageOps.equalize
46
+ PIL's implementation is different from cv2.equalize
47
+ """
48
+ n_bins = 256
49
+
50
+ def tune_channel(ch):
51
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52
+ non_zero_hist = hist[hist != 0].reshape(-1)
53
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54
+ if step == 0:
55
+ return ch
56
+ n = np.empty_like(hist)
57
+ n[0] = step // 2
58
+ n[1:] = hist[:-1]
59
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
60
+ return table[ch]
61
+
62
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
63
+ out = cv2.merge(channels)
64
+ return out
65
+
66
+
67
+ def rotate_func(img, degree, fill=(0, 0, 0)):
68
+ """
69
+ like PIL, rotate by degree, not radians
70
+ """
71
+ H, W = img.shape[0], img.shape[1]
72
+ center = W / 2, H / 2
73
+ M = cv2.getRotationMatrix2D(center, degree, 1)
74
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
75
+ return out
76
+
77
+
78
+ def solarize_func(img, thresh=128):
79
+ """
80
+ same output as PIL.ImageOps.posterize
81
+ """
82
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
83
+ table = table.clip(0, 255).astype(np.uint8)
84
+ out = table[img]
85
+ return out
86
+
87
+
88
+ def color_func(img, factor):
89
+ """
90
+ same output as PIL.ImageEnhance.Color
91
+ """
92
+ ## implementation according to PIL definition, quite slow
93
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
94
+ # out = blend(degenerate, img, factor)
95
+ # M = (
96
+ # np.eye(3) * factor
97
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
98
+ # )[np.newaxis, np.newaxis, :]
99
+ M = np.float32(
100
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
101
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
102
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
103
+ return out
104
+
105
+
106
+ def contrast_func(img, factor):
107
+ """
108
+ same output as PIL.ImageEnhance.Contrast
109
+ """
110
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
111
+ table = (
112
+ np.array([(el - mean) * factor + mean for el in range(256)])
113
+ .clip(0, 255)
114
+ .astype(np.uint8)
115
+ )
116
+ out = table[img]
117
+ return out
118
+
119
+
120
+ def brightness_func(img, factor):
121
+ """
122
+ same output as PIL.ImageEnhance.Contrast
123
+ """
124
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
125
+ out = table[img]
126
+ return out
127
+
128
+
129
+ def sharpness_func(img, factor):
130
+ """
131
+ The differences the this result and PIL are all on the 4 boundaries, the center
132
+ areas are same
133
+ """
134
+ kernel = np.ones((3, 3), dtype=np.float32)
135
+ kernel[1][1] = 5
136
+ kernel /= 13
137
+ degenerate = cv2.filter2D(img, -1, kernel)
138
+ if factor == 0.0:
139
+ out = degenerate
140
+ elif factor == 1.0:
141
+ out = img
142
+ else:
143
+ out = img.astype(np.float32)
144
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
145
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
146
+ out = out.astype(np.uint8)
147
+ return out
148
+
149
+
150
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
151
+ H, W = img.shape[0], img.shape[1]
152
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
153
+ out = cv2.warpAffine(
154
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
155
+ ).astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
160
+ """
161
+ same output as PIL.Image.transform
162
+ """
163
+ H, W = img.shape[0], img.shape[1]
164
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
165
+ out = cv2.warpAffine(
166
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
167
+ ).astype(np.uint8)
168
+ return out
169
+
170
+
171
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
172
+ """
173
+ same output as PIL.Image.transform
174
+ """
175
+ H, W = img.shape[0], img.shape[1]
176
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
177
+ out = cv2.warpAffine(
178
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
179
+ ).astype(np.uint8)
180
+ return out
181
+
182
+
183
+ def posterize_func(img, bits):
184
+ """
185
+ same output as PIL.ImageOps.posterize
186
+ """
187
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
188
+ return out
189
+
190
+
191
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
192
+ H, W = img.shape[0], img.shape[1]
193
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
194
+ out = cv2.warpAffine(
195
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
196
+ ).astype(np.uint8)
197
+ return out
198
+
199
+
200
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
201
+ replace = np.array(replace, dtype=np.uint8)
202
+ H, W = img.shape[0], img.shape[1]
203
+ rh, rw = np.random.random(2)
204
+ pad_size = pad_size // 2
205
+ ch, cw = int(rh * H), int(rw * W)
206
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
207
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
208
+ out = img.copy()
209
+ out[x1:x2, y1:y2, :] = replace
210
+ return out
211
+
212
+
213
+ ### level to args
214
+ def enhance_level_to_args(MAX_LEVEL):
215
+ def level_to_args(level):
216
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
217
+
218
+ return level_to_args
219
+
220
+
221
+ def shear_level_to_args(MAX_LEVEL, replace_value):
222
+ def level_to_args(level):
223
+ level = (level / MAX_LEVEL) * 0.3
224
+ if np.random.random() > 0.5:
225
+ level = -level
226
+ return (level, replace_value)
227
+
228
+ return level_to_args
229
+
230
+
231
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
232
+ def level_to_args(level):
233
+ level = (level / MAX_LEVEL) * float(translate_const)
234
+ if np.random.random() > 0.5:
235
+ level = -level
236
+ return (level, replace_value)
237
+
238
+ return level_to_args
239
+
240
+
241
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
242
+ def level_to_args(level):
243
+ level = int((level / MAX_LEVEL) * cutout_const)
244
+ return (level, replace_value)
245
+
246
+ return level_to_args
247
+
248
+
249
+ def solarize_level_to_args(MAX_LEVEL):
250
+ def level_to_args(level):
251
+ level = int((level / MAX_LEVEL) * 256)
252
+ return (level,)
253
+
254
+ return level_to_args
255
+
256
+
257
+ def none_level_to_args(level):
258
+ return ()
259
+
260
+
261
+ def posterize_level_to_args(MAX_LEVEL):
262
+ def level_to_args(level):
263
+ level = int((level / MAX_LEVEL) * 4)
264
+ return (level,)
265
+
266
+ return level_to_args
267
+
268
+
269
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
270
+ def level_to_args(level):
271
+ level = (level / MAX_LEVEL) * 30
272
+ if np.random.random() < 0.5:
273
+ level = -level
274
+ return (level, replace_value)
275
+
276
+ return level_to_args
277
+
278
+
279
+ func_dict = {
280
+ "Identity": identity_func,
281
+ "AutoContrast": autocontrast_func,
282
+ "Equalize": equalize_func,
283
+ "Rotate": rotate_func,
284
+ "Solarize": solarize_func,
285
+ "Color": color_func,
286
+ "Contrast": contrast_func,
287
+ "Brightness": brightness_func,
288
+ "Sharpness": sharpness_func,
289
+ "ShearX": shear_x_func,
290
+ "TranslateX": translate_x_func,
291
+ "TranslateY": translate_y_func,
292
+ "Posterize": posterize_func,
293
+ "ShearY": shear_y_func,
294
+ }
295
+
296
+ translate_const = 10
297
+ MAX_LEVEL = 10
298
+ replace_value = (128, 128, 128)
299
+ arg_dict = {
300
+ "Identity": none_level_to_args,
301
+ "AutoContrast": none_level_to_args,
302
+ "Equalize": none_level_to_args,
303
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
304
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
305
+ "Color": enhance_level_to_args(MAX_LEVEL),
306
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
307
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
308
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
309
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
310
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
311
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
312
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
313
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
314
+ }
315
+
316
+
317
+ class RandomAugment(object):
318
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
319
+ self.N = N
320
+ self.M = M
321
+ self.isPIL = isPIL
322
+ if augs:
323
+ self.augs = augs
324
+ else:
325
+ self.augs = list(arg_dict.keys())
326
+
327
+ def get_random_ops(self):
328
+ sampled_ops = np.random.choice(self.augs, self.N)
329
+ return [(op, 0.5, self.M) for op in sampled_ops]
330
+
331
+ def __call__(self, img):
332
+ if self.isPIL:
333
+ img = np.array(img)
334
+ ops = self.get_random_ops()
335
+ for name, prob, level in ops:
336
+ if np.random.random() > prob:
337
+ continue
338
+ args = arg_dict[name](level)
339
+ img = func_dict[name](img, *args)
340
+ return img
341
+
342
+
343
+ if __name__ == "__main__":
344
+ a = RandomAugment()
345
+ img = np.random.randn(32, 32, 3)
346
+ a(img)
Sample_Finetuning_SIIMACR/I1_classification/models/resnet.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torchvision.models as models
4
+ import torch
5
+ from einops import rearrange
6
+
7
+
8
+ class ModelRes_ft(nn.Module):
9
+ def __init__(
10
+ self,
11
+ res_base_model,
12
+ out_size,
13
+ imagenet_pretrain=False,
14
+ linear_probe=False,
15
+ use_base=True,
16
+ ):
17
+ super(ModelRes_ft, self).__init__()
18
+ self.resnet_dict = {
19
+ "resnet18": models.resnet18(pretrained=imagenet_pretrain),
20
+ "resnet50": models.resnet50(pretrained=imagenet_pretrain),
21
+ }
22
+ resnet = self._get_res_basemodel(res_base_model)
23
+ self.use_base = use_base
24
+
25
+ if not self.use_base:
26
+ num_ftrs = int(resnet.fc.in_features / 2)
27
+ self.res_features = nn.Sequential(*list(resnet.children())[:-3])
28
+ self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
29
+ self.res_l2_anatomy = nn.Linear(num_ftrs, 256)
30
+ self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
31
+ self.res_l2_pathology = nn.Linear(num_ftrs, 256)
32
+
33
+ self.mask_generator = nn.Linear(num_ftrs, num_ftrs)
34
+ self.back = nn.Linear(256, num_ftrs)
35
+ self.last_res = nn.Sequential(*list(resnet.children())[-3:-1])
36
+ else:
37
+ self.res_features = nn.Sequential(*list(resnet.children())[:-1])
38
+ self.res_out = nn.Linear(int(resnet.fc.in_features), out_size)
39
+
40
+ def _get_res_basemodel(self, res_model_name):
41
+ try:
42
+ res_model = self.resnet_dict[res_model_name]
43
+ print("Image feature extractor:", res_model_name)
44
+ return res_model
45
+ except:
46
+ raise (
47
+ "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
48
+ )
49
+
50
+ def image_encoder(self, xis):
51
+ # patch features
52
+ """
53
+ 16 torch.Size([16, 1024, 14, 14])
54
+ torch.Size([16, 196, 1024])
55
+ torch.Size([3136, 1024])
56
+ torch.Size([16, 196, 256])
57
+ """
58
+ batch_size = xis.shape[0]
59
+ res_fea = self.res_features(xis) # batch_size,feature_size,patch_num,patch_num
60
+ res_fea = rearrange(res_fea, "b d n1 n2 -> b (n1 n2) d")
61
+ x = rearrange(res_fea, "b n d -> (b n) d")
62
+ mask = self.mask_generator(x)
63
+ x_pathology = mask * x
64
+ x_pathology = self.res_l1_pathology(x_pathology)
65
+ x_pathology = F.relu(x_pathology)
66
+
67
+ x_pathology = self.res_l2_pathology(x_pathology)
68
+
69
+ out_emb_pathology = rearrange(x_pathology, "(b n) d -> b n d", b=batch_size)
70
+ out_emb_pathology = self.back(out_emb_pathology)
71
+ out_emb_pathology = rearrange(out_emb_pathology, "b (n1 n2) d -> b d n1 n2", n1=14, n2=14)
72
+ out_emb_pathology = self.last_res(out_emb_pathology)
73
+ out_emb_pathology = out_emb_pathology.squeeze()
74
+
75
+ return out_emb_pathology
76
+
77
+ def forward(self, img, linear_probe=False):
78
+ if self.use_base:
79
+ x = self.res_features(img)
80
+ else:
81
+ x = self.image_encoder(img)
82
+
83
+ x = x.squeeze()
84
+ if linear_probe:
85
+ return x
86
+ else:
87
+ x = self.res_out(x)
88
+ return x
Sample_Finetuning_SIIMACR/I1_classification/optim/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .adamp import AdamP
2
+ from .adamw import AdamW
3
+ from .adafactor import Adafactor
4
+ from .adahessian import Adahessian
5
+ from .lookahead import Lookahead
6
+ from .nadam import Nadam
7
+ from .novograd import NovoGrad
8
+ from .nvnovograd import NvNovoGrad
9
+ from .radam import RAdam
10
+ from .rmsprop_tf import RMSpropTF
11
+ from .sgdp import SGDP
12
+
13
+ from .optim_factory import create_optimizer
Sample_Finetuning_SIIMACR/I1_classification/optim/adafactor.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adafactor Optimizer
2
+
3
+ Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
4
+
5
+ Original header/copyright below.
6
+
7
+ """
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ #
10
+ # This source code is licensed under the MIT license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+ import torch
13
+ import math
14
+
15
+
16
+ class Adafactor(torch.optim.Optimizer):
17
+ """Implements Adafactor algorithm.
18
+ This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
19
+ (see https://arxiv.org/abs/1804.04235)
20
+
21
+ Note that this optimizer internally adjusts the learning rate depending on the
22
+ *scale_parameter*, *relative_step* and *warmup_init* options.
23
+
24
+ To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
25
+ `relative_step=False`.
26
+
27
+ Arguments:
28
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
29
+ lr (float, optional): external learning rate (default: None)
30
+ eps (tuple[float, float]): regularization constants for square gradient
31
+ and parameter scale respectively (default: (1e-30, 1e-3))
32
+ clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
33
+ decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
34
+ beta1 (float): coefficient used for computing running averages of gradient (default: None)
35
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
36
+ scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
37
+ relative_step (bool): if True, time-dependent learning rate is computed
38
+ instead of external learning rate (default: True)
39
+ warmup_init (bool): time-dependent learning rate computation depends on
40
+ whether warm-up initialization is being used (default: False)
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ params,
46
+ lr=None,
47
+ eps=1e-30,
48
+ eps_scale=1e-3,
49
+ clip_threshold=1.0,
50
+ decay_rate=-0.8,
51
+ betas=None,
52
+ weight_decay=0.0,
53
+ scale_parameter=True,
54
+ warmup_init=False,
55
+ ):
56
+ relative_step = lr is None
57
+ if warmup_init and not relative_step:
58
+ raise ValueError("warmup_init requires relative_step=True")
59
+
60
+ beta1 = (
61
+ None if betas is None else betas[0]
62
+ ) # make it compat with standard betas arg
63
+ defaults = dict(
64
+ lr=lr,
65
+ eps=eps,
66
+ eps_scale=eps_scale,
67
+ clip_threshold=clip_threshold,
68
+ decay_rate=decay_rate,
69
+ beta1=beta1,
70
+ weight_decay=weight_decay,
71
+ scale_parameter=scale_parameter,
72
+ relative_step=relative_step,
73
+ warmup_init=warmup_init,
74
+ )
75
+ super(Adafactor, self).__init__(params, defaults)
76
+
77
+ @staticmethod
78
+ def _get_lr(param_group, param_state):
79
+ if param_group["relative_step"]:
80
+ min_step = (
81
+ 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
82
+ )
83
+ lr_t = min(min_step, 1.0 / math.sqrt(param_state["step"]))
84
+ param_scale = 1.0
85
+ if param_group["scale_parameter"]:
86
+ param_scale = max(param_group["eps_scale"], param_state["RMS"])
87
+ param_group["lr"] = lr_t * param_scale
88
+ return param_group["lr"]
89
+
90
+ @staticmethod
91
+ def _get_options(param_group, param_shape):
92
+ factored = len(param_shape) >= 2
93
+ use_first_moment = param_group["beta1"] is not None
94
+ return factored, use_first_moment
95
+
96
+ @staticmethod
97
+ def _rms(tensor):
98
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
99
+
100
+ def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
101
+ r_factor = (
102
+ (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
103
+ .rsqrt_()
104
+ .unsqueeze(-1)
105
+ )
106
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
107
+ return torch.mul(r_factor, c_factor)
108
+
109
+ def step(self, closure=None):
110
+ """Performs a single optimization step.
111
+ Arguments:
112
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
113
+ """
114
+ loss = None
115
+ if closure is not None:
116
+ loss = closure()
117
+
118
+ for group in self.param_groups:
119
+ for p in group["params"]:
120
+ if p.grad is None:
121
+ continue
122
+ grad = p.grad.data
123
+ if grad.dtype in {torch.float16, torch.bfloat16}:
124
+ grad = grad.float()
125
+ if grad.is_sparse:
126
+ raise RuntimeError("Adafactor does not support sparse gradients.")
127
+
128
+ state = self.state[p]
129
+ grad_shape = grad.shape
130
+
131
+ factored, use_first_moment = self._get_options(group, grad_shape)
132
+ # State Initialization
133
+ if len(state) == 0:
134
+ state["step"] = 0
135
+
136
+ if use_first_moment:
137
+ # Exponential moving average of gradient values
138
+ state["exp_avg"] = torch.zeros_like(grad)
139
+ if factored:
140
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
141
+ state["exp_avg_sq_col"] = torch.zeros(
142
+ grad_shape[:-2] + grad_shape[-1:]
143
+ ).to(grad)
144
+ else:
145
+ state["exp_avg_sq"] = torch.zeros_like(grad)
146
+
147
+ state["RMS"] = 0
148
+ else:
149
+ if use_first_moment:
150
+ state["exp_avg"] = state["exp_avg"].to(grad)
151
+ if factored:
152
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
153
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
154
+ else:
155
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
156
+
157
+ p_data_fp32 = p.data
158
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
159
+ p_data_fp32 = p_data_fp32.float()
160
+
161
+ state["step"] += 1
162
+ state["RMS"] = self._rms(p_data_fp32)
163
+ lr_t = self._get_lr(group, state)
164
+
165
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
166
+ update = grad ** 2 + group["eps"]
167
+ if factored:
168
+ exp_avg_sq_row = state["exp_avg_sq_row"]
169
+ exp_avg_sq_col = state["exp_avg_sq_col"]
170
+
171
+ exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
172
+ exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
173
+ # exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
174
+ # exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
175
+
176
+ # Approximation of exponential moving average of square of gradient
177
+ update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
178
+ update.mul_(grad)
179
+ else:
180
+ exp_avg_sq = state["exp_avg_sq"]
181
+
182
+ exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
183
+ # exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
184
+ update = exp_avg_sq.rsqrt().mul_(grad)
185
+
186
+ update.div_(
187
+ (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
188
+ )
189
+ update.mul_(lr_t)
190
+
191
+ if use_first_moment:
192
+ exp_avg = state["exp_avg"]
193
+ exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
194
+ # exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
195
+ update = exp_avg
196
+
197
+ if group["weight_decay"] != 0:
198
+ p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
199
+ # p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
200
+
201
+ p_data_fp32.add_(-update)
202
+
203
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
204
+ p.data.copy_(p_data_fp32)
205
+
206
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/adahessian.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AdaHessian Optimizer
2
+
3
+ Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
4
+ Originally licensed MIT, Copyright 2020, David Samuel
5
+ """
6
+ import torch
7
+
8
+
9
+ class Adahessian(torch.optim.Optimizer):
10
+ """
11
+ Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
12
+
13
+ Arguments:
14
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15
+ lr (float, optional): learning rate (default: 0.1)
16
+ betas ((float, float), optional): coefficients used for computing running averages of gradient and the
17
+ squared hessian trace (default: (0.9, 0.999))
18
+ eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
19
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
20
+ hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
21
+ update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
22
+ (to save time) (default: 1)
23
+ n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ params,
29
+ lr=0.1,
30
+ betas=(0.9, 0.999),
31
+ eps=1e-8,
32
+ weight_decay=0.0,
33
+ hessian_power=1.0,
34
+ update_each=1,
35
+ n_samples=1,
36
+ avg_conv_kernel=False,
37
+ ):
38
+ if not 0.0 <= lr:
39
+ raise ValueError(f"Invalid learning rate: {lr}")
40
+ if not 0.0 <= eps:
41
+ raise ValueError(f"Invalid epsilon value: {eps}")
42
+ if not 0.0 <= betas[0] < 1.0:
43
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
44
+ if not 0.0 <= betas[1] < 1.0:
45
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
46
+ if not 0.0 <= hessian_power <= 1.0:
47
+ raise ValueError(f"Invalid Hessian power value: {hessian_power}")
48
+
49
+ self.n_samples = n_samples
50
+ self.update_each = update_each
51
+ self.avg_conv_kernel = avg_conv_kernel
52
+
53
+ # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
54
+ self.seed = 2147483647
55
+ self.generator = torch.Generator().manual_seed(self.seed)
56
+
57
+ defaults = dict(
58
+ lr=lr,
59
+ betas=betas,
60
+ eps=eps,
61
+ weight_decay=weight_decay,
62
+ hessian_power=hessian_power,
63
+ )
64
+ super(Adahessian, self).__init__(params, defaults)
65
+
66
+ for p in self.get_params():
67
+ p.hess = 0.0
68
+ self.state[p]["hessian step"] = 0
69
+
70
+ @property
71
+ def is_second_order(self):
72
+ return True
73
+
74
+ def get_params(self):
75
+ """
76
+ Gets all parameters in all param_groups with gradients
77
+ """
78
+
79
+ return (
80
+ p for group in self.param_groups for p in group["params"] if p.requires_grad
81
+ )
82
+
83
+ def zero_hessian(self):
84
+ """
85
+ Zeros out the accumalated hessian traces.
86
+ """
87
+
88
+ for p in self.get_params():
89
+ if (
90
+ not isinstance(p.hess, float)
91
+ and self.state[p]["hessian step"] % self.update_each == 0
92
+ ):
93
+ p.hess.zero_()
94
+
95
+ @torch.no_grad()
96
+ def set_hessian(self):
97
+ """
98
+ Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
99
+ """
100
+
101
+ params = []
102
+ for p in filter(lambda p: p.grad is not None, self.get_params()):
103
+ if (
104
+ self.state[p]["hessian step"] % self.update_each == 0
105
+ ): # compute the trace only each `update_each` step
106
+ params.append(p)
107
+ self.state[p]["hessian step"] += 1
108
+
109
+ if len(params) == 0:
110
+ return
111
+
112
+ if (
113
+ self.generator.device != params[0].device
114
+ ): # hackish way of casting the generator to the right device
115
+ self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
116
+
117
+ grads = [p.grad for p in params]
118
+
119
+ for i in range(self.n_samples):
120
+ # Rademacher distribution {-1.0, 1.0}
121
+ zs = [
122
+ torch.randint(0, 2, p.size(), generator=self.generator, device=p.device)
123
+ * 2.0
124
+ - 1.0
125
+ for p in params
126
+ ]
127
+ h_zs = torch.autograd.grad(
128
+ grads,
129
+ params,
130
+ grad_outputs=zs,
131
+ only_inputs=True,
132
+ retain_graph=i < self.n_samples - 1,
133
+ )
134
+ for h_z, z, p in zip(h_zs, zs, params):
135
+ p.hess += (
136
+ h_z * z / self.n_samples
137
+ ) # approximate the expected values of z*(H@z)
138
+
139
+ @torch.no_grad()
140
+ def step(self, closure=None):
141
+ """
142
+ Performs a single optimization step.
143
+ Arguments:
144
+ closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
145
+ """
146
+
147
+ loss = None
148
+ if closure is not None:
149
+ loss = closure()
150
+
151
+ self.zero_hessian()
152
+ self.set_hessian()
153
+
154
+ for group in self.param_groups:
155
+ for p in group["params"]:
156
+ if p.grad is None or p.hess is None:
157
+ continue
158
+
159
+ if self.avg_conv_kernel and p.dim() == 4:
160
+ p.hess = (
161
+ torch.abs(p.hess)
162
+ .mean(dim=[2, 3], keepdim=True)
163
+ .expand_as(p.hess)
164
+ .clone()
165
+ )
166
+
167
+ # Perform correct stepweight decay as in AdamW
168
+ p.mul_(1 - group["lr"] * group["weight_decay"])
169
+
170
+ state = self.state[p]
171
+
172
+ # State initialization
173
+ if len(state) == 1:
174
+ state["step"] = 0
175
+ # Exponential moving average of gradient values
176
+ state["exp_avg"] = torch.zeros_like(p)
177
+ # Exponential moving average of Hessian diagonal square values
178
+ state["exp_hessian_diag_sq"] = torch.zeros_like(p)
179
+
180
+ exp_avg, exp_hessian_diag_sq = (
181
+ state["exp_avg"],
182
+ state["exp_hessian_diag_sq"],
183
+ )
184
+ beta1, beta2 = group["betas"]
185
+ state["step"] += 1
186
+
187
+ # Decay the first and second moment running average coefficient
188
+ exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
189
+ exp_hessian_diag_sq.mul_(beta2).addcmul_(
190
+ p.hess, p.hess, value=1 - beta2
191
+ )
192
+
193
+ bias_correction1 = 1 - beta1 ** state["step"]
194
+ bias_correction2 = 1 - beta2 ** state["step"]
195
+
196
+ k = group["hessian_power"]
197
+ denom = (
198
+ (exp_hessian_diag_sq / bias_correction2)
199
+ .pow_(k / 2)
200
+ .add_(group["eps"])
201
+ )
202
+
203
+ # make update
204
+ step_size = group["lr"] / bias_correction1
205
+ p.addcdiv_(exp_avg, denom, value=-step_size)
206
+
207
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/adamp.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3
+
4
+ Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5
+ Code: https://github.com/clovaai/AdamP
6
+
7
+ Copyright (c) 2020-present NAVER Corp.
8
+ MIT license
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.optim.optimizer import Optimizer, required
14
+ import math
15
+
16
+
17
+ class AdamP(Optimizer):
18
+ def __init__(
19
+ self,
20
+ params,
21
+ lr=1e-3,
22
+ betas=(0.9, 0.999),
23
+ eps=1e-8,
24
+ weight_decay=0,
25
+ delta=0.1,
26
+ wd_ratio=0.1,
27
+ nesterov=False,
28
+ ):
29
+ defaults = dict(
30
+ lr=lr,
31
+ betas=betas,
32
+ eps=eps,
33
+ weight_decay=weight_decay,
34
+ delta=delta,
35
+ wd_ratio=wd_ratio,
36
+ nesterov=nesterov,
37
+ )
38
+ super(AdamP, self).__init__(params, defaults)
39
+
40
+ def _channel_view(self, x):
41
+ return x.view(x.size(0), -1)
42
+
43
+ def _layer_view(self, x):
44
+ return x.view(1, -1)
45
+
46
+ def _cosine_similarity(self, x, y, eps, view_func):
47
+ x = view_func(x)
48
+ y = view_func(y)
49
+
50
+ x_norm = x.norm(dim=1).add_(eps)
51
+ y_norm = y.norm(dim=1).add_(eps)
52
+ dot = (x * y).sum(dim=1)
53
+
54
+ return dot.abs() / x_norm / y_norm
55
+
56
+ def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
57
+ wd = 1
58
+ expand_size = [-1] + [1] * (len(p.shape) - 1)
59
+ for view_func in [self._channel_view, self._layer_view]:
60
+
61
+ cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
62
+
63
+ if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
64
+ p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
65
+ perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
66
+ wd = wd_ratio
67
+
68
+ return perturb, wd
69
+
70
+ return perturb, wd
71
+
72
+ def step(self, closure=None):
73
+ loss = None
74
+ if closure is not None:
75
+ loss = closure()
76
+
77
+ for group in self.param_groups:
78
+ for p in group["params"]:
79
+ if p.grad is None:
80
+ continue
81
+
82
+ grad = p.grad.data
83
+ beta1, beta2 = group["betas"]
84
+ nesterov = group["nesterov"]
85
+
86
+ state = self.state[p]
87
+
88
+ # State initialization
89
+ if len(state) == 0:
90
+ state["step"] = 0
91
+ state["exp_avg"] = torch.zeros_like(p.data)
92
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
93
+
94
+ # Adam
95
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
96
+
97
+ state["step"] += 1
98
+ bias_correction1 = 1 - beta1 ** state["step"]
99
+ bias_correction2 = 1 - beta2 ** state["step"]
100
+
101
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
102
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
103
+
104
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
105
+ group["eps"]
106
+ )
107
+ step_size = group["lr"] / bias_correction1
108
+
109
+ if nesterov:
110
+ perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
111
+ else:
112
+ perturb = exp_avg / denom
113
+
114
+ # Projection
115
+ wd_ratio = 1
116
+ if len(p.shape) > 1:
117
+ perturb, wd_ratio = self._projection(
118
+ p,
119
+ grad,
120
+ perturb,
121
+ group["delta"],
122
+ group["wd_ratio"],
123
+ group["eps"],
124
+ )
125
+
126
+ # Weight decay
127
+ if group["weight_decay"] > 0:
128
+ p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio)
129
+
130
+ # Step
131
+ p.data.add_(-step_size, perturb)
132
+
133
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/adamw.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AdamW Optimizer
2
+ Impl copied from PyTorch master
3
+ """
4
+ import math
5
+ import torch
6
+ from torch.optim.optimizer import Optimizer
7
+
8
+
9
+ class AdamW(Optimizer):
10
+ r"""Implements AdamW algorithm.
11
+
12
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
13
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
14
+
15
+ Arguments:
16
+ params (iterable): iterable of parameters to optimize or dicts defining
17
+ parameter groups
18
+ lr (float, optional): learning rate (default: 1e-3)
19
+ betas (Tuple[float, float], optional): coefficients used for computing
20
+ running averages of gradient and its square (default: (0.9, 0.999))
21
+ eps (float, optional): term added to the denominator to improve
22
+ numerical stability (default: 1e-8)
23
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
24
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
26
+ (default: False)
27
+
28
+ .. _Adam\: A Method for Stochastic Optimization:
29
+ https://arxiv.org/abs/1412.6980
30
+ .. _Decoupled Weight Decay Regularization:
31
+ https://arxiv.org/abs/1711.05101
32
+ .. _On the Convergence of Adam and Beyond:
33
+ https://openreview.net/forum?id=ryQu7f-RZ
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ params,
39
+ lr=1e-3,
40
+ betas=(0.9, 0.999),
41
+ eps=1e-8,
42
+ weight_decay=1e-2,
43
+ amsgrad=False,
44
+ ):
45
+ if not 0.0 <= lr:
46
+ raise ValueError("Invalid learning rate: {}".format(lr))
47
+ if not 0.0 <= eps:
48
+ raise ValueError("Invalid epsilon value: {}".format(eps))
49
+ if not 0.0 <= betas[0] < 1.0:
50
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
51
+ if not 0.0 <= betas[1] < 1.0:
52
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
53
+ defaults = dict(
54
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
55
+ )
56
+ super(AdamW, self).__init__(params, defaults)
57
+
58
+ def __setstate__(self, state):
59
+ super(AdamW, self).__setstate__(state)
60
+ for group in self.param_groups:
61
+ group.setdefault("amsgrad", False)
62
+
63
+ def step(self, closure=None):
64
+ """Performs a single optimization step.
65
+
66
+ Arguments:
67
+ closure (callable, optional): A closure that reevaluates the model
68
+ and returns the loss.
69
+ """
70
+ loss = None
71
+ if closure is not None:
72
+ loss = closure()
73
+
74
+ for group in self.param_groups:
75
+ for p in group["params"]:
76
+ if p.grad is None:
77
+ continue
78
+
79
+ # Perform stepweight decay
80
+ p.data.mul_(1 - group["lr"] * group["weight_decay"])
81
+
82
+ # Perform optimization step
83
+ grad = p.grad.data
84
+ if grad.is_sparse:
85
+ raise RuntimeError(
86
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
87
+ )
88
+ amsgrad = group["amsgrad"]
89
+
90
+ state = self.state[p]
91
+
92
+ # State initialization
93
+ if len(state) == 0:
94
+ state["step"] = 0
95
+ # Exponential moving average of gradient values
96
+ state["exp_avg"] = torch.zeros_like(p.data)
97
+ # Exponential moving average of squared gradient values
98
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
99
+ if amsgrad:
100
+ # Maintains max of all exp. moving avg. of sq. grad. values
101
+ state["max_exp_avg_sq"] = torch.zeros_like(p.data)
102
+
103
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
104
+ if amsgrad:
105
+ max_exp_avg_sq = state["max_exp_avg_sq"]
106
+ beta1, beta2 = group["betas"]
107
+
108
+ state["step"] += 1
109
+ bias_correction1 = 1 - beta1 ** state["step"]
110
+ bias_correction2 = 1 - beta2 ** state["step"]
111
+
112
+ # Decay the first and second moment running average coefficient
113
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
114
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
115
+ if amsgrad:
116
+ # Maintains the maximum of all 2nd moment running avg. till now
117
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
118
+ # Use the max. for normalizing running avg. of gradient
119
+ denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
120
+ group["eps"]
121
+ )
122
+ else:
123
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
124
+ group["eps"]
125
+ )
126
+
127
+ step_size = group["lr"] / bias_correction1
128
+
129
+ p.data.addcdiv_(-step_size, exp_avg, denom)
130
+
131
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/lookahead.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Lookahead Optimizer Wrapper.
2
+ Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3
+ Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import torch
8
+ from torch.optim.optimizer import Optimizer
9
+ from collections import defaultdict
10
+
11
+
12
+ class Lookahead(Optimizer):
13
+ def __init__(self, base_optimizer, alpha=0.5, k=6):
14
+ if not 0.0 <= alpha <= 1.0:
15
+ raise ValueError(f"Invalid slow update rate: {alpha}")
16
+ if not 1 <= k:
17
+ raise ValueError(f"Invalid lookahead steps: {k}")
18
+ defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
19
+ self.base_optimizer = base_optimizer
20
+ self.param_groups = self.base_optimizer.param_groups
21
+ self.defaults = base_optimizer.defaults
22
+ self.defaults.update(defaults)
23
+ self.state = defaultdict(dict)
24
+ # manually add our defaults to the param groups
25
+ for name, default in defaults.items():
26
+ for group in self.param_groups:
27
+ group.setdefault(name, default)
28
+
29
+ def update_slow(self, group):
30
+ for fast_p in group["params"]:
31
+ if fast_p.grad is None:
32
+ continue
33
+ param_state = self.state[fast_p]
34
+ if "slow_buffer" not in param_state:
35
+ param_state["slow_buffer"] = torch.empty_like(fast_p.data)
36
+ param_state["slow_buffer"].copy_(fast_p.data)
37
+ slow = param_state["slow_buffer"]
38
+ slow.add_(group["lookahead_alpha"], fast_p.data - slow)
39
+ fast_p.data.copy_(slow)
40
+
41
+ def sync_lookahead(self):
42
+ for group in self.param_groups:
43
+ self.update_slow(group)
44
+
45
+ def step(self, closure=None):
46
+ # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
47
+ loss = self.base_optimizer.step(closure)
48
+ for group in self.param_groups:
49
+ group["lookahead_step"] += 1
50
+ if group["lookahead_step"] % group["lookahead_k"] == 0:
51
+ self.update_slow(group)
52
+ return loss
53
+
54
+ def state_dict(self):
55
+ fast_state_dict = self.base_optimizer.state_dict()
56
+ slow_state = {
57
+ (id(k) if isinstance(k, torch.Tensor) else k): v
58
+ for k, v in self.state.items()
59
+ }
60
+ fast_state = fast_state_dict["state"]
61
+ param_groups = fast_state_dict["param_groups"]
62
+ return {
63
+ "state": fast_state,
64
+ "slow_state": slow_state,
65
+ "param_groups": param_groups,
66
+ }
67
+
68
+ def load_state_dict(self, state_dict):
69
+ fast_state_dict = {
70
+ "state": state_dict["state"],
71
+ "param_groups": state_dict["param_groups"],
72
+ }
73
+ self.base_optimizer.load_state_dict(fast_state_dict)
74
+
75
+ # We want to restore the slow state, but share param_groups reference
76
+ # with base_optimizer. This is a bit redundant but least code
77
+ slow_state_new = False
78
+ if "slow_state" not in state_dict:
79
+ print("Loading state_dict from optimizer without Lookahead applied.")
80
+ state_dict["slow_state"] = defaultdict(dict)
81
+ slow_state_new = True
82
+ slow_state_dict = {
83
+ "state": state_dict["slow_state"],
84
+ "param_groups": state_dict[
85
+ "param_groups"
86
+ ], # this is pointless but saves code
87
+ }
88
+ super(Lookahead, self).load_state_dict(slow_state_dict)
89
+ self.param_groups = (
90
+ self.base_optimizer.param_groups
91
+ ) # make both ref same container
92
+ if slow_state_new:
93
+ # reapply defaults to catch missing lookahead specific ones
94
+ for name, default in self.defaults.items():
95
+ for group in self.param_groups:
96
+ group.setdefault(name, default)
Sample_Finetuning_SIIMACR/I1_classification/optim/nadam.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.optim import Optimizer
3
+
4
+
5
+ class Nadam(Optimizer):
6
+ """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
7
+
8
+ It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
9
+
10
+ Arguments:
11
+ params (iterable): iterable of parameters to optimize or dicts defining
12
+ parameter groups
13
+ lr (float, optional): learning rate (default: 2e-3)
14
+ betas (Tuple[float, float], optional): coefficients used for computing
15
+ running averages of gradient and its square
16
+ eps (float, optional): term added to the denominator to improve
17
+ numerical stability (default: 1e-8)
18
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19
+ schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
20
+
21
+ __ http://cs229.stanford.edu/proj2015/054_report.pdf
22
+ __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
23
+
24
+ Originally taken from: https://github.com/pytorch/pytorch/pull/1408
25
+ NOTE: Has potential issues but does work well on some problems.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ params,
31
+ lr=2e-3,
32
+ betas=(0.9, 0.999),
33
+ eps=1e-8,
34
+ weight_decay=0,
35
+ schedule_decay=4e-3,
36
+ ):
37
+ defaults = dict(
38
+ lr=lr,
39
+ betas=betas,
40
+ eps=eps,
41
+ weight_decay=weight_decay,
42
+ schedule_decay=schedule_decay,
43
+ )
44
+ super(Nadam, self).__init__(params, defaults)
45
+
46
+ def step(self, closure=None):
47
+ """Performs a single optimization step.
48
+
49
+ Arguments:
50
+ closure (callable, optional): A closure that reevaluates the model
51
+ and returns the loss.
52
+ """
53
+ loss = None
54
+ if closure is not None:
55
+ loss = closure()
56
+
57
+ for group in self.param_groups:
58
+ for p in group["params"]:
59
+ if p.grad is None:
60
+ continue
61
+ grad = p.grad.data
62
+ state = self.state[p]
63
+
64
+ # State initialization
65
+ if len(state) == 0:
66
+ state["step"] = 0
67
+ state["m_schedule"] = 1.0
68
+ state["exp_avg"] = grad.new().resize_as_(grad).zero_()
69
+ state["exp_avg_sq"] = grad.new().resize_as_(grad).zero_()
70
+
71
+ # Warming momentum schedule
72
+ m_schedule = state["m_schedule"]
73
+ schedule_decay = group["schedule_decay"]
74
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
75
+ beta1, beta2 = group["betas"]
76
+ eps = group["eps"]
77
+ state["step"] += 1
78
+ t = state["step"]
79
+
80
+ if group["weight_decay"] != 0:
81
+ grad = grad.add(group["weight_decay"], p.data)
82
+
83
+ momentum_cache_t = beta1 * (1.0 - 0.5 * (0.96 ** (t * schedule_decay)))
84
+ momentum_cache_t_1 = beta1 * (
85
+ 1.0 - 0.5 * (0.96 ** ((t + 1) * schedule_decay))
86
+ )
87
+ m_schedule_new = m_schedule * momentum_cache_t
88
+ m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
89
+ state["m_schedule"] = m_schedule_new
90
+
91
+ # Decay the first and second moment running average coefficient
92
+ exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
93
+ exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
94
+ exp_avg_sq_prime = exp_avg_sq / (1.0 - beta2 ** t)
95
+ denom = exp_avg_sq_prime.sqrt_().add_(eps)
96
+
97
+ p.data.addcdiv_(
98
+ -group["lr"] * (1.0 - momentum_cache_t) / (1.0 - m_schedule_new),
99
+ grad,
100
+ denom,
101
+ )
102
+ p.data.addcdiv_(
103
+ -group["lr"] * momentum_cache_t_1 / (1.0 - m_schedule_next),
104
+ exp_avg,
105
+ denom,
106
+ )
107
+
108
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/novograd.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """NovoGrad Optimizer.
2
+ Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
3
+ Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
4
+ - https://arxiv.org/abs/1905.11286
5
+ """
6
+
7
+ import torch
8
+ from torch.optim.optimizer import Optimizer
9
+ import math
10
+
11
+
12
+ class NovoGrad(Optimizer):
13
+ def __init__(
14
+ self,
15
+ params,
16
+ grad_averaging=False,
17
+ lr=0.1,
18
+ betas=(0.95, 0.98),
19
+ eps=1e-8,
20
+ weight_decay=0,
21
+ ):
22
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
23
+ super(NovoGrad, self).__init__(params, defaults)
24
+ self._lr = lr
25
+ self._beta1 = betas[0]
26
+ self._beta2 = betas[1]
27
+ self._eps = eps
28
+ self._wd = weight_decay
29
+ self._grad_averaging = grad_averaging
30
+
31
+ self._momentum_initialized = False
32
+
33
+ def step(self, closure=None):
34
+ loss = None
35
+ if closure is not None:
36
+ loss = closure()
37
+
38
+ if not self._momentum_initialized:
39
+ for group in self.param_groups:
40
+ for p in group["params"]:
41
+ if p.grad is None:
42
+ continue
43
+ state = self.state[p]
44
+ grad = p.grad.data
45
+ if grad.is_sparse:
46
+ raise RuntimeError("NovoGrad does not support sparse gradients")
47
+
48
+ v = torch.norm(grad) ** 2
49
+ m = grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
50
+ state["step"] = 0
51
+ state["v"] = v
52
+ state["m"] = m
53
+ state["grad_ema"] = None
54
+ self._momentum_initialized = True
55
+
56
+ for group in self.param_groups:
57
+ for p in group["params"]:
58
+ if p.grad is None:
59
+ continue
60
+ state = self.state[p]
61
+ state["step"] += 1
62
+
63
+ step, v, m = state["step"], state["v"], state["m"]
64
+ grad_ema = state["grad_ema"]
65
+
66
+ grad = p.grad.data
67
+ g2 = torch.norm(grad) ** 2
68
+ grad_ema = (
69
+ g2
70
+ if grad_ema is None
71
+ else grad_ema * self._beta2 + g2 * (1.0 - self._beta2)
72
+ )
73
+ grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
74
+
75
+ if self._grad_averaging:
76
+ grad *= 1.0 - self._beta1
77
+
78
+ g2 = torch.norm(grad) ** 2
79
+ v = self._beta2 * v + (1.0 - self._beta2) * g2
80
+ m = self._beta1 * m + (
81
+ grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
82
+ )
83
+ bias_correction1 = 1 - self._beta1 ** step
84
+ bias_correction2 = 1 - self._beta2 ** step
85
+ step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
86
+
87
+ state["v"], state["m"] = v, m
88
+ state["grad_ema"] = grad_ema
89
+ p.data.add_(-step_size, m)
90
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/nvnovograd.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Nvidia NovoGrad Optimizer.
2
+ Original impl by Nvidia from Jasper example:
3
+ - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4
+ Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5
+ - https://arxiv.org/abs/1905.11286
6
+ """
7
+
8
+ import torch
9
+ from torch.optim.optimizer import Optimizer
10
+ import math
11
+
12
+
13
+ class NvNovoGrad(Optimizer):
14
+ """
15
+ Implements Novograd algorithm.
16
+
17
+ Args:
18
+ params (iterable): iterable of parameters to optimize or dicts defining
19
+ parameter groups
20
+ lr (float, optional): learning rate (default: 1e-3)
21
+ betas (Tuple[float, float], optional): coefficients used for computing
22
+ running averages of gradient and its square (default: (0.95, 0.98))
23
+ eps (float, optional): term added to the denominator to improve
24
+ numerical stability (default: 1e-8)
25
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
26
+ grad_averaging: gradient averaging
27
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
28
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
29
+ (default: False)
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ params,
35
+ lr=1e-3,
36
+ betas=(0.95, 0.98),
37
+ eps=1e-8,
38
+ weight_decay=0,
39
+ grad_averaging=False,
40
+ amsgrad=False,
41
+ ):
42
+ if not 0.0 <= lr:
43
+ raise ValueError("Invalid learning rate: {}".format(lr))
44
+ if not 0.0 <= eps:
45
+ raise ValueError("Invalid epsilon value: {}".format(eps))
46
+ if not 0.0 <= betas[0] < 1.0:
47
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
48
+ if not 0.0 <= betas[1] < 1.0:
49
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
50
+ defaults = dict(
51
+ lr=lr,
52
+ betas=betas,
53
+ eps=eps,
54
+ weight_decay=weight_decay,
55
+ grad_averaging=grad_averaging,
56
+ amsgrad=amsgrad,
57
+ )
58
+
59
+ super(NvNovoGrad, self).__init__(params, defaults)
60
+
61
+ def __setstate__(self, state):
62
+ super(NvNovoGrad, self).__setstate__(state)
63
+ for group in self.param_groups:
64
+ group.setdefault("amsgrad", False)
65
+
66
+ def step(self, closure=None):
67
+ """Performs a single optimization step.
68
+
69
+ Arguments:
70
+ closure (callable, optional): A closure that reevaluates the model
71
+ and returns the loss.
72
+ """
73
+ loss = None
74
+ if closure is not None:
75
+ loss = closure()
76
+
77
+ for group in self.param_groups:
78
+ for p in group["params"]:
79
+ if p.grad is None:
80
+ continue
81
+ grad = p.grad.data
82
+ if grad.is_sparse:
83
+ raise RuntimeError("Sparse gradients are not supported.")
84
+ amsgrad = group["amsgrad"]
85
+
86
+ state = self.state[p]
87
+
88
+ # State initialization
89
+ if len(state) == 0:
90
+ state["step"] = 0
91
+ # Exponential moving average of gradient values
92
+ state["exp_avg"] = torch.zeros_like(p.data)
93
+ # Exponential moving average of squared gradient values
94
+ state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
95
+ if amsgrad:
96
+ # Maintains max of all exp. moving avg. of sq. grad. values
97
+ state["max_exp_avg_sq"] = torch.zeros([]).to(
98
+ state["exp_avg"].device
99
+ )
100
+
101
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
102
+ if amsgrad:
103
+ max_exp_avg_sq = state["max_exp_avg_sq"]
104
+ beta1, beta2 = group["betas"]
105
+
106
+ state["step"] += 1
107
+
108
+ norm = torch.sum(torch.pow(grad, 2))
109
+
110
+ if exp_avg_sq == 0:
111
+ exp_avg_sq.copy_(norm)
112
+ else:
113
+ exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
114
+
115
+ if amsgrad:
116
+ # Maintains the maximum of all 2nd moment running avg. till now
117
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
118
+ # Use the max. for normalizing running avg. of gradient
119
+ denom = max_exp_avg_sq.sqrt().add_(group["eps"])
120
+ else:
121
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
122
+
123
+ grad.div_(denom)
124
+ if group["weight_decay"] != 0:
125
+ grad.add_(group["weight_decay"], p.data)
126
+ if group["grad_averaging"]:
127
+ grad.mul_(1 - beta1)
128
+ exp_avg.mul_(beta1).add_(grad)
129
+
130
+ p.data.add_(-group["lr"], exp_avg)
131
+
132
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/optim_factory.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Optimizer Factory w/ Custom Weight Decay
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ from torch import optim as optim
6
+
7
+ from .adafactor import Adafactor
8
+ from .adahessian import Adahessian
9
+ from .adamp import AdamP
10
+ from .lookahead import Lookahead
11
+ from .nadam import Nadam
12
+ from .novograd import NovoGrad
13
+ from .nvnovograd import NvNovoGrad
14
+ from .radam import RAdam
15
+ from .rmsprop_tf import RMSpropTF
16
+ from .sgdp import SGDP
17
+
18
+ try:
19
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
20
+
21
+ has_apex = True
22
+ except ImportError:
23
+ has_apex = False
24
+
25
+
26
+ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
27
+ decay = []
28
+ no_decay = []
29
+ for name, param in model.named_parameters():
30
+ if not param.requires_grad:
31
+ continue # frozen weights
32
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
33
+ no_decay.append(param)
34
+ else:
35
+ decay.append(param)
36
+ return [
37
+ {"params": no_decay, "weight_decay": 0.0},
38
+ {"params": decay, "weight_decay": weight_decay},
39
+ ]
40
+
41
+
42
+ def create_optimizer(args, model, filter_bias_and_bn=True):
43
+ opt_lower = args.opt.lower()
44
+ weight_decay = args.weight_decay
45
+ if weight_decay and filter_bias_and_bn:
46
+ skip = {}
47
+ if hasattr(model, "no_weight_decay"):
48
+ skip = model.no_weight_decay()
49
+ parameters = add_weight_decay(model, weight_decay, skip)
50
+ weight_decay = 0.0
51
+ else:
52
+ parameters = filter(
53
+ lambda p: p.requires_grad, model.parameters()
54
+ ) # model.parameters()
55
+
56
+ if "fused" in opt_lower:
57
+ assert (
58
+ has_apex and torch.cuda.is_available()
59
+ ), "APEX and CUDA required for fused optimizers"
60
+
61
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
62
+ if hasattr(args, "opt_eps") and args.opt_eps is not None:
63
+ opt_args["eps"] = args.opt_eps
64
+ if hasattr(args, "opt_betas") and args.opt_betas is not None:
65
+ opt_args["betas"] = args.opt_betas
66
+ if hasattr(args, "opt_args") and args.opt_args is not None:
67
+ opt_args.update(args.opt_args)
68
+
69
+ opt_split = opt_lower.split("_")
70
+ opt_lower = opt_split[-1]
71
+ if opt_lower == "sgd" or opt_lower == "nesterov":
72
+ opt_args.pop("eps", None)
73
+ optimizer = optim.SGD(
74
+ parameters, momentum=args.momentum, nesterov=True, **opt_args
75
+ )
76
+ elif opt_lower == "momentum":
77
+ opt_args.pop("eps", None)
78
+ optimizer = optim.SGD(
79
+ parameters, momentum=args.momentum, nesterov=False, **opt_args
80
+ )
81
+ elif opt_lower == "adam":
82
+ optimizer = optim.Adam(parameters, **opt_args)
83
+ elif opt_lower == "adamw":
84
+ optimizer = optim.AdamW(parameters, **opt_args)
85
+ elif opt_lower == "nadam":
86
+ optimizer = Nadam(parameters, **opt_args)
87
+ elif opt_lower == "radam":
88
+ optimizer = RAdam(parameters, **opt_args)
89
+ elif opt_lower == "adamp":
90
+ optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
91
+ elif opt_lower == "sgdp":
92
+ optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
93
+ elif opt_lower == "adadelta":
94
+ optimizer = optim.Adadelta(parameters, **opt_args)
95
+ elif opt_lower == "adafactor":
96
+ if not args.lr:
97
+ opt_args["lr"] = None
98
+ optimizer = Adafactor(parameters, **opt_args)
99
+ elif opt_lower == "adahessian":
100
+ optimizer = Adahessian(parameters, **opt_args)
101
+ elif opt_lower == "rmsprop":
102
+ optimizer = optim.RMSprop(
103
+ parameters, alpha=0.9, momentum=args.momentum, **opt_args
104
+ )
105
+ elif opt_lower == "rmsproptf":
106
+ optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
107
+ elif opt_lower == "novograd":
108
+ optimizer = NovoGrad(parameters, **opt_args)
109
+ elif opt_lower == "nvnovograd":
110
+ optimizer = NvNovoGrad(parameters, **opt_args)
111
+ elif opt_lower == "fusedsgd":
112
+ opt_args.pop("eps", None)
113
+ optimizer = FusedSGD(
114
+ parameters, momentum=args.momentum, nesterov=True, **opt_args
115
+ )
116
+ elif opt_lower == "fusedmomentum":
117
+ opt_args.pop("eps", None)
118
+ optimizer = FusedSGD(
119
+ parameters, momentum=args.momentum, nesterov=False, **opt_args
120
+ )
121
+ elif opt_lower == "fusedadam":
122
+ optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
123
+ elif opt_lower == "fusedadamw":
124
+ optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
125
+ elif opt_lower == "fusedlamb":
126
+ optimizer = FusedLAMB(parameters, **opt_args)
127
+ elif opt_lower == "fusednovograd":
128
+ opt_args.setdefault("betas", (0.95, 0.98))
129
+ optimizer = FusedNovoGrad(parameters, **opt_args)
130
+ else:
131
+ assert False and "Invalid optimizer"
132
+ raise ValueError
133
+
134
+ if len(opt_split) > 1:
135
+ if opt_split[0] == "lookahead":
136
+ optimizer = Lookahead(optimizer)
137
+
138
+ return optimizer
Sample_Finetuning_SIIMACR/I1_classification/optim/radam.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RAdam Optimizer.
2
+ Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
3
+ Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
4
+ """
5
+ import math
6
+ import torch
7
+ from torch.optim.optimizer import Optimizer, required
8
+
9
+
10
+ class RAdam(Optimizer):
11
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
12
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
13
+ self.buffer = [[None, None, None] for ind in range(10)]
14
+ super(RAdam, self).__init__(params, defaults)
15
+
16
+ def __setstate__(self, state):
17
+ super(RAdam, self).__setstate__(state)
18
+
19
+ def step(self, closure=None):
20
+
21
+ loss = None
22
+ if closure is not None:
23
+ loss = closure()
24
+
25
+ for group in self.param_groups:
26
+
27
+ for p in group["params"]:
28
+ if p.grad is None:
29
+ continue
30
+ grad = p.grad.data.float()
31
+ if grad.is_sparse:
32
+ raise RuntimeError("RAdam does not support sparse gradients")
33
+
34
+ p_data_fp32 = p.data.float()
35
+
36
+ state = self.state[p]
37
+
38
+ if len(state) == 0:
39
+ state["step"] = 0
40
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
41
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
42
+ else:
43
+ state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
44
+ state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
45
+
46
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
47
+ beta1, beta2 = group["betas"]
48
+
49
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
50
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
51
+
52
+ state["step"] += 1
53
+ buffered = self.buffer[int(state["step"] % 10)]
54
+ if state["step"] == buffered[0]:
55
+ N_sma, step_size = buffered[1], buffered[2]
56
+ else:
57
+ buffered[0] = state["step"]
58
+ beta2_t = beta2 ** state["step"]
59
+ N_sma_max = 2 / (1 - beta2) - 1
60
+ N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
61
+ buffered[1] = N_sma
62
+
63
+ # more conservative since it's an approximated value
64
+ if N_sma >= 5:
65
+ step_size = (
66
+ group["lr"]
67
+ * math.sqrt(
68
+ (1 - beta2_t)
69
+ * (N_sma - 4)
70
+ / (N_sma_max - 4)
71
+ * (N_sma - 2)
72
+ / N_sma
73
+ * N_sma_max
74
+ / (N_sma_max - 2)
75
+ )
76
+ / (1 - beta1 ** state["step"])
77
+ )
78
+ else:
79
+ step_size = group["lr"] / (1 - beta1 ** state["step"])
80
+ buffered[2] = step_size
81
+
82
+ if group["weight_decay"] != 0:
83
+ p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
84
+
85
+ # more conservative since it's an approximated value
86
+ if N_sma >= 5:
87
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
88
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
89
+ else:
90
+ p_data_fp32.add_(-step_size, exp_avg)
91
+
92
+ p.data.copy_(p_data_fp32)
93
+
94
+ return loss
95
+
96
+
97
+ class PlainRAdam(Optimizer):
98
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
99
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
100
+
101
+ super(PlainRAdam, self).__init__(params, defaults)
102
+
103
+ def __setstate__(self, state):
104
+ super(PlainRAdam, self).__setstate__(state)
105
+
106
+ def step(self, closure=None):
107
+
108
+ loss = None
109
+ if closure is not None:
110
+ loss = closure()
111
+
112
+ for group in self.param_groups:
113
+
114
+ for p in group["params"]:
115
+ if p.grad is None:
116
+ continue
117
+ grad = p.grad.data.float()
118
+ if grad.is_sparse:
119
+ raise RuntimeError("RAdam does not support sparse gradients")
120
+
121
+ p_data_fp32 = p.data.float()
122
+
123
+ state = self.state[p]
124
+
125
+ if len(state) == 0:
126
+ state["step"] = 0
127
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
128
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
129
+ else:
130
+ state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
131
+ state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
132
+
133
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
134
+ beta1, beta2 = group["betas"]
135
+
136
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
137
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
138
+
139
+ state["step"] += 1
140
+ beta2_t = beta2 ** state["step"]
141
+ N_sma_max = 2 / (1 - beta2) - 1
142
+ N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
143
+
144
+ if group["weight_decay"] != 0:
145
+ p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
146
+
147
+ # more conservative since it's an approximated value
148
+ if N_sma >= 5:
149
+ step_size = (
150
+ group["lr"]
151
+ * math.sqrt(
152
+ (1 - beta2_t)
153
+ * (N_sma - 4)
154
+ / (N_sma_max - 4)
155
+ * (N_sma - 2)
156
+ / N_sma
157
+ * N_sma_max
158
+ / (N_sma_max - 2)
159
+ )
160
+ / (1 - beta1 ** state["step"])
161
+ )
162
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
163
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
164
+ else:
165
+ step_size = group["lr"] / (1 - beta1 ** state["step"])
166
+ p_data_fp32.add_(-step_size, exp_avg)
167
+
168
+ p.data.copy_(p_data_fp32)
169
+
170
+ return loss
Sample_Finetuning_SIIMACR/I1_classification/optim/rmsprop_tf.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ RMSProp modified to behave like Tensorflow impl
2
+
3
+ Originally cut & paste from PyTorch RMSProp
4
+ https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
5
+ Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
6
+
7
+ Modifications Copyright 2020 Ross Wightman
8
+ """
9
+
10
+ import torch
11
+ from torch.optim import Optimizer
12
+
13
+
14
+ class RMSpropTF(Optimizer):
15
+ """Implements RMSprop algorithm (TensorFlow style epsilon)
16
+
17
+ NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
18
+ and a few other modifications to closer match Tensorflow for matching hyper-params.
19
+
20
+ Noteworthy changes include:
21
+ 1. Epsilon applied inside square-root
22
+ 2. square_avg initialized to ones
23
+ 3. LR scaling of update accumulated in momentum buffer
24
+
25
+ Proposed by G. Hinton in his
26
+ `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
27
+
28
+ The centered version first appears in `Generating Sequences
29
+ With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
30
+
31
+ Arguments:
32
+ params (iterable): iterable of parameters to optimize or dicts defining
33
+ parameter groups
34
+ lr (float, optional): learning rate (default: 1e-2)
35
+ momentum (float, optional): momentum factor (default: 0)
36
+ alpha (float, optional): smoothing (decay) constant (default: 0.9)
37
+ eps (float, optional): term added to the denominator to improve
38
+ numerical stability (default: 1e-10)
39
+ centered (bool, optional) : if ``True``, compute the centered RMSProp,
40
+ the gradient is normalized by an estimation of its variance
41
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
42
+ decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
43
+ lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
44
+ update as per defaults in Tensorflow
45
+
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ params,
51
+ lr=1e-2,
52
+ alpha=0.9,
53
+ eps=1e-10,
54
+ weight_decay=0,
55
+ momentum=0.0,
56
+ centered=False,
57
+ decoupled_decay=False,
58
+ lr_in_momentum=True,
59
+ ):
60
+ if not 0.0 <= lr:
61
+ raise ValueError("Invalid learning rate: {}".format(lr))
62
+ if not 0.0 <= eps:
63
+ raise ValueError("Invalid epsilon value: {}".format(eps))
64
+ if not 0.0 <= momentum:
65
+ raise ValueError("Invalid momentum value: {}".format(momentum))
66
+ if not 0.0 <= weight_decay:
67
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
68
+ if not 0.0 <= alpha:
69
+ raise ValueError("Invalid alpha value: {}".format(alpha))
70
+
71
+ defaults = dict(
72
+ lr=lr,
73
+ momentum=momentum,
74
+ alpha=alpha,
75
+ eps=eps,
76
+ centered=centered,
77
+ weight_decay=weight_decay,
78
+ decoupled_decay=decoupled_decay,
79
+ lr_in_momentum=lr_in_momentum,
80
+ )
81
+ super(RMSpropTF, self).__init__(params, defaults)
82
+
83
+ def __setstate__(self, state):
84
+ super(RMSpropTF, self).__setstate__(state)
85
+ for group in self.param_groups:
86
+ group.setdefault("momentum", 0)
87
+ group.setdefault("centered", False)
88
+
89
+ def step(self, closure=None):
90
+ """Performs a single optimization step.
91
+
92
+ Arguments:
93
+ closure (callable, optional): A closure that reevaluates the model
94
+ and returns the loss.
95
+ """
96
+ loss = None
97
+ if closure is not None:
98
+ loss = closure()
99
+
100
+ for group in self.param_groups:
101
+ for p in group["params"]:
102
+ if p.grad is None:
103
+ continue
104
+ grad = p.grad.data
105
+ if grad.is_sparse:
106
+ raise RuntimeError("RMSprop does not support sparse gradients")
107
+ state = self.state[p]
108
+
109
+ # State initialization
110
+ if len(state) == 0:
111
+ state["step"] = 0
112
+ state["square_avg"] = torch.ones_like(
113
+ p.data
114
+ ) # PyTorch inits to zero
115
+ if group["momentum"] > 0:
116
+ state["momentum_buffer"] = torch.zeros_like(p.data)
117
+ if group["centered"]:
118
+ state["grad_avg"] = torch.zeros_like(p.data)
119
+
120
+ square_avg = state["square_avg"]
121
+ one_minus_alpha = 1.0 - group["alpha"]
122
+
123
+ state["step"] += 1
124
+
125
+ if group["weight_decay"] != 0:
126
+ if "decoupled_decay" in group and group["decoupled_decay"]:
127
+ p.data.add_(-group["weight_decay"], p.data)
128
+ else:
129
+ grad = grad.add(group["weight_decay"], p.data)
130
+
131
+ # Tensorflow order of ops for updating squared avg
132
+ square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
133
+ # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
134
+
135
+ if group["centered"]:
136
+ grad_avg = state["grad_avg"]
137
+ grad_avg.add_(one_minus_alpha, grad - grad_avg)
138
+ # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
139
+ avg = (
140
+ square_avg.addcmul(-1, grad_avg, grad_avg)
141
+ .add(group["eps"])
142
+ .sqrt_()
143
+ ) # eps moved in sqrt
144
+ else:
145
+ avg = square_avg.add(group["eps"]).sqrt_() # eps moved in sqrt
146
+
147
+ if group["momentum"] > 0:
148
+ buf = state["momentum_buffer"]
149
+ # Tensorflow accumulates the LR scaling in the momentum buffer
150
+ if "lr_in_momentum" in group and group["lr_in_momentum"]:
151
+ buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg)
152
+ p.data.add_(-buf)
153
+ else:
154
+ # PyTorch scales the param update by LR
155
+ buf.mul_(group["momentum"]).addcdiv_(grad, avg)
156
+ p.data.add_(-group["lr"], buf)
157
+ else:
158
+ p.data.addcdiv_(-group["lr"], grad, avg)
159
+
160
+ return loss