Allan Victor committed
Commit 1ba9831 · Parent(s): 34fb5ee

Upload Util_funs.py

Util_funs.py  +247 -2  CHANGED
@@ -3,7 +3,6 @@ import torch
 import numpy as np
 import random
 import json, pickle
-# from ML_SLRC import SLR_DataSet, SLR_Classifier
 
 import torch.nn.functional as F
 import torch.nn as nn
@@ -596,4 +595,250 @@ class diagnosis():
 self.i= self.start-1
 
 clear_output()
-display(self.next_b)
+display(self.next_b)
+
+import math
+import time
+import shutil
+import datetime
+import re
+import json
+import unicodedata
+from copy import deepcopy, copy
+from pathlib import Path
+from pprint import pprint
+
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import matplotlib.pylab as plt
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+import transformers
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from sklearn.manifold import TSNE
+
+# Pre-trained model
+class Encoder(nn.Module):
+    def __init__(self, layers, freeze_bert, model):
+        super(Encoder, self).__init__()
+
+        # Dummy parameter, used to track the module's device/dtype
+        self.dummy_param = nn.Parameter(torch.empty(0))
+
+        # Deep-copy the pre-trained model so the original is left untouched
+        self.model = deepcopy(model)
+
+        # Freezing bert parameters
+        if freeze_bert:
+            for param in self.model.parameters():
+                param.requires_grad = False  # was `= freeze_bert`, which never actually froze anything
+
+        # Selecting hidden layers of the pre-trained model
+        old_model_encoder = self.model.encoder.layer
+        new_model_encoder = nn.ModuleList()
+
+        for i in layers:
+            new_model_encoder.append(old_model_encoder[i])
+
+        self.model.encoder.layer = new_model_encoder
+
+    # Feed forward
+    def forward(self, **x):
+        return self.model(**x)['pooler_output']
+
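
A minimal usage sketch (illustrative, not part of the commit): Encoder deep-copies a BERT-style backbone and keeps only the encoder layers whose indices appear in `layers`, so a shallower model can be carved out of a pre-trained one. The checkpoint name below is an assumption.

    from transformers import AutoModel

    backbone = AutoModel.from_pretrained("bert-base-uncased")  # assumed checkpoint
    encoder = Encoder(layers=range(4), freeze_bert=True, model=backbone)
    # encoder.model.encoder.layer now holds 4 of the original 12 layers
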
+# Complete model
+class SLR_Classifier(nn.Module):
+    def __init__(self, **data):
+        super(SLR_Classifier, self).__init__()
+
+        # Dummy parameter, used to track the module's device/dtype
+        self.dummy_param = nn.Parameter(torch.empty(0))
+
+        # Loss function: binary cross-entropy with logits, reduced to mean
+        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
+                                            pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))
+
+        # Pre-trained model
+        self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
+                               freeze_bert=data.get("freeze_bert", False),
+                               model=data.get("model"),
+                               )
+
+        # Feature map layer
+        self.feature_map = nn.Sequential(
+            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
+            nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
+            # nn.Dropout(data.get("drop", 0.5)),
+            nn.Linear(self.Encoder.model.config.hidden_size, 200),
+            nn.Dropout(data.get("drop", 0.5)),
+        )
+
+        # Classifier layer
+        self.classifier = nn.Sequential(
+            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
+            # nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
+            # nn.Dropout(data.get("drop", 0.5)),
+            nn.Tanh(),
+            nn.Linear(200, 1)
+        )
+
+        # Initializing the feature-map Linear layer (index 1 of the Sequential)
+        nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
+        nn.init.zeros_(self.feature_map[1].bias)
+
+    # Feed forward
+    def forward(self, input_ids, attention_mask, token_type_ids, labels):
+        pooled = self.Encoder(**{"input_ids": input_ids,
+                                 "attention_mask": attention_mask,
+                                 "token_type_ids": token_type_ids})
+        feature = self.feature_map(pooled)
+        logit = self.classifier(feature)
+
+        predict = torch.sigmoid(logit)
+
+        # Loss on the raw logits (BCEWithLogitsLoss applies the sigmoid itself)
+        loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))
+
+        return [loss, [feature, logit], predict]
+
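
A forward-pass sketch under the same assumptions (bert-base checkpoint, made-up text and label; illustrative, not part of the commit). The eval() call matters because nn.BatchNorm1d raises on a training-mode batch of size one:

    from transformers import AutoModel, AutoTokenizer

    backbone = AutoModel.from_pretrained("bert-base-uncased")   # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    clf = SLR_Classifier(model=backbone).eval()

    enc = tokenizer("an abstract to screen", return_tensors="pt",
                    padding="max_length", truncation=True, max_length=64,
                    return_token_type_ids=True)
    labels = torch.tensor([1])
    loss, (feature, logit), prob = clf(enc["input_ids"], enc["attention_mask"],
                                       enc["token_type_ids"], labels)
    # prob is the sigmoid of the logit; loss is BCE-with-logits against labels
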
+# Undesirable patterns within texts
+patterns = {
+    'CONCLUSIONS AND IMPLICATIONS': '',
+    'BACKGROUND AND PURPOSE': '',
+    'EXPERIMENTAL APPROACH': '',
+    'KEY RESULTS AEA': '',
+    '©': '',
+    '®': '',
+    'μ': '',
+    '(C)': '',
+    'OBJECTIVE:': '',
+    'MATERIALS AND METHODS:': '',
+    'SIGNIFICANCE:': '',
+    'BACKGROUND:': '',
+    'RESULTS:': '',
+    'METHODS:': '',
+    'CONCLUSIONS:': '',
+    'AIM:': '',
+    'STUDY DESIGN:': '',
+    'CLINICAL RELEVANCE:': '',
+    'CONCLUSION:': '',
+    'HYPOTHESIS:': '',
+    'Questions/Purposes:': '',
+    'Introduction:': '',
+    'PURPOSE:': '',
+    'PATIENTS AND METHODS:': '',
+    'FINDINGS:': '',
+    'INTERPRETATIONS:': '',
+    'FUNDING:': '',
+    'PROGRESS:': '',
+    'CONTEXT:': '',
+    'MEASURES:': '',
+    'DESIGN:': '',
+    'BACKGROUND AND OBJECTIVES:': '',
+    '<p>': '',
+    '</p>': '',
+    '<<ETX>>': '',
+    '+/-': '',
+}
+
+# Keys are lower-cased because treat_text() lower-cases the input before replacing
+patterns = {x.lower(): y for x, y in patterns.items()}
+
+LABEL_MAP = {'negative': 0, 'positive': 1}
+
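
A quick illustration of the intended cleanup, using the multiple_replace helper defined at the bottom of this file (the sample string is made up; illustrative, not part of the commit):

    raw = "BACKGROUND: We measured a rise <p>in vitro</p>."
    print(multiple_replace(patterns, raw.lower()))
    # -> ' we measured a rise in vitro.'
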
+class SLR_DataSet(Dataset):
+    def __init__(self, **args):
+        self.tokenizer = args.get('tokenizer')
+        self.data = args.get('data')
+        self.max_seq_length = args.get("max_seq_length", 512)
+        self.INPUT_NAME = args.get("input", 'x')
+        self.LABEL_NAME = args.get("output", 'y')
+
+    # Tokenizing and processing text
+    def encode_text(self, example):
+        comment_text = example[self.INPUT_NAME]
+        comment_text = self.treat_text(comment_text)
+
+        try:
+            labels = LABEL_MAP[example[self.LABEL_NAME]]
+        except KeyError:
+            # Unlabeled examples get a sentinel label
+            labels = -1
+
+        # The text is paired with a fixed second sentence, as in the original code
+        encoding = self.tokenizer.encode_plus(
+            (comment_text, "It is great text"),
+            add_special_tokens=True,
+            max_length=self.max_seq_length,
+            return_token_type_ids=True,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            return_tensors='pt',
+        )
+
+        return tuple((
+            encoding["input_ids"].flatten(),
+            encoding["attention_mask"].flatten(),
+            encoding["token_type_ids"].flatten(),
+            torch.tensor([labels], dtype=torch.long)
+        ))
+
+    # Text processing function
+    def treat_text(self, text):
+        text = unicodedata.normalize("NFKD", str(text))
+        text = multiple_replace(patterns, text.lower())
+        text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
+        text = re.sub(r'( +)', ' ', text)
+        text = re.sub(r'(, ,)|(,,)', ',', text)
+        text = re.sub(r'(%)|(per cent)', ' percent', text)
+        return text
+
+    def __len__(self):
+        return len(self.data)
+
+    # Returning data
+    def __getitem__(self, index: int):
+        data_row = self.data.reset_index().iloc[index]
+        temp_data = self.encode_text(data_row)
+        return temp_data
+
+# Regex multiple replace function
+def multiple_replace(mapping, text):
+    # Building one regex that matches any of the dict keys
+    regex = re.compile("(%s)" % "|".join(map(re.escape, mapping.keys())))
+
+    # Substituting each match with its mapped replacement
+    return regex.sub(lambda mo: mapping[mo.group(0)], text)
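
A usage sketch tying the pieces together (illustrative, not part of the commit), assuming a small pandas DataFrame whose column names match the class defaults 'x' and 'y':

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    df = pd.DataFrame({"x": ["first abstract", "second abstract"],
                       "y": ["positive", "negative"]})
    dataset = SLR_DataSet(data=df, tokenizer=tokenizer, max_seq_length=64)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    input_ids, attention_mask, token_type_ids, labels = next(iter(loader))
    # input_ids.shape == (2, 64); labels.shape == (2, 1)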