File size: 1,934 Bytes
62240fd
795587f
1d8d3a1
9675c65
7916f53
795587f
0c23ff5
44e2bdb
62240fd
0c23ff5
 
ac10e5c
0c23ff5
 
 
98173fb
0c23ff5
 
 
eaaa625
0c23ff5
 
 
 
 
 
 
3d70b45
0c23ff5
 
44e2bdb
1195aea
3d70b45
59a84ff
44e2bdb
 
5c5fdba
 
 
870601f
9e6957e
1cfd324
d08e430
1cfd324
d08e430
2892224
2088406
7916f53
 
5fa08a0
09297aa
 
0c23ff5
7916f53
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
import numpy as np
from network import SpanNet
from huggingface_hub import Repository
import os

    
def extract_spannet_scores(path,input_sentence,length, pos_col, task_col):
    """Run a SpanNet model on one sentence and collect its span scores.

    The model is loaded from ``path`` on every call, switched to eval mode
    and invoked with ``input_sentence`` as the only entry of ``sentences``.

    Args:
        path: Forwarded to ``SpanNet.load_model`` as ``model_path``.
        input_sentence: The sentence handed to the model (the caller in this
            file passes a list of whitespace-split tokens).
        length: Iterable of ints, zipped with the model's ``'span_scores'``
            output; each output entry is sliced to the matching length.
        pos_col: Unused; kept for interface compatibility.
        task_col: Unused; kept for interface compatibility.

    Returns:
        A one-element list. Its single item holds, per zipped sentence, a
        list with the ``tolist()`` result of each retained score entry.
    """
    model = SpanNet.load_model(model_path=path)
    model.eval()

    out_dict = model(sentences=[input_sentence], output_span_scores=True)

    per_sentence = []
    for sentence_scores, sentence_len in zip(out_dict['span_scores'], length):
        # Keep only the first `sentence_len` entries of this sentence's output.
        kept = sentence_scores[:sentence_len]
        per_sentence.append([entry.tolist() for entry in kept])

    # Wrapped in an outer list: one item per scored model.
    return [per_sentence]

def pool_span_scores(score_dicts, sent_lens):
    """Sum per-token scores across score sets and pick the best tag.

    Args:
        score_dicts: Sequence of score sets, each indexed as
            ``[sent_id][token_id][score_id]``.
        sent_lens: Number of tokens to label in each sentence; only the
            first ``sent_lens[i]`` tokens of sentence ``i`` are read.

    Returns:
        One list per sentence containing a tag from ``'B'``, ``'I'``, ``'O'``
        for every labelled token. The tag is chosen by the index of the
        largest summed score (first index wins on ties).
    """
    tag_names = ('B', 'I', 'O')

    best_indices = []
    for sent_id, token_count in enumerate(sent_lens):
        sentence_best = []
        for token_id in range(token_count):
            # The first score set defines how many scores each token has.
            score_count = len(score_dicts[0][sent_id][token_id])
            totals = []
            for score_id in range(score_count):
                totals.append(
                    sum(sd[sent_id][token_id][score_id] for sd in score_dicts)
                )
            sentence_best.append(np.argmax(totals))
        best_indices.append(sentence_best)

    labelled = []
    for sentence_best in best_indices:
        labelled.append([tag_names[idx] for idx in sentence_best])
    return labelled

def predict_label(text):
    """Tag every whitespace-separated token of ``text`` with the span model.

    The text is split on whitespace, scored with the model stored at
    ``models/span.model`` and the pooled tags of the single submitted
    sentence are joined into one string.

    Args:
        text: Raw input string.

    Returns:
        The predicted tags joined by commas, one tag per token. An empty
        string is returned when ``text`` contains no tokens.
    """
    model_path = 'models/span.model'
    tokens = text.split()
    if not tokens:
        # Nothing to tag; avoid loading and running the model on an empty sentence.
        return ''

    sent_lens = [len(tokens)]
    scores = extract_spannet_scores(model_path, tokens, sent_lens, pos_col=1, task_col=2)
    pooled_scores = pool_span_scores(scores, sent_lens)

    # Only one sentence is submitted, so its tags are the first entry.
    output = ','.join(pooled_scores[0])
    print('OUTPUT HERE')
    # Return the computed tags (previously the literal string 'output' was returned).
    return output
def temp(text):
    """Placeholder handler: print a marker line and return ``text`` unchanged."""
    marker = 'IN FUNCTION'
    print(marker)
    return text
print('STARTING ..')

# Access token read from the `key` environment variable (None when unset).
space_key = os.environ.get('key')

# NOTE(review): the object returned by gr.load() is discarded, so the loaded
# Space is never served from this script -- confirm whether this call is
# still required.
gr.load(name="nehalelkaref/flat-arabic-entity-classification", hf_token=space_key, src='spaces')

# The interface currently serves the `temp` echo handler; pass
# `fn=predict_label` instead to serve model predictions.
iface = gr.Interface(fn=temp, inputs="text", outputs="text", batch=False)

# Launch exactly once with every option. launch() blocks when run as a
# script, so a second launch() call would only execute after the first
# server had stopped and its options (show_api=False) would never apply to
# the running app.
iface.launch(share=True, blocked_paths=['models'], show_api=False)