Spaces:
Build error
Build error
Commit
·
608e3f9
1
Parent(s):
a567fbb
Update app.py
Browse files
app.py
CHANGED
@@ -3,55 +3,33 @@ import numpy as np
|
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import os
|
5 |
|
6 |
-
|
7 |
-
def extract_spannet_scores(model,input_sentence,length, pos_col, task_col):
|
8 |
-
|
9 |
-
all_scores = []
|
10 |
-
|
11 |
-
# model = SpanNet.load_model(model_path=path)
|
12 |
-
scores = []
|
13 |
-
model.eval()
|
14 |
-
|
15 |
-
out_dict = model(sentences=[input_sentence], output_span_scores=True)
|
16 |
-
scores.extend([[t.tolist() for t in o[:l]] for o, l in zip(out_dict['span_scores'], length)])
|
17 |
-
all_scores.append(scores)
|
18 |
-
return all_scores
|
19 |
-
|
20 |
-
def pool_span_scores(score_dicts, sent_lens):
|
21 |
-
TAGS = ['B', 'I', 'O']
|
22 |
-
pooled_scores = [[np.argmax([sum([sd[sent_id][token_id][score_id] for sd in score_dicts])
|
23 |
-
for score_id in range(len(score_dicts[0][sent_id][token_id]))])
|
24 |
-
for token_id in range(sent_lens[sent_id])]
|
25 |
-
for sent_id in range(len(sent_lens))]
|
26 |
-
|
27 |
-
r = [[TAGS[ps] for ps in sent_ps] for sent_ps in pooled_scores]
|
28 |
-
return r
|
29 |
|
30 |
def predict_label(text):
|
31 |
-
|
32 |
ip = text.split()
|
33 |
ip_len = [len(ip)]
|
34 |
-
scores = extract_spannet_scores(
|
35 |
pooled_scores = pool_span_scores(scores, ip_len)
|
36 |
-
|
37 |
-
# for op in pooled_scores[0]:
|
38 |
-
# output+= op + ','
|
39 |
return pooled_scores
|
40 |
|
41 |
|
42 |
if __name__ == '__main__':
|
|
|
43 |
space_key = os.environ.get('key')
|
44 |
-
filenames = ['network.py', 'layers.py', 'utils.py', 'representation.py']
|
45 |
for file in filenames:
|
46 |
hf_hub_download('nehalelkaref/stagedNER',
|
47 |
filename=file,
|
48 |
local_dir='src',
|
49 |
token=space_key)
|
50 |
|
51 |
-
from src.
|
|
|
52 |
|
53 |
-
|
54 |
-
|
|
|
55 |
|
56 |
iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
|
57 |
iface.launch(show_api=False)
|
|
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import os
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def predict_label(text):
|
8 |
+
|
9 |
ip = text.split()
|
10 |
ip_len = [len(ip)]
|
11 |
+
scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
|
12 |
pooled_scores = pool_span_scores(scores, ip_len)
|
13 |
+
|
|
|
|
|
14 |
return pooled_scores
|
15 |
|
16 |
|
17 |
if __name__ == '__main__':
|
18 |
+
|
19 |
space_key = os.environ.get('key')
|
20 |
+
filenames = ['network.py', 'layers.py', 'utils.py', 'representation.py', 'predict.py']
|
21 |
for file in filenames:
|
22 |
hf_hub_download('nehalelkaref/stagedNER',
|
23 |
filename=file,
|
24 |
local_dir='src',
|
25 |
token=space_key)
|
26 |
|
27 |
+
from src.predict import extract_spannet_scores,pool_span_scores
|
28 |
+
from src.network import SpanNet, EntNet
|
29 |
|
30 |
+
span_path = 'models/span.model'
|
31 |
+
# span_msa_path = 'models/sp'
|
32 |
+
span_model = SpanNet.load_model(model_path)
|
33 |
|
34 |
iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
|
35 |
iface.launch(show_api=False)
|