Spaces:
Sleeping
Sleeping
Upload 226 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- AKSHAYRAJAA/README.md +52 -0
- AKSHAYRAJAA/__pycache__/config.cpython-39.pyc +0 -0
- AKSHAYRAJAA/__pycache__/dataset.cpython-39.pyc +0 -0
- AKSHAYRAJAA/__pycache__/model.cpython-39.pyc +0 -0
- AKSHAYRAJAA/__pycache__/utils.cpython-39.pyc +0 -0
- AKSHAYRAJAA/checkpoints.zip +3 -0
- AKSHAYRAJAA/checkpoints/x_ray_model.pth.tar +3 -0
- AKSHAYRAJAA/config.py +36 -0
- AKSHAYRAJAA/dataset.py +165 -0
- AKSHAYRAJAA/dataset/images/CXR1000_IM-0003-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1000_IM-0003-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1000_IM-0003-3001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1001_IM-0004-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1001_IM-0004-1002.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1002_IM-0004-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1002_IM-0004-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1003_IM-0005-2002.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1004_IM-0005-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1004_IM-0005-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1005_IM-0006-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1005_IM-0006-3003.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1006_IM-0007-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1006_IM-0007-3003.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1007_IM-0008-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1007_IM-0008-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1007_IM-0008-3001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1008_IM-0009-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1008_IM-0009-4004.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1009_IM-0010-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1009_IM-0010-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR100_IM-0002-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR100_IM-0002-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1010_IM-0012-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1010_IM-0012-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1011_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1011_IM-0013-1002.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1012_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1013_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1013_IM-0013-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1014_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1014_IM-0013-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1015_IM-0001-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1015_IM-0001-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1015_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1015_IM-0013-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1016_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1016_IM-0013-2001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1017_IM-0013-1001.png +0 -0
- AKSHAYRAJAA/dataset/images/CXR1017_IM-0013-1002.png +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
AKSHAYRAJAA/ngrok-v3-stable-linux-amd64.zip.1 filter=lfs diff=lfs merge=lfs -text
|
AKSHAYRAJAA/README.md
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Chest X-Ray Report Generator
|
2 |
+
|
3 |
+
> This project is part of a task for the college where I study, so `task-parts` contains files that associated with that task, whishing that I would get the full mark ;). In general the base code doesn't have any special parts except that folder.
|
4 |
+
|
5 |
+
## Installation
|
6 |
+
|
7 |
+
After cloning the repository, install the required packages in a virtual environment.
|
8 |
+
|
9 |
+
Next, download the datasets and checkpoints, as describe below.
|
10 |
+
|
11 |
+
## Dataset
|
12 |
+
|
13 |
+
### IU X-Ray
|
14 |
+
|
15 |
+
1. Download the Chen et al. labels and the chest X-rays in png format for IU X-Ray from:
|
16 |
+
|
17 |
+
```
|
18 |
+
https://openi.nlm.nih.gov
|
19 |
+
```
|
20 |
+
|
21 |
+
2. Place the files into `dataset` folder, such that their paths are `dataset/reports` and `dataset/images`.
|
22 |
+
|
23 |
+
## Checkpoints
|
24 |
+
|
25 |
+
This approach uses `CheXNet`, and `DenseNet121` as a CNN Encoder model. By default the `CheXNet` pretrained weights are located in `weights` folder.
|
26 |
+
|
27 |
+
## Config
|
28 |
+
|
29 |
+
The model configurations for each task can be found in its `config.py` file.
|
30 |
+
|
31 |
+
## Training and Evaluation
|
32 |
+
|
33 |
+
### Training
|
34 |
+
|
35 |
+
Use the below command to train the model form a saved checkpoint or without a checkpoint.
|
36 |
+
|
37 |
+
```bash
|
38 |
+
python train.py
|
39 |
+
```
|
40 |
+
|
41 |
+
### Evaluation
|
42 |
+
|
43 |
+
The model performance measure is based of the `BLEU` metric.
|
44 |
+
|
45 |
+
> Feel free to change the performance measure metric in the `check_accuracy` method that is located in the `eval.py` file
|
46 |
+
|
47 |
+
Run the following command to calculate `BLEU` score.
|
48 |
+
|
49 |
+
```bash
|
50 |
+
python eval.py
|
51 |
+
```
|
52 |
+
|
AKSHAYRAJAA/__pycache__/config.cpython-39.pyc
ADDED
Binary file (829 Bytes). View file
|
|
AKSHAYRAJAA/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (4.9 kB). View file
|
|
AKSHAYRAJAA/__pycache__/model.cpython-39.pyc
ADDED
Binary file (5.96 kB). View file
|
|
AKSHAYRAJAA/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.74 kB). View file
|
|
AKSHAYRAJAA/checkpoints.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ed4226fab578a672602a194b066fdeb3b72225c3054b955fcdba2e8b59cb661
|
3 |
+
size 65840736
|
AKSHAYRAJAA/checkpoints/x_ray_model.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aca89cd72ca242bf8e2e8406ccd855fcdd922fe788583eb109af765100b169be
|
3 |
+
size 65840564
|
AKSHAYRAJAA/config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from albumentations.pytorch import ToTensorV2
|
5 |
+
|
6 |
+
|
7 |
+
CHECKPOINT_FILE = 'd:\\AKSHAYRAJAA\\checkpoints\\x_ray_model.pth.tar'
|
8 |
+
DATASET_PATH = 'D:\\AKSHAYRAJAA\\dataset\\'
|
9 |
+
IMAGES_DATASET = 'D:\\AKSHAYRAJAA\\dataset\\images'
|
10 |
+
|
11 |
+
DEVICE = 'cpu'
|
12 |
+
BATCH_SIZE = 16
|
13 |
+
PIN_MEMORY = False
|
14 |
+
VOCAB_THRESHOLD = 2
|
15 |
+
|
16 |
+
FEATURES_SIZE = 1024
|
17 |
+
EMBED_SIZE = 300
|
18 |
+
HIDDEN_SIZE = 256
|
19 |
+
|
20 |
+
LEARNING_RATE = 4e-5
|
21 |
+
EPOCHS = 50
|
22 |
+
|
23 |
+
LOAD_MODEL = True
|
24 |
+
SAVE_MODEL = True
|
25 |
+
|
26 |
+
basic_transforms = A.Compose([
|
27 |
+
A.Resize(
|
28 |
+
height=256,
|
29 |
+
width=256
|
30 |
+
),
|
31 |
+
A.Normalize(
|
32 |
+
mean=(0.485, 0.456, 0.406),
|
33 |
+
std=(0.229, 0.224, 0.225),
|
34 |
+
),
|
35 |
+
ToTensorV2()
|
36 |
+
])
|
AKSHAYRAJAA/dataset.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import spacy
|
3 |
+
import torch
|
4 |
+
import config
|
5 |
+
import utils
|
6 |
+
import numpy as np
|
7 |
+
import xml.etree.ElementTree as ET
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from torch.nn.utils.rnn import pad_sequence
|
11 |
+
from torch.utils.data import Dataset, DataLoader
|
12 |
+
|
13 |
+
|
14 |
+
spacy_eng = spacy.load('en_core_web_sm')
|
15 |
+
|
16 |
+
|
17 |
+
class Vocabulary:
|
18 |
+
def __init__(self, freq_threshold):
|
19 |
+
self.itos = {
|
20 |
+
0: '<PAD>',
|
21 |
+
1: '<SOS>',
|
22 |
+
2: '<EOS>',
|
23 |
+
3: '<UNK>',
|
24 |
+
}
|
25 |
+
self.stoi = {
|
26 |
+
'<PAD>': 0,
|
27 |
+
'<SOS>': 1,
|
28 |
+
'<EOS>': 2,
|
29 |
+
'<UNK>': 3,
|
30 |
+
}
|
31 |
+
self.freq_threshold = freq_threshold
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
def tokenizer(text):
|
35 |
+
return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
|
36 |
+
|
37 |
+
def build_vocabulary(self, sentence_list):
|
38 |
+
frequencies = {}
|
39 |
+
idx = 4
|
40 |
+
|
41 |
+
for sent in sentence_list:
|
42 |
+
for word in self.tokenizer(sent):
|
43 |
+
if word not in frequencies:
|
44 |
+
frequencies[word] = 1
|
45 |
+
else:
|
46 |
+
frequencies[word] += 1
|
47 |
+
|
48 |
+
if frequencies[word] == self.freq_threshold:
|
49 |
+
self.stoi[word] = idx
|
50 |
+
self.itos[idx] = word
|
51 |
+
|
52 |
+
idx += 1
|
53 |
+
|
54 |
+
def numericalize(self, text):
|
55 |
+
tokenized_text = self.tokenizer(text)
|
56 |
+
|
57 |
+
return [
|
58 |
+
self.stoi[token] if token in self.stoi else self.stoi['<UNK>']
|
59 |
+
for token in tokenized_text
|
60 |
+
]
|
61 |
+
|
62 |
+
def __len__(self):
|
63 |
+
return len(self.itos)
|
64 |
+
|
65 |
+
|
66 |
+
class XRayDataset(Dataset):
|
67 |
+
def __init__(self, root, transform=None, freq_threshold=3, raw_caption=False):
|
68 |
+
self.root = root
|
69 |
+
self.transform = transform
|
70 |
+
self.raw_caption = raw_caption
|
71 |
+
|
72 |
+
self.vocab = Vocabulary(freq_threshold=freq_threshold)
|
73 |
+
|
74 |
+
self.captions = []
|
75 |
+
self.imgs = []
|
76 |
+
|
77 |
+
for file in os.listdir(os.path.join(self.root, 'reports')):
|
78 |
+
if file.endswith('.xml'):
|
79 |
+
tree = ET.parse(os.path.join(self.root, 'reports', file))
|
80 |
+
|
81 |
+
frontal_img = ''
|
82 |
+
findings = tree.find(".//AbstractText[@Label='FINDINGS']").text
|
83 |
+
|
84 |
+
if findings is None:
|
85 |
+
continue
|
86 |
+
|
87 |
+
for x in tree.findall('parentImage'):
|
88 |
+
if frontal_img != '':
|
89 |
+
break
|
90 |
+
|
91 |
+
img = x.attrib['id']
|
92 |
+
img = os.path.join(config.IMAGES_DATASET, f'{img}.png')
|
93 |
+
|
94 |
+
frontal_img = img
|
95 |
+
|
96 |
+
if frontal_img == '':
|
97 |
+
continue
|
98 |
+
|
99 |
+
self.captions.append(findings)
|
100 |
+
self.imgs.append(frontal_img)
|
101 |
+
|
102 |
+
|
103 |
+
self.vocab.build_vocabulary(self.captions)
|
104 |
+
|
105 |
+
def __getitem__(self, item):
|
106 |
+
img = self.imgs[item]
|
107 |
+
caption = utils.normalize_text(self.captions[item])
|
108 |
+
|
109 |
+
img = np.array(Image.open(img).convert('L'))
|
110 |
+
img = np.expand_dims(img, axis=-1)
|
111 |
+
img = img.repeat(3, axis=-1)
|
112 |
+
|
113 |
+
if self.transform is not None:
|
114 |
+
img = self.transform(image=img)['image']
|
115 |
+
|
116 |
+
if self.raw_caption:
|
117 |
+
return img, caption
|
118 |
+
|
119 |
+
numericalized_caption = [self.vocab.stoi['<SOS>']]
|
120 |
+
numericalized_caption += self.vocab.numericalize(caption)
|
121 |
+
numericalized_caption.append(self.vocab.stoi['<EOS>'])
|
122 |
+
|
123 |
+
return img, torch.as_tensor(numericalized_caption, dtype=torch.long)
|
124 |
+
|
125 |
+
def __len__(self):
|
126 |
+
return len(self.captions)
|
127 |
+
|
128 |
+
def get_caption(self, item):
|
129 |
+
return self.captions[item].split(' ')
|
130 |
+
|
131 |
+
|
132 |
+
class CollateDataset:
|
133 |
+
def __init__(self, pad_idx):
|
134 |
+
self.pad_idx = pad_idx
|
135 |
+
|
136 |
+
def __call__(self, batch):
|
137 |
+
images, captions = zip(*batch)
|
138 |
+
|
139 |
+
images = torch.stack(images, 0)
|
140 |
+
|
141 |
+
targets = [item for item in captions]
|
142 |
+
targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)
|
143 |
+
|
144 |
+
return images, targets
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == '__main__':
|
148 |
+
all_dataset = XRayDataset(
|
149 |
+
root=config.DATASET_PATH,
|
150 |
+
transform=config.basic_transforms,
|
151 |
+
freq_threshold=config.VOCAB_THRESHOLD,
|
152 |
+
)
|
153 |
+
|
154 |
+
train_loader = DataLoader(
|
155 |
+
dataset=all_dataset,
|
156 |
+
batch_size=config.BATCH_SIZE,
|
157 |
+
pin_memory=config.PIN_MEMORY,
|
158 |
+
drop_last=True,
|
159 |
+
shuffle=True,
|
160 |
+
collate_fn=CollateDataset(pad_idx=all_dataset.vocab.stoi['<PAD>']),
|
161 |
+
)
|
162 |
+
|
163 |
+
for img, caption in train_loader:
|
164 |
+
print(img.shape, caption.shape)
|
165 |
+
break
|
AKSHAYRAJAA/dataset/images/CXR1000_IM-0003-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1000_IM-0003-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1000_IM-0003-3001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1001_IM-0004-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1001_IM-0004-1002.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1002_IM-0004-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1002_IM-0004-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1003_IM-0005-2002.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1004_IM-0005-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1004_IM-0005-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1005_IM-0006-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1005_IM-0006-3003.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1006_IM-0007-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1006_IM-0007-3003.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1007_IM-0008-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1007_IM-0008-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1007_IM-0008-3001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1008_IM-0009-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1008_IM-0009-4004.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1009_IM-0010-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1009_IM-0010-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR100_IM-0002-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR100_IM-0002-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1010_IM-0012-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1010_IM-0012-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1011_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1011_IM-0013-1002.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1012_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1013_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1013_IM-0013-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1014_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1014_IM-0013-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1015_IM-0001-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1015_IM-0001-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1015_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1015_IM-0013-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1016_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1016_IM-0013-2001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1017_IM-0013-1001.png
ADDED
![]() |
AKSHAYRAJAA/dataset/images/CXR1017_IM-0013-1002.png
ADDED
![]() |