Allan Victor committed
Commit 1ba9831 · 1 Parent(s): 34fb5ee

Upload Util_funs.py

Files changed (1):
  1. Util_funs.py +247 -2
Util_funs.py CHANGED
@@ -3,7 +3,6 @@ import torch
 import numpy as np
 import random
 import json, pickle
-# from ML_SLRC import SLR_DataSet, SLR_Classifier
 
 import torch.nn.functional as F
 import torch.nn as nn
@@ -596,4 +595,250 @@ class diagnosis():
         self.i = self.start - 1
 
         clear_output()
-        display(self.next_b)
+        display(self.next_b)
+
+
+import torch.nn.functional as F
+import torch.nn as nn
+import math
+import torch
+import numpy as np
+import pandas as pd
+import time
+import transformers
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from sklearn.manifold import TSNE
+from copy import deepcopy, copy
+import seaborn as sns
+import matplotlib.pyplot as plt
+from pprint import pprint
+import shutil
+import datetime
+import re
+import json
+import unicodedata
+from pathlib import Path
+from torch.utils.data import Dataset, DataLoader
+
+
+# Pre-trained model
+class Encoder(nn.Module):
+    def __init__(self, layers, freeze_bert, model):
+        super(Encoder, self).__init__()
+
+        # Dummy parameter used to track the module's device
+        self.dummy_param = nn.Parameter(torch.empty(0))
+
+        # Pre-trained model
+        self.model = deepcopy(model)
+
+        # Freezing BERT parameters (requires_grad must be False to freeze)
+        if freeze_bert:
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+        # Selecting hidden layers of the pre-trained model
+        old_model_encoder = self.model.encoder.layer
+        new_model_encoder = nn.ModuleList()
+
+        for i in layers:
+            new_model_encoder.append(old_model_encoder[i])
+
+        self.model.encoder.layer = new_model_encoder
+
+    # Feed forward
+    def forward(self, **x):
+        return self.model(**x)['pooler_output']
+
+
+# Complete model
+class SLR_Classifier(nn.Module):
+    def __init__(self, **data):
+        super(SLR_Classifier, self).__init__()
+
+        # Dummy parameter used to track the module's device
+        self.dummy_param = nn.Parameter(torch.empty(0))
+
+        # Loss function: binary cross entropy with logits, reduced to mean
+        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
+                                            pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))
+
+        # Pre-trained model
+        self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
+                               freeze_bert=data.get("freeze_bert", False),
+                               model=data.get("model"))
+
+        # Feature map layer
+        self.feature_map = nn.Sequential(
+            nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
+            nn.Linear(self.Encoder.model.config.hidden_size, 200),
+            nn.Dropout(data.get("drop", 0.5)),
+        )
+
+        # Classifier layer
+        self.classifier = nn.Sequential(
+            nn.Tanh(),
+            nn.Linear(200, 1)
+        )
+
+        # Initializing the feature-map linear layer's parameters
+        nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
+        nn.init.zeros_(self.feature_map[1].bias)
+
+    # Feed forward
+    def forward(self, input_ids, attention_mask, token_type_ids, labels):
+        embedding = self.Encoder(**{"input_ids": input_ids,
+                                    "attention_mask": attention_mask,
+                                    "token_type_ids": token_type_ids})
+        feature = self.feature_map(embedding)
+        logit = self.classifier(feature)
+
+        predict = torch.sigmoid(logit)
+
+        # Loss function
+        loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))
+
+        return [loss, [feature, logit], predict]
+
+
+# Undesirable patterns within texts
+patterns = {
+    'CONCLUSIONS AND IMPLICATIONS': '',
+    'BACKGROUND AND PURPOSE': '',
+    'EXPERIMENTAL APPROACH': '',
+    'KEY RESULTS AEA': '',
+    '©': '',
+    '®': '',
+    'μ': '',
+    '(C)': '',
+    'OBJECTIVE:': '',
+    'MATERIALS AND METHODS:': '',
+    'SIGNIFICANCE:': '',
+    'BACKGROUND:': '',
+    'RESULTS:': '',
+    'METHODS:': '',
+    'CONCLUSIONS:': '',
+    'AIM:': '',
+    'STUDY DESIGN:': '',
+    'CLINICAL RELEVANCE:': '',
+    'CONCLUSION:': '',
+    'HYPOTHESIS:': '',
+    'Questions/Purposes:': '',
+    'Introduction:': '',
+    'PURPOSE:': '',
+    'PATIENTS AND METHODS:': '',
+    'FINDINGS:': '',
+    'INTERPRETATIONS:': '',
+    'FUNDING:': '',
+    'PROGRESS:': '',
+    'CONTEXT:': '',
+    'MEASURES:': '',
+    'DESIGN:': '',
+    'BACKGROUND AND OBJECTIVES:': '',
+    '<p>': '',
+    '</p>': '',
+    '<<ETX>>': '',
+    '+/-': '',
+}
+
+# Texts are lower-cased before replacement, so lower-case the keys too
+patterns = {x.lower(): y for x, y in patterns.items()}
+
+LABEL_MAP = {'negative': 0, 'positive': 1}
+
+
+class SLR_DataSet(Dataset):
+    def __init__(self, **args):
+        self.tokenizer = args.get('tokenizer')
+        self.data = args.get('data')
+        self.max_seq_length = args.get("max_seq_length", 512)
+        self.INPUT_NAME = args.get("input", 'x')
+        self.LABEL_NAME = args.get("output", 'y')
+
+    # Tokenizing and processing text
+    def encode_text(self, example):
+        comment_text = example[self.INPUT_NAME]
+        comment_text = self.treat_text(comment_text)
+
+        try:
+            labels = LABEL_MAP[example[self.LABEL_NAME]]
+        except KeyError:
+            # Unlabeled examples get a sentinel label
+            labels = -1
+
+        encoding = self.tokenizer.encode_plus(
+            (comment_text, "It is great text"),
+            add_special_tokens=True,
+            max_length=self.max_seq_length,
+            return_token_type_ids=True,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            return_tensors='pt',
+        )
+
+        return tuple((
+            encoding["input_ids"].flatten(),
+            encoding["attention_mask"].flatten(),
+            encoding["token_type_ids"].flatten(),
+            torch.tensor([torch.tensor(labels).to(int)])
+        ))
+
+    # Text processing function
+    def treat_text(self, text):
+        text = unicodedata.normalize("NFKD", str(text))
+        text = multiple_replace(patterns, text.lower())
+        text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
+        text = re.sub(r'( +)', ' ', text)
+        text = re.sub(r'(, ,)|(,,)', ',', text)
+        text = re.sub(r'(%)|(per cent)', ' percent', text)
+        return text
+
+    def __len__(self):
+        return len(self.data)
+
+    # Returning data
+    def __getitem__(self, index: int):
+        data_row = self.data.reset_index().iloc[index]
+        temp_data = self.encode_text(data_row)
+        return temp_data
+
+
+# Regex multiple replace function
+def multiple_replace(replacements, text):
+    # Building a single regex that matches any of the dict keys
+    regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
+
+    # Substituting each match with its mapped value
+    return regex.sub(lambda mo: replacements[mo.group(0)], text)
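For orientation, below is a minimal sketch of how the classes this commit adds might be wired together. It is not part of the commit: the checkpoint name ("bert-base-uncased"), the toy DataFrame, and all hyperparameter values are illustrative assumptions.

# Hypothetical usage sketch (not part of the commit). Assumes the classes
# above are importable and a BERT-style checkpoint with a pooler is used.
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
bert = AutoModel.from_pretrained("bert-base-uncased")

# Toy data using the default column names expected by SLR_DataSet ('x' and 'y')
df = pd.DataFrame({"x": ["first abstract text", "second abstract text"],
                   "y": ["positive", "negative"]})

dataset = SLR_DataSet(tokenizer=tokenizer, data=df, max_seq_length=128)
loader = DataLoader(dataset, batch_size=2)

model = SLR_Classifier(model=bert,
                       bert_layers=range(6),  # keep only the first six encoder layers
                       freeze_bert=False,
                       pos_weight=2.5,
                       drop=0.5)

model.train()
for input_ids, attention_mask, token_type_ids, labels in loader:
    # labels are collated with shape (batch, 1); the forward pass unsqueezes
    # once more, so squeeze here to keep logits and targets aligned at (batch, 1)
    loss, (feature, logit), predict = model(input_ids, attention_mask,
                                            token_type_ids, labels.squeeze(1))
    loss.backward()
    print(float(loss), predict.detach().squeeze(1).tolist())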