Spaces:
Build error
Build error
File size: 2,059 Bytes
62240fd 795587f 28502e7 7916f53 795587f 0c23ff5 1e7b155 608e3f9 3d70b45 59a84ff 608e3f9 4338e7c 7f319ed 86ce30b 7f319ed 4338e7c 420f35b 1e7b155 608e3f9 1e7b155 460a31b 1e7b155 134f9fd 608e3f9 7f319ed 1e7b155 608e3f9 42065ef 4338e7c 7f319ed 463935b 7f319ed 4338e7c f9eff7d 1e7b155 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
import os
def predict_label(text):
    """Tag ``text`` with the staged span-then-entity ensemble.

    The input is whitespace-tokenized and treated as a single sentence. Span
    scores from ``span_model`` and ``msa_span_model`` are concatenated and
    pooled jointly. The pooled scores are decoded into span tags with
    ``entities_from_token_classes``, and those tags are scored by
    ``entity_model``.

    This function reads the module-level names ``span_model``,
    ``msa_span_model``, ``entity_model`` and the helpers imported from
    ``src`` inside the ``__main__`` block. It can only run after that block
    has executed.

    Args:
        text: Raw string from the Gradio textbox (``None`` is treated as
            empty).

    Returns:
        ``""`` when the input contains no tokens; otherwise the
        ``combined_sequences`` value returned by ``pool_ent_scores``.
    """
    tokens = (text or "").split()
    if not tokens:
        # Nothing to tag: don't feed a zero-length sentence to the models.
        return ""
    ip_len = [len(tokens)]

    span_scores = extract_spannet_scores(span_model, tokens, ip_len, pos_col=1, task_col=2)
    msa_span_scores = extract_spannet_scores(msa_span_model, tokens, ip_len, pos_col=1, task_col=2)

    # Ensemble by concatenating both models' score lists and pooling them
    # together. The previous version also pooled each model's scores on their
    # own but never used the results; those calls were dropped.
    # NOTE(review): assumes pool_span_scores does not mutate its inputs.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, ip_len)
    ensemble_pred_tags = [entities_from_token_classes(sent_targs) for sent_targs in ensemble_pooled_scores]
    print('ensemble_pred_tags: ', ensemble_pred_tags)

    ent_scores = extract_ent_scores(entity_model, tokens, ensemble_pred_tags, pos_col=1, task_col=2)
    # The second element (per-token entity tags) is not needed for the UI.
    combined_sequences, _ = pool_ent_scores(ent_scores, ip_len)
    return combined_sequences
if __name__ == '__main__':
    # Download the model-code modules from the Hub repo into ./src. The
    # access token comes from the `key` environment variable (None if unset).
    hub_token = os.environ.get('key')
    module_files = ('network.py', 'layers.py', 'utils.py',
                    'representation.py', 'predict.py', 'validate.py')
    for module_file in module_files:
        hf_hub_download(
            'nehalelkaref/stagedNER',
            filename=module_file,
            local_dir='src',
            token=hub_token,
        )

    # These imports only resolve after the download above has populated ./src.
    # They bind the module-level helper names that predict_label() reads.
    from src.predict import (
        extract_spannet_scores,
        extract_ent_scores,
        pool_span_scores,
        pool_ent_scores,
    )
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes

    # Module-level model objects read by predict_label().
    span_model = SpanNet.load_model('models/span.model')
    msa_span_model = SpanNet.load_model('models/msa.best.model')
    entity_model = EntNet.load_model('models/entity.msa.model')

    iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")
    iface.launch(show_api=False)
|