ajanz committed
Commit e30b005 · Parent: 3ebb702

reverting unnecessary changes

Files changed (1): app.py (+10, -50)
app.py CHANGED
@@ -3,7 +3,7 @@ import datasets
 import faiss
 import os
 
-from transformers import pipeline, AutoModel, AutoTokenizer
+from transformers import pipeline  # , AutoModel, AutoTokenizer
 
 
 auth_token = os.environ.get("CLARIN_KNEXT")
@@ -23,49 +23,6 @@ textbox = gr.Textbox(
 )
 
 
-def prepare_query(tokenizer, query, max_seq_length=300):
-    # temporary solution b/c of padding (which is unnecessary for inference)
-    start_token: str = "[unused0]"
-    end_token: str = "[unused1]"
-
-    left_context = query.split(start_token)[0]
-    right_context = query.split(end_token)[-1]
-    mention = query.split(start_token)[-1].split(end_token)[0]
-
-    mention_ids = tokenizer(
-        f"{start_token}{mention}{end_token}",
-        add_special_tokens=False
-    )['input_ids']
-
-    left_ids = tokenizer(left_context, add_special_tokens=False)['input_ids']
-    left_quota = (max_seq_length - len(mention_ids)) // 2 - 1
-
-    right_ids = tokenizer(right_context, add_special_tokens=False)['input_ids']
-    right_quota = max_seq_length - len(mention_ids) - left_quota - 2
-
-    left_add, right_add = len(left_ids), len(right_ids)
-    if left_add <= left_quota:
-        right_quota += left_quota - left_add if right_add > right_quota else 0
-    else:
-        left_quota += right_quota - right_add if right_add <= right_quota else 0
-
-    context_ids = [
-        tokenizer.cls_token_id,
-        *left_ids[-left_quota:],
-        *mention_ids,
-        *right_ids[:right_quota],
-        tokenizer.sep_token_id
-    ]
-
-    padding_length = max_seq_length - len(context_ids)
-    # attention_mask = [1] * len(context_ids) + [0] * padding_length
-
-    context_ids += [tokenizer.pad_token_id] * padding_length
-
-    assert len(context_ids) == max_seq_length
-    return context_ids
-
-
 def load_index(index_data: str = "clarin-knext/entity-linking-index"):
     ds = datasets.load_dataset(index_data, use_auth_token=auth_token)['train']
     index_data = {
@@ -77,20 +34,23 @@ def load_index(index_data: str = "clarin-knext/entity-linking-index"):
 
 
 def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
-    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
-    model = AutoModel.from_pretrained(model_name, use_auth_token=auth_token)
-    return tokenizer, model
+    # tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
+    # model = AutoModel.from_pretrained(model_name, use_auth_token=auth_token)
+    model = pipeline("feature-extraction", model="clarin-knext/entity-linking-encoder", use_auth_token=auth_token)
+    # return tokenizer, model
+    return model
 
 
-tokenizer, model = load_model()
+# tokenizer, model = load_model()
+model = load_model()
 index = load_index()
 
 
 def predict(text: str = sample_text, top_k: int=3):
-    query = prepare_query(tokenizer, text)
+    # query = prepare_query(tokenizer, text)
     index_data, faiss_index = index
     # takes only the [CLS] embedding (for now)
-    query = model(query, return_tensors = "pt")[0][0].numpy().reshape(1, -1)
+    query = model(query)[0][0].numpy().reshape(1, -1)
 
     scores, indices = faiss_index.search(query, top_k)
     scores, indices = scores.tolist(), indices.tolist()
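For reference, the removed prepare_query helper built a fixed-size token window (max_seq_length=300) around an entity mention that the caller had wrapped in the reserved [unused0]/[unused1] marker tokens. A hypothetical example of the query format it expected (the sentence is invented for illustration):

    # Mention wrapped in the marker tokens prepare_query splits on:
    query = "The capital of [unused0] Poland [unused1] is Warsaw."
    # prepare_query tokenized the marked mention, filled the remaining budget
    # with as much left/right context as fit, padded with pad_token_id up to
    # 300 ids, and returned the resulting id list.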
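One caveat survives the revert: in the new predict, model(query) is evaluated before query is ever assigned (the prepare_query call that used to produce it is commented out), so the function as committed raises a NameError; model(text) was presumably intended. The feature-extraction pipeline also returns nested Python lists rather than tensors by default, so calling .numpy() on its output would fail. A minimal sketch of the likely intent, reusing model, index, and sample_text from app.py and assuming faiss expects a float32 matrix of shape (n_queries, dim):

    import numpy as np

    def predict(text: str = sample_text, top_k: int = 3):
        index_data, faiss_index = index
        # Pipeline output shape: [batch][token][hidden]; token 0 is [CLS].
        cls_embedding = model(text)[0][0]
        # faiss requires a float32 matrix of shape (n_queries, dim).
        query = np.asarray(cls_embedding, dtype="float32").reshape(1, -1)
        scores, indices = faiss_index.search(query, top_k)
        scores, indices = scores.tolist(), indices.tolist()
        # ...remainder as in app.py (the diff is truncated at this point)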