nehalelkaref commited on
Commit
d59f693
·
1 Parent(s): 8295f3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -3,22 +3,19 @@ import numpy as np
3
  import os
4
  from huggingface_hub import hf_hub_download
5
  from camel_tools.data import CATALOGUE
 
 
6
 
7
-
8
-
9
- CATALOGUE.download_package("light",
10
- recursive=True,
11
- force=True,
12
- print_status=True)
13
  def predict_label(text):
14
 
15
  ip = text.split()
16
  ip_len = [len(ip)]
17
 
18
- span_scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
19
  span_pooled_scores = pool_span_scores(span_scores, ip_len)
20
 
21
- msa_span_scores = extract_spannet_scores(msa_span_model,ip,ip_len, pos_col=1, task_col=2, pos='not none')
 
22
  msa_pooled_scores = pool_span_scores(msa_span_scores, ip_len)
23
 
24
  ensemble_span_scores = [score for scores in [span_scores, msa_span_scores] for score in scores]
@@ -41,21 +38,22 @@ if __name__ == '__main__':
41
  filename=file,
42
  local_dir='src',
43
  token=space_key)
 
 
 
 
 
44
 
45
  from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
46
  from src.network import SpanNet, EntNet
47
  from src.validate import entities_from_token_classes
48
 
49
- from camel_tools.disambig.mle import MLEDisambiguator
50
- from camel_tools.tagger.default import DefaultTagger
51
-
52
- mled = MLEDisambiguator.pretrained()
53
- tagger = DefaultTagger(mled, 'pos')
54
 
55
- print(tagger.tag('ذهبت الى المدرسة'.split()))
 
56
 
57
  span_path = 'models/span.model'
58
- msa_span_path = 'models/msa.best.model'
59
  entity_path= 'models/entity.msa.model'
60
 
61
  span_model = SpanNet.load_model(span_path)
 
3
  import os
4
  from huggingface_hub import hf_hub_download
5
  from camel_tools.data import CATALOGUE
6
+ from camel_tools.tagger.default import DefaultTagger
7
+ from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
8
 
 
 
 
 
 
 
9
  def predict_label(text):
10
 
11
  ip = text.split()
12
  ip_len = [len(ip)]
13
 
14
+ span_scores = extract_spannet_scores(span_model,ip,ip_len)
15
  span_pooled_scores = pool_span_scores(span_scores, ip_len)
16
 
17
+ pos_tags = tagger.tag(ip)
18
+ msa_span_scores = extract_spannet_scores(msa_span_model,ip,ip_len,pos=pos_tags)
19
  msa_pooled_scores = pool_span_scores(msa_span_scores, ip_len)
20
 
21
  ensemble_span_scores = [score for scores in [span_scores, msa_span_scores] for score in scores]
 
38
  filename=file,
39
  local_dir='src',
40
  token=space_key)
41
+
42
+ CATALOGUE.download_package("light",
43
+ recursive=True,
44
+ force=True,
45
+ print_status=True)
46
 
47
  from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
48
  from src.network import SpanNet, EntNet
49
  from src.validate import entities_from_token_classes
50
 
 
 
 
 
 
51
 
52
+ diasmbig = BERTUnfactoredDisambiguator.pretrained('msa')
53
+ tagger = DefaultTagger(diasmbig, 'pos')
54
 
55
  span_path = 'models/span.model'
56
+ msa_span_path = 'new_models/msa.best.model'
57
  entity_path= 'models/entity.msa.model'
58
 
59
  span_model = SpanNet.load_model(span_path)