ekaterinatao committed
Commit f08f4e6 · verified · 1 Parent(s): 48394e3

Update utils/func.py

Files changed (1)
  1. utils/func.py +86 -1
utils/func.py CHANGED
@@ -2,4 +2,89 @@ import torch
  import faiss
  import numpy as np
  import datasets
- from transformers import AutoTokenizer, AutoModel
+ from transformers import AutoTokenizer, AutoModel
+
+ # Run tensor ops on GPU when one is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def embed_bert_cls(text, model, tokenizer):
+     """Encode text into a normalized CLS embedding with the bi-encoder."""
+     t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
+     with torch.no_grad():
+         model_output = model(**{k: v.to(model.device) for k, v in t.items()})
+     embeds = model_output.last_hidden_state[:, 0, :]
+     embeds = torch.nn.functional.normalize(embeds)
+     return embeds[0].cpu().numpy()
+
+
+ def get_ranked_docs(query, vec_query_base, data,
+                     bi_model, bi_tok, cross_model, cross_tok):
+     """Retrieve 50 candidate answers with faiss, then rerank them with the cross-encoder."""
+     vec_shape = vec_query_base.shape[1]
+     index = faiss.IndexFlatL2(vec_shape)
+     index.add(vec_query_base.astype(np.float32))  # faiss expects float32 vectors
+     xq = embed_bert_cls(query, bi_model, bi_tok)
+     _, I = index.search(xq.reshape(1, vec_shape), 50)
+     corpus = [data[int(i)]['answer'] for i in I[0]]
+
+     # Encode every (query, candidate answer) pair with the cross-encoder
+     queries = [query] * len(corpus)
+     tokenized_texts = cross_tok(
+         queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
+     ).to(device)
+
+     with torch.no_grad():
+         model_output = cross_model(
+             **{k: v.to(cross_model.device) for k, v in tokenized_texts.items()}
+         )
+     ce_scores = model_output.last_hidden_state[:, 0, :].cpu().numpy()  # CLS embedding per pair
+     scores = np.matmul(ce_scores, ce_scores.T)  # pairwise similarities between pairs
+     scores_ix = np.argsort(scores, axis=1)[:, ::-1]  # rank candidates per row, best first
+
+     return corpus[scores_ix[0][0]]
+
+
+ def load_dataset(url='ekaterinatao/house_md_context3'):
+     """Load the dialogue dataset and keep only rows with labels == 0."""
+     dataset = datasets.load_dataset(url, split='train')
+     house_dataset = dataset.filter(lambda row: row['labels'] == 0)
+
+     return house_dataset
+
+
+ def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
+     """Load the precomputed CLS embeddings that serve as the retrieval base."""
+     cls_dataset = datasets.load_dataset(url, split='train')
+     cls_base = np.stack([embed['cls_embeds'] for embed in cls_dataset])
+
+     return cls_base
+
+
+ def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
+     """Load the bi-encoder checkpoint and its tokenizer."""
+     bi_model = AutoModel.from_pretrained(checkpoint)
+     bi_tok = AutoTokenizer.from_pretrained(checkpoint)
+
+     return bi_model, bi_tok
+
+
+ def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):
+     """Load the cross-encoder checkpoint and its tokenizer."""
+     cross_model = AutoModel.from_pretrained(checkpoint)
+     cross_tok = AutoTokenizer.from_pretrained(checkpoint)
+
+     return cross_model, cross_tok
+
+
+ def get_answer(message):
+     """End-to-end pipeline: load data and models, then return the best-ranked answer."""
+     dataset = load_dataset()
+     cls_base = load_cls_base()
+     bi_enc_model = load_bi_enc_model()
+     cross_enc_model = load_cross_enc_model()
+
+     answer = get_ranked_docs(
+         query=message, vec_query_base=cls_base, data=dataset,
+         bi_model=bi_enc_model[0], bi_tok=bi_enc_model[1],
+         cross_model=cross_enc_model[0], cross_tok=cross_enc_model[1]
+     )
+     return answer
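
For context, the updated module is meant to be driven through `get_answer` alone; the dataset, the embedding base and both checkpoints are all loaded inside that call. A minimal usage sketch, assuming the repository root is on the Python path so `utils.func` is importable (the import path and the sample message below are illustrative, not part of this commit):

# Minimal usage sketch (illustrative, not part of the commit).
# Assumes utils/func.py is importable as utils.func from the repository root.
from utils.func import get_answer

reply = get_answer("My leg hurts and the pills are not working.")
print(reply)  # highest-ranked answer retrieved from the dataset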