Spaces:
Sleeping
Sleeping
from model import get_model_tokenizer_classifier, InferenceArguments | |
from utils import jaccard, safe_print | |
from transformers import HfArgumentParser | |
from preprocess import get_words, clean_text | |
from shared import GeneralArguments, DatasetArguments | |
from predict import predict | |
from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words | |
import pandas as pd | |
from dataclasses import dataclass, field | |
from typing import Optional | |
from tqdm import tqdm | |
import json | |
import os | |
import random | |
from shared import seconds_to_time | |
from urllib.parse import quote | |
import logging | |
# Install a default handler on the root logger, then fetch this module's own logger.
logging.basicConfig()
logger = logging.getLogger(name=__name__)
# @dataclass is required: without it the `field(...)` declarations below stay
# plain class attributes holding `Field` objects, so HfArgumentParser (which
# builds its options from dataclass fields) would not register them and
# e.g. `output_file` would not be a string at runtime.
@dataclass
class EvaluationArguments(InferenceArguments):
    """Arguments pertaining to how evaluation will occur."""

    # Destination of the per-video metrics table.
    output_file: Optional[str] = field(
        default='metrics.csv',
        metadata={
            'help': 'Save metrics to output file'
        }
    )

    # When True, no predictions are made, so missed segments are not searched for.
    skip_missing: bool = field(
        default=False,
        metadata={
            'help': 'Whether to skip checking for missing segments. If False, predictions will be made.'
        }
    )

    # When True, existing labelled segments are not re-classified.
    skip_incorrect: bool = field(
        default=False,
        metadata={
            'help': 'Whether to skip checking for incorrect segments. If False, classifications will be made on existing segments.'
        }
    )
def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
    """Annotate each prediction with the labelled segment it overlaps most.

    For every prediction dict, ``best_overlap`` is set to the highest
    ``jaccard`` score against any entry of ``sponsor_segments`` (0 when none
    scores above 0), and ``best_sponsorship`` to the segment achieving it
    (``None`` when none scores above 0). On ties the earliest segment wins.

    The predictions are modified in place; ``sponsor_segments`` is returned
    unchanged.
    """
    for prediction in predictions:
        prediction.update(best_overlap=0, best_sponsorship=None)

        for candidate in sponsor_segments:
            overlap = jaccard(prediction['start'], prediction['end'],
                              candidate['start'], candidate['end'])
            # Strictly greater: the first segment is kept on ties.
            if overlap > prediction['best_overlap']:
                prediction.update(best_overlap=overlap,
                                  best_sponsorship=candidate)

    return sponsor_segments
def calculate_metrics(labelled_words, predictions):
    """Compute time-weighted confusion-matrix metrics for one video.

    Every labelled word except the last contributes its duration
    (``word_end(word) - word_start(word)``) to exactly one bucket. The bucket
    depends on two things:

    - whether the word's ``start`` lies inside some prediction's
      ``[start, end]`` interval;
    - whether the word carries a non-None ``category``.

    Args:
        labelled_words: sequence of word dicts (with ``start`` and an optional
            ``category``), e.g. the output of ``add_labels_to_words``.
        predictions: dicts with numeric ``start``/``end`` keys.

    Returns:
        A dict with the keys, in order:

        - ``true_positive``, ``false_negative``, ``false_positive``,
          ``true_negative``: summed durations per bucket.
        - ``video_duration``: span from the first word's start to the last
          word's end (0 for an empty transcript).
        - ``accuracy``, ``precision``, ``recall``: reported as 1 when their
          denominator is 0.
        - ``f-score``: reported as 0 when precision + recall is 0.
    """
    metrics = {
        'true_positive': 0,  # Is sponsor, predicted sponsor
        # Is sponsor, predicted not sponsor (i.e., missed it - bad)
        'false_negative': 0,
        # Is not sponsor, predicted sponsor (classified incorrectly, not that bad since we do manual checking afterwards)
        'false_positive': 0,
        'true_negative': 0,  # Is not sponsor, predicted not sponsor
    }

    if labelled_words:
        metrics['video_duration'] = word_end(
            labelled_words[-1]) - word_start(labelled_words[0])
    else:
        # Empty transcript: indexing [0]/[-1] would raise IndexError.
        metrics['video_duration'] = 0

    # The final word's duration is not counted.
    for word in labelled_words[:-1]:
        duration = word_end(word) - word_start(word)

        predicted_sponsor = any(
            p['start'] <= word['start'] <= p['end'] for p in predictions)
        actual_sponsor = word.get('category') is not None

        if predicted_sponsor:
            bucket = 'true_positive' if actual_sponsor else 'false_positive'
        else:
            bucket = 'false_negative' if actual_sponsor else 'true_negative'
        metrics[bucket] += duration

    tp = metrics['true_positive']
    tn = metrics['true_negative']
    fp = metrics['false_positive']
    fn = metrics['false_negative']

    # NOTE In cases where we encounter division by 0, we say that the value is 1
    # https://stats.stackexchange.com/a/1775
    # (Precision) TP+FP=0: means that all instances were predicted as negative
    # (Recall) TP+FN=0: means that there were no positive cases in the input data

    # The fraction of predictions our model got right
    total = tp + tn + fp + fn
    metrics['accuracy'] = ((tp + tn) / total) if total > 0 else 1

    # What proportion of positive identifications was actually correct?
    predicted_positive = tp + fp
    metrics['precision'] = (tp / predicted_positive) if predicted_positive > 0 else 1

    # What proportion of actual positives was identified correctly?
    actual_positive = tp + fn
    metrics['recall'] = (tp / actual_positive) if actual_positive > 0 else 1

    # https://deepai.org/machine-learning-glossary-and-terms/f-score
    s = metrics['precision'] + metrics['recall']
    metrics['f-score'] = (2 * (metrics['precision'] *
                               metrics['recall']) / s) if s > 0 else 0

    return metrics
def _attach_text(prediction):
    """Swap a prediction's ``words`` list for a space-joined ``text`` string.

    The ``words`` key is popped (treated as empty when absent) and each word's
    ``text`` is joined with single spaces. The dict is modified in place and
    also returned for convenience.
    """
    prediction_words = prediction.pop('words', [])
    prediction['text'] = ' '.join(x['text'] for x in prediction_words)
    return prediction


def _report_issues(video_id, video_index, missed_segments, incorrect_segments, output_as_json):
    """Print the potential issues found for a single video.

    Args:
        video_id: YouTube video id, used in the output and the submission link.
        video_index: position of the video in the current evaluation run.
        missed_segments: predictions with no overlapping labelled segment;
            each needs ``start``, ``end``, ``text`` and ``category``.
        incorrect_segments: labelled segments whose stored category differs
            from the classifier's prediction (annotated with ``predicted`` and
            ``scores``).
        output_as_json: print one JSON object instead of the human-readable
            report.
    """
    if output_as_json:
        to_print = {'video_id': video_id}
        if missed_segments:
            to_print['missed'] = missed_segments
        if incorrect_segments:
            to_print['incorrect'] = incorrect_segments
        safe_print(json.dumps(to_print))
        return

    safe_print(
        f'Issues identified for {video_id} (#{video_index})')

    # Potentially missed segments (model predicted, but not in database)
    if missed_segments:
        safe_print(' - Missed segments:')
        segments_to_submit = []
        for i, missed_segment in enumerate(missed_segments, start=1):
            safe_print(f'\t#{i}:', seconds_to_time(
                missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
            safe_print('\t\tText: "',
                       missed_segment['text'], '"', sep='')
            safe_print('\t\tCategory:',
                       missed_segment.get('category'))
            if 'probability' in missed_segment:
                safe_print('\t\tProbability:',
                           missed_segment['probability'])

            segments_to_submit.append({
                'segment': [missed_segment['start'], missed_segment['end']],
                'category': missed_segment['category'].lower(),
                'actionType': 'skip'
            })

        json_data = quote(json.dumps(segments_to_submit))
        safe_print(
            f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')

    # Incorrect segments (in database, but incorrectly classified)
    if incorrect_segments:
        safe_print(' - Incorrect segments:')
        for i, incorrect_segment in enumerate(incorrect_segments, start=1):
            safe_print(f'\t#{i}:', seconds_to_time(
                incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
            safe_print(
                '\t\tText: "', incorrect_segment['text'], '"', sep='')
            safe_print(
                '\t\tUUID:', incorrect_segment['uuid'])
            safe_print(
                '\t\tVotes:', incorrect_segment['votes'])
            safe_print(
                '\t\tViews:', incorrect_segment['views'])
            safe_print('\t\tLocked:',
                       incorrect_segment['locked'])
            safe_print('\t\tCurrent Category:',
                       incorrect_segment['category'])
            safe_print('\t\tPredicted Category:',
                       incorrect_segment['predicted'])
            safe_print('\t\tProbabilities:')
            for label, score in incorrect_segment['scores'].items():
                safe_print(
                    f"\t\t\t{label}: {score}")

    safe_print()


def main():
    """Evaluate the model and classifier against the processed segment database.

    For each selected video, unless disabled by the command-line flags, this
    function performs two checks:

    1. It predicts segments and scores them against the labelled ones,
       reporting predictions that overlap no labelled segment as potentially
       *missed*.
    2. It re-classifies existing labelled segments, reporting those whose
       stored category disagrees with the classifier as potentially
       *incorrect*.

    Per-video metrics are written to ``evaluation_args.output_file`` and the
    means of the numeric columns are logged. A KeyboardInterrupt stops the loop
    early, but the metrics gathered so far are still saved.
    """
    logger.setLevel(logging.DEBUG)

    hf_parser = HfArgumentParser((
        EvaluationArguments,
        DatasetArguments,
        SegmentationArguments,
        GeneralArguments
    ))
    evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

    if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
        logger.error('ERROR: Nothing to do')
        return

    # Load labelled data:
    final_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_file)
    if not os.path.exists(final_path):
        logger.error('ERROR: Processed database not found.\n'
                     f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
        return

    model, tokenizer, classifier = get_model_tokenizer_classifier(
        evaluation_args, general_args)

    with open(final_path) as fp:
        final_data = json.load(fp)

    if evaluation_args.video_ids:  # Use specified
        video_ids = evaluation_args.video_ids
    else:  # Use items found in preprocessed database
        video_ids = list(final_data.keys())
        random.shuffle(video_ids)

        if evaluation_args.start_index is not None:
            video_ids = video_ids[evaluation_args.start_index:]
        if evaluation_args.max_videos is not None:
            video_ids = video_ids[:evaluation_args.max_videos]

    out_metrics = []

    all_metrics = {}
    if not evaluation_args.skip_missing:
        all_metrics['total_prediction_accuracy'] = 0
        all_metrics['total_prediction_precision'] = 0
        all_metrics['total_prediction_recall'] = 0
        all_metrics['total_prediction_fscore'] = 0
    if not evaluation_args.skip_incorrect:
        all_metrics['classifier_segment_correct'] = 0
        all_metrics['classifier_segment_count'] = 0

    # Number of videos scored by calculate_metrics, i.e. the denominator of
    # the running averages shown in the progress bar.
    metric_count = 0

    postfix_info = {}

    try:
        with tqdm(video_ids) as progress:
            for video_index, video_id in enumerate(progress):
                progress.set_description(f'Processing {video_id}')

                words = get_words(video_id)
                if not words:
                    continue

                # Get labels
                sponsor_segments = final_data.get(video_id)

                # Reset previous
                missed_segments = []
                incorrect_segments = []

                current_metrics = {
                    'video_id': video_id
                }

                if not evaluation_args.skip_missing:  # Make predictions
                    predictions = predict(video_id, model, tokenizer, segmentation_args,
                                          classifier=classifier,
                                          min_probability=evaluation_args.min_probability)

                    if sponsor_segments:
                        labelled_words = add_labels_to_words(
                            words, sponsor_segments)
                        current_metrics.update(
                            calculate_metrics(labelled_words, predictions))

                        # Count only videos whose metrics were added to the
                        # totals; unscored rows are NaN and skipped by the final
                        # df.mean() as well.
                        metric_count += 1

                        all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
                        all_metrics['total_prediction_precision'] += current_metrics['precision']
                        all_metrics['total_prediction_recall'] += current_metrics['recall']
                        all_metrics['total_prediction_fscore'] += current_metrics['f-score']

                        # Just for display purposes
                        postfix_info.update({
                            'accuracy': all_metrics['total_prediction_accuracy']/metric_count,
                            'precision': all_metrics['total_prediction_precision']/metric_count,
                            'recall': all_metrics['total_prediction_recall']/metric_count,
                            'f-score': all_metrics['total_prediction_fscore']/metric_count,
                        })

                        sponsor_segments = attach_predictions_to_sponsor_segments(
                            predictions, sponsor_segments)

                        # Identify possible issues: predictions that overlap no
                        # labelled segment are potentially missed segments.
                        missed_segments = [
                            _attach_text(prediction)
                            for prediction in predictions
                            if prediction['best_sponsorship'] is None
                        ]
                    else:
                        # Not in database (all segments missed). Text is attached
                        # here too, because the report prints missed_segment['text'].
                        missed_segments = [
                            _attach_text(prediction) for prediction in predictions]

                if not evaluation_args.skip_incorrect and sponsor_segments:
                    # Check for incorrect segments using the classifier
                    segments_to_check = []
                    cleaned_texts = []  # Texts to send through tokenizer
                    for sponsor_segment in sponsor_segments:
                        segment_words = extract_segment(
                            words, sponsor_segment['start'], sponsor_segment['end'])
                        sponsor_segment['text'] = ' '.join(
                            x['text'] for x in segment_words)

                        duration = sponsor_segment['end'] - \
                            sponsor_segment['start']
                        wps = (len(segment_words) /
                               duration) if duration > 0 else 0
                        # Skip segments with fewer than 1.5 words per second.
                        if wps < 1.5:
                            continue

                        # Do not worry about those that are locked or have enough votes
                        # or segment['votes'] > 5:
                        if sponsor_segment['locked']:
                            continue

                        cleaned_texts.append(
                            clean_text(sponsor_segment['text']))
                        segments_to_check.append(sponsor_segment)

                    if segments_to_check:  # Some segments to check
                        segments_scores = classifier(cleaned_texts)

                        num_correct = 0
                        for segment, scores in zip(segments_to_check, segments_scores):
                            fixed_scores = {
                                score['label']: score['score']
                                for score in scores
                            }
                            all_metrics['classifier_segment_count'] += 1

                            prediction = max(scores, key=lambda x: x['score'])
                            predicted_category = prediction['label'].lower()

                            if predicted_category == segment['category']:
                                num_correct += 1
                                continue  # Ignore correct segments

                            segment.update({
                                'predicted': predicted_category,
                                'scores': fixed_scores
                            })
                            incorrect_segments.append(segment)

                        current_metrics['num_segments'] = len(
                            segments_to_check)
                        current_metrics['classified_correct'] = num_correct
                        all_metrics['classifier_segment_correct'] += num_correct

                    if all_metrics['classifier_segment_count'] > 0:
                        postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
                            all_metrics['classifier_segment_count']

                out_metrics.append(current_metrics)
                progress.set_postfix(postfix_info)

                if missed_segments or incorrect_segments:
                    _report_issues(video_id, video_index, missed_segments,
                                   incorrect_segments, evaluation_args.output_as_json)

    except KeyboardInterrupt:
        # Stop early; the metrics gathered so far are still saved below.
        pass

    df = pd.DataFrame(out_metrics)
    df.to_csv(evaluation_args.output_file)
    # numeric_only=True: the 'video_id' column is not numeric, and pandas >= 2.0
    # raises TypeError when asked to average it.
    logger.info(df.mean(numeric_only=True))


if __name__ == '__main__':
    main()