Spaces:
Build error
Build error
Commit
·
d59f693
1
Parent(s):
8295f3b
Update app.py
Browse files
app.py
CHANGED
@@ -3,22 +3,19 @@ import numpy as np
|
|
3 |
import os
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from camel_tools.data import CATALOGUE
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
CATALOGUE.download_package("light",
|
10 |
-
recursive=True,
|
11 |
-
force=True,
|
12 |
-
print_status=True)
|
13 |
def predict_label(text):
|
14 |
|
15 |
ip = text.split()
|
16 |
ip_len = [len(ip)]
|
17 |
|
18 |
-
span_scores = extract_spannet_scores(span_model,ip,ip_len
|
19 |
span_pooled_scores = pool_span_scores(span_scores, ip_len)
|
20 |
|
21 |
-
|
|
|
22 |
msa_pooled_scores = pool_span_scores(msa_span_scores, ip_len)
|
23 |
|
24 |
ensemble_span_scores = [score for scores in [span_scores, msa_span_scores] for score in scores]
|
@@ -41,21 +38,22 @@ if __name__ == '__main__':
|
|
41 |
filename=file,
|
42 |
local_dir='src',
|
43 |
token=space_key)
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
|
46 |
from src.network import SpanNet, EntNet
|
47 |
from src.validate import entities_from_token_classes
|
48 |
|
49 |
-
from camel_tools.disambig.mle import MLEDisambiguator
|
50 |
-
from camel_tools.tagger.default import DefaultTagger
|
51 |
-
|
52 |
-
mled = MLEDisambiguator.pretrained()
|
53 |
-
tagger = DefaultTagger(mled, 'pos')
|
54 |
|
55 |
-
|
|
|
56 |
|
57 |
span_path = 'models/span.model'
|
58 |
-
msa_span_path = '
|
59 |
entity_path= 'models/entity.msa.model'
|
60 |
|
61 |
span_model = SpanNet.load_model(span_path)
|
|
|
3 |
import os
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
from camel_tools.data import CATALOGUE
|
6 |
+
from camel_tools.tagger.default import DefaultTagger
|
7 |
+
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def predict_label(text):
|
10 |
|
11 |
ip = text.split()
|
12 |
ip_len = [len(ip)]
|
13 |
|
14 |
+
span_scores = extract_spannet_scores(span_model,ip,ip_len)
|
15 |
span_pooled_scores = pool_span_scores(span_scores, ip_len)
|
16 |
|
17 |
+
pos_tags = tagger.tag(ip)
|
18 |
+
msa_span_scores = extract_spannet_scores(msa_span_model,ip,ip_len,pos=pos_tags)
|
19 |
msa_pooled_scores = pool_span_scores(msa_span_scores, ip_len)
|
20 |
|
21 |
ensemble_span_scores = [score for scores in [span_scores, msa_span_scores] for score in scores]
|
|
|
38 |
filename=file,
|
39 |
local_dir='src',
|
40 |
token=space_key)
|
41 |
+
|
42 |
+
CATALOGUE.download_package("light",
|
43 |
+
recursive=True,
|
44 |
+
force=True,
|
45 |
+
print_status=True)
|
46 |
|
47 |
from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
|
48 |
from src.network import SpanNet, EntNet
|
49 |
from src.validate import entities_from_token_classes
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
diasmbig = BERTUnfactoredDisambiguator.pretrained('msa')
|
53 |
+
tagger = DefaultTagger(diasmbig, 'pos')
|
54 |
|
55 |
span_path = 'models/span.model'
|
56 |
+
msa_span_path = 'new_models/msa.best.model'
|
57 |
entity_path= 'models/entity.msa.model'
|
58 |
|
59 |
span_model = SpanNet.load_model(span_path)
|