nehalelkaref commited on
Commit
4338e7c
·
1 Parent(s): 463935b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -8,10 +8,14 @@ def predict_label(text):
8
 
9
  ip = text.split()
10
  ip_len = [len(ip)]
11
- scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
12
- pooled_scores = pool_span_scores(scores, ip_len)
13
 
14
- return pooled_scores
 
 
 
 
 
 
15
 
16
 
17
  if __name__ == '__main__':
@@ -30,8 +34,10 @@ if __name__ == '__main__':
30
  from src.network import SpanNet, EntNet
31
 
32
  span_path = 'models/span.model'
 
33
  # span_msa_path = 'models/sp'
34
  span_model = SpanNet.load_model(span_path)
 
35
 
36
  iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
37
  iface.launch(show_api=False)
 
8
 
9
  ip = text.split()
10
  ip_len = [len(ip)]
 
 
11
 
12
+ span_scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
13
+ pooled_scores = pool_span_scores(span_scores, ip_len)
14
+
15
+ ent_scores = extract_ent_scores(ent_model, ip,pooled_scores, pos_col=1, task_col=2)
16
+ combined_sequences, ent_pred_tags = pool_ent_scores(ent_scores, ip_len)
17
+
18
+ return combined_sequences
19
 
20
 
21
  if __name__ == '__main__':
 
34
  from src.network import SpanNet, EntNet
35
 
36
  span_path = 'models/span.model'
37
+ entity_path= 'models/entity.msa.model'
38
  # span_msa_path = 'models/sp'
39
  span_model = SpanNet.load_model(span_path)
40
+ entity_model = EntNet.load_model(entity_path)
41
 
42
  iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
43
  iface.launch(show_api=False)