Spaces (Sleeping)
ngocminhta committed · Commit 3fef185 · Parent(s): 5287779
upload hf
Files changed:
- app.py +80 -0
- gen_database.py +300 -0
- infer.py +130 -0
- requirements.txt +11 -0
- src/.DS_Store +0 -0
- src/__init__.py +0 -0
- src/index.py +80 -0
- src/simclr.py +280 -0
- src/text_embedding.py +55 -0
- utils/__init__.py +0 -0
- utils/load_dataset.py +205 -0
- utils/utils.py +132 -0
app.py
ADDED
@@ -0,0 +1,80 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch
from src.text_embedding import TextEmbeddingModel
from src.index import Indexer
import os
import pickle
from infer import infer_3_class
import uvicorn

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class Opt:
    def __init__(self):
        self.model_name = "./unsup-simcse-xlm-roberta-base"
        self.model_path = "core/model.pth"
        self.database_path = "core/seen_db"
        self.embedding_dim = 768
        self.device_num = 1

opt = Opt()

def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

@app.on_event("startup")
def load_model_resources():
    global model, tokenizer, index, label_dict, is_mixed_dict

    model = TextEmbeddingModel(opt.model_name)
    state_dict = torch.load(opt.model_path, map_location=model.model.device)
    new_state_dict={}
    for key in state_dict.keys():
        if key.startswith('model.'):
            new_state_dict[key[6:]]=state_dict[key]
    model.load_state_dict(state_dict)
    tokenizer=model.tokenizer

    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)
    label_dict=load_pkl(os.path.join(opt.database_path,'label_dict.pkl'))
    is_mixed_dict=load_pkl(os.path.join(opt.database_path,'is_mixed_dict.pkl'))


@app.route('/predict', methods=['POST'])
async def predict(request: Request):
    data = await request.json()
    mode = data.get("mode", "normal").lower()
    text_list = data.get("text", [])

    if mode == "normal":
        results = []
        for text in text_list:
            result = infer_3_class(model=model,
                                   tokenizer=tokenizer,
                                   index=index,
                                   label_dict=label_dict,
                                   is_mixed_dict=is_mixed_dict,
                                   text=text,
                                   K=20)
            results.append(result)
        return JSONResponse(content={"results": results})
    elif mode == "advanced":
        return 0

if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
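Editor's note, not part of the commit: a minimal client-side sketch for the /predict endpoint defined above, assuming the server runs locally on port 8000 and that the `requests` package is available (it is not listed in requirements.txt). The response shape follows the JSONResponse built in predict().

# Editor's sketch (assumption): call the Space's /predict route with a batch of texts.
import requests

payload = {
    "mode": "normal",
    # One prediction is returned per entry in "text":
    # 0 = Human, 1 = AI, 2 = Mixed (see infer_3_class in infer.py).
    "text": ["This paragraph was written entirely by a person."],
}
resp = requests.post("http://localhost:8000/predict", json=payload)
print(resp.json())  # e.g. {"results": [0]}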
gen_database.py
ADDED
@@ -0,0 +1,300 @@
import os
import pickle
import random
import faiss
from src.index import Indexer
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from lightning import Fabric
from tqdm import tqdm
import argparse
from src.text_embedding import TextEmbeddingModel
from utils.load_dataset import load_dataset, TextDataset, load_outdomain_dataset

def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

def infer(passages_dataloder,fabric,tokenizer,model,ood=False):
    if fabric.global_rank == 0 :
        passages_dataloder=tqdm(passages_dataloder,total=len(passages_dataloder))
    if ood:
        allids, allembeddings,alllabels,all_is_mixed= [],[],[],[]
    else:
        allids, allembeddings,alllabels,all_is_mixed,all_write_model= [],[],[],[],[]
    model.model.eval()
    with torch.no_grad():
        for batch in passages_dataloder:
            if ood:
                ids, text, label, is_mixed = batch
                encoded_batch = tokenizer.batch_encode_plus(
                    text,
                    return_tensors="pt",
                    max_length=512,
                    padding="max_length",
                    # padding=True,
                    truncation=True,
                )
                encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
                # output = model(**encoded_batch).last_hidden_state
                # embeddings = pooling(output, encoded_batch)
                # print(encoded_batch)
                embeddings = model(encoded_batch)
                # print(encoded_batch['input_ids'].shape)
                embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(1))
                label = fabric.all_gather(label).view(-1)
                ids = fabric.all_gather(ids).view(-1)
                is_mixed = fabric.all_gather(is_mixed).view(-1)
                if fabric.global_rank == 0 :
                    allembeddings.append(embeddings.cpu())
                    allids.extend(ids.cpu().tolist())
                    alllabels.extend(label.cpu().tolist())
                    all_is_mixed.extend(is_mixed.cpu().tolist())
            else:
                ids, text, label, is_mixed, write_model = batch
                encoded_batch = tokenizer.batch_encode_plus(
                    text,
                    return_tensors="pt",
                    max_length=512,
                    padding="max_length",
                    # padding=True,
                    truncation=True,
                )
                encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
                # output = model(**encoded_batch).last_hidden_state
                # embeddings = pooling(output, encoded_batch)
                # print(encoded_batch)
                embeddings = model(encoded_batch)
                # print(encoded_batch['input_ids'].shape)
                embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(1))
                label = fabric.all_gather(label).view(-1)
                ids = fabric.all_gather(ids).view(-1)
                is_mixed = fabric.all_gather(is_mixed).view(-1)
                write_model = fabric.all_gather(write_model).view(-1)
                if fabric.global_rank == 0 :
                    allembeddings.append(embeddings.cpu())
                    allids.extend(ids.cpu().tolist())
                    alllabels.extend(label.cpu().tolist())
                    all_is_mixed.extend(is_mixed.cpu().tolist())
                    all_write_model.extend(write_model.cpu().tolist())
    if fabric.global_rank == 0 :
        allembeddings = torch.cat(allembeddings, dim=0)
        epsilon = 1e-6
        if ood:
            emb_dict,label_dict,is_mixed_dict={},{},{}
            allembeddings= F.normalize(allembeddings,dim=-1)
            for i in range(len(allids)):
                emb_dict[allids[i]]=allembeddings[i]
                label_dict[allids[i]]=alllabels[i]
                is_mixed_dict[allids[i]]=all_is_mixed[i]
            allids,allembeddings,alllabels,all_is_mixed=[],[],[],[]
            for key in emb_dict:
                allids.append(key)
                allembeddings.append(emb_dict[key])
                alllabels.append(label_dict[key])
                all_is_mixed.append(is_mixed_dict[key])
            allembeddings = torch.stack(allembeddings, dim=0)
            return allids,allembeddings.numpy(),alllabels,all_is_mixed
        else:
            emb_dict,label_dict,is_mixed_dict,write_model_dict={},{},{},{}
            allembeddings= F.normalize(allembeddings,dim=-1)
            for i in range(len(allids)):
                emb_dict[allids[i]]=allembeddings[i]
                label_dict[allids[i]]=alllabels[i]
                is_mixed_dict[allids[i]]=all_is_mixed[i]
                write_model_dict[allids[i]]=all_write_model[i]
            allids,allembeddings,alllabels,all_is_mixed,all_write_model=[],[],[],[],[]
            for key in emb_dict:
                allids.append(key)
                allembeddings.append(emb_dict[key])
                alllabels.append(label_dict[key])
                all_is_mixed.append(is_mixed_dict[key])
                all_write_model.append(write_model_dict[key])
            allembeddings = torch.stack(allembeddings, dim=0)
            return allids, allembeddings.numpy(),alllabels,all_is_mixed,all_write_model
    else:
        if ood:
            return [],[],[],[]
        return [],[],[],[],[]

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.

def test(opt):
    if opt.device_num>1:
        fabric = Fabric(accelerator="cuda",devices=opt.device_num,strategy='ddp')
    else:
        fabric = Fabric(accelerator="cuda",devices=opt.device_num)
    fabric.launch()
    model = TextEmbeddingModel(opt.model_name).cuda()
    state_dict = torch.load(opt.model_path, map_location=model.model.device)
    new_state_dict={}
    for key in state_dict.keys():
        if key.startswith('model.'):
            new_state_dict[key[6:]]=state_dict[key]
    model.load_state_dict(state_dict)
    tokenizer=model.tokenizer
    database = load_dataset(opt.dataset_name,opt.database_path)[opt.database_name]
    passage_dataset = TextDataset(database,need_ids=True)
    print(len(passage_dataset))

    passages_dataloder = DataLoader(passage_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, pin_memory=True)
    passages_dataloder=fabric.setup_dataloaders(passages_dataloder)
    model=fabric.setup(model)

    train_ids, train_embeddings,train_labels, train_is_mixed, train_write_model = infer(passages_dataloder,fabric,tokenizer,model)
    fabric.barrier()

    if fabric.global_rank == 0:
        index = Indexer(opt.embedding_dim)
        index.index_data(train_ids, train_embeddings)
        label_dict={}
        is_mixed_dict={}
        write_model_dict={}
        for i in range(len(train_ids)):
            label_dict[train_ids[i]]=train_labels[i]
            is_mixed_dict[train_ids[i]]=train_is_mixed[i]
            write_model_dict[train_ids[i]]=train_write_model[i]

        if not os.path.exists(opt.save_path):
            os.makedirs(opt.save_path)
        index.serialize(opt.save_path)
        #save label_dict using pickle
        with open(os.path.join(opt.save_path, 'label_dict.pkl'), 'wb') as f:
            pickle.dump(label_dict, f)
        #save is_mixed_dict using pickle
        with open(os.path.join(opt.save_path, 'is_mixed_dict.pkl'), 'wb') as f:
            pickle.dump(is_mixed_dict, f)
        #save write_model_dict using pickle
        with open(os.path.join(opt.save_path, 'write_model_dict.pkl'), 'wb') as f:
            pickle.dump(write_model_dict, f)

def add_to_existed_index(opt):
    if opt.device_num>1:
        fabric = Fabric(accelerator="cuda",devices=opt.device_num,strategy='ddp')
    else:
        fabric = Fabric(accelerator="cuda",devices=opt.device_num)
    fabric.launch()
    model = TextEmbeddingModel(opt.model_name).cuda()
    state_dict = torch.load(opt.model_path, map_location=model.model.device)
    new_state_dict={}
    for key in state_dict.keys():
        if key.startswith('model.'):
            new_state_dict[key[6:]]=state_dict[key]
    model.load_state_dict(state_dict)
    tokenizer=model.tokenizer

    if opt.ood:
        database = load_outdomain_dataset(opt.database_path)[opt.database_name]
    else:
        database = load_dataset(opt.dataset_name,opt.database_path)[opt.database_name]

    passage_dataset = TextDataset(database,need_ids=True,out_domain=opt.ood)
    print(len(passage_dataset))

    passages_dataloder = DataLoader(passage_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, pin_memory=True)
    passages_dataloder=fabric.setup_dataloaders(passages_dataloder)
    model=fabric.setup(model)

    if opt.ood:
        train_ids, train_embeddings,train_labels, train_is_mixed = infer(passages_dataloder,fabric,tokenizer,model,ood=True)
    else:
        train_ids, train_embeddings,train_labels, train_is_mixed, train_write_model = infer(passages_dataloder,fabric,tokenizer,model)
    fabric.barrier()

    if fabric.global_rank == 0:
        new_index = Indexer(opt.embedding_dim)
        new_index.index_data(train_ids, train_embeddings)

        old_index = Indexer(opt.embedding_dim)
        old_index.deserialize_from(opt.existed_index_path)
        old_ids = old_index.index_id_to_db_id

        # Ensure both indexes are of type IndexFlatIP
        # assert isinstance(new_index.index, faiss.IndexFlatIP)
        # assert isinstance(old_index.index, faiss.IndexFlatIP)

        # Ensure both indexes have the same dimensionality
        assert new_index.index.d == old_index.index.d

        # Extract vectors from old_index.index
        vectors = old_index.index.reconstruct_n(0, old_index.index.ntotal)

        # Add vectors to new_index.index
        new_index.index_data(old_ids, vectors)

        if not os.path.exists(opt.new_save_path):
            os.makedirs(opt.new_save_path)
        new_index.serialize(opt.new_save_path)

        if opt.ood:
            label_dict=load_pkl(os.path.join(opt.existed_index_path, 'label_dict.pkl'))
            is_mixed_dict=load_pkl(os.path.join(opt.existed_index_path, 'is_mixed_dict.pkl'))
            for i in range(len(train_ids)):
                label_dict[train_ids[i]]=train_labels[i]
                is_mixed_dict[train_ids[i]]=train_is_mixed[i]
            #save label_dict using pickle
            with open(os.path.join(opt.new_save_path, 'label_dict.pkl'), 'wb') as f:
                pickle.dump(label_dict, f)
            #save is_mixed_dict using pickle
            with open(os.path.join(opt.new_save_path, 'is_mixed_dict.pkl'), 'wb') as f:
                pickle.dump(is_mixed_dict, f)

        else:
            label_dict=load_pkl(os.path.join(opt.existed_index_path, 'label_dict.pkl'))
            is_mixed_dict=load_pkl(os.path.join(opt.existed_index_path, 'is_mixed_dict.pkl'))
            write_model_dict=load_pkl(os.path.join(opt.existed_index_path, 'write_model_dict.pkl'))
            for i in range(len(train_ids)):
                label_dict[train_ids[i]]=train_labels[i]
                is_mixed_dict[train_ids[i]]=train_is_mixed[i]
                write_model_dict[train_ids[i]]=train_write_model[i]
            #save label_dict using pickle
            with open(os.path.join(opt.new_save_path, 'label_dict.pkl'), 'wb') as f:
                pickle.dump(label_dict, f)
            #save is_mixed_dict using pickle
            with open(os.path.join(opt.new_save_path, 'is_mixed_dict.pkl'), 'wb') as f:
                pickle.dump(is_mixed_dict, f)
            #save write_model_dict using pickle
            with open(os.path.join(opt.new_save_path, 'write_model_dict.pkl'), 'wb') as f:
                pickle.dump(write_model_dict, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--embedding_dim', type=int, default=768)

    # parser.add_argument('--mode', type=str, default='deepfake', help="deepfake,MGT or MGTDetect_CoCo")
    parser.add_argument("--database_path", type=str, default="data/FALCONSet", help="Path to the data")
    parser.add_argument('--dataset_name', type=str, default='falconset', help="falconset, llmdetectaive, hart")
    parser.add_argument('--database_name', type=str, default='train', help="train,valid,test,test_ood")
    parser.add_argument("--model_path", type=str, default="runs/authscan_v6/model_best.pth",
                        help="Path to the embedding model checkpoint")
    parser.add_argument('--model_name', type=str, default="FacebookAI/xlm-roberta-base", help="Model name")
    parser.add_argument("--save_path", type=str, default="/output", help="Path to save the database")
    parser.add_argument("--add_to_existed_index", type=int, default=0)
    # parser.add_argument("--add_to_existed_index_path", type=str, default="/output", help="Path to save the database")
    parser.add_argument("--ood", type=int, default=0)
    parser.add_argument("--existed_index_path", type=str, default="/output", help="Path of existed index")
    parser.add_argument("--new_save_path", type=str, default="/new_db", help="Path to save the database")

    parser.add_argument('--seed', type=int, default=0)
    opt = parser.parse_args()
    set_seed(opt.seed)

    if not opt.add_to_existed_index:
        test(opt)
    else:
        add_to_existed_index(opt)
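Editor's note, not part of the commit: a small sketch of the merge step used in add_to_existed_index() above, in isolation. Two flat inner-product indexes of the same dimensionality are combined by reconstructing the stored vectors of the old index and adding them, with their external ids, to the new one. The ids and dimensions here are made up for illustration.

# Editor's sketch (assumption): merging an existing Indexer into a freshly built one.
import numpy as np
from src.index import Indexer

dim = 4
old_index, new_index = Indexer(dim), Indexer(dim)
old_index.index_data([101, 102], np.random.rand(2, dim).astype("float32"))
new_index.index_data([201], np.random.rand(1, dim).astype("float32"))

assert new_index.index.d == old_index.index.d
vectors = old_index.index.reconstruct_n(0, old_index.index.ntotal)   # pull stored vectors back out
new_index.index_data(old_index.index_id_to_db_id, vectors)           # re-add them with their ids
print(new_index.index.ntotal)  # 3 vectors, id mapping [201, 101, 102]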
infer.py
ADDED
@@ -0,0 +1,130 @@
import os
import pickle
import numpy as np
from src.index import Indexer
import torch
import argparse
from src.text_embedding import TextEmbeddingModel
import random
from collections import Counter


def softmax_weights(scores, temperature=1.0):
    scores = np.array(scores)
    scores = scores / temperature
    e_scores = np.exp(scores - np.max(scores))
    return e_scores / np.sum(e_scores)

def normalize_fuzzy_cnt(fuzzy_cnt):
    total = sum(fuzzy_cnt.values())
    if total == 0:
        return fuzzy_cnt
    for key in fuzzy_cnt:
        fuzzy_cnt[key] /= total
    return fuzzy_cnt

def class_type_boost(query_type, candidate_type):
    if query_type == candidate_type:
        return 1.3
    elif abs(query_type - candidate_type) == 1:
        return 1.1
    elif abs(query_type - candidate_type) == 2:
        return 0.9
    else:
        return 0.8

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.

def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

def infer_3_class(model, tokenizer, index, label_dict, is_mixed_dict, text, K):
    # model = TextEmbeddingModel(opt.model_name).cuda()
    # state_dict = torch.load(opt.model_path, map_location=model.model.device)
    # new_state_dict={}
    # for key in state_dict.keys():
    #     if key.startswith('model.'):
    #         new_state_dict[key[6:]]=state_dict[key]
    # model.load_state_dict(state_dict)
    # tokenizer=model.tokenizer

    # index = Indexer(opt.embedding_dim)
    # index.deserialize_from(opt.database_path)
    # label_dict=load_pkl(os.path.join(opt.database_path,'label_dict.pkl'))
    # is_mixed_dict=load_pkl(os.path.join(opt.database_path,'is_mixed_dict.pkl'))

    # text = opt.text
    encoded_text = tokenizer.batch_encode_plus(
        [text],
        return_tensors="pt",
        max_length=512,
        padding="max_length",
        truncation=True,
    )
    encoded_text = {k: v for k, v in encoded_text.items()}
    embeddings = model(encoded_text).cpu().detach().numpy()
    top_ids_and_scores = index.search_knn(embeddings, K)
    pred = []
    for i, (ids, scores) in enumerate(top_ids_and_scores):
        print(f"Top {K} results for text:")
        sorted_scores = np.argsort(scores)
        sorted_scores = sorted_scores[::-1]

        topk_ids = [ids[j] for j in sorted_scores]
        topk_scores = [scores[j] for j in sorted_scores]
        weights = softmax_weights(topk_scores, temperature=0.1)

        candidate_models = [is_mixed_dict[int(_id)] for _id in topk_ids]
        initial_pred = Counter(candidate_models).most_common(1)[0][0]

        fuzzy_cnt = {(1,0): 0.0, (0,10^3): 0.0, (1,1): 0.0}
        for id, weight in zip(topk_ids, weights):
            label = (label_dict[int(id)], is_mixed_dict[int(id)])
            boost = class_type_boost(is_mixed_dict[int(id)],initial_pred)
            fuzzy_cnt[label] += weight * boost

        final = max(fuzzy_cnt, key=fuzzy_cnt.get)

        # print(f"Top {opt.K} results for text:")
        # cnt = {(1,0):0,(0,10^3):0,(1,1):0}
        # for j, (id, score) in enumerate(zip(ids, scores)):
        #     print(f"{j+1}. ID {id} Label {label_dict[int(id)]} Is_mixed {is_mixed_dict[int(id)]} Score {score}")
        #     cnt[(label_dict[int(id)], is_mixed_dict[int(id)])]+=1
        # final = max(cnt, key=cnt.get)
        # pred.append(final)
        if final==(1,0):
            print("Human")
            return 0
        elif final==(0,10^3):
            print("AI")
            return 1
        else:
            print("Mixed")
            return 2
        # pred.append(final)
    return -1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--embedding_dim', type=int, default=768)
    parser.add_argument('--database_path', type=str, default="database", help="Path to the index file")

    parser.add_argument("--model_path", type=str, default="core/model.pth",
                        help="Path to the embedding model checkpoint")
    parser.add_argument('--model_name', type=str, default="ZurichNLPZurichNLP/unsup-simcse-xlm-roberta-base", help="Model name")

    parser.add_argument('--K', type=int, default=20, help="Search [1,K] nearest neighbors,choose the best K")
    parser.add_argument('--pooling', type=str, default="average", help="Pooling method, average or cls")
    parser.add_argument('--text', type=str, default="")
    parser.add_argument('--seed', type=int, default=0)

    opt = parser.parse_args()
    set_seed(opt.seed)
    infer(opt)
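Editor's note, not part of the commit: a toy walk-through of the fuzzy voting used in infer_3_class() above. Retrieved neighbours are weighted by a softmax over their similarity scores and boosted when their is_mixed value matches the majority vote; the (label, is_mixed) bucket with the largest mass wins. The neighbour tuples and scores below are made up; note that `10^3` in this codebase is Python's bitwise XOR and evaluates to 9, not 1000.

# Editor's sketch (assumption): the weighted, boosted vote on 4 hypothetical neighbours.
from collections import Counter
from infer import softmax_weights, class_type_boost

# (label, is_mixed) of each neighbour and its similarity score.
neigh = [((1, 0), 0.82), ((1, 0), 0.80), ((0, 10 ^ 3), 0.55), ((1, 1), 0.40)]
weights = softmax_weights([s for _, s in neigh], temperature=0.1)
initial_pred = Counter(lab[1] for lab, _ in neigh).most_common(1)[0][0]  # majority is_mixed

fuzzy_cnt = {(1, 0): 0.0, (0, 10 ^ 3): 0.0, (1, 1): 0.0}
for (lab, _), w in zip(neigh, weights):
    fuzzy_cnt[lab] += w * class_type_boost(lab[1], initial_pred)
print(max(fuzzy_cnt, key=fuzzy_cnt.get))  # (1, 0) -> "Human"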
requirements.txt
ADDED
@@ -0,0 +1,11 @@
pandas~=2.0.3
tqdm~=4.66.4
torch~=2.3.0
transformers~=4.41.1
scikit-learn~=1.3.2
datasets~=2.19.1
nltk~=3.8.1
tiktoken~=0.7.0
faiss-cpu
uvicorn
fastapi
src/.DS_Store
ADDED
Binary file (6.15 kB)
src/__init__.py
ADDED
File without changes
src/index.py
ADDED
@@ -0,0 +1,80 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from tqdm import tqdm

class Indexer(object):

    def __init__(self, vector_sz,device='cpu'):
        self.index = faiss.IndexFlatIP(vector_sz)
        self.device = device
        if self.device == 'cuda':
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        self.index_id_to_db_id = []

    def index_data(self, ids, embeddings):
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        if not self.index.is_trained:
            self.index.train(embeddings)
        self.index.add(embeddings)

        print(f'Total data indexed {self.index.ntotal}')

    def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 8) -> List[Tuple[List[object], List[float]]]:
        query_vectors = query_vectors.astype('float32')
        result = []
        nbatch = (len(query_vectors)-1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k*index_batch_size
            end_idx = min((k+1)*index_batch_size, len(query_vectors))
            q = query_vectors[start_idx: end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # convert to external ids
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
        return result

    def serialize(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Serializing index to {index_file}, meta data to {meta_file}')
        if self.device == 'cuda':
            save_index = faiss.index_gpu_to_cpu(self.index)
        else:
            save_index = self.index
        faiss.write_index(save_index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        if self.device == 'cuda':
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        print('Loaded index of type %s and size %d', type(self.index), self.index.ntotal)

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)

    def reset(self):
        self.index.reset()
        self.index_id_to_db_id = []
        print(f'Index reset, total data indexed {self.index.ntotal}')
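Editor's note, not part of the commit: a minimal round trip with the Indexer above — index a few normalised vectors, search, then persist and reload. The ids, dimensionality, and the "tmp_db" directory are illustrative only.

# Editor's sketch (assumption): Indexer usage on random unit vectors.
import os
import numpy as np
from src.index import Indexer

dim = 8
vecs = np.random.rand(3, dim).astype("float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # IndexFlatIP -> cosine on unit vectors

idx = Indexer(dim)
idx.index_data([11, 22, 33], vecs)
for ids, scores in idx.search_knn(vecs[:1], top_docs=2):
    print(ids, scores)  # external ids (returned as strings) and inner-product scores

os.makedirs("tmp_db", exist_ok=True)
idx.serialize("tmp_db")          # writes index.faiss + index_meta.faiss
idx2 = Indexer(dim)
idx2.deserialize_from("tmp_db")  # restores the vectors and the id mapping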
src/simclr.py
ADDED
@@ -0,0 +1,280 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.text_embedding import TextEmbeddingModel

class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, in_dim, out_dim):
        super(ClassificationHead, self).__init__()
        self.dense1 = nn.Linear(in_dim, in_dim//4)
        self.dense2 = nn.Linear(in_dim//4, in_dim//16)
        self.out_proj = nn.Linear(in_dim//16, out_dim)

        nn.init.xavier_uniform_(self.dense1.weight)
        nn.init.xavier_uniform_(self.dense2.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.normal_(self.dense1.bias, std=1e-6)
        nn.init.normal_(self.dense2.bias, std=1e-6)
        nn.init.normal_(self.out_proj.bias, std=1e-6)

    def forward(self, features):
        x = features
        x = self.dense1(x)
        x = torch.tanh(x)
        x = self.dense2(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x

class SimCLR_Classifier_SCL(nn.Module):
    def __init__(self, opt,fabric):
        super(SimCLR_Classifier_SCL, self).__init__()

        self.temperature = opt.temperature
        self.opt=opt
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        self.device=self.model.model.device
        if opt.resum:
            state_dict = torch.load(opt.pth_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.esp=torch.tensor(1e-6,device=self.device)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)

        self.a=torch.tensor(opt.a,device=self.device)
        self.d=torch.tensor(opt.d,device=self.device)
        self.only_classifier=opt.only_classifier


    def get_encoder(self):
        return self.model

    def _compute_logits(self, q,q_index1, q_index2,q_label,k,k_index1,k_index2,k_label):
        def cosine_similarity_matrix(q, k):

            q_norm = F.normalize(q,dim=-1)
            k_norm = F.normalize(k,dim=-1)
            cosine_similarity = q_norm@k_norm.T

            return cosine_similarity

        logits=cosine_similarity_matrix(q,k)/self.temperature

        q_labels=q_label.view(-1, 1)# N,1
        k_labels=k_label.view(1, -1)# 1,N+K

        same_label=(q_labels==k_labels)# N,N+K

        #model:model set
        pos_logits_model = torch.sum(logits*same_label,dim=1)/torch.max(torch.sum(same_label,dim=1),self.esp)
        neg_logits_model=logits*torch.logical_not(same_label)
        logits_model=torch.cat((pos_logits_model.unsqueeze(1), neg_logits_model), dim=1)

        return logits_model

    def forward(self, batch, indices1, indices2,label):
        bsz = batch['input_ids'].size(0)
        q = self.model(batch)
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_label = self.fabric.all_gather(label).view(-1)
        k_index1 = self.fabric.all_gather(indices1).view(-1)
        k_index2 = self.fabric.all_gather(indices2).view(-1)
        #q:N
        #k:4N
        logits_label = self._compute_logits(q,indices1, indices2,label,k,k_index1,k_index2,k_label)

        out = self.classifier(q)

        if self.opt.AA:
            loss_classfiy = F.cross_entropy(out, indices1)
        else:
            loss_classfiy = F.cross_entropy(out, label)

        gt = torch.zeros(bsz, dtype=torch.long,device=logits_label.device)

        if self.only_classifier:
            loss_label = torch.tensor(0,device=self.device)
        else:
            loss_label = F.cross_entropy(logits_label, gt)

        loss = self.a*loss_label+self.d*loss_classfiy
        if self.training:
            return loss,loss_label,loss_classfiy,k,k_label
        else:
            out = self.fabric.all_gather(out).view(-1, out.size(1))
            return loss,out,k,k_label


class SimCLR_Classifier_test(nn.Module):
    def __init__(self, opt,fabric):
        super(SimCLR_Classifier_test, self).__init__()

        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        self.device=self.model.model.device

    def forward(self, batch):
        q = self.model(batch)
        out = self.classifier(q)
        return out

class SimCLR_Classifier(nn.Module):
    def __init__(self, opt,fabric):
        super(SimCLR_Classifier, self).__init__()

        self.temperature = opt.temperature
        self.opt=opt
        self.fabric = fabric

        self.model = TextEmbeddingModel(opt.model_name)
        if opt.resum:
            state_dict = torch.load(opt.pth_path,
                                    map_location=self.model.device)
            self.model.load_state_dict(state_dict)

        self.device = self.model.model.device
        self.esp = torch.tensor(1e-6,device=self.device)
        self.a = torch.tensor(opt.a,
                              device=self.device)
        self.b = torch.tensor(opt.b,
                              device=self.device)
        self.c = torch.tensor(opt.c,
                              device=self.device)

        self.classifier = ClassificationHead(opt.projection_size,
                                             opt.classifier_dim)
        self.only_classifier = opt.only_classifier


    def get_encoder(self):
        return self.model

    def _compute_logits(self,
                        q,q_index1, q_index2, q_label,
                        k,k_index1,k_index2,k_label):
        def cosine_similarity_matrix(q, k):

            q_norm = F.normalize(q,dim=-1)
            k_norm = F.normalize(k,dim=-1)
            cosine_similarity = q_norm@k_norm.T
            return cosine_similarity

        logits=cosine_similarity_matrix(q,k)/self.temperature

        q_index1=q_index1.view(-1, 1)# change to tensor of size N, 1
        q_index2=q_index2.view(-1, 1)# change to tensor of size N, 1
        q_labels=q_label.view(-1, 1)# change to tensor of size N, 1

        k_index1=k_index1.view(1, -1)# 1,N+K
        k_index2=k_index2.view(1, -1) #1, N+K
        k_labels=k_label.view(1, -1)# 1,N+K

        same_mixed = (q_index1== k_index1)
        same_set=(q_index2==k_index2)# N,N+K
        same_label=(q_labels==k_labels)# N,N+K

        is_human=(q_label==1).view(-1)
        is_machine=(q_label==0).view(-1)

        is_mixed=(q_index1==1).view(-1)

        #human: human
        pos_logits_human = torch.sum(logits*same_label,dim=1)/torch.max(torch.sum(same_label,dim=1),self.esp)
        neg_logits_human=logits*torch.logical_not(same_label)
        logits_human=torch.cat((pos_logits_human.unsqueeze(1), neg_logits_human), dim=1)
        logits_human=logits_human[is_human]

        #human+ai: general
        pos_logits_mixed = torch.sum(logits*same_mixed,dim=1)/torch.maximum(torch.sum(same_mixed,dim=1),self.esp)
        neg_logits_mixed=logits*torch.logical_not(same_mixed)
        logits_mixed=torch.cat((pos_logits_mixed.unsqueeze(1), neg_logits_mixed), dim=1)
        logits_mixed=logits_mixed[is_mixed]

        #human+ai: model
        pos_logits_mixed_set = torch.sum(logits*torch.logical_and(same_mixed, same_set),dim=1)/torch.max(torch.sum(torch.logical_and(same_mixed, same_set),dim=1),self.esp)
        neg_logits_mixed_set=logits*torch.logical_not(torch.logical_and(same_mixed, same_set))
        logits_mixed_set=torch.cat((pos_logits_mixed_set.unsqueeze(1), neg_logits_mixed_set), dim=1)
        logits_mixed_set=logits_mixed_set[is_mixed]

        #model set:label
        pos_logits_set = torch.sum(logits*same_set,dim=1)/torch.max(torch.sum(same_set,dim=1),self.esp)
        neg_logits_set=logits*torch.logical_not(same_set)
        logits_set=torch.cat((pos_logits_set.unsqueeze(1), neg_logits_set), dim=1)
        logits_set=logits_set[is_machine]

        #label: label
        pos_logits_label = torch.sum(logits*same_label, dim=1)/torch.max(torch.sum(same_label,dim=1),self.esp)
        neg_logits_label=logits*torch.logical_not(same_label)
        logits_label=torch.cat((pos_logits_label.unsqueeze(1), neg_logits_label), dim=1)
        logits_label=logits_label[is_machine]

        return logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label

    def forward(self, encoded_batch, label, indices1, indices2):#, weights):
        # print(len(text))
        q = self.model(encoded_batch)
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_label = self.fabric.all_gather(label).view(-1)
        k_index1 = self.fabric.all_gather(indices1).view(-1)
        k_index2 = self.fabric.all_gather(indices2).view(-1)
        #q:N
        #k:4N
        logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label = self._compute_logits(q,indices1, indices2,label,
                                                                                                      k,k_index1,k_index2,k_label)
        out = self.classifier(q)

        if self.opt.AA:
            loss_classfiy = F.cross_entropy(out, indices1)
        else:
            loss_classfiy = F.cross_entropy(out, label) #, weight=weights)

        gt_mixed = torch.zeros(logits_mixed.size(0),
                               dtype=torch.long,
                               device=logits_mixed.device)
        gt_mixed_set = torch.zeros(logits_mixed_set.size(0),
                                   dtype=torch.long,
                                   device=logits_mixed_set.device)
        gt_set = torch.zeros(logits_set.size(0),
                             dtype=torch.long,
                             device=logits_set.device)
        gt_label = torch.zeros(logits_label.size(0),
                               dtype=torch.long,
                               device=logits_label.device)
        gt_human = torch.zeros(logits_human.size(0),
                               dtype=torch.long,
                               device=logits_human.device)


        loss_mixed = F.cross_entropy(logits_mixed,
                                     gt_mixed)
        loss_mixed_set = F.cross_entropy(logits_mixed_set,
                                         gt_mixed_set)
        loss_set = F.cross_entropy(logits_set,
                                   gt_set)
        loss_label = F.cross_entropy(logits_label,
                                     gt_label)
        if logits_human.numel()!=0:
            loss_human = F.cross_entropy(logits_human.to(torch.float64),
                                         gt_human)
        else:
            loss_human=torch.tensor(0,device=self.device)

        loss = self.a*loss_set + (4*self.b-self.a)*loss_label + self.b*loss_human+ self.b*loss_mixed + \
               2*self.b*loss_mixed_set+self.c*loss_classfiy

        if self.training:
            if self.opt.AA:
                return loss,loss_mixed, loss_mixed_set,loss_set,loss_label,loss_human,loss_classfiy,k,k_index1
            else:
                return loss,loss_mixed, loss_mixed_set,loss_set,loss_label,loss_classfiy,loss_human,k,k_label
        else:
            out = self.fabric.all_gather(out).view(-1, out.size(1))
            if self.opt.AA:
                return loss,out,k,k_index1
            else:
                return loss,out,k,k_label
src/text_embedding.py
ADDED
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

class TextEmbeddingModel(nn.Module):
    def __init__(self, model_name,output_hidden_states=False):
        super(TextEmbeddingModel, self).__init__()
        self.model_name = model_name
        if output_hidden_states:
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, output_hidden_states=True)
        else:
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def pooling(self, model_output, attention_mask, use_pooling='average',hidden_states=False):
        if hidden_states:
            model_output.masked_fill(~attention_mask[None,..., None].bool(), 0.0)
            if use_pooling == "average":
                emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
            else:
                emb = model_output[:,:, 0]
            emb = emb.permute(1, 0, 2)
        else:
            model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
            if use_pooling == "average":
                emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            elif use_pooling == "cls":
                emb = model_output[:, 0]
        return emb

    def forward(self, encoded_batch, use_pooling='average',hidden_states=False):
        if "t5" in self.model_name.lower():
            input_ids = encoded_batch['input_ids']
            decoder_input_ids = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
            model_output = self.model(**encoded_batch,
                                      decoder_input_ids=decoder_input_ids)
        else:
            model_output = self.model(**encoded_batch)

        if 'bge' in self.model_name.lower() or 'mxbai' in self.model_name.lower():
            use_pooling = 'cls'
        if isinstance(model_output, tuple):
            model_output = model_output[0]
        if isinstance(model_output, dict):
            if hidden_states:
                model_output = model_output["hidden_states"]
                model_output = torch.stack(model_output, dim=0)
            else:
                model_output = model_output["last_hidden_state"]

        emb = self.pooling(model_output, encoded_batch['attention_mask'], use_pooling,hidden_states)
        emb = torch.nn.functional.normalize(emb, dim=-1)
        return emb
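Editor's note, not part of the commit: embedding a small batch of texts with the TextEmbeddingModel above. The model name here is illustrative; the Space itself loads a local "./unsup-simcse-xlm-roberta-base" checkpoint in app.py.

# Editor's sketch (assumption): encode two sentences and inspect the embedding shape.
import torch
from src.text_embedding import TextEmbeddingModel

model = TextEmbeddingModel("FacebookAI/xlm-roberta-base")
batch = model.tokenizer.batch_encode_plus(
    ["a human-written sentence", "an AI-generated sentence"],
    return_tensors="pt", max_length=512, padding="max_length", truncation=True,
)
with torch.no_grad():
    emb = model(batch)   # average pooling followed by L2 normalisation
print(emb.shape)         # torch.Size([2, 768]) for xlm-roberta-base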
utils/__init__.py
ADDED
File without changes
utils/load_dataset.py
ADDED
@@ -0,0 +1,205 @@
from torch.utils.data import Dataset

import os
import json
import random
import hashlib

def stable_long_hash(input_string):
    hash_object = hashlib.sha256(input_string.encode())
    hex_digest = hash_object.hexdigest()
    int_hash = int(hex_digest, 16)
    long_long_hash = (int_hash & ((1 << 63) - 1))
    return long_long_hash

model_map_authscan = {
    "gpt-4o-mini-text": 1,
    "gemini-2.0-text": 2,
    "deepseek-text": 3,
    "llama-text": 4
}

model_map_llmdetectaive = {
    "gemma-text": 1,
    "mixtral-text": 2,
    "llama3-text": 3
}

model_map_hart = {
    "claude-text": 1,
    "gemini-text": 2,
    "gpt-text": 3
}

def load_dataset(dataset_name,path=None):
    dataset = {
        "train": [],
        "valid": [],
        "test": []
    }
    if dataset_name == "falconset":
        model_map = model_map_authscan
    elif dataset_name == "llmdetectaive":
        model_map = model_map_llmdetectaive
    elif dataset_name == "hart":
        model_map = model_map_hart

    folder = os.listdir(path)
    # print(folder)
    for sub in folder:
        sub_path = os.path.join(path, sub)
        files = os.listdir(sub_path)
        for file in files:
            if not file.endswith('.jsonl'):
                continue
            file_path = os.path.join(sub_path, file)
            key_name = file.split('.')[0]

            assert key_name in dataset.keys(), f'{key_name} is not in dataset.keys()'
            with open(file_path, 'r') as f:
                data = [json.loads(line) for line in f]
            for i in range(len(data)):
                dct = {}
                dct['text'] = data[i]['text']
                if sub == "human-text":
                    dct['label'] = "human"
                    dct['label_detailed'] = "human"
                    dct['index'] = (1,0,0)
                elif sub.startswith("human---"):
                    dct['label'] = "human+AI"
                    model = sub.split("---")[1]
                    dct['label_detailed'] = model
                    dct['index'] = (1, 1, model_map[model])
                else:
                    dct['label'] = "AI"
                    dct['label_detailed'] = sub
                    dct['index'] = (0, 10^3, model_map[sub])
                dataset[key_name].append(dct)
    return dataset

def load_outdomain_dataset(path):
    dataset = {
        "valid": [],
        "test": []
    }
    folder = os.listdir(path)
    for sub in folder:
        sub_path = os.path.join(path, sub)
        files = os.listdir(sub_path)
        for file in files:
            if not file.endswith('.jsonl'):
                continue
            file_path = os.path.join(sub_path, file)
            key_name = file.split('.')[0]
            assert key_name in dataset.keys(), f'{key_name} is not in dataset.keys()'
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]
            for i in range(len(data)):
                dct = {}
                dct['text'] = data[i]['text']
                if sub == "human-text":
                    dct['label'] = "human"
                    dct['label_detailed'] = "human"
                    dct['index'] = (1,0)
                elif sub.startswith("human---"):
                    dct['label'] = "human+AI"
                    model = sub.split("---")[1]
                    dct['label_detailed'] = model
                    dct['index'] = (1, 1)
                else:
                    dct['label'] = "AI"
                    dct['label_detailed'] = sub
                    dct['index'] = (0, 10^3)
                dataset[key_name].append(dct)
    return dataset

def load_dataset_conditional_lang(path=None, language='vi', seed=42):
    dataset = {
        "train": [],
        "val": [],
        "test": []
    }
    combined_data = []

    random.seed(seed)  # for reproducibility
    folder = os.listdir(path)
    print("Subfolders:", folder)

    for sub in folder:
        sub_path = os.path.join(path, sub)
        if not os.path.isdir(sub_path):
            continue
        files = os.listdir(sub_path)

        for file in files:
            if not file.endswith('.jsonl') or language not in file:
                continue

            file_path = os.path.join(sub_path, file)
            with open(file_path, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f]

            for entry in data:
                if 'content' not in entry:
                    print("Key does not exist!")
                    continue

                dct = {}
                dct['text'] = entry['content']

                if sub == "human":
                    dct['label'] = "human"
                    dct['label_detailed'] = "human"
                    dct['index'] = (1, 0, 0)
                elif sub == "human+AI":
                    model = entry['label_detailed'].split("+")[1]
                    dct['label'] = "human+AI"
                    dct['label_detailed'] = model
                    dct['index'] = (1, 1, model_map[model])
                else:
                    dct['label'] = "AI"
                    dct['label_detailed'] = entry['label_detailed']
                    dct['index'] = (0, 10**3, model_map[entry['label_detailed']])

                combined_data.append(dct)

    random.shuffle(combined_data)
    total = len(combined_data)
    train_end = int(total * 0.9)
    val_end = train_end + int(total * 0.05)

    dataset['train'] = combined_data[:train_end]
    dataset['val'] = combined_data[train_end:val_end]
    dataset['test'] = combined_data[val_end:]

    print(f"Total: {total} | Train: {len(dataset['train'])} | Val: {len(dataset['val'])} | Test: {len(dataset['test'])}")
    return dataset



class TextDataset(Dataset):
    def __init__(self, dataset,need_ids=True,out_domain=0):
        self.dataset = dataset
        self.need_ids=need_ids
        self.out_domain = out_domain

    def get_class(self):
        return self.classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text, label, label_detailed, index = self.dataset[idx].values()
        id = stable_long_hash(text)
        if self.out_domain:
            label, is_mixed = index
            if self.need_ids:
                return int(id), text, int(label), int(is_mixed)
            return text, int(label), int(is_mixed)
        else:
            label, is_mixed, write_model = index
            if self.need_ids:
                return int(id), text, int(label), int(is_mixed), int(write_model)
            return text, int(label), int(is_mixed), int(write_model)
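Editor's note, not part of the commit: TextDataset expects dicts with the keys text / label / label_detailed / index, in that insertion order, as produced by load_dataset(). The rows below are made up; note that `10^3` in this codebase is Python's bitwise XOR and evaluates to 9, not 1000, so the "AI" index tuple is effectively (0, 9, model_id).

# Editor's sketch (assumption): feeding TextDataset with two hand-built rows.
from utils.load_dataset import TextDataset

rows = [
    {"text": "written by a person", "label": "human", "label_detailed": "human", "index": (1, 0, 0)},
    {"text": "written by a model", "label": "AI", "label_detailed": "gpt-4o-mini-text", "index": (0, 10 ^ 3, 1)},
]
ds = TextDataset(rows, need_ids=True)
sample_id, text, label, is_mixed, write_model = ds[1]
print(sample_id, label, is_mixed, write_model)  # stable 63-bit hash, 0, 9, 1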
utils/utils.py
ADDED
@@ -0,0 +1,132 @@
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, mean_absolute_error, hamming_loss
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

def find_top_n(embeddings,n,index,data):
    if len(embeddings.shape) == 1:
        embeddings = embeddings.reshape(1, -1)
    top_ids_and_scores = index.search_knn(embeddings, n)
    data_ans=[]
    for i, (ids, scores) in enumerate(top_ids_and_scores):
        data_now=[]
        for id in ids:
            data_now.append((data[0][int(id)],data[1][int(id)],data[2][int(id)]))
        data_ans.append(data_now)
    return data_ans


def print_line(class_name, metrics, is_header=False):
    if is_header:
        line = f"| {'Class':<10} | " + " | ".join([f"{metric:<10}" for metric in metrics])
    else:
        line = f"| {class_name:<10} | " + " | ".join([f"{metrics[metric]:<10.3f}" for metric in metrics])
    print(line)
    if is_header:
        print('-' * len(line))

def calculate_per_class_metrics(classes, ground_truth, predictions):
    # Convert ground truth and predictions to numeric format
    gt_numeric = np.array([int(gt) for gt in ground_truth])
    pred_numeric = np.array([int(pred) for pred in predictions])

    results = {}
    for i, class_name in enumerate(classes):
        # For each class, calculate the 'vs rest' binary labels
        gt_binary = (gt_numeric == i).astype(int)
        pred_binary = (pred_numeric == i).astype(int)

        # Calculate metrics, handling cases where a class is not present in predictions or ground truth
        precision = precision_score(gt_binary, pred_binary, zero_division=0)
        recall = recall_score(gt_binary, pred_binary, zero_division=0)
        f1 = f1_score(gt_binary, pred_binary, zero_division=0)
        acc = np.mean(gt_binary == pred_binary)
        # Calculate recall for all other classes as 'rest'
        rest_recall = recall_score(1 - gt_binary, 1 - pred_binary, zero_division=0)

        results[class_name] = {
            'Precision': precision,
            'Recall': recall,
            'F1 Score': f1,
            'Accuracy': acc,
            'Avg Recall (with rest)': (recall + rest_recall) / 2
        }

    print_line("Metric", results[classes[0]], is_header=True)
    for class_name, metrics in results.items():
        print_line(class_name, metrics)
    overall_metrics = {metric_name: np.mean([metrics[metric_name] for metrics in results.values()]) for metric_name in results[classes[0]].keys()}
    print_line("Overall", overall_metrics)

def calculate_metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    avg_f1 = f1_score(y_true, y_pred, average='macro')
    avg_recall = recall_score(y_true, y_pred, average='macro')
    return accuracy, avg_f1,avg_recall

def compute_three_recalls(labels, preds):
    all_n, all_p, tn, tp = 0, 0, 0, 0
    for label, pred in zip(labels, preds):
        if label == '0':
            all_p += 1
        if label == '1':
            all_n += 1
        if pred is not None and label == pred == '0':
            tp += 1
        if pred is not None and label == pred == '1':
            tn += 1
        if pred is None:
            continue
    machine_rec , human_rec= tp * 100 / all_p if all_p != 0 else 0, tn * 100 / all_n if all_n != 0 else 0
    avg_rec = (human_rec + machine_rec) / 2
    return (human_rec, machine_rec, avg_rec)


def compute_metrics(labels, preds,ids=None, full_labels=False):
    if ids is not None:
        # unique ids
        dict_labels,dict_preds={},{}
        for i in range(len(ids)):
            dict_labels[ids[i]]=labels[i]
            dict_preds[ids[i]]=preds[i]
        labels=list(dict_labels.values())
        preds=list(dict_preds.values())

    if not full_labels:
        labels_map = {(1,0): 0, (0,10^3): 1, (1,1): 2}
        labels_bin = [labels_map[tup] for tup in labels]
        preds_bin = [labels_map[tup] for tup in preds]

    else:
        labels_map ={
            (1, 0, 0): 0,  # Human
            (0, 10^3, 1): 1, (0, 10^3, 2): 2, (0, 10^3, 3): 3, (0, 10^3, 4): 4,  # AI
            (1, 1, 1): 5, (1, 1, 2): 6, (1, 1, 3): 7, (1, 1, 4): 8  # Human+AI
        }
        labels_bin = [labels_map[tup] for tup in labels]
        preds_bin = [labels_map[tup] for tup in preds]
    acc = accuracy_score(labels_bin, preds_bin)
    precision = precision_score(labels_bin, preds_bin, average="macro")
    recall = recall_score(labels_bin, preds_bin, average="macro")
    f1 = f1_score(labels_bin, preds_bin, average="macro")
    mse = mean_squared_error(labels_bin, preds_bin)
    mae = mean_absolute_error(labels_bin, preds_bin)

    return (acc, precision, recall, f1, mse, mae)

def compute_metrics_train(labels, preds,ids=None):
    if ids is not None:
        # unique ids
        dict_labels,dict_preds={},{}
        for i in range(len(ids)):
            dict_labels[ids[i]]=labels[i]
            dict_preds[ids[i]]=preds[i]
        labels=list(dict_labels.values())
        preds=list(dict_preds.values())

    human_rec, machine_rec, avg_rec = compute_three_recalls(labels, preds)
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="macro")
    recall = recall_score(labels, preds, average="macro")
    f1 = f1_score(labels, preds, average="macro")

    return (human_rec, machine_rec, avg_rec, acc, precision, recall, f1)
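Editor's note, not part of the commit: compute_metrics() above maps the 3-class (label, is_mixed) tuples to integer ids before scoring; since `10^3` is bitwise XOR in Python, the "AI" key is effectively (0, 9). The toy labels and predictions below are made up for illustration.

# Editor's sketch (assumption): scoring a hand-built set of 3-class predictions.
from utils.utils import compute_metrics

labels = [(1, 0), (0, 10 ^ 3), (1, 1), (1, 0)]          # Human, AI, Mixed, Human
preds  = [(1, 0), (1, 1), (1, 1), (0, 10 ^ 3)]          # Human, Mixed, Mixed, AI
acc, precision, recall, f1, mse, mae = compute_metrics(labels, preds)
print(f"acc={acc:.2f} macro-F1={f1:.2f}")               # e.g. acc=0.50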