Spaces:
Build error
Build error
import gradio as gr | |
from network import SpanNet | |
from huggingface_hub import Repository | |
def extract_spannet_scores(path, input_sentence, length, pos_col, task_col):
    """Run a SpanNet model on one tokenized sentence and collect its span scores.

    Args:
        path: Model location, forwarded to ``SpanNet.load_model(model_path=...)``.
        input_sentence: The sentence as a list of tokens (``predict_label``
            passes ``text.split()``).
        length: List of sentence lengths, one per sentence in the batch; each
            sentence's score sequence is truncated to its length.
        pos_col: Unused; kept for interface compatibility.
        task_col: Unused; kept for interface compatibility.

    Returns:
        A one-element list (one entry per model run). That entry holds, per
        sentence, a list of per-token score lists, i.e. it is indexed as
        ``[model][sentence][token][score]``.
    """
    model = SpanNet.load_model(model_path=path)
    model.eval()
    # ``input_sentence`` is already a token list; wrap it as a batch of one.
    # (Previously this referenced an undefined name ``sent``.)
    out_dict = model(sentences=[input_sentence], output_span_scores=True)
    # Presumably each ``t`` is a tensor; ``tolist()`` turns it into plain
    # Python values. Scores past each sentence's length are dropped.
    scores = [
        [t.tolist() for t in sent_scores[:sent_len]]
        for sent_scores, sent_len in zip(out_dict['span_scores'], length)
    ]
    return [scores]
def pool_span_scores(score_dicts, sent_lens):
    """Pool per-model span scores into B/I/O tags by summing and taking the argmax.

    Args:
        score_dicts: One entry per model, each indexed as
            ``[sentence][token][score]`` (the nesting returned by
            ``extract_spannet_scores``). The number of scores per token is
            taken from the first model.
        sent_lens: Number of tokens to tag for each sentence.

    Returns:
        One list of tags per sentence. Each tag is ``'B'``, ``'I'`` or ``'O'``,
        chosen by the index of the largest summed score. On ties the first
        index wins, as with ``np.argmax``.
    """
    # numpy is used here but was never imported at file level.
    import numpy as np

    tags = ['B', 'I', 'O']
    pooled = []
    for sent_id, sent_len in enumerate(sent_lens):
        sent_tags = []
        for token_id in range(sent_len):
            n_scores = len(score_dicts[0][sent_id][token_id])
            # Sum each score position across all models.
            summed = [
                sum(sd[sent_id][token_id][score_id] for sd in score_dicts)
                for score_id in range(n_scores)
            ]
            sent_tags.append(tags[int(np.argmax(summed))])
        pooled.append(sent_tags)
    return pooled
def predict_label(text):
    """Tag the whitespace-separated tokens of ``text`` with B/I/O labels.

    The text is split on whitespace and scored with the SpanNet model stored
    at ``models/span.model``. The scores are then pooled into tags by
    ``pool_span_scores``.

    Returns:
        A list holding one list of tag strings for the single input sentence.
    """
    tokens = text.split()
    lengths = [len(tokens)]
    span_scores = extract_spannet_scores(
        'models/span.model', tokens, lengths, pos_col=1, task_col=2
    )
    return pool_span_scores(span_scores, lengths)
# Gradio UI: a single text input mapped through ``predict_label`` to a text output.
iface = gr.Interface(fn=predict_label, inputs="text", outputs="text")

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tests) does not launch the app.
    iface.launch()