ngocminhta committed
Commit 3fef185 · 1 Parent(s): 5287779
Files changed (12)
  1. app.py +80 -0
  2. gen_database.py +300 -0
  3. infer.py +130 -0
  4. requirements.txt +11 -0
  5. src/.DS_Store +0 -0
  6. src/__init__.py +0 -0
  7. src/index.py +80 -0
  8. src/simclr.py +280 -0
  9. src/text_embedding.py +55 -0
  10. utils/__init__.py +0 -0
  11. utils/load_dataset.py +205 -0
  12. utils/utils.py +132 -0
app.py ADDED
@@ -0,0 +1,80 @@
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import torch
+ from src.text_embedding import TextEmbeddingModel
+ from src.index import Indexer
+ import os
+ import pickle
+ from infer import infer_3_class
+ import uvicorn
+
+ app = FastAPI()
+
+ origins = ["*"]
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ class Opt:
+     def __init__(self):
+         self.model_name = "./unsup-simcse-xlm-roberta-base"
+         self.model_path = "core/model.pth"
+         self.database_path = "core/seen_db"
+         self.embedding_dim = 768
+         self.device_num = 1
+
+ opt = Opt()
+
+ def load_pkl(path):
+     with open(path, 'rb') as f:
+         return pickle.load(f)
+
+ @app.on_event("startup")
+ def load_model_resources():
+     global model, tokenizer, index, label_dict, is_mixed_dict
+
+     model = TextEmbeddingModel(opt.model_name)
+     state_dict = torch.load(opt.model_path, map_location=model.model.device)
+     new_state_dict = {}
+     for key in state_dict.keys():
+         if key.startswith('model.'):
+             new_state_dict[key[6:]] = state_dict[key]
+     model.load_state_dict(state_dict)
+     tokenizer = model.tokenizer
+
+     index = Indexer(opt.embedding_dim)
+     index.deserialize_from(opt.database_path)
+     label_dict = load_pkl(os.path.join(opt.database_path, 'label_dict.pkl'))
+     is_mixed_dict = load_pkl(os.path.join(opt.database_path, 'is_mixed_dict.pkl'))
+
+
+ @app.post('/predict')
+ async def predict(request: Request):
+     data = await request.json()
+     mode = data.get("mode", "normal").lower()
+     text_list = data.get("text", [])
+
+     if mode == "normal":
+         results = []
+         for text in text_list:
+             result = infer_3_class(model=model,
+                                    tokenizer=tokenizer,
+                                    index=index,
+                                    label_dict=label_dict,
+                                    is_mixed_dict=is_mixed_dict,
+                                    text=text,
+                                    K=20)
+             results.append(result)
+         return JSONResponse(content={"results": results})
+     elif mode == "advanced":
+         return 0
+
+ if __name__ == "__main__":
+     port = int(os.getenv("PORT", 8000))
+     uvicorn.run(app, host="0.0.0.0", port=port)
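Note: a minimal client sketch for the `/predict` endpoint above. The host, port, payload values, and the use of `requests` are illustrative assumptions, not part of the commit.

```python
# Hypothetical client call against the /predict endpoint defined in app.py.
# Assumes the server is running locally on the default port (8000).
import requests

payload = {
    "mode": "normal",
    "text": ["An example passage to classify.", "Another passage."],
}
resp = requests.post("http://localhost:8000/predict", json=payload)
print(resp.json())  # e.g. {"results": [0, 1]} with 0 = Human, 1 = AI, 2 = Mixed
```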
gen_database.py ADDED
@@ -0,0 +1,300 @@
+ import os
+ import pickle
+ import random
+ import faiss
+ from src.index import Indexer
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from torch.utils.data import DataLoader
+ from lightning import Fabric
+ from tqdm import tqdm
+ import argparse
+ from src.text_embedding import TextEmbeddingModel
+ from utils.load_dataset import load_dataset, TextDataset, load_outdomain_dataset
+
+ def load_pkl(path):
+     with open(path, 'rb') as f:
+         return pickle.load(f)
+
+ def infer(passages_dataloder, fabric, tokenizer, model, ood=False):
+     if fabric.global_rank == 0:
+         passages_dataloder = tqdm(passages_dataloder, total=len(passages_dataloder))
+     if ood:
+         allids, allembeddings, alllabels, all_is_mixed = [], [], [], []
+     else:
+         allids, allembeddings, alllabels, all_is_mixed, all_write_model = [], [], [], [], []
+     model.model.eval()
+     with torch.no_grad():
+         for batch in passages_dataloder:
+             if ood:
+                 ids, text, label, is_mixed = batch
+                 encoded_batch = tokenizer.batch_encode_plus(
+                     text,
+                     return_tensors="pt",
+                     max_length=512,
+                     padding="max_length",
+                     # padding=True,
+                     truncation=True,
+                 )
+                 encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
+                 # output = model(**encoded_batch).last_hidden_state
+                 # embeddings = pooling(output, encoded_batch)
+                 # print(encoded_batch)
+                 embeddings = model(encoded_batch)
+                 # print(encoded_batch['input_ids'].shape)
+                 embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(1))
+                 label = fabric.all_gather(label).view(-1)
+                 ids = fabric.all_gather(ids).view(-1)
+                 is_mixed = fabric.all_gather(is_mixed).view(-1)
+                 if fabric.global_rank == 0:
+                     allembeddings.append(embeddings.cpu())
+                     allids.extend(ids.cpu().tolist())
+                     alllabels.extend(label.cpu().tolist())
+                     all_is_mixed.extend(is_mixed.cpu().tolist())
+             else:
+                 ids, text, label, is_mixed, write_model = batch
+                 encoded_batch = tokenizer.batch_encode_plus(
+                     text,
+                     return_tensors="pt",
+                     max_length=512,
+                     padding="max_length",
+                     # padding=True,
+                     truncation=True,
+                 )
+                 encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
+                 # output = model(**encoded_batch).last_hidden_state
+                 # embeddings = pooling(output, encoded_batch)
+                 # print(encoded_batch)
+                 embeddings = model(encoded_batch)
+                 # print(encoded_batch['input_ids'].shape)
+                 embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(1))
+                 label = fabric.all_gather(label).view(-1)
+                 ids = fabric.all_gather(ids).view(-1)
+                 is_mixed = fabric.all_gather(is_mixed).view(-1)
+                 write_model = fabric.all_gather(write_model).view(-1)
+                 if fabric.global_rank == 0:
+                     allembeddings.append(embeddings.cpu())
+                     allids.extend(ids.cpu().tolist())
+                     alllabels.extend(label.cpu().tolist())
+                     all_is_mixed.extend(is_mixed.cpu().tolist())
+                     all_write_model.extend(write_model.cpu().tolist())
+     if fabric.global_rank == 0:
+         allembeddings = torch.cat(allembeddings, dim=0)
+         epsilon = 1e-6
+         if ood:
+             emb_dict, label_dict, is_mixed_dict = {}, {}, {}
+             allembeddings = F.normalize(allembeddings, dim=-1)
+             for i in range(len(allids)):
+                 emb_dict[allids[i]] = allembeddings[i]
+                 label_dict[allids[i]] = alllabels[i]
+                 is_mixed_dict[allids[i]] = all_is_mixed[i]
+             allids, allembeddings, alllabels, all_is_mixed = [], [], [], []
+             for key in emb_dict:
+                 allids.append(key)
+                 allembeddings.append(emb_dict[key])
+                 alllabels.append(label_dict[key])
+                 all_is_mixed.append(is_mixed_dict[key])
+             allembeddings = torch.stack(allembeddings, dim=0)
+             return allids, allembeddings.numpy(), alllabels, all_is_mixed
+         else:
+             emb_dict, label_dict, is_mixed_dict, write_model_dict = {}, {}, {}, {}
+             allembeddings = F.normalize(allembeddings, dim=-1)
+             for i in range(len(allids)):
+                 emb_dict[allids[i]] = allembeddings[i]
+                 label_dict[allids[i]] = alllabels[i]
+                 is_mixed_dict[allids[i]] = all_is_mixed[i]
+                 write_model_dict[allids[i]] = all_write_model[i]
+             allids, allembeddings, alllabels, all_is_mixed, all_write_model = [], [], [], [], []
+             for key in emb_dict:
+                 allids.append(key)
+                 allembeddings.append(emb_dict[key])
+                 alllabels.append(label_dict[key])
+                 all_is_mixed.append(is_mixed_dict[key])
+                 all_write_model.append(write_model_dict[key])
+             allembeddings = torch.stack(allembeddings, dim=0)
+             return allids, allembeddings.numpy(), alllabels, all_is_mixed, all_write_model
+     else:
+         if ood:
+             return [], [], [], []
+         return [], [], [], [], []
+
+ def set_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
+     np.random.seed(seed)  # Numpy module.
+     random.seed(seed)  # Python random module.
+
+ def test(opt):
+     if opt.device_num > 1:
+         fabric = Fabric(accelerator="cuda", devices=opt.device_num, strategy='ddp')
+     else:
+         fabric = Fabric(accelerator="cuda", devices=opt.device_num)
+     fabric.launch()
+     model = TextEmbeddingModel(opt.model_name).cuda()
+     state_dict = torch.load(opt.model_path, map_location=model.model.device)
+     new_state_dict = {}
+     for key in state_dict.keys():
+         if key.startswith('model.'):
+             new_state_dict[key[6:]] = state_dict[key]
+     model.load_state_dict(state_dict)
+     tokenizer = model.tokenizer
+     database = load_dataset(opt.dataset_name, opt.database_path)[opt.database_name]
+     passage_dataset = TextDataset(database, need_ids=True)
+     print(len(passage_dataset))
+
+     passages_dataloder = DataLoader(passage_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, pin_memory=True)
+     passages_dataloder = fabric.setup_dataloaders(passages_dataloder)
+     model = fabric.setup(model)
+
+     train_ids, train_embeddings, train_labels, train_is_mixed, train_write_model = infer(passages_dataloder, fabric, tokenizer, model)
+     fabric.barrier()
+
+     if fabric.global_rank == 0:
+         index = Indexer(opt.embedding_dim)
+         index.index_data(train_ids, train_embeddings)
+         label_dict = {}
+         is_mixed_dict = {}
+         write_model_dict = {}
+         for i in range(len(train_ids)):
+             label_dict[train_ids[i]] = train_labels[i]
+             is_mixed_dict[train_ids[i]] = train_is_mixed[i]
+             write_model_dict[train_ids[i]] = train_write_model[i]
+
+         if not os.path.exists(opt.save_path):
+             os.makedirs(opt.save_path)
+         index.serialize(opt.save_path)
+         # save label_dict using pickle
+         with open(os.path.join(opt.save_path, 'label_dict.pkl'), 'wb') as f:
+             pickle.dump(label_dict, f)
+         # save is_mixed_dict using pickle
+         with open(os.path.join(opt.save_path, 'is_mixed_dict.pkl'), 'wb') as f:
+             pickle.dump(is_mixed_dict, f)
+         # save write_model_dict using pickle
+         with open(os.path.join(opt.save_path, 'write_model_dict.pkl'), 'wb') as f:
+             pickle.dump(write_model_dict, f)
+
+ def add_to_existed_index(opt):
+     if opt.device_num > 1:
+         fabric = Fabric(accelerator="cuda", devices=opt.device_num, strategy='ddp')
+     else:
+         fabric = Fabric(accelerator="cuda", devices=opt.device_num)
+     fabric.launch()
+     model = TextEmbeddingModel(opt.model_name).cuda()
+     state_dict = torch.load(opt.model_path, map_location=model.model.device)
+     new_state_dict = {}
+     for key in state_dict.keys():
+         if key.startswith('model.'):
+             new_state_dict[key[6:]] = state_dict[key]
+     model.load_state_dict(state_dict)
+     tokenizer = model.tokenizer
+
+     if opt.ood:
+         database = load_outdomain_dataset(opt.database_path)[opt.database_name]
+     else:
+         database = load_dataset(opt.dataset_name, opt.database_path)[opt.database_name]
+
+     passage_dataset = TextDataset(database, need_ids=True, out_domain=opt.ood)
+     print(len(passage_dataset))
+
+     passages_dataloder = DataLoader(passage_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, pin_memory=True)
+     passages_dataloder = fabric.setup_dataloaders(passages_dataloder)
+     model = fabric.setup(model)
+
+     if opt.ood:
+         train_ids, train_embeddings, train_labels, train_is_mixed = infer(passages_dataloder, fabric, tokenizer, model, ood=True)
+     else:
+         train_ids, train_embeddings, train_labels, train_is_mixed, train_write_model = infer(passages_dataloder, fabric, tokenizer, model)
+     fabric.barrier()
+
+     if fabric.global_rank == 0:
+         new_index = Indexer(opt.embedding_dim)
+         new_index.index_data(train_ids, train_embeddings)
+
+         old_index = Indexer(opt.embedding_dim)
+         old_index.deserialize_from(opt.existed_index_path)
+         old_ids = old_index.index_id_to_db_id
+
+         # Ensure both indexes are of type IndexFlatIP
+         # assert isinstance(new_index.index, faiss.IndexFlatIP)
+         # assert isinstance(old_index.index, faiss.IndexFlatIP)
+
+         # Ensure both indexes have the same dimensionality
+         assert new_index.index.d == old_index.index.d
+
+         # Extract vectors from old_index.index
+         vectors = old_index.index.reconstruct_n(0, old_index.index.ntotal)
+
+         # Add vectors to new_index.index
+         new_index.index_data(old_ids, vectors)
+
+         if not os.path.exists(opt.new_save_path):
+             os.makedirs(opt.new_save_path)
+         new_index.serialize(opt.new_save_path)
+
+         if opt.ood:
+             label_dict = load_pkl(os.path.join(opt.existed_index_path, 'label_dict.pkl'))
+             is_mixed_dict = load_pkl(os.path.join(opt.existed_index_path, 'is_mixed_dict.pkl'))
+             for i in range(len(train_ids)):
+                 label_dict[train_ids[i]] = train_labels[i]
+                 is_mixed_dict[train_ids[i]] = train_is_mixed[i]
+             # save label_dict using pickle
+             with open(os.path.join(opt.new_save_path, 'label_dict.pkl'), 'wb') as f:
+                 pickle.dump(label_dict, f)
+             # save is_mixed_dict using pickle
+             with open(os.path.join(opt.new_save_path, 'is_mixed_dict.pkl'), 'wb') as f:
+                 pickle.dump(is_mixed_dict, f)
+
+         else:
+             label_dict = load_pkl(os.path.join(opt.existed_index_path, 'label_dict.pkl'))
+             is_mixed_dict = load_pkl(os.path.join(opt.existed_index_path, 'is_mixed_dict.pkl'))
+             write_model_dict = load_pkl(os.path.join(opt.existed_index_path, 'write_model_dict.pkl'))
+             for i in range(len(train_ids)):
+                 label_dict[train_ids[i]] = train_labels[i]
+                 is_mixed_dict[train_ids[i]] = train_is_mixed[i]
+                 write_model_dict[train_ids[i]] = train_write_model[i]
+             # save label_dict using pickle
+             with open(os.path.join(opt.new_save_path, 'label_dict.pkl'), 'wb') as f:
+                 pickle.dump(label_dict, f)
+             # save is_mixed_dict using pickle
+             with open(os.path.join(opt.new_save_path, 'is_mixed_dict.pkl'), 'wb') as f:
+                 pickle.dump(is_mixed_dict, f)
+             # save write_model_dict using pickle
+             with open(os.path.join(opt.new_save_path, 'write_model_dict.pkl'), 'wb') as f:
+                 pickle.dump(write_model_dict, f)
+
+
+
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device_num', type=int, default=1)
+     parser.add_argument('--batch_size', type=int, default=128)
+     parser.add_argument('--num_workers', type=int, default=8)
+     parser.add_argument('--embedding_dim', type=int, default=768)
+
+     # parser.add_argument('--mode', type=str, default='deepfake', help="deepfake,MGT or MGTDetect_CoCo")
+     parser.add_argument("--database_path", type=str, default="data/FALCONSet", help="Path to the data")
+     parser.add_argument('--dataset_name', type=str, default='falconset', help="falconset, llmdetectaive, hart")
+     parser.add_argument('--database_name', type=str, default='train', help="train,valid,test,test_ood")
+     parser.add_argument("--model_path", type=str, default="runs/authscan_v6/model_best.pth",
+                         help="Path to the embedding model checkpoint")
+     parser.add_argument('--model_name', type=str, default="FacebookAI/xlm-roberta-base", help="Model name")
+     parser.add_argument("--save_path", type=str, default="/output", help="Path to save the database")
+     parser.add_argument("--add_to_existed_index", type=int, default=0)
+     # parser.add_argument("--add_to_existed_index_path", type=str, default="/output", help="Path to save the database")
+     parser.add_argument("--ood", type=int, default=0)
+     parser.add_argument("--existed_index_path", type=str, default="/output", help="Path of existed index")
+     parser.add_argument("--new_save_path", type=str, default="/new_db", help="Path to save the database")
+
+     parser.add_argument('--seed', type=int, default=0)
+     opt = parser.parse_args()
+     set_seed(opt.seed)
+
+     if not opt.add_to_existed_index:
+         test(opt)
+     else:
+         add_to_existed_index(opt)
+
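Note: `test()` writes a FAISS index plus three pickled lookup dictionaries into `--save_path`. A minimal sketch of loading those artifacts back for querying (the `db_dir` path is a placeholder, not a path from this commit):

```python
# Sketch: reload the artifacts written by gen_database.py.
import os
import pickle

from src.index import Indexer

db_dir = "output_db"               # placeholder for whatever --save_path was used
index = Indexer(768)               # must match --embedding_dim
index.deserialize_from(db_dir)     # reads index.faiss and index_meta.faiss

with open(os.path.join(db_dir, "label_dict.pkl"), "rb") as f:
    label_dict = pickle.load(f)    # id -> label component of the index tuple
with open(os.path.join(db_dir, "is_mixed_dict.pkl"), "rb") as f:
    is_mixed_dict = pickle.load(f) # id -> is_mixed component of the index tuple
```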
infer.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import pickle
+ import numpy as np
+ from src.index import Indexer
+ import torch
+ import argparse
+ from src.text_embedding import TextEmbeddingModel
+ import random
+ from collections import Counter
+
+
+ def softmax_weights(scores, temperature=1.0):
+     scores = np.array(scores)
+     scores = scores / temperature
+     e_scores = np.exp(scores - np.max(scores))
+     return e_scores / np.sum(e_scores)
+
+ def normalize_fuzzy_cnt(fuzzy_cnt):
+     total = sum(fuzzy_cnt.values())
+     if total == 0:
+         return fuzzy_cnt
+     for key in fuzzy_cnt:
+         fuzzy_cnt[key] /= total
+     return fuzzy_cnt
+
+ def class_type_boost(query_type, candidate_type):
+     if query_type == candidate_type:
+         return 1.3
+     elif abs(query_type - candidate_type) == 1:
+         return 1.1
+     elif abs(query_type - candidate_type) == 2:
+         return 0.9
+     else:
+         return 0.8
+
+ def set_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
+     np.random.seed(seed)  # Numpy module.
+     random.seed(seed)  # Python random module.
+
+ def load_pkl(path):
+     with open(path, 'rb') as f:
+         return pickle.load(f)
+
+ def infer_3_class(model, tokenizer, index, label_dict, is_mixed_dict, text, K):
+     # model = TextEmbeddingModel(opt.model_name).cuda()
+     # state_dict = torch.load(opt.model_path, map_location=model.model.device)
+     # new_state_dict={}
+     # for key in state_dict.keys():
+     #     if key.startswith('model.'):
+     #         new_state_dict[key[6:]]=state_dict[key]
+     # model.load_state_dict(state_dict)
+     # tokenizer=model.tokenizer
+
+     # index = Indexer(opt.embedding_dim)
+     # index.deserialize_from(opt.database_path)
+     # label_dict=load_pkl(os.path.join(opt.database_path,'label_dict.pkl'))
+     # is_mixed_dict=load_pkl(os.path.join(opt.database_path,'is_mixed_dict.pkl'))
+
+     # text = opt.text
+     encoded_text = tokenizer.batch_encode_plus(
+         [text],
+         return_tensors="pt",
+         max_length=512,
+         padding="max_length",
+         truncation=True,
+     )
+     encoded_text = {k: v for k, v in encoded_text.items()}
+     embeddings = model(encoded_text).cpu().detach().numpy()
+     top_ids_and_scores = index.search_knn(embeddings, K)
+     pred = []
+     for i, (ids, scores) in enumerate(top_ids_and_scores):
+         print(f"Top {K} results for text:")
+         sorted_scores = np.argsort(scores)
+         sorted_scores = sorted_scores[::-1]
+
+         topk_ids = [ids[j] for j in sorted_scores]
+         topk_scores = [scores[j] for j in sorted_scores]
+         weights = softmax_weights(topk_scores, temperature=0.1)
+
+         candidate_models = [is_mixed_dict[int(_id)] for _id in topk_ids]
+         initial_pred = Counter(candidate_models).most_common(1)[0][0]
+
+         fuzzy_cnt = {(1, 0): 0.0, (0, 10^3): 0.0, (1, 1): 0.0}
+         for id, weight in zip(topk_ids, weights):
+             label = (label_dict[int(id)], is_mixed_dict[int(id)])
+             boost = class_type_boost(is_mixed_dict[int(id)], initial_pred)
+             fuzzy_cnt[label] += weight * boost
+
+         final = max(fuzzy_cnt, key=fuzzy_cnt.get)
+
+         # print(f"Top {opt.K} results for text:")
+         # cnt = {(1,0):0,(0,10^3):0,(1,1):0}
+         # for j, (id, score) in enumerate(zip(ids, scores)):
+         #     print(f"{j+1}. ID {id} Label {label_dict[int(id)]} Is_mixed {is_mixed_dict[int(id)]} Score {score}")
+         #     cnt[(label_dict[int(id)], is_mixed_dict[int(id)])]+=1
+         # final = max(cnt, key=cnt.get)
+         # pred.append(final)
+         if final == (1, 0):
+             print("Human")
+             return 0
+         elif final == (0, 10^3):
+             print("AI")
+             return 1
+         else:
+             print("Mixed")
+             return 2
+         # pred.append(final)
+     return -1
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--embedding_dim', type=int, default=768)
+     parser.add_argument('--database_path', type=str, default="database", help="Path to the index file")
+
+     parser.add_argument("--model_path", type=str, default="core/model.pth",
+                         help="Path to the embedding model checkpoint")
+     parser.add_argument('--model_name', type=str, default="ZurichNLP/unsup-simcse-xlm-roberta-base", help="Model name")
+
+     parser.add_argument('--K', type=int, default=20, help="Search [1,K] nearest neighbors, choose the best K")
+     parser.add_argument('--pooling', type=str, default="average", help="Pooling method, average or cls")
+     parser.add_argument('--text', type=str, default="")
+     parser.add_argument('--seed', type=int, default=0)
+
+     opt = parser.parse_args()
+     set_seed(opt.seed)
+     infer(opt)
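Note: the dictionary key written as `(0, 10^3)` uses Python's XOR operator, so it evaluates to `(0, 9)`; the code is internally consistent because the same expression is used when the database is built. To make the weighted vote in `infer_3_class` concrete, here is a small self-contained sketch of how `softmax_weights` and `class_type_boost` combine neighbour labels; all scores and labels below are invented for illustration.

```python
# Toy illustration of the weighted neighbour vote used in infer_3_class (values invented).
import numpy as np
from collections import Counter

def softmax_weights(scores, temperature=1.0):
    scores = np.array(scores) / temperature
    e = np.exp(scores - np.max(scores))
    return e / e.sum()

def class_type_boost(query_type, candidate_type):
    diff = abs(query_type - candidate_type)
    return {0: 1.3, 1: 1.1, 2: 0.9}.get(diff, 0.8)

scores   = [0.92, 0.90, 0.75]        # retrieval similarities (fake)
is_mixed = [0, 9, 0]                 # neighbour is_mixed codes (fake; 9 == 10^3)
labels   = [(1, 0), (0, 9), (1, 0)]  # neighbour (label, is_mixed) pairs (fake)

weights = softmax_weights(scores, temperature=0.1)
initial = Counter(is_mixed).most_common(1)[0][0]   # majority is_mixed code

fuzzy = {(1, 0): 0.0, (0, 9): 0.0, (1, 1): 0.0}
for lab, mix, w in zip(labels, is_mixed, weights):
    fuzzy[lab] += w * class_type_boost(mix, initial)

print(max(fuzzy, key=fuzzy.get))  # -> (1, 0), i.e. "Human" for these fake values
```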
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ pandas~=2.0.3
+ tqdm~=4.66.4
+ torch~=2.3.0
+ transformers~=4.41.1
+ scikit-learn~=1.3.2
+ datasets~=2.19.1
+ nltk~=3.8.1
+ tiktoken~=0.7.0
+ faiss-cpu
+ uvicorn
+ fastapi
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/__init__.py ADDED
File without changes
src/index.py ADDED
@@ -0,0 +1,80 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import pickle
+ from typing import List, Tuple
+
+ import faiss
+ import numpy as np
+ from tqdm import tqdm
+
+ class Indexer(object):
+
+     def __init__(self, vector_sz, device='cpu'):
+         self.index = faiss.IndexFlatIP(vector_sz)
+         self.device = device
+         if self.device == 'cuda':
+             self.index = faiss.index_cpu_to_all_gpus(self.index)
+         self.index_id_to_db_id = []
+
+     def index_data(self, ids, embeddings):
+         self._update_id_mapping(ids)
+         embeddings = embeddings.astype('float32')
+         if not self.index.is_trained:
+             self.index.train(embeddings)
+         self.index.add(embeddings)
+
+         print(f'Total data indexed {self.index.ntotal}')
+
+     def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 8) -> List[Tuple[List[object], List[float]]]:
+         query_vectors = query_vectors.astype('float32')
+         result = []
+         nbatch = (len(query_vectors) - 1) // index_batch_size + 1
+         for k in tqdm(range(nbatch)):
+             start_idx = k * index_batch_size
+             end_idx = min((k + 1) * index_batch_size, len(query_vectors))
+             q = query_vectors[start_idx: end_idx]
+             scores, indexes = self.index.search(q, top_docs)
+             # convert to external ids
+             db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
+             result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
+         return result
+
+     def serialize(self, dir_path):
+         index_file = os.path.join(dir_path, 'index.faiss')
+         meta_file = os.path.join(dir_path, 'index_meta.faiss')
+         print(f'Serializing index to {index_file}, meta data to {meta_file}')
+         if self.device == 'cuda':
+             save_index = faiss.index_gpu_to_cpu(self.index)
+         else:
+             save_index = self.index
+         faiss.write_index(save_index, index_file)
+         with open(meta_file, mode='wb') as f:
+             pickle.dump(self.index_id_to_db_id, f)
+
+     def deserialize_from(self, dir_path):
+         index_file = os.path.join(dir_path, 'index.faiss')
+         meta_file = os.path.join(dir_path, 'index_meta.faiss')
+         print(f'Loading index from {index_file}, meta data from {meta_file}')
+
+         self.index = faiss.read_index(index_file)
+         if self.device == 'cuda':
+             self.index = faiss.index_cpu_to_all_gpus(self.index)
+         print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')
+
+         with open(meta_file, "rb") as reader:
+             self.index_id_to_db_id = pickle.load(reader)
+         assert len(
+             self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
+
+     def _update_id_mapping(self, db_ids: List):
+         self.index_id_to_db_id.extend(db_ids)
+
+     def reset(self):
+         self.index.reset()
+         self.index_id_to_db_id = []
+         print(f'Index reset, total data indexed {self.index.ntotal}')
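Note: a minimal usage sketch for `Indexer`, on random data rather than real embeddings. Because the underlying index is `IndexFlatIP` (inner product), L2-normalizing the vectors makes the scores cosine similarities, which matches how `gen_database.py` normalizes embeddings before indexing.

```python
# Minimal Indexer usage sketch (purely illustrative, random vectors).
import numpy as np
from src.index import Indexer

dim = 768
idx = Indexer(dim)

ids = list(range(100))
vecs = np.random.rand(100, dim).astype("float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # normalize so IP == cosine
idx.index_data(ids, vecs)

query = vecs[:2]  # pretend these are query embeddings
for doc_ids, scores in idx.search_knn(query, top_docs=5):
    print(doc_ids, scores)
```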
src/simclr.py ADDED
@@ -0,0 +1,280 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from src.text_embedding import TextEmbeddingModel
+
+ class ClassificationHead(nn.Module):
+     """Head for sentence-level classification tasks."""
+
+     def __init__(self, in_dim, out_dim):
+         super(ClassificationHead, self).__init__()
+         self.dense1 = nn.Linear(in_dim, in_dim // 4)
+         self.dense2 = nn.Linear(in_dim // 4, in_dim // 16)
+         self.out_proj = nn.Linear(in_dim // 16, out_dim)
+
+         nn.init.xavier_uniform_(self.dense1.weight)
+         nn.init.xavier_uniform_(self.dense2.weight)
+         nn.init.xavier_uniform_(self.out_proj.weight)
+         nn.init.normal_(self.dense1.bias, std=1e-6)
+         nn.init.normal_(self.dense2.bias, std=1e-6)
+         nn.init.normal_(self.out_proj.bias, std=1e-6)
+
+     def forward(self, features):
+         x = features
+         x = self.dense1(x)
+         x = torch.tanh(x)
+         x = self.dense2(x)
+         x = torch.tanh(x)
+         x = self.out_proj(x)
+         return x
+
+ class SimCLR_Classifier_SCL(nn.Module):
+     def __init__(self, opt, fabric):
+         super(SimCLR_Classifier_SCL, self).__init__()
+
+         self.temperature = opt.temperature
+         self.opt = opt
+         self.fabric = fabric
+         self.model = TextEmbeddingModel(opt.model_name)
+         self.device = self.model.model.device
+         if opt.resum:
+             state_dict = torch.load(opt.pth_path, map_location=self.device)
+             self.model.load_state_dict(state_dict)
+         self.esp = torch.tensor(1e-6, device=self.device)
+         self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
+
+         self.a = torch.tensor(opt.a, device=self.device)
+         self.d = torch.tensor(opt.d, device=self.device)
+         self.only_classifier = opt.only_classifier
+
+
+     def get_encoder(self):
+         return self.model
+
+     def _compute_logits(self, q, q_index1, q_index2, q_label, k, k_index1, k_index2, k_label):
+         def cosine_similarity_matrix(q, k):
+
+             q_norm = F.normalize(q, dim=-1)
+             k_norm = F.normalize(k, dim=-1)
+             cosine_similarity = q_norm @ k_norm.T
+
+             return cosine_similarity
+
+         logits = cosine_similarity_matrix(q, k) / self.temperature
+
+         q_labels = q_label.view(-1, 1)  # N,1
+         k_labels = k_label.view(1, -1)  # 1,N+K
+
+         same_label = (q_labels == k_labels)  # N,N+K
+
+         # model:model set
+         pos_logits_model = torch.sum(logits * same_label, dim=1) / torch.max(torch.sum(same_label, dim=1), self.esp)
+         neg_logits_model = logits * torch.logical_not(same_label)
+         logits_model = torch.cat((pos_logits_model.unsqueeze(1), neg_logits_model), dim=1)
+
+         return logits_model
+
+     def forward(self, batch, indices1, indices2, label):
+         bsz = batch['input_ids'].size(0)
+         q = self.model(batch)
+         k = q.clone().detach()
+         k = self.fabric.all_gather(k).view(-1, k.size(1))
+         k_label = self.fabric.all_gather(label).view(-1)
+         k_index1 = self.fabric.all_gather(indices1).view(-1)
+         k_index2 = self.fabric.all_gather(indices2).view(-1)
+         # q:N
+         # k:4N
+         logits_label = self._compute_logits(q, indices1, indices2, label, k, k_index1, k_index2, k_label)
+
+         out = self.classifier(q)
+
+         if self.opt.AA:
+             loss_classfiy = F.cross_entropy(out, indices1)
+         else:
+             loss_classfiy = F.cross_entropy(out, label)
+
+         gt = torch.zeros(bsz, dtype=torch.long, device=logits_label.device)
+
+         if self.only_classifier:
+             loss_label = torch.tensor(0, device=self.device)
+         else:
+             loss_label = F.cross_entropy(logits_label, gt)
+
+         loss = self.a * loss_label + self.d * loss_classfiy
+         if self.training:
+             return loss, loss_label, loss_classfiy, k, k_label
+         else:
+             out = self.fabric.all_gather(out).view(-1, out.size(1))
+             return loss, out, k, k_label
+
+
+ class SimCLR_Classifier_test(nn.Module):
+     def __init__(self, opt, fabric):
+         super(SimCLR_Classifier_test, self).__init__()
+
+         self.fabric = fabric
+         self.model = TextEmbeddingModel(opt.model_name)
+         self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
+         self.device = self.model.model.device
+
+     def forward(self, batch):
+         q = self.model(batch)
+         out = self.classifier(q)
+         return out
+
+ class SimCLR_Classifier(nn.Module):
+     def __init__(self, opt, fabric):
+         super(SimCLR_Classifier, self).__init__()
+
+         self.temperature = opt.temperature
+         self.opt = opt
+         self.fabric = fabric
+
+         self.model = TextEmbeddingModel(opt.model_name)
+         if opt.resum:
+             state_dict = torch.load(opt.pth_path,
+                                     map_location=self.model.model.device)
+             self.model.load_state_dict(state_dict)
+
+         self.device = self.model.model.device
+         self.esp = torch.tensor(1e-6, device=self.device)
+         self.a = torch.tensor(opt.a,
+                               device=self.device)
+         self.b = torch.tensor(opt.b,
+                               device=self.device)
+         self.c = torch.tensor(opt.c,
+                               device=self.device)
+
+         self.classifier = ClassificationHead(opt.projection_size,
+                                              opt.classifier_dim)
+         self.only_classifier = opt.only_classifier
+
+
+     def get_encoder(self):
+         return self.model
+
+     def _compute_logits(self,
+                         q, q_index1, q_index2, q_label,
+                         k, k_index1, k_index2, k_label):
+         def cosine_similarity_matrix(q, k):
+
+             q_norm = F.normalize(q, dim=-1)
+             k_norm = F.normalize(k, dim=-1)
+             cosine_similarity = q_norm @ k_norm.T
+             return cosine_similarity
+
+         logits = cosine_similarity_matrix(q, k) / self.temperature
+
+         q_index1 = q_index1.view(-1, 1)  # change to tensor of size N, 1
+         q_index2 = q_index2.view(-1, 1)  # change to tensor of size N, 1
+         q_labels = q_label.view(-1, 1)   # change to tensor of size N, 1
+
+         k_index1 = k_index1.view(1, -1)  # 1,N+K
+         k_index2 = k_index2.view(1, -1)  # 1,N+K
+         k_labels = k_label.view(1, -1)   # 1,N+K
+
+         same_mixed = (q_index1 == k_index1)
+         same_set = (q_index2 == k_index2)    # N,N+K
+         same_label = (q_labels == k_labels)  # N,N+K
+
+         is_human = (q_label == 1).view(-1)
+         is_machine = (q_label == 0).view(-1)
+
+         is_mixed = (q_index1 == 1).view(-1)
+
+         # human: human
+         pos_logits_human = torch.sum(logits * same_label, dim=1) / torch.max(torch.sum(same_label, dim=1), self.esp)
+         neg_logits_human = logits * torch.logical_not(same_label)
+         logits_human = torch.cat((pos_logits_human.unsqueeze(1), neg_logits_human), dim=1)
+         logits_human = logits_human[is_human]
+
+         # human+ai: general
+         pos_logits_mixed = torch.sum(logits * same_mixed, dim=1) / torch.maximum(torch.sum(same_mixed, dim=1), self.esp)
+         neg_logits_mixed = logits * torch.logical_not(same_mixed)
+         logits_mixed = torch.cat((pos_logits_mixed.unsqueeze(1), neg_logits_mixed), dim=1)
+         logits_mixed = logits_mixed[is_mixed]
+
+         # human+ai: model
+         pos_logits_mixed_set = torch.sum(logits * torch.logical_and(same_mixed, same_set), dim=1) / torch.max(torch.sum(torch.logical_and(same_mixed, same_set), dim=1), self.esp)
+         neg_logits_mixed_set = logits * torch.logical_not(torch.logical_and(same_mixed, same_set))
+         logits_mixed_set = torch.cat((pos_logits_mixed_set.unsqueeze(1), neg_logits_mixed_set), dim=1)
+         logits_mixed_set = logits_mixed_set[is_mixed]
+
+         # model set:label
+         pos_logits_set = torch.sum(logits * same_set, dim=1) / torch.max(torch.sum(same_set, dim=1), self.esp)
+         neg_logits_set = logits * torch.logical_not(same_set)
+         logits_set = torch.cat((pos_logits_set.unsqueeze(1), neg_logits_set), dim=1)
+         logits_set = logits_set[is_machine]
+
+         # label: label
+         pos_logits_label = torch.sum(logits * same_label, dim=1) / torch.max(torch.sum(same_label, dim=1), self.esp)
+         neg_logits_label = logits * torch.logical_not(same_label)
+         logits_label = torch.cat((pos_logits_label.unsqueeze(1), neg_logits_label), dim=1)
+         logits_label = logits_label[is_machine]
+
+         return logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label
+
+     def forward(self, encoded_batch, label, indices1, indices2):  # , weights):
+         # print(len(text))
+         q = self.model(encoded_batch)
+         k = q.clone().detach()
+         k = self.fabric.all_gather(k).view(-1, k.size(1))
+         k_label = self.fabric.all_gather(label).view(-1)
+         k_index1 = self.fabric.all_gather(indices1).view(-1)
+         k_index2 = self.fabric.all_gather(indices2).view(-1)
+         # q:N
+         # k:4N
+         logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label = self._compute_logits(q, indices1, indices2, label,
+                                                                                                       k, k_index1, k_index2, k_label)
+         out = self.classifier(q)
+
+         if self.opt.AA:
+             loss_classfiy = F.cross_entropy(out, indices1)
+         else:
+             loss_classfiy = F.cross_entropy(out, label)  # , weight=weights)
+
+         gt_mixed = torch.zeros(logits_mixed.size(0),
+                                dtype=torch.long,
+                                device=logits_mixed.device)
+         gt_mixed_set = torch.zeros(logits_mixed_set.size(0),
+                                    dtype=torch.long,
+                                    device=logits_mixed_set.device)
+         gt_set = torch.zeros(logits_set.size(0),
+                              dtype=torch.long,
+                              device=logits_set.device)
+         gt_label = torch.zeros(logits_label.size(0),
+                                dtype=torch.long,
+                                device=logits_label.device)
+         gt_human = torch.zeros(logits_human.size(0),
+                                dtype=torch.long,
+                                device=logits_human.device)
+
+
+         loss_mixed = F.cross_entropy(logits_mixed,
+                                      gt_mixed)
+         loss_mixed_set = F.cross_entropy(logits_mixed_set,
+                                          gt_mixed_set)
+         loss_set = F.cross_entropy(logits_set,
+                                    gt_set)
+         loss_label = F.cross_entropy(logits_label,
+                                      gt_label)
+         if logits_human.numel() != 0:
+             loss_human = F.cross_entropy(logits_human.to(torch.float64),
+                                          gt_human)
+         else:
+             loss_human = torch.tensor(0, device=self.device)
+
+         loss = self.a * loss_set + (4 * self.b - self.a) * loss_label + self.b * loss_human + self.b * loss_mixed + \
+             2 * self.b * loss_mixed_set + self.c * loss_classfiy
+
+         if self.training:
+             if self.opt.AA:
+                 return loss, loss_mixed, loss_mixed_set, loss_set, loss_label, loss_human, loss_classfiy, k, k_index1
+             else:
+                 return loss, loss_mixed, loss_mixed_set, loss_set, loss_label, loss_classfiy, loss_human, k, k_label
+         else:
+             out = self.fabric.all_gather(out).view(-1, out.size(1))
+             if self.opt.AA:
+                 return loss, out, k, k_index1
+             else:
+                 return loss, out, k, k_label
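Note: the `_compute_logits` methods all follow the same pattern. For each query they average the similarities of its positives, then concatenate that average in front of the masked negatives, so that cross-entropy against target class 0 pulls positives together and pushes negatives apart. A small standalone sketch of that pattern, with toy tensors rather than the training pipeline:

```python
# Toy sketch of the positive-mean / concatenated-negatives construction used above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = F.normalize(torch.randn(4, 8), dim=-1)    # 4 query embeddings
k = F.normalize(torch.randn(12, 8), dim=-1)   # gathered key embeddings
q_label = torch.tensor([0, 1, 0, 1])
k_label = torch.randint(0, 2, (12,))

logits = (q @ k.T) / 0.07                      # cosine similarity / temperature
same = q_label.view(-1, 1) == k_label.view(1, -1)
eps = torch.tensor(1e-6)

pos = (logits * same).sum(dim=1) / torch.max(same.sum(dim=1).float(), eps)
neg = logits * (~same)
out = torch.cat((pos.unsqueeze(1), neg), dim=1)

# Target 0 makes the averaged positive the "correct" class for every query.
loss = F.cross_entropy(out, torch.zeros(4, dtype=torch.long))
print(loss)
```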
src/text_embedding.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ import torch.nn as nn
+ from transformers import AutoTokenizer, AutoModel
+
+ class TextEmbeddingModel(nn.Module):
+     def __init__(self, model_name, output_hidden_states=False):
+         super(TextEmbeddingModel, self).__init__()
+         self.model_name = model_name
+         if output_hidden_states:
+             self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, output_hidden_states=True)
+         else:
+             self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     def pooling(self, model_output, attention_mask, use_pooling='average', hidden_states=False):
+         if hidden_states:
+             model_output.masked_fill(~attention_mask[None, ..., None].bool(), 0.0)
+             if use_pooling == "average":
+                 emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
+             else:
+                 emb = model_output[:, :, 0]
+                 emb = emb.permute(1, 0, 2)
+         else:
+             model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
+             if use_pooling == "average":
+                 emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+             elif use_pooling == "cls":
+                 emb = model_output[:, 0]
+         return emb
+
+     def forward(self, encoded_batch, use_pooling='average', hidden_states=False):
+         if "t5" in self.model_name.lower():
+             input_ids = encoded_batch['input_ids']
+             decoder_input_ids = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
+             model_output = self.model(**encoded_batch,
+                                       decoder_input_ids=decoder_input_ids)
+         else:
+             model_output = self.model(**encoded_batch)
+
+         if 'bge' in self.model_name.lower() or 'mxbai' in self.model_name.lower():
+             use_pooling = 'cls'
+         if isinstance(model_output, tuple):
+             model_output = model_output[0]
+         if isinstance(model_output, dict):
+             if hidden_states:
+                 model_output = model_output["hidden_states"]
+                 model_output = torch.stack(model_output, dim=0)
+             else:
+                 model_output = model_output["last_hidden_state"]
+
+         emb = self.pooling(model_output, encoded_batch['attention_mask'], use_pooling, hidden_states)
+         emb = torch.nn.functional.normalize(emb, dim=-1)
+         return emb
+
+
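Note: a minimal sketch of embedding a couple of sentences with `TextEmbeddingModel`. The model name mirrors the default used in `gen_database.py`; any encoder supported by `AutoModel` should behave the same way here.

```python
# Sketch: mean-pooled, L2-normalized sentence embeddings with TextEmbeddingModel.
import torch
from src.text_embedding import TextEmbeddingModel

model = TextEmbeddingModel("FacebookAI/xlm-roberta-base")
batch = model.tokenizer(
    ["A short example sentence.", "Another one."],
    return_tensors="pt", padding=True, truncation=True, max_length=512,
)
with torch.no_grad():
    emb = model(dict(batch))          # shape (2, hidden_size)
print(emb.shape, emb.norm(dim=-1))    # norms are ~1 after normalization
```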
utils/__init__.py ADDED
File without changes
utils/load_dataset.py ADDED
@@ -0,0 +1,205 @@
+ from torch.utils.data import Dataset
+
+ import os
+ import json
+ import random
+ import hashlib
+
+ def stable_long_hash(input_string):
+     hash_object = hashlib.sha256(input_string.encode())
+     hex_digest = hash_object.hexdigest()
+     int_hash = int(hex_digest, 16)
+     long_long_hash = (int_hash & ((1 << 63) - 1))
+     return long_long_hash
+
+ model_map_authscan = {
+     "gpt-4o-mini-text": 1,
+     "gemini-2.0-text": 2,
+     "deepseek-text": 3,
+     "llama-text": 4
+ }
+
+ model_map_llmdetectaive = {
+     "gemma-text": 1,
+     "mixtral-text": 2,
+     "llama3-text": 3
+ }
+
+ model_map_hart = {
+     "claude-text": 1,
+     "gemini-text": 2,
+     "gpt-text": 3
+ }
+
+ def load_dataset(dataset_name, path=None):
+     dataset = {
+         "train": [],
+         "valid": [],
+         "test": []
+     }
+     if dataset_name == "falconset":
+         model_map = model_map_authscan
+     elif dataset_name == "llmdetectaive":
+         model_map = model_map_llmdetectaive
+     elif dataset_name == "hart":
+         model_map = model_map_hart
+
+     folder = os.listdir(path)
+     # print(folder)
+     for sub in folder:
+         sub_path = os.path.join(path, sub)
+         files = os.listdir(sub_path)
+         for file in files:
+             if not file.endswith('.jsonl'):
+                 continue
+             file_path = os.path.join(sub_path, file)
+             key_name = file.split('.')[0]
+
+             assert key_name in dataset.keys(), f'{key_name} is not in dataset.keys()'
+             with open(file_path, 'r') as f:
+                 data = [json.loads(line) for line in f]
+             for i in range(len(data)):
+                 dct = {}
+                 dct['text'] = data[i]['text']
+                 if sub == "human-text":
+                     dct['label'] = "human"
+                     dct['label_detailed'] = "human"
+                     dct['index'] = (1, 0, 0)
+                 elif sub.startswith("human---"):
+                     dct['label'] = "human+AI"
+                     model = sub.split("---")[1]
+                     dct['label_detailed'] = model
+                     dct['index'] = (1, 1, model_map[model])
+                 else:
+                     dct['label'] = "AI"
+                     dct['label_detailed'] = sub
+                     dct['index'] = (0, 10^3, model_map[sub])
+                 dataset[key_name].append(dct)
+     return dataset
+
+ def load_outdomain_dataset(path):
+     dataset = {
+         "valid": [],
+         "test": []
+     }
+     folder = os.listdir(path)
+     for sub in folder:
+         sub_path = os.path.join(path, sub)
+         files = os.listdir(sub_path)
+         for file in files:
+             if not file.endswith('.jsonl'):
+                 continue
+             file_path = os.path.join(sub_path, file)
+             key_name = file.split('.')[0]
+             assert key_name in dataset.keys(), f'{key_name} is not in dataset.keys()'
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 data = [json.loads(line) for line in f]
+             for i in range(len(data)):
+                 dct = {}
+                 dct['text'] = data[i]['text']
+                 if sub == "human-text":
+                     dct['label'] = "human"
+                     dct['label_detailed'] = "human"
+                     dct['index'] = (1, 0)
+                 elif sub.startswith("human---"):
+                     dct['label'] = "human+AI"
+                     model = sub.split("---")[1]
+                     dct['label_detailed'] = model
+                     dct['index'] = (1, 1)
+                 else:
+                     dct['label'] = "AI"
+                     dct['label_detailed'] = sub
+                     dct['index'] = (0, 10^3)
+                 dataset[key_name].append(dct)
+     return dataset
+
+ def load_dataset_conditional_lang(path=None, language='vi', seed=42):
+     dataset = {
+         "train": [],
+         "val": [],
+         "test": []
+     }
+     combined_data = []
+
+     random.seed(seed)  # for reproducibility
+     folder = os.listdir(path)
+     print("Subfolders:", folder)
+
+     for sub in folder:
+         sub_path = os.path.join(path, sub)
+         if not os.path.isdir(sub_path):
+             continue
+         files = os.listdir(sub_path)
+
+         for file in files:
+             if not file.endswith('.jsonl') or language not in file:
+                 continue
+
+             file_path = os.path.join(sub_path, file)
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 data = [json.loads(line) for line in f]
+
+             for entry in data:
+                 if 'content' not in entry:
+                     print("Key does not exist!")
+                     continue
+
+                 dct = {}
+                 dct['text'] = entry['content']
+
+                 if sub == "human":
+                     dct['label'] = "human"
+                     dct['label_detailed'] = "human"
+                     dct['index'] = (1, 0, 0)
+                 elif sub == "human+AI":
+                     model = entry['label_detailed'].split("+")[1]
+                     dct['label'] = "human+AI"
+                     dct['label_detailed'] = model
+                     dct['index'] = (1, 1, model_map[model])
+                 else:
+                     dct['label'] = "AI"
+                     dct['label_detailed'] = entry['label_detailed']
+                     dct['index'] = (0, 10**3, model_map[entry['label_detailed']])
+
+                 combined_data.append(dct)
+
+     random.shuffle(combined_data)
+     total = len(combined_data)
+     train_end = int(total * 0.9)
+     val_end = train_end + int(total * 0.05)
+
+     dataset['train'] = combined_data[:train_end]
+     dataset['val'] = combined_data[train_end:val_end]
+     dataset['test'] = combined_data[val_end:]
+
+     print(f"Total: {total} | Train: {len(dataset['train'])} | Val: {len(dataset['val'])} | Test: {len(dataset['test'])}")
+     return dataset
+
+
+
+ class TextDataset(Dataset):
+     def __init__(self, dataset, need_ids=True, out_domain=0):
+         self.dataset = dataset
+         self.need_ids = need_ids
+         self.out_domain = out_domain
+
+     def get_class(self):
+         return self.classes
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         text, label, label_detailed, index = self.dataset[idx].values()
+         id = stable_long_hash(text)
+         if self.out_domain:
+             label, is_mixed = index
+             if self.need_ids:
+                 return int(id), text, int(label), int(is_mixed)
+             return text, int(label), int(is_mixed)
+         else:
+             label, is_mixed, write_model = index
+             if self.need_ids:
+                 return int(id), text, int(label), int(is_mixed), int(write_model)
+             return text, int(label), int(is_mixed), int(write_model)
+
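Note: each record carries an `index` tuple of the form `(label, is_mixed, write_model)` (only `(label, is_mixed)` for out-of-domain data), and `TextDataset.__getitem__` unpacks it along with a stable SHA-256-derived id. A small sketch with an invented record, just to show the round trip:

```python
# Sketch: what a TextDataset item looks like for an invented record.
from utils.load_dataset import TextDataset, stable_long_hash

records = [{
    "text": "An invented passage.",
    "label": "AI",
    "label_detailed": "gpt-4o-mini-text",
    "index": (0, 10^3, 1),   # (label, is_mixed, write_model); 10^3 is XOR, i.e. 9
}]
ds = TextDataset(records, need_ids=True)
_id, text, label, is_mixed, write_model = ds[0]
assert _id == stable_long_hash(text)
print(label, is_mixed, write_model)   # -> 0 9 1
```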
utils/utils.py ADDED
@@ -0,0 +1,132 @@
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, mean_absolute_error, hamming_loss
+ import numpy as np
+ from sklearn.preprocessing import MultiLabelBinarizer
+
+ def find_top_n(embeddings, n, index, data):
+     if len(embeddings.shape) == 1:
+         embeddings = embeddings.reshape(1, -1)
+     top_ids_and_scores = index.search_knn(embeddings, n)
+     data_ans = []
+     for i, (ids, scores) in enumerate(top_ids_and_scores):
+         data_now = []
+         for id in ids:
+             data_now.append((data[0][int(id)], data[1][int(id)], data[2][int(id)]))
+         data_ans.append(data_now)
+     return data_ans
+
+
+ def print_line(class_name, metrics, is_header=False):
+     if is_header:
+         line = f"| {'Class':<10} | " + " | ".join([f"{metric:<10}" for metric in metrics])
+     else:
+         line = f"| {class_name:<10} | " + " | ".join([f"{metrics[metric]:<10.3f}" for metric in metrics])
+     print(line)
+     if is_header:
+         print('-' * len(line))
+
+ def calculate_per_class_metrics(classes, ground_truth, predictions):
+     # Convert ground truth and predictions to numeric format
+     gt_numeric = np.array([int(gt) for gt in ground_truth])
+     pred_numeric = np.array([int(pred) for pred in predictions])
+
+     results = {}
+     for i, class_name in enumerate(classes):
+         # For each class, calculate the 'vs rest' binary labels
+         gt_binary = (gt_numeric == i).astype(int)
+         pred_binary = (pred_numeric == i).astype(int)
+
+         # Calculate metrics, handling cases where a class is not present in predictions or ground truth
+         precision = precision_score(gt_binary, pred_binary, zero_division=0)
+         recall = recall_score(gt_binary, pred_binary, zero_division=0)
+         f1 = f1_score(gt_binary, pred_binary, zero_division=0)
+         acc = np.mean(gt_binary == pred_binary)
+         # Calculate recall for all other classes as 'rest'
+         rest_recall = recall_score(1 - gt_binary, 1 - pred_binary, zero_division=0)
+
+         results[class_name] = {
+             'Precision': precision,
+             'Recall': recall,
+             'F1 Score': f1,
+             'Accuracy': acc,
+             'Avg Recall (with rest)': (recall + rest_recall) / 2
+         }
+
+     print_line("Metric", results[classes[0]], is_header=True)
+     for class_name, metrics in results.items():
+         print_line(class_name, metrics)
+     overall_metrics = {metric_name: np.mean([metrics[metric_name] for metrics in results.values()]) for metric_name in results[classes[0]].keys()}
+     print_line("Overall", overall_metrics)
+
+ def calculate_metrics(y_true, y_pred):
+     accuracy = accuracy_score(y_true, y_pred)
+     avg_f1 = f1_score(y_true, y_pred, average='macro')
+     avg_recall = recall_score(y_true, y_pred, average='macro')
+     return accuracy, avg_f1, avg_recall
+
+ def compute_three_recalls(labels, preds):
+     all_n, all_p, tn, tp = 0, 0, 0, 0
+     for label, pred in zip(labels, preds):
+         if label == '0':
+             all_p += 1
+         if label == '1':
+             all_n += 1
+         if pred is not None and label == pred == '0':
+             tp += 1
+         if pred is not None and label == pred == '1':
+             tn += 1
+         if pred is None:
+             continue
+     machine_rec, human_rec = tp * 100 / all_p if all_p != 0 else 0, tn * 100 / all_n if all_n != 0 else 0
+     avg_rec = (human_rec + machine_rec) / 2
+     return (human_rec, machine_rec, avg_rec)
+
+
+ def compute_metrics(labels, preds, ids=None, full_labels=False):
+     if ids is not None:
+         # unique ids
+         dict_labels, dict_preds = {}, {}
+         for i in range(len(ids)):
+             dict_labels[ids[i]] = labels[i]
+             dict_preds[ids[i]] = preds[i]
+         labels = list(dict_labels.values())
+         preds = list(dict_preds.values())
+
+     if not full_labels:
+         labels_map = {(1, 0): 0, (0, 10^3): 1, (1, 1): 2}
+         labels_bin = [labels_map[tup] for tup in labels]
+         preds_bin = [labels_map[tup] for tup in preds]
+
+     else:
+         labels_map = {
+             (1, 0, 0): 0,  # Human
+             (0, 10^3, 1): 1, (0, 10^3, 2): 2, (0, 10^3, 3): 3, (0, 10^3, 4): 4,  # AI
+             (1, 1, 1): 5, (1, 1, 2): 6, (1, 1, 3): 7, (1, 1, 4): 8  # Human+AI
+         }
+         labels_bin = [labels_map[tup] for tup in labels]
+         preds_bin = [labels_map[tup] for tup in preds]
+     acc = accuracy_score(labels_bin, preds_bin)
+     precision = precision_score(labels_bin, preds_bin, average="macro")
+     recall = recall_score(labels_bin, preds_bin, average="macro")
+     f1 = f1_score(labels_bin, preds_bin, average="macro")
+     mse = mean_squared_error(labels_bin, preds_bin)
+     mae = mean_absolute_error(labels_bin, preds_bin)
+
+     return (acc, precision, recall, f1, mse, mae)
+
+ def compute_metrics_train(labels, preds, ids=None):
+     if ids is not None:
+         # unique ids
+         dict_labels, dict_preds = {}, {}
+         for i in range(len(ids)):
+             dict_labels[ids[i]] = labels[i]
+             dict_preds[ids[i]] = preds[i]
+         labels = list(dict_labels.values())
+         preds = list(dict_preds.values())
+
+     human_rec, machine_rec, avg_rec = compute_three_recalls(labels, preds)
+     acc = accuracy_score(labels, preds)
+     precision = precision_score(labels, preds, average="macro")
+     recall = recall_score(labels, preds, average="macro")
+     f1 = f1_score(labels, preds, average="macro")
+
+     return (human_rec, machine_rec, avg_rec, acc, precision, recall, f1)
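Note: `compute_metrics` expects the label tuples described in `utils/load_dataset.py`. A tiny call with invented values, using the default three-class mapping:

```python
# Toy call of compute_metrics with (label, is_mixed) tuples (values invented).
from utils.utils import compute_metrics

labels = [(1, 0), (0, 10^3), (1, 1), (1, 0)]
preds  = [(1, 0), (0, 10^3), (1, 0), (1, 0)]
acc, precision, recall, f1, mse, mae = compute_metrics(labels, preds)
print(acc, f1)
```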