File size: 2,059 Bytes
62240fd
795587f
28502e7
7916f53
795587f
0c23ff5
1e7b155
608e3f9
3d70b45
59a84ff
608e3f9
4338e7c
7f319ed
 
 
 
 
 
 
 
86ce30b
7f319ed
4338e7c
 
 
420f35b
 
1e7b155
608e3f9
1e7b155
460a31b
 
 
1e7b155
 
 
 
 
 
134f9fd
608e3f9
7f319ed
1e7b155
608e3f9
42065ef
4338e7c
7f319ed
463935b
7f319ed
4338e7c
f9eff7d
1e7b155
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
import os


def predict_label(text):
    """Run the staged span/entity pipeline on one whitespace-split input.

    The input is split on whitespace and scored by the two span models
    (``span_model`` and ``msa_span_model``); their score lists are
    concatenated, pooled, and converted into predicted spans, which are then
    scored by ``entity_model``.

    Relies on module-level names bound by the ``__main__`` block of this
    file: the three models and the helper functions imported from ``src``.

    Args:
        text: Raw input string from the Gradio text box.

    Returns:
        The first element returned by ``pool_ent_scores`` for the input, or
        an empty string when ``text`` is ``None``, empty, or whitespace-only.
    """
    ip = text.split() if text else []
    if not ip:
        # Nothing to tag: avoid feeding a zero-length sequence to the models.
        return ''
    ip_len = [len(ip)]

    span_scores = extract_spannet_scores(span_model, ip, ip_len, pos_col=1, task_col=2)
    # NOTE(review): the pooled per-model results are never read. The calls are
    # kept in case pool_span_scores has side effects -- confirm, then remove.
    pool_span_scores(span_scores, ip_len)

    msa_span_scores = extract_spannet_scores(msa_span_model, ip, ip_len, pos_col=1, task_col=2)
    pool_span_scores(msa_span_scores, ip_len)

    # Ensemble: concatenate the score lists of both span models, then pool.
    ensemble_span_scores = [score for scores in [span_scores, msa_span_scores] for score in scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, ip_len)
    ensemble_pred_tags = [entities_from_token_classes(sent_targs) for sent_targs in ensemble_pooled_scores]
    print('ensemble_pred_tags: ', ensemble_pred_tags)

    ent_scores = extract_ent_scores(entity_model, ip, ensemble_pred_tags, pos_col=1, task_col=2)
    # The second return value (predicted entity tags) is not used here.
    combined_sequences, _ = pool_ent_scores(ent_scores, ip_len)

    return combined_sequences


if __name__ == '__main__':
    # Hub token is read from the environment variable named 'key'
    # (None when the variable is not set).
    hub_token = os.environ.get('key')

    # Download the pipeline's source modules into ./src; the imports further
    # down depend on these files being present.
    source_files = (
        'network.py', 'layers.py', 'utils.py',
        'representation.py', 'predict.py', 'validate.py',
    )
    for remote_name in source_files:
        hf_hub_download(
            'nehalelkaref/stagedNER',
            filename=remote_name,
            local_dir='src',
            token=hub_token,
        )

    # Deliberately placed after the download loop: the `src` modules are
    # fetched by the loop above. The imported names become module globals
    # used by predict_label.
    from src.predict import (
        extract_spannet_scores,
        extract_ent_scores,
        pool_span_scores,
        pool_ent_scores,
    )
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes

    # Load the two span models and the entity model from the local models/ dir.
    span_model = SpanNet.load_model('models/span.model')
    msa_span_model = SpanNet.load_model('models/msa.best.model')
    entity_model = EntNet.load_model('models/entity.msa.model')

    # Text-in / text-out UI around predict_label.
    iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
    iface.launch(show_api=False)