kaushalya committed on
Commit
9f17cdb
·
1 Parent(s): aaa16b9

Update training scripts

Browse files
.gitattributes CHANGED
@@ -14,3 +14,4 @@
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
+ *.json filter=lfs diff=lfs merge=lfs -text
data/train_dataset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6f8f9ecea3f4c6f8196f194510159fccde43ee7f2192b259a11d6bc9ad684cb
3
+ size 13426560
data/valid_dataset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dbb940f0dee7cb4a85959dc6018aafc824a988b46e3ae8ca2fea6500251ee0a
3
+ size 4132661
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ jax>=0.2.8
2
+ jaxlib>=0.1.59
3
+ flax>=0.3.4
4
+ optax>=0.0.8
5
+ -f https://download.pytorch.org/whl/torch_stable.html
6
+ torch==1.9.0+cpu
7
+ -f https://download.pytorch.org/whl/torch_stable.html
8
+ torchvision==0.10.0+cpu
run_medclip.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python src/hybrid_clip/run_hybrid_clip.py \
2
+ --output_dir snapshots \
3
+ --text_model_name_or_path="roberta-base" \
4
+ --vision_model_name_or_path="openai/clip-vit-base-patch32" \
5
+ --tokenizer_name="roberta-base" \
6
+ --train_file="data/train_dataset.json" \
7
+ --validation_file="data/valid_dataset.json" \
8
+ --do_train --do_eval \
9
+ --num_train_epochs="40" --max_seq_length 96 \
10
+ --per_device_train_batch_size="64" \
11
+ --per_device_eval_batch_size="64" \
12
+ --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
13
+ --overwrite_output_dir \
14
+ --preprocessing_num_workers 32 \
15
+ # --push_to_hub
src/hybrid_clip/utils/create_roco_dataset.py ADDED
File without changes
src/hybrid_clip/utils/roco_dataset.ipynb ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "source": [
7
+ "import json\n",
8
+ "import os\n",
9
+ "import pandas as pd\n",
10
+ "import matplotlib.pyplot as plt"
11
+ ],
12
+ "outputs": [],
13
+ "metadata": {}
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "source": [
19
+ "train_path = '/home/kaumad/roco-dataset/train'"
20
+ ],
21
+ "outputs": [],
22
+ "metadata": {}
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "source": [
28
+ "img_dir = os.path.join(train_path, 'radiology', 'images')"
29
+ ],
30
+ "outputs": [],
31
+ "metadata": {}
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 5,
36
+ "source": [
37
+ "train_csv = pd.read_csv(os.path.join(train_path, 'radiology', 'traindata.csv'))"
38
+ ],
39
+ "outputs": [],
40
+ "metadata": {}
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 6,
45
+ "source": [
46
+ "lines = []\n",
47
+ "for id, row in train_csv.iterrows():\n",
48
+ " img_path = os.path.join(img_dir, 'radiology', row['name'])\n",
49
+ " line = json.dumps({\"image_path\": img_path, \"captions\": row['caption']})\n",
50
+ " lines.append(line)\n",
51
+ " # if id>100:\n",
52
+ " # break\n"
53
+ ],
54
+ "outputs": [],
55
+ "metadata": {}
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 7,
60
+ "source": [
61
+ "len(lines)"
62
+ ],
63
+ "outputs": [
64
+ {
65
+ "output_type": "execute_result",
66
+ "data": {
67
+ "text/plain": [
68
+ "65450"
69
+ ]
70
+ },
71
+ "metadata": {},
72
+ "execution_count": 7
73
+ }
74
+ ],
75
+ "metadata": {}
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 8,
80
+ "source": [
81
+ "train_lines = lines[:45000]\n",
82
+ "val_lines = lines[-45000:]\n",
83
+ "\n",
84
+ "json_dir = '../../../data'\n",
85
+ "with open(os.path.join(json_dir, \"train_dataset.json\"), \"w\") as f:\n",
86
+ " f.write(\"\\n\".join(train_lines))\n",
87
+ "\n",
88
+ "with open(os.path.join(json_dir, \"valid_dataset.json\"), \"w\") as f:\n",
89
+ " f.write(\"\\n\".join(val_lines))"
90
+ ],
91
+ "outputs": [],
92
+ "metadata": {}
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 14,
97
+ "source": [
98
+ "os.listdir('../../../data')"
99
+ ],
100
+ "outputs": [
101
+ {
102
+ "output_type": "execute_result",
103
+ "data": {
104
+ "text/plain": [
105
+ "['train_dataset.json', 'valid_dataset.json']"
106
+ ]
107
+ },
108
+ "metadata": {},
109
+ "execution_count": 14
110
+ }
111
+ ],
112
+ "metadata": {}
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "orig_nbformat": 4,
117
+ "language_info": {
118
+ "name": "python"
119
+ }
120
+ },
121
+ "nbformat": 4,
122
+ "nbformat_minor": 2
123
+ }