Spaces:
Runtime error
Runtime error
reverting unnecessary changes
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import datasets
|
|
3 |
import faiss
|
4 |
import os
|
5 |
|
6 |
-
from transformers import pipeline, AutoModel, AutoTokenizer
|
7 |
|
8 |
|
9 |
auth_token = os.environ.get("CLARIN_KNEXT")
|
@@ -23,49 +23,6 @@ textbox = gr.Textbox(
|
|
23 |
)
|
24 |
|
25 |
|
26 |
-
def prepare_query(tokenizer, query, max_seq_length=300):
|
27 |
-
# temporary solution b/c of padding (which is unnecessary for inference)
|
28 |
-
start_token: str = "[unused0]"
|
29 |
-
end_token: str = "[unused1]"
|
30 |
-
|
31 |
-
left_context = query.split(start_token)[0]
|
32 |
-
right_context = query.split(end_token)[-1]
|
33 |
-
mention = query.split(start_token)[-1].split(end_token)[0]
|
34 |
-
|
35 |
-
mention_ids = tokenizer(
|
36 |
-
f"{start_token}{mention}{end_token}",
|
37 |
-
add_special_tokens=False
|
38 |
-
)['input_ids']
|
39 |
-
|
40 |
-
left_ids = tokenizer(left_context, add_special_tokens=False)['input_ids']
|
41 |
-
left_quota = (max_seq_length - len(mention_ids)) // 2 - 1
|
42 |
-
|
43 |
-
right_ids = tokenizer(right_context, add_special_tokens=False)['input_ids']
|
44 |
-
right_quota = max_seq_length - len(mention_ids) - left_quota - 2
|
45 |
-
|
46 |
-
left_add, right_add = len(left_ids), len(right_ids)
|
47 |
-
if left_add <= left_quota:
|
48 |
-
right_quota += left_quota - left_add if right_add > right_quota else 0
|
49 |
-
else:
|
50 |
-
left_quota += right_quota - right_add if right_add <= right_quota else 0
|
51 |
-
|
52 |
-
context_ids = [
|
53 |
-
tokenizer.cls_token_id,
|
54 |
-
*left_ids[-left_quota:],
|
55 |
-
*mention_ids,
|
56 |
-
*right_ids[:right_quota],
|
57 |
-
tokenizer.sep_token_id
|
58 |
-
]
|
59 |
-
|
60 |
-
padding_length = max_seq_length - len(context_ids)
|
61 |
-
# attention_mask = [1] * len(context_ids) + [0] * padding_length
|
62 |
-
|
63 |
-
context_ids += [tokenizer.pad_token_id] * padding_length
|
64 |
-
|
65 |
-
assert len(context_ids) == max_seq_length
|
66 |
-
return context_ids
|
67 |
-
|
68 |
-
|
69 |
def load_index(index_data: str = "clarin-knext/entity-linking-index"):
|
70 |
ds = datasets.load_dataset(index_data, use_auth_token=auth_token)['train']
|
71 |
index_data = {
|
@@ -77,20 +34,23 @@ def load_index(index_data: str = "clarin-knext/entity-linking-index"):
|
|
77 |
|
78 |
|
79 |
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
|
80 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
|
81 |
-
model = AutoModel.from_pretrained(model_name, use_auth_token=auth_token)
|
82 |
-
|
|
|
|
|
83 |
|
84 |
|
85 |
-
tokenizer, model = load_model()
|
|
|
86 |
index = load_index()
|
87 |
|
88 |
|
89 |
def predict(text: str = sample_text, top_k: int=3):
|
90 |
-
query = prepare_query(tokenizer, text)
|
91 |
index_data, faiss_index = index
|
92 |
# takes only the [CLS] embedding (for now)
|
93 |
-
query = model(query
|
94 |
|
95 |
scores, indices = faiss_index.search(query, top_k)
|
96 |
scores, indices = scores.tolist(), indices.tolist()
|
|
|
3 |
import faiss
|
4 |
import os
|
5 |
|
6 |
+
from transformers import pipeline # , AutoModel, AutoTokenizer
|
7 |
|
8 |
|
9 |
auth_token = os.environ.get("CLARIN_KNEXT")
|
|
|
23 |
)
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def load_index(index_data: str = "clarin-knext/entity-linking-index"):
|
27 |
ds = datasets.load_dataset(index_data, use_auth_token=auth_token)['train']
|
28 |
index_data = {
|
|
|
34 |
|
35 |
|
36 |
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
|
37 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
|
38 |
+
# model = AutoModel.from_pretrained(model_name, use_auth_token=auth_token)
|
39 |
+
model = pipeline("feature-extraction", model="clarin-knext/entity-linking-encoder", use_auth_token=auth_token)
|
40 |
+
# return tokenizer, model
|
41 |
+
return model
|
42 |
|
43 |
|
44 |
+
# tokenizer, model = load_model()
|
45 |
+
model = load_model()
|
46 |
index = load_index()
|
47 |
|
48 |
|
49 |
def predict(text: str = sample_text, top_k: int=3):
|
50 |
+
# query = prepare_query(tokenizer, text)
|
51 |
index_data, faiss_index = index
|
52 |
# takes only the [CLS] embedding (for now)
|
53 |
+
query = model(query)[0][0].numpy().reshape(1, -1)
|
54 |
|
55 |
scores, indices = faiss_index.search(query, top_k)
|
56 |
scores, indices = scores.tolist(), indices.tolist()
|