import base64
import json
import time
import webbrowser

import argilla as rg
import pandas as pd
from argilla.client.singleton import active_client

from dataset.base_dataset import DatasetBase
from utils.config import Color


class ArgillaEstimator:
    """
    The ArgillaEstimator class is responsible for generating the ground truth (GT) for the
    dataset using the Argilla annotation interface, specifically its text classification mode.
    """
    def __init__(self, opt):
        """
        Initialize a new instance of the ArgillaEstimator class.
        :param opt: An options object that must provide api_url, api_key, workspace and time_interval
        """
        try:
            self.opt = opt
            # Connect to the Argilla server; subsequent rg.* calls use this client
            rg.init(
                api_url=opt.api_url,
                api_key=opt.api_key,
                workspace=opt.workspace
            )
            # Polling interval (in seconds) while waiting for annotations
            self.time_interval = opt.time_interval
        except Exception as e:
            raise Exception("Failed to connect to Argilla, check connection details") from e
    @staticmethod
    def initialize_dataset(dataset_name: str, label_schema: set[str]):
        """
        Initialize a new dataset in the Argilla system
        :param dataset_name: The name of the dataset
        :param label_schema: The set of valid class labels
        """
        try:
            settings = rg.TextClassificationSettings(label_schema=label_schema)
            rg.configure_dataset_settings(name=dataset_name, settings=settings)
        except Exception as e:
            raise Exception("Failed to create dataset") from e
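
    # A minimal sketch of configuring a dataset directly; the dataset name and labels
    # below are illustrative placeholders, not values used elsewhere in this module:
    #
    #     ArgillaEstimator.initialize_dataset("movie_reviews", {"positive", "negative"})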

    @staticmethod
    def upload_missing_records(dataset_name: str, batch_id: int, batch_records: pd.DataFrame):
        """
        Update the Argilla dataset by uploading the records of batch_id that appear in
        batch_records but are not yet present in Argilla
        :param dataset_name: The dataset name
        :param batch_id: The batch id
        :param batch_records: A dataframe of the batch records
        """
        query = "metadata.batch_id:{}".format(batch_id)
        result = rg.load(name=dataset_name, query=query)
        df = result.to_pandas()
        if len(df) == len(batch_records):
            # Every record of this batch has already been uploaded
            return
        if df.empty:
            upload_df = batch_records
        else:
            # Keep only the rows of batch_records whose text is not already in Argilla
            merged_df = pd.merge(batch_records, df['text'], on='text', how='left', indicator=True)
            upload_df = merged_df[merged_df['_merge'] == 'left_only'].drop(columns=['_merge'])
        record_list = []
        for index, row in upload_df.iterrows():
            config = {'text': row['text'],
                      'metadata': {'batch_id': row['batch_id'], 'id': row['id']},
                      'id': row['id']}
            # Pre-fill the annotation when one already exists for this record
            if not pd.isnull(row['annotation']):
                config['annotation'] = row['annotation']
            record_list.append(rg.TextClassificationRecord(**config))
        rg.log(records=record_list, name=dataset_name)
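
    # For reference, a hedged sketch of the dataframe upload_missing_records expects;
    # the column names below are the ones the method reads, the values are illustrative:
    #
    #     batch_records = pd.DataFrame([
    #         {'id': 0, 'batch_id': 0, 'text': 'great movie!', 'annotation': 'positive'},
    #         {'id': 1, 'batch_id': 0, 'text': 'waste of time', 'annotation': None},
    #     ])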

    def calc_usage(self):
        """
        Dummy function to calculate the usage of the estimator (always returns 0)
        """
        return 0

    def apply(self, dataset: DatasetBase, batch_id: int):
        """
        Apply the estimator on the dataset. The function enters an infinite loop that polls
        Argilla until every record of the batch is annotated, then returns the annotations.
        :param dataset: DatasetBase object, contains all the processed records
        :param batch_id: The batch id to annotate
        :return: A dataframe with the text, annotation, metadata and id of each record
        """
        current_api = active_client()
        try:
            rg_dataset = current_api.datasets.find_by_name(dataset.name)
        except Exception:
            # The dataset does not exist yet, so create it first
            self.initialize_dataset(dataset.name, dataset.label_schema)
            rg_dataset = current_api.datasets.find_by_name(dataset.name)
        batch_records = dataset[batch_id]
        if batch_records.empty:
            return []
        self.upload_missing_records(dataset.name, batch_id, batch_records)
        # Build a link to the Argilla UI pre-filtered to this batch: the UI expects the
        # filter as a base64-encoded JSON object in the `query` parameter
        data = {'metadata': {'batch_id': [str(batch_id)]}}
        json_data = json.dumps(data)
        encoded_string = base64.b64encode(json_data.encode('utf-8')).decode('utf-8')
        url_link = self.opt.api_url + '/datasets/' + self.opt.workspace + '/' \
                   + dataset.name + '?query=' + encoded_string
        print(f"{Color.GREEN}Waiting for annotations from batch {batch_id}:\n{url_link}{Color.END}")
        webbrowser.open(url_link)
        while True:
            # Count the records of this batch that were either validated or discarded;
            # size=0 fetches only the total, not the records themselves
            query = "(status:Validated OR status:Discarded) AND metadata.batch_id:{}".format(batch_id)
            search_results = current_api.search.search_records(
                name=dataset.name,
                task=rg_dataset.task,
                size=0,
                query_text=query,
            )
            if search_results.total == len(batch_records):
                # All records are annotated: fetch them and normalize the result
                result = rg.load(name=dataset.name, query=query)
                df = result.to_pandas()[['text', 'annotation', 'metadata', 'status']]
                # Discarded records get a 'Discarded' pseudo-label instead of a class
                df['annotation'] = df.apply(
                    lambda x: 'Discarded' if x['status'] == 'Discarded' else x['annotation'], axis=1)
                df = df.drop(columns=['status'])
                df['id'] = df.apply(lambda x: x['metadata']['id'], axis=1)
                return df
            time.sleep(self.time_interval)
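
# --- Usage sketch ---
# A minimal, hedged example of driving the whole annotation flow. The `opt` namespace
# and `my_dataset` are illustrative stand-ins: `opt` only needs the four fields read in
# __init__, and `my_dataset` is assumed to be a DatasetBase exposing a name, a
# label_schema and batches indexable by id.
#
#     from types import SimpleNamespace
#
#     opt = SimpleNamespace(
#         api_url='http://localhost:6900',  # local Argilla server (assumed address)
#         api_key='admin.apikey',           # placeholder credential
#         workspace='admin',
#         time_interval=5,                  # seconds between polling rounds
#     )
#     estimator = ArgillaEstimator(opt)
#     annotations = estimator.apply(my_dataset, batch_id=0)  # blocks until batch 0 is done
#     print(annotations[['id', 'text', 'annotation']])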