an extended tokenizing function (as proposed in the source project)
app.py CHANGED
@@ -24,16 +24,16 @@ textbox = gr.Textbox(
 
 
 def prepare_query(tokenizer, query, max_seq_length=300):
-    # temporary solution
-
-
+    # temporary solution b/c of padding (which is unnecessary for inference)
+    start_token: str = "[unused0]"
+    end_token: str = "[unused1]"
 
-    left_context = query.split(
-    right_context = query.split(
-    mention = query.split(
+    left_context = query.split(start_token)[0]
+    right_context = query.split(end_token)[-1]
+    mention = query.split(start_token)[-1].split(end_token)[0]
 
     mention_ids = tokenizer(
-
+        f"{start_token}{mention}{end_token}",
         add_special_tokens=False
     )['input_ids']
 
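As a quick sanity check, here is the new splitting applied to a hand-marked query (the sentence is invented; the sentinel strings come from the hunk above):

query = "The capital of [unused0]Poland[unused1] is Warsaw."
start_token, end_token = "[unused0]", "[unused1]"

left_context = query.split(start_token)[0]                  # 'The capital of '
right_context = query.split(end_token)[-1]                  # ' is Warsaw.'
mention = query.split(start_token)[-1].split(end_token)[0]  # 'Poland'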
@@ -79,16 +79,15 @@ def load_index(index_data: str = "clarin-knext/entity-linking-index"):
 def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
     model = AutoModel.from_pretrained(model_name, use_auth_token=auth_token)
-
-    return pipe
+    return tokenizer, model
 
 
-model = load_model()
+tokenizer, model = load_model()
 index = load_index()
 
 
 def predict(text: str = sample_text, top_k: int=3):
-    query = prepare_query(text)
+    query = prepare_query(tokenizer, text)
     index_data, faiss_index = index
     # takes only the [CLS] embedding (for now)
     query = model(query, return_tensors = "pt")[0][0].numpy().reshape(1, -1)
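Taken together, the two hunks make load_model return the (tokenizer, model) pair and thread the tokenizer through prepare_query into the encoder-plus-FAISS lookup. The sketch below reconstructs that loop end to end; it is a sketch under stated assumptions, not the Space's code: a public bert-base-cased checkpoint stands in for the gated clarin-knext/entity-linking-encoder, a toy random index stands in for clarin-knext/entity-linking-index, and the reassembly of context around the marked mention is guessed, since the body of prepare_query between the hunks is not shown.

import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

START_TOKEN = "[unused0]"  # sentinel opening the mention span
END_TOKEN = "[unused1]"    # sentinel closing the mention span


def load_model(model_name: str = "bert-base-cased"):
    # Stand-in for the gated "clarin-knext/entity-linking-encoder" checkpoint.
    # Registering the sentinels as additional special tokens keeps them atomic
    # during tokenization; they already sit in BERT's vocab as [unused0] and
    # [unused1], so no new embeddings are added. Whether the real encoder's
    # tokenizer config does the same is an assumption.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, additional_special_tokens=[START_TOKEN, END_TOKEN]
    )
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model  # the pair the diff unpacks at module level


def prepare_query(tokenizer, query, max_seq_length=300):
    # The splitting mirrors the updated function above.
    left_context = query.split(START_TOKEN)[0]
    right_context = query.split(END_TOKEN)[-1]
    mention = query.split(START_TOKEN)[-1].split(END_TOKEN)[0]

    # Marked mention, tokenized without [CLS]/[SEP] as in the diff.
    mention_ids = tokenizer(
        f"{START_TOKEN}{mention}{END_TOKEN}",
        add_special_tokens=False,
    )["input_ids"]

    # Guessed reassembly (elided in the diff): contexts around the marked
    # mention, truncated to max_seq_length, wrapped in [CLS] ... [SEP].
    left_ids = tokenizer(left_context, add_special_tokens=False)["input_ids"]
    right_ids = tokenizer(right_context, add_special_tokens=False)["input_ids"]
    ids = [tokenizer.cls_token_id] + left_ids + mention_ids + right_ids
    ids = ids[: max_seq_length - 1] + [tokenizer.sep_token_id]
    return torch.tensor([ids])


tokenizer, model = load_model()

# Toy index over random "entity" vectors in place of the prebuilt FAISS index.
dim = model.config.hidden_size
faiss_index = faiss.IndexFlatIP(dim)
faiss_index.add(np.random.rand(10, dim).astype("float32"))


def predict(text, top_k=3):
    input_ids = prepare_query(tokenizer, text)
    with torch.no_grad():
        output = model(input_ids)
    # Take only the [CLS] embedding, as the commented line in the diff does.
    query_vec = output.last_hidden_state[0, 0].numpy().reshape(1, -1)
    return faiss_index.search(query_vec, top_k)  # (scores, entity row ids)


print(predict(f"The capital of {START_TOKEN}Poland{END_TOKEN} is Warsaw."))

One detail worth noting: return_tensors is a tokenizer argument, not an AutoModel.forward one, so the sketch tokenizes inside prepare_query and feeds plain tensors to the model rather than calling model(query, return_tensors="pt") as the unchanged context line does.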