File size: 2,733 Bytes
62240fd
795587f
7916f53
4cb376e
8295f3b
d59f693
 
0c23ff5
1e7b155
608e3f9
3d70b45
59a84ff
608e3f9
d59f693
7f319ed
 
d59f693
 
7f319ed
 
 
 
186dc98
abad191
4338e7c
 
 
420f35b
 
1e7b155
608e3f9
4889e9c
 
 
460a31b
4889e9c
 
 
 
 
d59f693
4889e9c
 
 
 
1e7b155
4889e9c
 
 
2c7b4a1
 
4889e9c
 
1e7b155
4889e9c
 
 
7f319ed
4889e9c
 
 
f9eff7d
cd4dfb7
 
 
 
 
 
 
cace402
7958cde
cd4dfb7
1e7b155
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import numpy as np
import os
from huggingface_hub import hf_hub_download
from camel_tools.data import CATALOGUE
from camel_tools.tagger.default import DefaultTagger
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator

def predict_label(text):
    """Run the staged NER pipeline on one input sentence.

    Relies on module-level globals initialized in ``__main__``
    (``span_model``, ``msa_span_model``, ``entity_model``, ``tagger``,
    and the ``extract_*``/``pool_*`` helpers imported from
    ``src.predict``) — NOTE(review): that setup is currently commented
    out below; confirm it is enabled before deploying.

    Args:
        text: raw input sentence; tokenized by a simple whitespace split.

    Returns:
        ``combined_sequences`` as produced by ``pool_ent_scores`` for the
        single input sentence.
    """
    tokens = text.split()
    lengths = [len(tokens)]

    # Span scores from the base span model, and from the MSA model which
    # additionally takes POS tags.
    span_scores = extract_spannet_scores(span_model, tokens, lengths)
    pos_tags = tagger.tag(tokens)
    msa_span_scores = extract_spannet_scores(msa_span_model, tokens, lengths,
                                             pos=pos_tags)

    # Ensemble: concatenate both models' score lists, then pool once.
    # (The original also pooled each model's scores individually, but those
    # results were never used — removed as dead work; assumes
    # pool_span_scores is side-effect-free — TODO confirm.)
    ensemble_span_scores = span_scores + msa_span_scores
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, lengths)

    # Entity classification on top of the pooled ensemble span scores.
    ent_scores = extract_ent_scores(entity_model, tokens,
                                    ensemble_pooled_scores)
    combined_sequences, _ent_pred_tags = pool_ent_scores(ent_scores, lengths)

    return combined_sequences


if __name__ == '__main__':

    # --- One-time setup (currently disabled) -------------------------------
    # Downloads model code/weights from the Hub and the CAMeL data packages,
    # then loads the taggers and models that predict_label() depends on.
    # NOTE(review): while this stays commented out, predict_label() will fail
    # with NameError at request time — re-enable before deploying.
    #
    # space_key = os.environ.get('key')
    # filenames = ['network.py', 'layers.py', 'utils.py',
    #              'representation.py', 'predict.py', 'validate.py']
    #
    # for file in filenames:
    #     hf_hub_download('nehalelkaref/stagedNER',
    #                 filename=file,
    #                 local_dir='src',
    #                 token=space_key)
    #
    # CATALOGUE.download_package("all",
    #                        recursive=True,
    #                        force=True,
    #                        print_status=True)
    #
    # from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
    # from src.network import SpanNet, EntNet
    # from src.validate import entities_from_token_classes
    #
    # diasmbig = BERTUnfactoredDisambiguator.pretrained('msa')
    # tagger = DefaultTagger(diasmbig, 'pos')
    #
    # span_path = 'models/span.model'
    # msa_span_path = 'new_models/msa.best.model'
    # entity_path= 'models/entity.msa.model'
    #
    # span_model = SpanNet.load_model(span_path)
    # msa_span_model = SpanNet.load_model(msa_span_path)
    # entity_model = EntNet.load_model(entity_path)

    # --- UI ---------------------------------------------------------------
    # Fixes vs. the previous version: removed a dangling `comp =` assignment
    # (a SyntaxError), and `examples` is now a plain list passed directly to
    # gr.Interface — previously a gr.Examples component was referenced before
    # it was defined (NameError) and built outside the Blocks context.
    example_sentences = [
        "النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
    ]

    with gr.Blocks() as iface:
        gr.Interface(fn=predict_label, inputs="text", outputs="text",
                     examples=example_sentences,
                     theme="finlaymacklon/smooth_slate")

    iface.launch(show_api=False)