# Hugging Face Spaces status: Build error
import warnings

import evaluate
import numpy as np
# Metric objects are loaded once, at import time, and reused by compute_metrics.
# roc_auc is loaded with its 'multiclass' configuration.
metrics = {
    'f1': evaluate.load('f1'),
    'accuracy': evaluate.load('accuracy'),
    'roc_auc': evaluate.load('roc_auc', 'multiclass'),
}
def compute_metrics(p, f1_average=None):
    """Compute accuracy, F1 and, when possible, ROC AUC for a classification eval.

    Partly adapted from https://huggingface.co/blog/fine-tune-vit.

    Args:
        p: A ``(predictions, label_ids)`` pair (e.g. a ``transformers``
            ``EvalPrediction``). ``predictions`` is indexed along axis 1 as
            per-class scores; presumably raw logits of shape
            ``(n_samples, n_classes)`` -- TODO confirm against the model head.
            ``label_ids`` holds the integer class labels.
        f1_average: ``average`` mode forwarded to the F1 metric. When ``None``
            (the default) it is ``"binary"`` for outputs with at most two
            columns and ``"macro"`` otherwise, because binary averaging
            rejects multiclass labels.

    Returns:
        dict with the ``accuracy`` and ``f1`` results, plus the ``roc_auc``
        result when it could be computed (a warning is emitted otherwise).
    """
    predictions, label_ids = p
    predictions = np.asarray(predictions)
    # Hard class predictions, computed once and shared by accuracy and F1.
    predicted_labels = np.argmax(predictions, axis=1)

    if f1_average is None:
        f1_average = 'binary' if predictions.shape[1] <= 2 else 'macro'

    result = metrics['accuracy'].compute(predictions=predicted_labels, references=label_ids)
    result.update(
        metrics['f1'].compute(
            predictions=predicted_labels, references=label_ids, average=f1_average
        )
    )

    # Multiclass ROC AUC needs per-row probabilities that sum to 1 (not raw
    # scores) and an explicit multi_class strategy; without both it raises.
    # Softmax is shifted by the row max for numerical stability.
    shifted = predictions - predictions.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    probabilities = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    try:
        auc = metrics['roc_auc'].compute(
            prediction_scores=probabilities, references=label_ids, multi_class='ovr'
        )
    except Exception as err:  # AUC is deliberately best-effort (e.g. classes missing from labels)
        warnings.warn(f'Skipping roc_auc: {err}', RuntimeWarning, stacklevel=2)
    else:
        result.update(auc)
    return result