=
adding package
592bfb5
raw
history blame contribute delete
795 Bytes
import warnings

import evaluate
import numpy as np
# Metric objects are loaded once, when the module is imported, and are reused
# by compute_metrics on every evaluation call. ROC AUC is loaded with
# evaluate's 'multiclass' configuration.
metrics = dict(
    f1=evaluate.load('f1'),
    accuracy=evaluate.load('accuracy'),
    roc_auc=evaluate.load('roc_auc', 'multiclass'),
)
def _to_probabilities(scores):
    """Return per-row class probabilities for a 2-D score array.

    Rows that already look like probability distributions (all non-negative
    and summing to 1) are returned as-is; otherwise a numerically stable
    softmax is applied along axis 1.
    """
    scores = np.asarray(scores, dtype=np.float64)
    if (scores >= 0).all() and np.allclose(scores.sum(axis=1), 1.0):
        return scores
    # Subtract the row max before exponentiating so large logits cannot overflow.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


def compute_metrics(p):  # partly adapted from https://huggingface.co/blog/fine-tune-vit
    """Compute accuracy, F1 and, when possible, ROC AUC for an evaluation run.

    Args:
        p: A pair ``(predictions, label_ids)``. ``predictions`` is a 2-D array
            of per-class scores, one row per sample (argmax is taken over
            axis 1). Presumably these are raw logits from the Hugging Face
            Trainer -- TODO confirm; already-normalized probabilities are also
            accepted. ``label_ids`` holds the integer reference labels.

    Returns:
        A dict with ``'accuracy'`` and ``'f1'``, plus ``'roc_auc'`` when the
        AUC could be computed for this batch of predictions.
    """
    predictions, label_ids = p
    predictions = np.asarray(predictions)
    predicted_labels = np.argmax(predictions, axis=1)
    num_classes = predictions.shape[1]

    metric = metrics['accuracy'].compute(predictions=predicted_labels, references=label_ids)

    # evaluate's f1 defaults to average='binary', which fails for more than
    # two classes; use a macro average in that case and keep the default otherwise.
    f1_kwargs = {'average': 'macro'} if num_classes > 2 else {}
    metric.update(
        metrics['f1'].compute(predictions=predicted_labels, references=label_ids, **f1_kwargs)
    )

    # ROC AUC is best-effort. The multiclass metric needs probabilities (not
    # logits) and an explicit one-vs-rest strategy. It can still legitimately
    # fail, e.g. when a class is missing from label_ids or (presumably) for
    # two-class 2-column scores. Report that instead of silently dropping it.
    try:
        auc = metrics['roc_auc'].compute(
            prediction_scores=_to_probabilities(predictions),
            references=label_ids,
            multi_class='ovr',
        )
    except ValueError as err:
        warnings.warn(f"roc_auc could not be computed: {err}", RuntimeWarning, stacklevel=2)
    else:
        metric.update(auc)
    return metric