File size: 1,915 Bytes
5a79fe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch import nn, cuda, load, device
from ps4_models.classifiers import PS4_Mega, PS4_Conv


def load_trained_model(load_path, model_name='PS4_Mega'):
    """Build a PS4 classifier and optionally load trained weights into it.

    Args:
        load_path: Path to a checkpoint readable by ``torch.load`` that stores
            the weights under the ``'model_state_dict'`` key. An empty string
            skips loading and returns the freshly constructed model.
        model_name: ``'PS4_Mega'`` or ``'PS4_Conv'`` (case-insensitive).

    Returns:
        The constructed ``nn.Module``. Its architecture and trainable
        parameter count are printed as a side effect.

    Raises:
        ValueError: If ``model_name`` is not one of the recognised names.
        ImportError: If the checkpoint file is missing, or if it cannot be
            read or applied to the model. The underlying error is chained.
    """
    name = model_name.lower()
    if name not in ('ps4_conv', 'ps4_mega'):
        raise ValueError(f'Model name {model_name} not recognised, please choose from PS4_Conv, PS4_Mega')

    model: nn.Module = PS4_Mega() if name == 'ps4_mega' else PS4_Conv()

    if load_path != '':
        # Without a GPU, remap tensors saved from CUDA onto the CPU; with one,
        # keep torch.load's default placement (map_location=None).
        map_location = None if cuda.is_available() else device('cpu')
        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = load(load_path, map_location=map_location)
            model.load_state_dict(checkpoint['model_state_dict'])
        except FileNotFoundError as err:
            raise ImportError(f'No file located at {load_path}, could not load parameters') from err
        except Exception as err:
            # Keep ImportError as the type callers see, but report the real
            # cause (bad key, shape mismatch, corrupt file, ...) instead of
            # claiming the file is missing.
            raise ImportError(f'Could not load parameters from {load_path}: {err}') from err
        print("loded params from", load_path)
    print(model)

    pytorch_total_params = sum(par.numel() for par in model.parameters() if par.requires_grad)
    print(pytorch_total_params)

    return model


# MARK: sampling from new sequence
def sample_new_sequence(embs, weights_load_path, model_name='PS4_Mega'):
    """Predict a secondary-structure string for a single input sequence.

    Args:
        embs: Tensor holding the model input for one sequence. It is reshaped
            with ``embs.view(1, len(embs))``, so it must contain exactly
            ``len(embs)`` elements.
        weights_load_path: Checkpoint path forwarded to ``load_trained_model``.
        model_name: Model selector forwarded to ``load_trained_model``.

    Returns:
        A string with one secondary-structure letter per predicted position.
    """
    model = load_trained_model(weights_load_path, model_name)
    # Switch to inference mode so layers such as dropout or batch norm (if the
    # model has any) behave deterministically instead of as in training.
    model.eval()

    seq_size = len(embs)
    batch = embs.view(1, seq_size)

    with torch.no_grad():
        y_hat = model(batch)
        # Class scores are expected on dim 2; pick the most probable class
        # index for every position.
        probs = torch.softmax(y_hat, 2)
        _, ss_preds = torch.max(probs, 2)

    # Convert the single batch row to plain ints once, then decode and join.
    class_ids = ss_preds[0].tolist()[:seq_size]
    return ''.join(ss_tokeniser(class_id, reverse=True) for class_id in class_ids)


def ss_tokeniser(ss, reverse=False):
    """Convert between a secondary-structure letter and its class index.

    With ``reverse`` falsy, ``ss`` is a letter and its index is returned;
    ``'P'`` and ``' '`` both map to 0. With ``reverse`` truthy, ``ss`` is an
    index and the call is delegated to ``inverse_ss_tokeniser``.

    Raises:
        ValueError: In the forward direction, if ``ss`` is not a known letter.
    """
    if reverse:
        return inverse_ss_tokeniser(ss)

    if ss in ('P', ' '):
        return 0

    # list.index returns the first match, so 'C' always resolves to 0 even
    # though it appears again at the end of the table.
    alphabet = ['C', 'T', 'G', 'H', 'S', 'B', 'I', 'E', 'C']
    return alphabet.index(ss)


def inverse_ss_tokeniser(ss):
    """Return the secondary-structure letter for class index ``ss``.

    Indices 0, 8 and 9 all decode to ``'C'``. An index outside the table
    raises ``IndexError``.
    """
    index_to_letter = [
        'C', 'T', 'G', 'H', 'S',
        'B', 'I', 'E', 'C', 'C',
    ]
    letter = index_to_letter[ss]
    return letter