nehalelkaref commited on
Commit
608e3f9
·
1 Parent(s): a567fbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -32
app.py CHANGED
@@ -3,55 +3,33 @@ import numpy as np
3
  from huggingface_hub import hf_hub_download
4
  import os
5
 
6
-
7
- def extract_spannet_scores(model,input_sentence,length, pos_col, task_col):
8
-
9
- all_scores = []
10
-
11
- # model = SpanNet.load_model(model_path=path)
12
- scores = []
13
- model.eval()
14
-
15
- out_dict = model(sentences=[input_sentence], output_span_scores=True)
16
- scores.extend([[t.tolist() for t in o[:l]] for o, l in zip(out_dict['span_scores'], length)])
17
- all_scores.append(scores)
18
- return all_scores
19
-
20
- def pool_span_scores(score_dicts, sent_lens):
21
- TAGS = ['B', 'I', 'O']
22
- pooled_scores = [[np.argmax([sum([sd[sent_id][token_id][score_id] for sd in score_dicts])
23
- for score_id in range(len(score_dicts[0][sent_id][token_id]))])
24
- for token_id in range(sent_lens[sent_id])]
25
- for sent_id in range(len(sent_lens))]
26
-
27
- r = [[TAGS[ps] for ps in sent_ps] for sent_ps in pooled_scores]
28
- return r
29
 
30
  def predict_label(text):
31
- # model_path = 'models/span.model'
32
  ip = text.split()
33
  ip_len = [len(ip)]
34
- scores = extract_spannet_scores(model,ip,ip_len, pos_col=1, task_col=2)
35
  pooled_scores = pool_span_scores(scores, ip_len)
36
- # output=''
37
- # for op in pooled_scores[0]:
38
- # output+= op + ','
39
  return pooled_scores
40
 
41
 
42
  if __name__ == '__main__':
 
43
  space_key = os.environ.get('key')
44
- filenames = ['network.py', 'layers.py', 'utils.py', 'representation.py']
45
  for file in filenames:
46
  hf_hub_download('nehalelkaref/stagedNER',
47
  filename=file,
48
  local_dir='src',
49
  token=space_key)
50
 
51
- from src.network import SpanNet
 
52
 
53
- model_path = 'models/span.model'
54
- model = SpanNet.load_model(model_path)
 
55
 
56
  iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
57
  iface.launch(show_api=False)
 
3
  from huggingface_hub import hf_hub_download
4
  import os
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def predict_label(text):
8
+
9
  ip = text.split()
10
  ip_len = [len(ip)]
11
+ scores = extract_spannet_scores(span_model,ip,ip_len, pos_col=1, task_col=2)
12
  pooled_scores = pool_span_scores(scores, ip_len)
13
+
 
 
14
  return pooled_scores
15
 
16
 
17
  if __name__ == '__main__':
18
+
19
  space_key = os.environ.get('key')
20
+ filenames = ['network.py', 'layers.py', 'utils.py', 'representation.py', 'predict.py']
21
  for file in filenames:
22
  hf_hub_download('nehalelkaref/stagedNER',
23
  filename=file,
24
  local_dir='src',
25
  token=space_key)
26
 
27
+ from src.predict import extract_spannet_scores,pool_span_scores
28
+ from src.network import SpanNet, EntNet
29
 
30
+ span_path = 'models/span.model'
31
+ # span_msa_path = 'models/sp'
32
+ span_model = SpanNet.load_model(model_path)
33
 
34
  iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
35
  iface.launch(show_api=False)