import os
import re
import html
import string
import unicodedata

import torch
from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split as sklearn_train_test_split
from torch.utils.data import Subset

import config
from dataset import XRayDataset
from model import EncoderDecoderNet


def load_dataset(raw_caption=False):
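    """Construct the full XRayDataset using the path, transforms and vocab threshold from config."""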
    return XRayDataset(
        root=config.DATASET_PATH,
        transform=config.basic_transforms,
        freq_threshold=config.VOCAB_THRESHOLD,
        raw_caption=raw_caption
    )


def get_model_instance(vocabulary):
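    """Build the encoder-decoder network, loading pretrained CheXNet encoder weights, and move it to config.DEVICE."""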
    model = EncoderDecoderNet(
        features_size=config.FEATURES_SIZE,
        embed_size=config.EMBED_SIZE,
        hidden_size=config.HIDDEN_SIZE,
        vocabulary=vocabulary,
        encoder_checkpoint='./weights/chexnet.pth.tar'
    )
    model = model.to(config.DEVICE)

    return model


def train_test_split(dataset, test_size=0.25, random_state=44):
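    """Split a dataset into train/test Subsets by index (a fixed seed keeps splits reproducible)."""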
    train_idx, test_idx = sklearn_train_test_split(
        list(range(len(dataset))),
        test_size=test_size,
        random_state=random_state
    )

    return Subset(dataset, train_idx), Subset(dataset, test_idx)


def save_checkpoint(checkpoint):
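    """Serialize the given checkpoint dict to config.CHECKPOINT_FILE."""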
    print('=> Saving checkpoint')

    torch.save(checkpoint, config.CHECKPOINT_FILE)


def load_checkpoint(model, optimizer=None):
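    """Restore model (and optionally optimizer) state from config.CHECKPOINT_FILE and return the saved epoch."""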
    print('=> Loading checkpoint')

    checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    return checkpoint['epoch']


def can_load_checkpoint():
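    """Return True if a checkpoint file exists and resuming is enabled via config.LOAD_MODEL."""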
    return os.path.exists(config.CHECKPOINT_FILE) and config.LOAD_MODEL


def remove_special_chars(text):
    # Escaped HTML-entity fragments and tokenizer artifacts mapped to plain text,
    # applied in insertion order to match the original replace() chain.
    replacements = {
        '#39;': "'", 'amp;': '&', '#146;': "'", 'nbsp;': ' ',
        '#36;': '$', '\\n': '\n', 'quot;': "'", '<br />': '\n',
        '\\"': '"', '<unk>': 'u_n', ' @.@ ': '.', ' @-@ ': '-',
        '\\': ' \\ ',
    }
    x1 = text.lower()
    for old, new in replacements.items():
        x1 = x1.replace(old, new)

    # Unescape any remaining HTML entities, then collapse runs of spaces.
    return re.sub(r'  +', ' ', html.unescape(x1))


def remove_non_ascii(text):
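    """Drop non-ASCII characters (accents etc.) via NFKD normalization."""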
    return unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8', 'ignore')


def to_lowercase(text):
    return text.lower()


def remove_punctuation(text):
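    """Strip all string.punctuation characters from the text."""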
    translator = str.maketrans('', '', string.punctuation)
    return text.translate(translator)


def replace_numbers(text):
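    """Remove digit sequences entirely (note: despite the name, digits are deleted, not replaced)."""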
    return re.sub(r'\d+', '', text)


def text2words(text):
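    """Tokenize text into words with NLTK (requires the 'punkt' tokenizer data to be downloaded)."""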
    return word_tokenize(text)


def normalize_text(text):
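    """Run the full cleaning pipeline: entities/artifacts, non-ASCII, punctuation, casing, digits."""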
    text = remove_special_chars(text)
    text = remove_non_ascii(text)
    text = remove_punctuation(text)
    text = to_lowercase(text)
    text = replace_numbers(text)

    return text


def normalize_corpus(corpus):
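    """Apply normalize_text to every document in the corpus."""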
    return [normalize_text(t) for t in corpus]
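

# Minimal usage sketch (assumes config provides DATASET_PATH, DEVICE and
# CHECKPOINT_FILE as used above, and that XRayDataset exposes its vocabulary
# as `.vocab` -- that attribute name is an assumption, not confirmed here):
#
#   dataset = load_dataset()
#   train_set, test_set = train_test_split(dataset, test_size=0.25)
#   model = get_model_instance(dataset.vocab)
#   if can_load_checkpoint():
#       start_epoch = load_checkpoint(model)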