ajanz committed
Commit 3ebb702 · 1 Parent(s): 66d0fee

An extended tokenizing function (as proposed in the source project)

Files changed (1): app.py +10 -11
app.py CHANGED
@@ -24,16 +24,16 @@ textbox = gr.Textbox(
 
 
 def prepare_query(tokenizer, query, max_seq_length=300):
-    # temporary solution
-    mention_start_token: str = "[unused0]"
-    mention_end_token: str = "[unused1]"
+    # temporary solution b/c of padding (which is unnecessary for inference)
+    start_token: str = "[unused0]"
+    end_token: str = "[unused1]"
 
-    left_context = query.split(mention_start_token)[0]
-    right_context = query.split(mention_end_token)[-1]
-    mention = query.split(mention_start_token)[-1].split(mention_end_token)[0]
+    left_context = query.split(start_token)[0]
+    right_context = query.split(end_token)[-1]
+    mention = query.split(start_token)[-1].split(end_token)[0]
 
     mention_ids = tokenizer(
-        mention_start_token + mention + mention_end_token,
+        f"{start_token}{mention}{end_token}",
         add_special_tokens=False
     )['input_ids']
 
@@ -79,16 +79,15 @@ def load_index(index_data: str = "clarin-knext/entity-linking-index"):
 def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
     model = AutoModel.from_pretrained(model_name, use_auth_token=auth_token)
-    pipe = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
-    return pipe
+    return tokenizer, model
 
 
-model = load_model()
+tokenizer, model = load_model()
 index = load_index()
 
 
 def predict(text: str = sample_text, top_k: int=3):
-    query = prepare_query(text)
+    query = prepare_query(tokenizer, text)
     index_data, faiss_index = index
     # takes only the [CLS] embedding (for now)
     query = model(query, return_tensors = "pt")[0][0].numpy().reshape(1, -1)
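For readers following along, below is a self-contained sketch of the pattern the reworked prepare_query implements. It is an approximation, not the Space's exact code: bert-base-uncased stands in for the auth-gated clarin-knext/entity-linking-encoder, the markers are mapped to ids via convert_tokens_to_ids rather than tokenized inline (a stock BERT tokenizer may split the bracketed marker strings), and the context-truncation step is an assumption, since the diff cuts off after mention_ids.

from transformers import AutoTokenizer

# Mention boundary markers; BERT-style vocabularies reserve [unused0]/[unused1].
START_TOKEN = "[unused0]"
END_TOKEN = "[unused1]"

def prepare_query(tokenizer, query, max_seq_length=300):
    # Split the raw query around the mention markers.
    left_context = query.split(START_TOKEN)[0]
    right_context = query.split(END_TOKEN)[-1]
    mention = query.split(START_TOKEN)[-1].split(END_TOKEN)[0]

    # Wrap the tokenized mention with the reserved marker ids.
    start_id = tokenizer.convert_tokens_to_ids(START_TOKEN)
    end_id = tokenizer.convert_tokens_to_ids(END_TOKEN)
    mention_ids = (
        [start_id]
        + tokenizer(mention, add_special_tokens=False)["input_ids"]
        + [end_id]
    )

    # Assumed truncation scheme: spend the remaining budget (minus [CLS]/[SEP])
    # evenly on the left and right contexts.
    left_ids = tokenizer(left_context, add_special_tokens=False)["input_ids"]
    right_ids = tokenizer(right_context, add_special_tokens=False)["input_ids"]
    half = max(max_seq_length - len(mention_ids) - 2, 0) // 2
    left_keep = left_ids[-half:] if half > 0 else []
    return (
        [tokenizer.cls_token_id]
        + left_keep
        + mention_ids
        + right_ids[:half]
        + [tokenizer.sep_token_id]
    )

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = prepare_query(tokenizer, "The city of [unused0]Paris[unused1] hosted the games.")
print(tokenizer.convert_ids_to_tokens(ids))

The second hunk follows the same idea: load_model now returns the raw (tokenizer, model) pair instead of a feature-extraction pipeline, so predict can hand the tokenizer to prepare_query itself.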