nehalelkaref commited on
Commit
7f319ed
·
1 Parent(s): 134f9fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -10,9 +10,16 @@ def predict_label(text):
10
  ip_len = [len(ip)]
11
 
12
  span_scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
13
- pooled_scores = pool_span_scores(span_scores, ip_len)
14
-
15
- ent_scores = extract_ent_scores(entity_model,ip,pooled_scores, pos_col=1, task_col=2)
 
 
 
 
 
 
 
16
  combined_sequences, ent_pred_tags = pool_ent_scores(ent_scores, ip_len)
17
 
18
  return combined_sequences
@@ -32,11 +39,14 @@ if __name__ == '__main__':
32
 
33
  from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
34
  from src.network import SpanNet, EntNet
 
35
 
36
  span_path = 'models/span.model'
 
37
  entity_path= 'models/entity.msa.model'
38
- # span_msa_path = 'models/sp'
39
  span_model = SpanNet.load_model(span_path)
 
40
  entity_model = EntNet.load_model(entity_path)
41
 
42
  iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
 
10
  ip_len = [len(ip)]
11
 
12
  span_scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
13
+ span_pooled_scores = pool_span_scores(span_scores, ip_len)
14
+
15
+ msa_span_scores = extract_spannet_scores(msa_span_model,ip,ip_len, pos_col=1, task_col=2)
16
+ msa_pooled_scores = pool_span_scores(msa_span_scores, ip_len)
17
+
18
+ ensemble_span_scores = [score for scores in [span_scores, msa_span_scores] for score in scores]
19
+ ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, ip_len)
20
+ ensemble_pred_tags = [entities_from_token_classes(sent_targs) for sent_targs in ensemble_pooled_scores]
21
+
22
+ ent_scores = extract_ent_scores(entity_model,ip,ensemble_pred_tags, pos_col=1, task_col=2)
23
  combined_sequences, ent_pred_tags = pool_ent_scores(ent_scores, ip_len)
24
 
25
  return combined_sequences
 
39
 
40
  from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
41
  from src.network import SpanNet, EntNet
42
+ from src.validate import entities_from_token_classes
43
 
44
  span_path = 'models/span.model'
45
+ msa_span_path = 'models/entity.msa.model'
46
  entity_path= 'models/entity.msa.model'
47
+
48
  span_model = SpanNet.load_model(span_path)
49
+ msa_span_model = SpanNet.load_model(msa_span_path)
50
  entity_model = EntNet.load_model(entity_path)
51
 
52
  iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")