Spaces:
Build error
Build error
File size: 3,117 Bytes
62240fd 795587f 7916f53 4cb376e 8295f3b d59f693 0c23ff5 1e7b155 6a541d9 3d70b45 59a84ff 608e3f9 d59f693 7f319ed d59f693 7f319ed 186dc98 abad191 4338e7c 8caf785 895c98f 4338e7c 895c98f 420f35b 1e7b155 608e3f9 c70f0bf bbdd0e4 460a31b bbdd0e4 d59f693 bbdd0e4 1e7b155 bbdd0e4 2c7b4a1 bbdd0e4 1e7b155 467a0eb d5d1243 7f319ed bbdd0e4 19cd698 dd3bb53 cd4dfb7 2d7bba1 19cd698 cd4dfb7 8caf785 236e944 8caf785 1db1d29 189ddb9 8caf785 cd4dfb7 19cd698 1e7b155 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
import numpy as np
import os
from huggingface_hub import hf_hub_download
from camel_tools.data import CATALOGUE
from camel_tools.tagger.default import DefaultTagger
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
def predict_label(text):
    """Run the staged NER pipeline on ``text`` and return the predicted sequences.

    Pipeline:
      1. Whitespace-tokenize the input.
      2. Score spans with ``span_model``.
      3. Score spans again with ``msa_span_model``, which also receives POS tags
         from the module-level ``tagger``.
      4. Concatenate both score lists and pool them with ``pool_span_scores``.
      5. Feed the pooled span scores to ``entity_model`` and pool its output
         with ``pool_ent_scores``.

    Relies on module-level names created in the ``__main__`` block:
    ``span_model``, ``msa_span_model``, ``entity_model``, ``tagger`` and the
    ``src.predict`` helpers.

    Args:
        text: Raw string from the Gradio input textbox (may be empty or None).

    Returns:
        The first value returned by ``pool_ent_scores`` (``combined_sequences``),
        or ``""`` when ``text`` contains no tokens.
    """
    tokens = text.split() if text else []
    # Guard: an empty / whitespace-only textbox would otherwise reach the
    # models as a zero-length sentence.
    if not tokens:
        return ""
    ip_len = [len(tokens)]

    span_scores = extract_spannet_scores(span_model, tokens, ip_len)

    pos_tags = tagger.tag(tokens)
    msa_span_scores = extract_spannet_scores(msa_span_model, tokens, ip_len,
                                             pos=pos_tags)

    # Ensemble: flatten both models' span scores into one list and pool them
    # jointly; only this pooled result feeds the entity model.
    # NOTE(review): assumes pool_span_scores is side-effect free, so pooling
    # each model's scores separately (unused results) is unnecessary.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, ip_len)

    ent_scores = extract_ent_scores(entity_model, tokens, ensemble_pooled_scores)
    combined_sequences, _ent_pred_tags = pool_ent_scores(ent_scores, ip_len)
    return combined_sequences
if __name__ == '__main__':
    # Hub token (read from the Space secret named 'key') used to fetch the
    # model source files below.
    space_key = os.environ.get('key')

    # Download the model code from the Hub into ./src so it can be imported.
    filenames = ['network.py', 'layers.py', 'utils.py',
                 'representation.py', 'predict.py', 'validate.py']
    for file in filenames:
        hf_hub_download('nehalelkaref/stagedNER',
                        filename=file,
                        local_dir='src',
                        token=space_key)

    # Fetch all CAMeL Tools data packages (force=True downloads them again
    # even if already present).
    CATALOGUE.download_package("all",
                               recursive=True,
                               force=True,
                               print_status=True)

    # These imports must come after the download above. Because they run at
    # module scope, they also provide the globals used by predict_label.
    from src.predict import extract_spannet_scores, extract_ent_scores, pool_span_scores, pool_ent_scores
    from src.network import SpanNet, EntNet
    # NOTE(review): entities_from_token_classes is not referenced in this file.
    from src.validate import entities_from_token_classes

    disambig = BERTUnfactoredDisambiguator.pretrained('msa')
    tagger = DefaultTagger(disambig, 'pos')

    # Model checkpoint locations. Without these assignments the load calls
    # below raise NameError. Each path can be overridden via an env var.
    # NOTE(review): confirm these files exist in the Space repo. An earlier
    # candidate for the entity model was the Hub-style path
    # 'nehalelkaref/entity_model/entity.msa.model'.
    span_path = os.environ.get('SPAN_MODEL_PATH', 'models/span.model')
    msa_span_path = os.environ.get('MSA_SPAN_MODEL_PATH', 'new_models/msa.best.model')
    entity_path = os.environ.get('ENTITY_MODEL_PATH', 'models/entity.msa.model')

    span_model = SpanNet.load_model(span_path)
    msa_span_model = SpanNet.load_model(msa_span_path)
    entity_model = EntNet.load_model(entity_path)

    with gr.Blocks(theme='finlaymacklon/smooth_slate') as iface:
        example_input = gr.Textbox(label="Input Example", lines=3)
        prediction = gr.Text(label="Predicted Entities")
        # NOTE(review): the enclosing Blocks already sets a theme; confirm
        # whether this nested Interface's theme argument has any effect.
        gr.Interface(fn=predict_label, inputs=example_input,
                     outputs=prediction, theme="smooth_slate",
                     title="Flat Entity Classification for Levantine Arabic")
        gr.Examples(
            examples=["النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
                      "صورة لمدينة أريحا القديمة :تل السلطان",
                      "صورة اطفال مخيم للاجئين الفلسطينيين في لبنان"],
            inputs=example_input)
    iface.launch(show_api=False)
|