ekaterinatao committed on
Commit
f4a73c5
·
verified ·
1 Parent(s): 2e205ce

Update utils/func.py

Browse files
Files changed (1) hide show
  1. utils/func.py +26 -26
utils/func.py CHANGED
@@ -3,33 +3,44 @@ import faiss
3
  import numpy as np
4
  import datasets
5
  from transformers import AutoTokenizer, AutoModel
 
6
 
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
 
 
 
 
 
 
 
 
9
 
10
def embed_bert_cls(text, model, tokenizer):
    """Encode *text* into a single L2-normalized [CLS] embedding.

    Tokenizes with padding/truncation, moves the tensors to the model's
    device, runs the model without gradient tracking, and returns the
    first sequence's [CLS] vector as a numpy array.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    on_device = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**on_device).last_hidden_state
    # the [CLS] token sits at position 0 of every sequence
    cls_vec = torch.nn.functional.normalize(hidden[:, 0, :])
    return cls_vec[0].cpu().numpy()
17
 
18
 
19
- def get_ranked_docs(query, vec_query_base, data,
20
- bi_model, bi_tok, cross_model, cross_tok):
 
 
 
21
 
22
  vec_shape = vec_query_base.shape[1]
23
  index = faiss.IndexFlatL2(vec_shape)
24
  index.add(vec_query_base)
25
  xq = embed_bert_cls(query, bi_model, bi_tok)
26
- _, I = index.search(xq.reshape(1, vec_shape), 50)
27
  corpus = [data[int(i)]['answer'] for i in I[0]]
28
 
29
  queries = [query] * len(corpus)
30
  tokenized_texts = cross_tok(
31
  queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
32
- ).to(device)
33
 
34
  with torch.no_grad():
35
  model_output = cross_model(
@@ -43,7 +54,7 @@ def get_ranked_docs(query, vec_query_base, data,
43
  return corpus[scores_ix[0][0]]
44
 
45
 
46
- def load_dataset(url='ekaterinatao/house_md_context3'):
47
 
48
  dataset = datasets.load_dataset(url, split='train')
49
  house_dataset = dataset.filter(lambda row: row['labels'] == 0)
@@ -51,7 +62,7 @@ def load_dataset(url='ekaterinatao/house_md_context3'):
51
  return house_dataset
52
 
53
 
54
- def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
55
 
56
  cls_dataset = datasets.load_dataset(url, split='train')
57
  cls_base = np.stack([embed['cls_embeds'] for embed in cls_dataset])
@@ -59,7 +70,9 @@ def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
59
  return cls_base
60
 
61
 
62
- def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
 
 
63
 
64
  bi_model = AutoModel.from_pretrained(checkpoint)
65
  bi_tok = AutoTokenizer.from_pretrained(checkpoint)
@@ -67,24 +80,11 @@ def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
67
  return bi_model, bi_tok
68
 
69
 
70
def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):
    """Load the cross-encoder model and its tokenizer from the HF Hub.

    Returns the pair ``(model, tokenizer)``.
    """
    model = AutoModel.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer
76
-
77
-
78
def get_answer(message):
    """Return the best-ranked corpus answer for the user *message*.

    Loads the dataset, the precomputed [CLS] embedding base and both
    encoder models, then delegates ranking to ``get_ranked_docs``.
    """
    # NOTE: data and both models are re-loaded on every call
    data = load_dataset()
    embeds = load_cls_base()
    bi_model, bi_tok = load_bi_enc_model()
    cross_model, cross_tok = load_cross_enc_model()

    return get_ranked_docs(
        query=message, vec_query_base=embeds, data=data,
        bi_model=bi_model, bi_tok=bi_tok,
        cross_model=cross_model, cross_tok=cross_tok,
    )
 
3
  import numpy as np
4
  import datasets
5
  from transformers import AutoTokenizer, AutoModel
6
+ from config_data.config import Config, load_config
7
 
 
8
 
9
+ config: Config = load_config()
10
+
11
+
12
def embed_bert_cls(
    text: str,
    model: "AutoModel",
    tokenizer: "AutoTokenizer"
) -> np.ndarray:
    """Embed *text* as the model's L2-normalized [CLS] vector.

    Tokenizes with padding/truncation, moves the inputs to the model's
    own device, runs the forward pass without gradients, and returns the
    first sequence's [CLS] embedding as a numpy array.
    """
    batch = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    on_device = {name: tensor.to(model.device) for name, tensor in batch.items()}
    with torch.no_grad():
        hidden = model(**on_device).last_hidden_state
    # position 0 of each sequence holds the [CLS] token
    normalized = torch.nn.functional.normalize(hidden[:, 0, :])
    return normalized[0].cpu().numpy()
25
 
26
 
27
+ def get_ranked_docs(
28
+ query: str, vec_query_base: np.ndarray, data: datasets,
29
+ bi_model: AutoModel, bi_tok: AutoTokenizer,
30
+ cross_model: AutoModel, cross_tok: AutoTokenizer
31
+ ) -> str:
32
 
33
  vec_shape = vec_query_base.shape[1]
34
  index = faiss.IndexFlatL2(vec_shape)
35
  index.add(vec_query_base)
36
  xq = embed_bert_cls(query, bi_model, bi_tok)
37
+ _, I = index.search(xq.reshape(1, vec_shape), 50) # corpus contains 50 similar queries
38
  corpus = [data[int(i)]['answer'] for i in I[0]]
39
 
40
  queries = [query] * len(corpus)
41
  tokenized_texts = cross_tok(
42
  queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
43
+ ).to(config.model.device)
44
 
45
  with torch.no_grad():
46
  model_output = cross_model(
 
54
  return corpus[scores_ix[0][0]]
55
 
56
 
57
+ def load_dataset(url: str=config.data.dataset) -> datasets:
58
 
59
  dataset = datasets.load_dataset(url, split='train')
60
  house_dataset = dataset.filter(lambda row: row['labels'] == 0)
 
62
  return house_dataset
63
 
64
 
65
+ def load_cls_base(url: str=config.data.cls_vec) -> np.array:
66
 
67
  cls_dataset = datasets.load_dataset(url, split='train')
68
  cls_base = np.stack([embed['cls_embeds'] for embed in cls_dataset])
 
70
  return cls_base
71
 
72
 
73
+ def load_bi_enc_model(
74
+ checkpoint: str=config.model.bi_checkpoint
75
+ ) -> tuple[AutoTokenizer, AutoModel]:
76
 
77
  bi_model = AutoModel.from_pretrained(checkpoint)
78
  bi_tok = AutoTokenizer.from_pretrained(checkpoint)
 
80
  return bi_model, bi_tok
81
 
82
 
83
def load_cross_enc_model(
    checkpoint: str=config.model.cross_checkpoint
) -> tuple[AutoModel, AutoTokenizer]:
    """Load the fine-tuned cross-encoder and its tokenizer.

    Args:
        checkpoint: HF Hub model id; defaults to the project config value
            (evaluated once at import time).

    Returns:
        ``(model, tokenizer)`` — the annotation previously declared
        ``tuple[AutoTokenizer, AutoModel]``, which contradicted the
        actual return order; fixed to match.
    """
    cross_model = AutoModel.from_pretrained(checkpoint)
    cross_tok = AutoTokenizer.from_pretrained(checkpoint)

    return cross_model, cross_tok