File size: 2,185 Bytes
62240fd
795587f
28502e7
7916f53
795587f
0c23ff5
1e7b155
608e3f9
3d70b45
59a84ff
608e3f9
4338e7c
7f319ed
 
383c547
7f319ed
 
 
 
186dc98
 
4338e7c
 
 
420f35b
 
1e7b155
608e3f9
1e7b155
460a31b
 
 
1e7b155
 
 
 
 
 
134f9fd
608e3f9
7f319ed
2c7b4a1
 
 
 
 
 
 
 
1e7b155
608e3f9
42065ef
4338e7c
7f319ed
463935b
7f319ed
4338e7c
f9eff7d
1e7b155
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
import os


def predict_label(text):
    """Run the staged NER pipeline on one input string and return its output.

    The text is whitespace-tokenised and scored by two span models
    (``span_model`` and ``msa_span_model``). Their raw span scores are
    concatenated, pooled together, and passed to ``entity_model``. The
    first value returned by ``pool_ent_scores`` is what this function
    returns; the Gradio interface displays it as text.

    The models and the ``src.predict`` helpers are module globals bound in
    the ``__main__`` block, so this function works only after that block
    has run.

    Args:
        text: Raw input text. ``None`` or whitespace-only input is treated
            as empty.

    Returns:
        ``""`` for empty input. Otherwise, the first value returned by
        ``pool_ent_scores``.
    """
    tokens = (text or "").split()
    if not tokens:
        # Nothing to tag: don't feed a zero-length sentence to the models.
        return ""
    # The helpers take per-sentence lengths; here it is a batch of one.
    lengths = [len(tokens)]

    span_scores = extract_spannet_scores(
        span_model, tokens, lengths, pos_col=1, task_col=2)
    # NOTE(review): `pos='not none'` is passed only for the MSA model; its
    # meaning is defined in src.predict — confirm there.
    msa_span_scores = extract_spannet_scores(
        msa_span_model, tokens, lengths, pos_col=1, task_col=2,
        pos='not none')

    # Ensemble: concatenate both models' span scores, then pool them jointly.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, lengths)

    ent_scores = extract_ent_scores(
        entity_model, tokens, ensemble_pooled_scores, pos_col=1, task_col=2)
    # pool_ent_scores returns a pair; only the first element is displayed.
    combined_sequences, _ = pool_ent_scores(ent_scores, lengths)

    return combined_sequences


if __name__ == '__main__':

    # Fetch the pipeline's source modules from the Hugging Face Hub into
    # ./src. The access token comes from the `key` environment variable
    # (None when unset); presumably the repo is private — confirm.
    hub_token = os.environ.get('key')
    for source_name in ('network.py', 'layers.py', 'utils.py',
                        'representation.py', 'predict.py', 'validate.py'):
        hf_hub_download(
            'nehalelkaref/stagedNER',
            filename=source_name,
            local_dir='src',
            token=hub_token,
        )

    # These modules exist only after the download above, so they are imported
    # here. Because this `if` runs at module scope, the imported names become
    # globals that predict_label can see.
    from src.predict import (extract_spannet_scores, extract_ent_scores,
                             pool_span_scores, pool_ent_scores)
    from src.network import SpanNet, EntNet
    # NOTE(review): imported but not referenced elsewhere in this file.
    from src.validate import entities_from_token_classes

    from camel_tools.disambig.mle import MLEDisambiguator
    from camel_tools.tagger.default import DefaultTagger

    # Sanity-check the camel_tools POS tagger on a short Arabic sentence and
    # print the result. The tagger is not used anywhere else in this file.
    pos_tagger = DefaultTagger(MLEDisambiguator.pretrained(), 'pos')
    print(pos_tagger.tag('ذهبت الى المدرسة'.split()))

    # Load the three checkpoints that predict_label reads as globals.
    span_model = SpanNet.load_model('models/span.model')
    msa_span_model = SpanNet.load_model('models/msa.best.model')
    entity_model = EntNet.load_model('models/entity.msa.model')

    # Serve the pipeline as a simple text-in / text-out web UI.
    gr.Interface(
        fn=predict_label, inputs="text", outputs="text"
    ).launch(show_api=False)