Update training scripts
- .gitattributes +1 -0
- data/train_dataset.json +3 -0
- data/valid_dataset.json +3 -0
- requirements.txt +8 -0
- run_medclip.sh +15 -0
- src/hybrid_clip/utils/create_roco_dataset.py +0 -0
- src/hybrid_clip/utils/roco_dataset.ipynb +123 -0
.gitattributes
CHANGED
@@ -14,3 +14,4 @@
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
data/train_dataset.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6f8f9ecea3f4c6f8196f194510159fccde43ee7f2192b259a11d6bc9ad684cb
+size 13426560
data/valid_dataset.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7dbb940f0dee7cb4a85959dc6018aafc824a988b46e3ae8ca2fea6500251ee0a
+size 4132661
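
Both dataset files are stored as Git LFS pointer stubs (matching the *.json rule added to .gitattributes above); the actual payload, produced by the notebook at the end of this commit, is one JSON object per line with image_path and captions fields. A minimal loading sketch, assuming that line format and that the LFS objects have been pulled:

import json

# Sketch only: read a JSON-lines caption file of the shape written by the
# notebook below, i.e. one {"image_path": ..., "captions": ...} object per line.
def load_examples(path):
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

examples = load_examples("data/train_dataset.json")
print(len(examples), examples[0]["image_path"])
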
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+jax>=0.2.8
+jaxlib>=0.1.59
+flax>=0.3.4
+optax>=0.0.8
+-f https://download.pytorch.org/whl/torch_stable.html
+torch==1.9.0+cpu
+-f https://download.pytorch.org/whl/torch_stable.html
+torchvision==0.10.0+cpu
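
After pip install -r requirements.txt, a quick way to confirm the pinned JAX/Flax stack and the CPU-only PyTorch wheels resolved correctly (a convenience sketch, not part of the commit):

# Sketch only: print the installed versions and the devices JAX can see.
import flax
import jax
import optax
import torch
import torchvision

print("jax", jax.__version__, "flax", flax.__version__, "optax", optax.__version__)
print("torch", torch.__version__, "torchvision", torchvision.__version__)
print("jax devices:", jax.devices())
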
run_medclip.sh
ADDED
@@ -0,0 +1,15 @@
+python src/hybrid_clip/run_hybrid_clip.py \
+--output_dir snapshots \
+--text_model_name_or_path="roberta-base" \
+--vision_model_name_or_path="openai/clip-vit-base-patch32" \
+--tokenizer_name="roberta-base" \
+--train_file="data/train_dataset.json" \
+--validation_file="data/valid_dataset.json" \
+--do_train --do_eval \
+--num_train_epochs="40" --max_seq_length 96 \
+--per_device_train_batch_size="64" \
+--per_device_eval_batch_size="64" \
+--learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
+--overwrite_output_dir \
+--preprocessing_num_workers 32 \
+# --push_to_hub
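
The launch script pairs a roberta-base text encoder with the openai/clip-vit-base-patch32 vision tower and caps captions at 96 tokens. run_hybrid_clip.py handles preprocessing itself; the snippet below is only an illustrative sketch of what those flag values imply, built from standard transformers classes rather than the training script's actual code path:

from PIL import Image
from transformers import AutoTokenizer, CLIPFeatureExtractor

# Sketch only: tokenize one caption and preprocess one image consistently with
# the flags above (max_seq_length 96, CLIP ViT-B/32 image preprocessing).
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

caption = "Chest X-ray showing clear lung fields."  # hypothetical example caption
image = Image.new("RGB", (224, 224))                # stand-in for a ROCO image

text_inputs = tokenizer(caption, max_length=96, padding="max_length",
                        truncation=True, return_tensors="np")
image_inputs = feature_extractor(images=image, return_tensors="np")
print(text_inputs["input_ids"].shape, image_inputs["pixel_values"].shape)
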
src/hybrid_clip/utils/create_roco_dataset.py
ADDED
File without changes
src/hybrid_clip/utils/roco_dataset.ipynb
ADDED
@@ -0,0 +1,123 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "source": [
+    "import json\n",
+    "import os\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "source": [
+    "train_path = '/home/kaumad/roco-dataset/train'"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "source": [
+    "img_dir = os.path.join(train_path, 'radiology', 'images')"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "source": [
+    "train_csv = pd.read_csv(os.path.join(train_path, 'radiology', 'traindata.csv'))"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "source": [
+    "lines = []\n",
+    "for id, row in train_csv.iterrows():\n",
+    "    img_path = os.path.join(img_dir, 'radiology', row['name'])\n",
+    "    line = json.dumps({\"image_path\": img_path, \"captions\": row['caption']})\n",
+    "    lines.append(line)\n",
+    "    # if id>100:\n",
+    "    #     break\n"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "source": [
+    "len(lines)"
+   ],
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "65450"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 7
+    }
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "source": [
+    "train_lines = lines[:45000]\n",
+    "val_lines = lines[-45000:]\n",
+    "\n",
+    "json_dir = '../../../data'\n",
+    "with open(os.path.join(json_dir, \"train_dataset.json\"), \"w\") as f:\n",
+    "    f.write(\"\\n\".join(train_lines))\n",
+    "\n",
+    "with open(os.path.join(json_dir, \"valid_dataset.json\"), \"w\") as f:\n",
+    "    f.write(\"\\n\".join(val_lines))"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "source": [
+    "os.listdir('../../../data')"
+   ],
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "['train_dataset.json', 'valid_dataset.json']"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 14
+    }
+   ],
+   "metadata": {}
+  }
+ ],
+ "metadata": {
+  "orig_nbformat": 4,
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
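
create_roco_dataset.py is added empty in this commit; a script-style sketch of the notebook logic above could look like the following (the ROCO paths, the CSV column names 'name' and 'caption', the output directory, and the non-overlapping 45000-line split are assumptions carried over from, or adjusted relative to, the notebook cells):

import json
import os

import pandas as pd

# Paths mirror the notebook; adjust to the local ROCO checkout.
ROCO_TRAIN = "/home/kaumad/roco-dataset/train"
IMG_DIR = os.path.join(ROCO_TRAIN, "radiology", "images")
JSON_DIR = "data"  # repo-relative output directory (assumed)


def build_lines(csv_path):
    """Turn each (image name, caption) CSV row into one JSON line."""
    df = pd.read_csv(csv_path)
    lines = []
    for _, row in df.iterrows():
        # Path construction kept as in the notebook.
        img_path = os.path.join(IMG_DIR, "radiology", row["name"])
        lines.append(json.dumps({"image_path": img_path, "captions": row["caption"]}))
    return lines


if __name__ == "__main__":
    lines = build_lines(os.path.join(ROCO_TRAIN, "radiology", "traindata.csv"))
    # Non-overlapping split; the exact sizes are an assumption, not the notebook's slicing.
    train_lines, val_lines = lines[:45000], lines[45000:]
    with open(os.path.join(JSON_DIR, "train_dataset.json"), "w") as f:
        f.write("\n".join(train_lines))
    with open(os.path.join(JSON_DIR, "valid_dataset.json"), "w") as f:
        f.write("\n".join(val_lines))
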