# NOTE: hosting-page status captured with this file: "Spaces: Runtime error".
# It is not part of the program.
import os | |
import gradio as gr | |
from gradio import FlaggingCallback | |
from gradio.components import IOComponent | |
from transformers import pipeline | |
from typing import List, Optional, Any | |
import argilla as rg | |
import os | |
# Token-classification pipeline used to pre-annotate every flagged text.
nlp = pipeline(task="ner", model="mrm8488/bert-spanish-cased-finetuned-ner")

# Sample inputs offered by the demo UI: one inner list per example row.
examples = [["Mi nombre es Juan y vivo en Barcelona"]]
def create_record(input_text, feedback):
    """Build an Argilla token-classification record for a flagged input.

    Runs the module-level ``nlp`` pipeline on ``input_text`` and packages the
    predicted entity spans, the word-level tokens and the user's feedback
    into an ``rg.TokenClassificationRecord``.

    Args:
        input_text: Text that was submitted to the demo.
        feedback: Flag option chosen by the user; stored verbatim in the
            record metadata.

    Returns:
        The newly built ``rg.TokenClassificationRecord``.
    """
    # "Validated" means the prediction was checked and is correct. Any other
    # feedback ("Incorrecto", "Ambiguo", or none) leaves the record as
    # "Default" so it still gets reviewed. The compared value must be one of
    # the flagging options the interface actually offers.
    status = "Validated" if feedback == "Correcto" else "Default"

    # Predicted entities as (entity, start_char, end_char, score) tuples.
    predictions = nlp(input_text, aggregation_strategy="first")
    prediction = [
        (pred["entity_group"], pred["start"], pred["end"], pred["score"])
        for pred in predictions
    ]

    # Word tokens are rebuilt from the tokenizer's word-to-character mapping;
    # special tokens report a word id of None and are dropped.
    batch_encoding = nlp.tokenizer(input_text)
    word_ids = sorted(set(batch_encoding.word_ids()) - {None})
    char_spans = (batch_encoding.word_to_chars(word_id) for word_id in word_ids)
    words = [input_text[span.start:span.end] for span in char_spans]

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=words,
        prediction=prediction,
        prediction_agent="gradio_crowd",
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record
class ArgillaLogger(FlaggingCallback):
    """Flagging callback that logs every flagged sample to an Argilla dataset."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the dataset that records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples logged through this callback instance.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing; records go to Argilla, so no local setup is needed."""

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Flagged component values; the first item is the input
                text. Remaining items are ignored because the prediction is
                recomputed when the record is built.
            flag_option: Feedback label selected by the user, if any.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged through this instance so far.
        """
        text = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(text, flag_option))
        self._flag_count += 1
        return self._flag_count
# Build the demo from the hosted model and send every manual flag to Argilla.
# The "models/" prefix tells Gradio which kind of Hub repository to load;
# without it the first path segment ("mrm8488") is taken as the source.
demo = gr.Interface.load(
    "models/mrm8488/bert-spanish-cased-finetuned-ner",
    examples=examples,
    title="NER en Español, crowdsource con Argilla",
    description="Ayudanos a mejorar este model introduciendo un ejemplo clasificandolo como correcto, incorrecto o ambiguo",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://dvilasuero-taller-somosnlp.hf.space",
        api_key=os.getenv("TEAM_API_KEY"),
        dataset_name="ner-flags",
    ),
    flagging_options=["Correcto", "Incorrecto", "Ambiguo"],
)
demo.launch()