minko186 commited on
Commit
394a4ce
·
1 Parent(s): 97dec89

Create predictors.py

Browse files
Files changed (1) hide show
  1. predictors.py +175 -0
predictors.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import httpx
3
+ import torch
4
+ import re
5
+ from bs4 import BeautifulSoup
6
+ import numpy as np
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ import asyncio
9
+ from evaluate import load
10
+ from datetime import date
11
+ import nltk
12
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
13
+ import plotly.graph_objects as go
14
+ import torch.nn.functional as F
15
+ import nltk
16
+ from unidecode import unidecode
17
+ import time
18
+ from scipy.special import softmax
19
+ import yaml
20
+ import os
21
+ from utils import *
22
+ from dotenv import load_dotenv
23
+ with open('config.yaml', 'r') as file:
24
+ params = yaml.safe_load(file)
25
+ nltk.download('punkt')
26
+ nltk.download('stopwords')
27
+ load_dotenv()
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ hf_token = os.getenv("HF_TOKEN")
30
+ text_bc_model_path = os.getenv("TEXT_BC_MODEL_PATH")
31
+ text_mc_model_path = os.getenv("TEXT_MC_MODEL_PATH")
32
+ text_quillbot_model_path = os.getenv("TEXT_QUILLBOT_MODEL_PATH")
33
+ quillbot_labels = params["QUILLBOT_LABELS"]
34
+ mc_label_map = params["MC_OUTPUT_LABELS"]
35
+ mc_token_size = int(os.getenv("MC_TOKEN_SIZE"))
36
+ bc_token_size = int(os.getenv("BC_TOKEN_SIZE"))
37
+ text_bc_tokenizer = AutoTokenizer.from_pretrained(text_bc_model_path, use_auth_token=hf_token)
38
+ text_bc_model = AutoModelForSequenceClassification.from_pretrained(text_bc_model_path, use_auth_token=hf_token).to(device)
39
+ text_mc_tokenizer = AutoTokenizer.from_pretrained(text_mc_model_path, use_auth_token=hf_token)
40
+ text_mc_model = AutoModelForSequenceClassification.from_pretrained(text_mc_model_path, use_auth_token=hf_token).to(device)
41
+ quillbot_tokenizer = AutoTokenizer.from_pretrained(text_quillbot_model_path, use_auth_token=hf_token)
42
+ quillbot_model = AutoModelForSequenceClassification.from_pretrained(text_quillbot_model_path, use_auth_token=hf_token).to(device)
43
+
44
+ def split_text_allow_complete_sentences_nltk(text, max_length=256, tolerance=30, min_last_segment_length=100, type_det='bc'):
45
+ sentences = nltk.sent_tokenize(text)
46
+ segments = []
47
+ current_segment = []
48
+ current_length = 0
49
+ if type_det == 'bc':
50
+ tokenizer = text_bc_tokenizer
51
+ max_length = bc_token_size
52
+ elif type_det == 'mc':
53
+ tokenizer = text_mc_tokenizer
54
+ max_length = mc_token_size
55
+ for sentence in sentences:
56
+ tokens = tokenizer.tokenize(sentence)
57
+ sentence_length = len(tokens)
58
+
59
+ if current_length + sentence_length <= max_length + tolerance - 2:
60
+ current_segment.append(sentence)
61
+ current_length += sentence_length
62
+ else:
63
+ if current_segment:
64
+ encoded_segment = tokenizer.encode(' '.join(current_segment), add_special_tokens=True, max_length=max_length+tolerance, truncation=True)
65
+ segments.append((current_segment, len(encoded_segment)))
66
+ current_segment = [sentence]
67
+ current_length = sentence_length
68
+
69
+ if current_segment:
70
+ encoded_segment = tokenizer.encode(' '.join(current_segment), add_special_tokens=True, max_length=max_length+tolerance, truncation=True)
71
+ segments.append((current_segment, len(encoded_segment)))
72
+
73
+ final_segments = []
74
+ for i, (seg, length) in enumerate(segments):
75
+ if i == len(segments) - 1:
76
+ if length < min_last_segment_length and len(final_segments) > 0:
77
+ prev_seg, prev_length = final_segments[-1]
78
+ combined_encoded = tokenizer.encode(' '.join(prev_seg + seg), add_special_tokens=True, max_length=max_length+tolerance, truncation=True)
79
+ if len(combined_encoded) <= max_length + tolerance:
80
+ final_segments[-1] = (prev_seg + seg, len(combined_encoded))
81
+ else:
82
+ final_segments.append((seg, length))
83
+ else:
84
+ final_segments.append((seg, length))
85
+ else:
86
+ final_segments.append((seg, length))
87
+
88
+ decoded_segments = []
89
+ encoded_segments = []
90
+ for seg, _ in final_segments:
91
+ encoded_segment = tokenizer.encode(' '.join(seg), add_special_tokens=True, max_length=max_length+tolerance, truncation=True)
92
+ decoded_segment = tokenizer.decode(encoded_segment)
93
+ decoded_segments.append(decoded_segment)
94
+ return decoded_segments
95
+
96
+ def predict_quillbot(text):
97
+ with torch.no_grad():
98
+ quillbot_model.eval()
99
+ tokenized_text = quillbot_tokenizer(text, padding="max_length", truncation=True, max_length=256, return_tensors="pt").to(device)
100
+ output = quillbot_model(**tokenized_text)
101
+ output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
102
+ q_score = {"QuillBot": output_norm[1].item(), "Original": output_norm[0].item()}
103
+ return q_score
104
+
105
+ def predict_bc(model, tokenizer, text):
106
+ with torch.no_grad():
107
+ model.eval()
108
+ tokens = text_bc_tokenizer(
109
+ text, padding='max_length', truncation=True, max_length=bc_token_size, return_tensors="pt"
110
+ ).to(device)
111
+ output = model(**tokens)
112
+ output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
113
+ return output_norm
114
+
115
+ def predict_mc(model, tokenizer, text):
116
+ with torch.no_grad():
117
+ model.eval()
118
+ tokens = text_mc_tokenizer(
119
+ text, padding='max_length', truncation=True, return_tensors="pt", max_length=mc_token_size
120
+ ).to(device)
121
+ output = model(**tokens)
122
+ output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
123
+ return output_norm
124
+
125
+ def predict_mc_scores(input):
126
+ bc_scores = []
127
+ mc_scores = []
128
+
129
+ samples_len_bc = len(split_text_allow_complete_sentences_nltk(input, type_det = 'bc'))
130
+ segments_bc = split_text_allow_complete_sentences_nltk(input, type_det = 'bc')
131
+ for i in range(samples_len_bc):
132
+ cleaned_text_bc = remove_special_characters(segments_bc[i])
133
+ bc_score = predict_bc(text_bc_model, text_bc_tokenizer,cleaned_text_bc )
134
+ bc_scores.append(bc_score)
135
+ bc_scores_array = np.array(bc_scores)
136
+ average_bc_scores = np.mean(bc_scores_array, axis=0)
137
+ bc_score_list = average_bc_scores.tolist()
138
+ bc_score = {"AI": bc_score_list[1], "HUMAN": bc_score_list[0]}
139
+ segments_mc = split_text_allow_complete_sentences_nltk(input, type_det = 'mc')
140
+ samples_len_mc = len(split_text_allow_complete_sentences_nltk(input, type_det = 'mc'))
141
+ for i in range(samples_len_mc):
142
+ cleaned_text_mc = remove_special_characters(segments_mc[i])
143
+ mc_score = predict_mc(text_mc_model, text_mc_tokenizer, cleaned_text_mc)
144
+ mc_scores.append(mc_score)
145
+ mc_scores_array = np.array(mc_scores)
146
+ average_mc_scores = np.mean(mc_scores_array, axis=0)
147
+ mc_score_list = average_mc_scores.tolist()
148
+ mc_score = {}
149
+ for score, label in zip(mc_score_list, mc_label_map):
150
+ mc_score[label.upper()] = score
151
+
152
+ sum_prob = 1 - bc_score['HUMAN']
153
+ for key, value in mc_score.items():
154
+ mc_score[key] = value * sum_prob
155
+ if sum_prob < 0.01 :
156
+ mc_score = {}
157
+
158
+ mc_score['HUMAN'] = bc_score['HUMAN']
159
+ return mc_score
160
+
161
+
162
+ def predict_bc_scores(input):
163
+ bc_scores = []
164
+ mc_scores = []
165
+ samples_len_bc = len(split_text_allow_complete_sentences_nltk(input, type_det = 'bc'))
166
+ segments_bc = split_text_allow_complete_sentences_nltk(input, type_det = 'bc')
167
+ for i in range(samples_len_bc):
168
+ cleaned_text_bc = remove_special_characters(segments_bc[i])
169
+ bc_score = predict_bc(text_bc_model, text_bc_tokenizer,cleaned_text_bc )
170
+ bc_scores.append(bc_score)
171
+ bc_scores_array = np.array(bc_scores)
172
+ average_bc_scores = np.mean(bc_scores_array, axis=0)
173
+ bc_score_list = average_bc_scores.tolist()
174
+ bc_score = {"AI": bc_score_list[1], "HUMAN": bc_score_list[0]}
175
+ return bc_score