nehalelkaref commited on
Commit
0c23ff5
·
1 Parent(s): 549e03f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -5
app.py CHANGED
@@ -2,12 +2,40 @@ import gradio as gr
2
  from network import SpanNet
3
  from huggingface_hub import Repository
4
 
 
5
  def greet(name):
6
  return "Hello " + name + "!!"
 
 
7
 
8
- # args= dict({'transformer_model_name':'nehalelkaref/plain_span',
9
- # 'device':'cuda'})
10
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
11
- model = SpanNet(transformer_model_name= 'nehalelkaref/plain_span',device='cuda')
 
 
 
 
 
 
 
 
 
12
 
13
- # iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from network import SpanNet
3
  from huggingface_hub import Repository
4
 
5
+ git clone https://huggingface.co/spaces/nehalelkaref/plain_span
6
  def greet(name):
7
  return "Hello " + name + "!!"
8
+
9
+ def extract_spannet_scores(path,input_sentence, pos_col, task_col):
10
 
11
+ sent = input_sentence.split()
12
+ length = [len(sent)]
13
+
14
+ all_scores = []
15
+
16
+ model = SpanNet.load_model(model_path=path, device='cuda').cuda()
17
+ scores = []
18
+ model.eval()
19
+
20
+ out_dict = model(sentences=[sent], output_span_scores=True)
21
+ scores.extend([[t.tolist() for t in o[:l]] for o, l in zip(out_dict['span_scores'], length)])
22
+ all_scores.append(scores)
23
+ return all_scores
24
 
25
+ def pool_span_scores(score_dicts, sent_lens):
26
+ TAGS = ['B', 'I', 'O']
27
+ pooled_scores = [[np.argmax([sum([sd[sent_id][token_id][score_id] for sd in score_dicts])
28
+ for score_id in range(len(score_dicts[0][sent_id][token_id]))])
29
+ for token_id in range(sent_lens[sent_id])]
30
+ for sent_id in range(len(sent_lens))]
31
+
32
+ r = [[TAGS[ps] for ps in sent_ps]for sent_ps in pooled_scores]
33
+ return r
34
+
35
+ input_text = gr.textbox(label="Text 1")
36
+
37
+ iface = gr.Interface(fn=greet, inputs=input_text, outputs="text")
38
+ iface.
39
+ model_path = 'models/span_model'
40
+ scores = extract_spannet_scores(model_path,input_text, pos_col=1, task_col=2)
41
+ iface.launch()