"""Multi-label emotion classification pipeline built on a Perceiver model."""
from typing import List

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PerceiverTokenizer
def _map_outputs(predictions): | |
""" | |
Map model outputs to classes. | |
:param predictions: model ouptut batch | |
:return: | |
""" | |
labels = [ | |
"admiration", | |
"amusement", | |
"anger", | |
"annoyance", | |
"approval", | |
"caring", | |
"confusion", | |
"curiosity", | |
"desire", | |
"disappointment", | |
"disapproval", | |
"disgust", | |
"embarrassment", | |
"excitement", | |
"fear", | |
"gratitude", | |
"grief", | |
"joy", | |
"love", | |
"nervousness", | |
"optimism", | |
"pride", | |
"realization", | |
"relief", | |
"remorse", | |
"sadness", | |
"surprise", | |
"neutral" | |
] | |
classes = [] | |
for i, example in enumerate(predictions): | |
out_batch = [] | |
for j, category in enumerate(example): | |
out_batch.append(labels[j]) if category > 0.5 else None | |
classes.append(out_batch) | |
return classes | |
class MultiLabelPipeline:
    """
    Multi-label emotion classification pipeline.

    Wraps a torch-serialized model plus the Perceiver tokenizer and turns a
    ``datasets.Dataset`` with a ``text`` column into per-example lists of
    emotion label names.
    """

    def __init__(self, model_path):
        """
        Init MLC pipeline.

        :param model_path: path to a torch-serialized model (a full model
            object, not a state_dict -- it is loaded with ``torch.load``)
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # map_location lets a CUDA-saved checkpoint load on CPU-only hosts;
        # .to(self.device) then moves it onto the selected device. (The
        # original compared a torch.device object to the string 'cuda',
        # which is unreliable across torch versions and made the CUDA
        # branch effectively dead.)
        self.model = torch.load(model_path, map_location=self.device).eval().to(self.device)
        self.tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')

    def __call__(self, dataset, batch_size: int = 4):
        """
        Run inference over a dataset.

        :param dataset: ``datasets.Dataset`` with a ``text`` column
        :param batch_size: DataLoader batch size (default 4)
        :return: list of per-example label-name lists
        """
        # Tokenize inputs (fixed-length padding so batches stack cleanly).
        dataset = dataset.map(lambda row: self.tokenizer(row['text'], padding="max_length", truncation=True),
                              batched=True, remove_columns=['text'], desc='Tokenizing')
        dataset.set_format('torch', columns=['input_ids', 'attention_mask'])
        dataloader = DataLoader(dataset, batch_size=batch_size)

        classes = []
        mem_logs = []
        with tqdm(dataloader, unit='batches') as progression:
            for batch in progression:
                progression.set_description('Inference')
                # Inference only: no_grad avoids building autograd graphs
                # and substantially cuts memory use.
                with torch.no_grad():
                    outputs = self.model(inputs=batch['input_ids'].to(self.device),
                                         attention_mask=batch['attention_mask'].to(self.device))
                predictions = outputs.logits.cpu().detach().numpy()
                # Map raw scores to label names and collect them.
                classes.extend(_map_outputs(predictions))
                # Memory statistics are CUDA-only; report 0 on CPU rather
                # than querying torch.cuda with a CPU device (which errors).
                if self.device.type == 'cuda':
                    memory = round(torch.cuda.memory_reserved(self.device) / 1e9, 2)
                else:
                    memory = 0.0
                mem_logs.append(memory)
                # Update pbar with the running mean of reserved memory.
                progression.set_postfix(memory=f"{round(sum(mem_logs) / len(mem_logs), 2)}Go")
        return classes
def inputs_to_dataset(inputs: List[str]) -> Dataset:
    """
    Convert a list of strings to a dataset object.

    :param inputs: list of input texts
    :return: ``datasets.Dataset`` with a single ``text`` column
    """
    # list(...) copies the caller's sequence (so the Dataset does not alias
    # it) without the identity comprehension that shadowed the ``input``
    # builtin in the original.
    return Dataset.from_dict({'text': list(inputs)})