Spaces:
Sleeping
Sleeping
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments | |
from shared import CustomTokens, GeneralArguments | |
from dataclasses import dataclass, field | |
from typing import Optional, Union | |
import torch | |
import classify | |
import base64 | |
import re | |
import requests | |
import json | |
import logging | |
logging.basicConfig()
logger = logging.getLogger(__name__)

# Public innertube API key. It is kept base64-encoded in the source so that
# automated secret scanners do not mistake it for a leaked credential.
INNERTUBE_KEY = base64.b64decode(
    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4'
).decode('utf-8')

# Client description sent as the ``context`` of innertube API requests
# (a desktop Chrome user agent with the 'WEB' client name/version).
YT_CONTEXT = dict(
    client=dict(
        userAgent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
        clientName='WEB',
        clientVersion='2.20211221.00.00',
    ),
)

# Finds the object literal assigned to ``ytInitialData`` in a YouTube page's
# HTML, written either as ``window["ytInitialData"] = {...};`` or
# ``ytInitialData = {...};``. Group 1 captures the ``{...}`` text.
_YT_INITIAL_DATA_RE = (
    r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)'
    r'\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
)
def get_all_channel_vids(channel_id, *, timeout=30):
    """Yield the video IDs of every upload in a YouTube channel.

    The first page comes from the ``ytInitialData`` blob embedded in the
    channel's uploads-playlist HTML page. Subsequent pages are requested from
    the innertube ``browse`` endpoint using the continuation token found at the
    end of each page, until a page carries no token.

    Args:
        channel_id: YouTube channel ID (assumed to start with ``'UC'``).
        timeout: Seconds to wait for each HTTP request; forwarded to
            ``requests`` so that a stalled connection cannot hang forever.

    Yields:
        str: video IDs, in the order the playlist returns them.

    Raises:
        requests.HTTPError: if YouTube answers with an error status.
        ValueError: if the playlist page contains no ``ytInitialData`` blob.
    """
    continuation = None
    while True:
        if continuation is None:
            items = _fetch_first_uploads_page(channel_id, timeout)
        else:
            items = _fetch_continuation_page(continuation, timeout)

        # Reset before scanning so the loop ends when a page has no continuation item.
        continuation = None
        for item in items:
            video = item.get('playlistVideoRenderer')
            if video:
                yield video['videoId']
                continue
            next_page = item.get('continuationItemRenderer')
            if next_page:
                continuation = next_page['continuationEndpoint']['continuationCommand']['token']

        if continuation is None:
            break


def _fetch_first_uploads_page(channel_id, timeout):
    """Return the item list embedded in the channel's uploads-playlist page."""
    # The uploads playlist ID is derived by swapping the first 'UC' (the
    # channel-ID prefix) for 'UU'.
    params = {'list': channel_id.replace('UC', 'UU', 1)}
    response = requests.get(
        'https://www.youtube.com/playlist', params=params, timeout=timeout)
    response.raise_for_status()
    match = re.search(_YT_INITIAL_DATA_RE, response.text)
    if match is None:
        raise ValueError(
            f'Could not find ytInitialData on the playlist page for channel {channel_id}')
    initial_data = json.loads(match.group(1))
    return initial_data['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer'][
        'content']['sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0][
        'playlistVideoListRenderer']['contents']


def _fetch_continuation_page(continuation, timeout):
    """Return the next page of playlist items for ``continuation`` via innertube."""
    response = requests.post(
        'https://www.youtube.com/youtubei/v1/browse',
        params={'key': INNERTUBE_KEY},
        json={'context': YT_CONTEXT, 'continuation': continuation},
        timeout=timeout)
    response.raise_for_status()
    return response.json()['onResponseReceivedActions'][0][
        'appendContinuationItemsAction']['continuationItems']
# ``@dataclass`` is required: without it the ``field(...)`` declarations below
# stay raw ``Field`` objects and the class accepts no constructor arguments.
@dataclass
class ModelArguments:
    """
    Arguments describing which pretrained model/config/tokenizer to load
    (used both for fine-tuning and, via subclasses, for inference).
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Path to pretrained model or model identifier from huggingface.co/models'
        }
    )
    cache_dir: Optional[str] = field(
        default='models',
        metadata={
            'help': 'Where to store the pretrained models downloaded from huggingface.co'
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
        },
    )
    model_revision: str = field(
        default='main',
        metadata={
            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
            'with private models).'
        },
    )
import itertools | |
from errors import InferenceException, ModelLoadError | |
# ``@dataclass`` is required for the ``field(...)`` defaults to take effect and
# for ``__post_init__`` to run on construction.
@dataclass
class InferenceArguments(ModelArguments):
    """
    Arguments for running inference: which prediction and classifier models to
    load, which videos to process (explicit IDs and/or a channel's uploads) and
    how to report the results.
    """
    model_name_or_path: str = field(
        default='Xenova/sponsorblock-small',
        metadata={
            'help': 'Path to pretrained model used for prediction'
        }
    )
    classifier_model_name_or_path: str = field(
        default='EColi/SB_Classifier',
        metadata={
            'help': 'Use a pretrained classifier'
        }
    )
    # Within ``__post_init__``, ``max_videos`` and ``start_index`` only limit
    # the channel listing; explicit ``video_ids`` are kept as given.
    max_videos: Optional[int] = field(
        default=None,
        metadata={
            'help': 'The number of videos to test on'
        }
    )
    start_index: Optional[int] = field(default=None, metadata={
        'help': 'Video to start the evaluation at.'})
    channel_id: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Used to evaluate a channel'
        }
    )
    # Annotated as ``str``, but the ``nargs='+'`` metadata is meant for the
    # argument parser, which then supplies a list of strings.
    # ``__post_init__`` accepts either form.
    video_ids: str = field(
        default_factory=lambda: [],
        metadata={
            'nargs': '+'
        }
    )
    output_as_json: bool = field(default=False, metadata={
        'help': 'Output evaluations as JSON'})
    min_probability: float = field(
        default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})

    def __post_init__(self):
        """Validate ``video_ids`` and append the channel's uploads, if requested.

        Raises:
            InferenceException: if any supplied video ID is not 11 characters long.
        """
        # A plain string (as the annotation allows) is treated as
        # whitespace-separated IDs rather than being iterated character by character.
        if isinstance(self.video_ids, str):
            self.video_ids = self.video_ids.split()
        self.video_ids = list(map(str.strip, self.video_ids))
        if any(len(video_id) != 11 for video_id in self.video_ids):
            raise InferenceException('Invalid video IDs (length not 11)')

        if self.channel_id is not None:
            start = self.start_index or 0
            end = None if self.max_videos is None else start + self.max_videos
            # islice keeps the channel listing lazy: pages are only fetched until ``end``.
            channel_video_ids = list(itertools.islice(
                get_all_channel_vids(self.channel_id), start, end))
            logger.info('Found %d videos for channel %s',
                        len(channel_video_ids), self.channel_id)
            self.video_ids += channel_video_ids
def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_args: GeneralArguments):
    """Load the prediction model/tokenizer and the SponsorBlock classifier pipeline.

    The classifier is loaded through ``get_model_tokenizer`` by temporarily
    pointing ``inference_args.model_name_or_path`` at
    ``inference_args.classifier_model_name_or_path``. The other loading options
    (cache dir, revision, auth token) are therefore shared with the main model.
    The caller's ``model_name_or_path`` is restored even if loading fails.

    Args:
        inference_args: Model locations and loading options.
        general_args: Forwarded to ``get_model_tokenizer`` (device placement).

    Returns:
        tuple: ``(model, tokenizer, classifier)``, where ``classifier`` is a
        ``classify.SponsorBlockClassificationPipeline``.
    """
    # Load main model and tokenizer
    model, tokenizer = get_model_tokenizer(inference_args, general_args)

    # Load classifier; try/finally guarantees the caller's args are left untouched
    # even when the classifier cannot be loaded.
    original_path = inference_args.model_name_or_path
    inference_args.model_name_or_path = inference_args.classifier_model_name_or_path
    try:
        classifier_model, classifier_tokenizer = get_model_tokenizer(
            inference_args, general_args, model_type='classifier')
    finally:
        inference_args.model_name_or_path = original_path

    classifier = classify.SponsorBlockClassificationPipeline(
        classifier_model, classifier_tokenizer)
    return model, tokenizer, classifier
def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
    """Load a config, tokenizer and model from ``model_args.model_name_or_path``.

    Args:
        model_args: Where to load from and how (cache dir, revision,
            fast-tokenizer flag, auth-token flag).
        general_args: Optional. When given and its ``no_cuda`` is false, the
            model is moved to ``'cuda'`` if available, otherwise to ``'cpu'``.
        config_args: Extra keyword arguments for ``AutoConfig.from_pretrained``.
        model_type: ``'seq2seq'`` loads ``AutoModelForSeq2SeqLM``; any other
            value loads ``AutoModelForSequenceClassification``.

    Returns:
        tuple: ``(model, tokenizer)``. The project's custom tokens are added to
        the tokenizer, and the model's token embeddings are resized to match.

    Raises:
        ModelLoadError: if ``model_args.model_name_or_path`` is None.
    """
    path = model_args.model_name_or_path
    if path is None:
        raise ModelLoadError('Must specify --model_name_or_path')

    extra_config = {} if config_args is None else config_args

    # Loading options shared by every ``from_pretrained`` call below.
    hub_kwargs = dict(
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    config = AutoConfig.from_pretrained(path, **hub_kwargs, **extra_config)
    tokenizer = AutoTokenizer.from_pretrained(
        path, use_fast=model_args.use_fast_tokenizer, **hub_kwargs)

    if model_type == 'seq2seq':
        model_cls = AutoModelForSeq2SeqLM
    else:
        model_cls = AutoModelForSequenceClassification
    model = model_cls.from_pretrained(path, config=config, **hub_kwargs)

    # Add the custom tokens, then resize the embeddings to the tokenizer's new length.
    CustomTokens.add_custom_tokens(tokenizer)
    model.resize_token_embeddings(len(tokenizer))

    if general_args is None or general_args.no_cuda:
        return model, tokenizer

    model.to('cuda' if torch.cuda.is_available() else 'cpu')
    return model, tokenizer