import os
import sys
import pathlib

import nltk
import numpy as np
import pandas as pd
import torch
from nltk.tokenize import sent_tokenize
from transformers import pipeline

# Make the project root importable so utils can be found
folder_name = pathlib.Path(__file__).parent.resolve()
sys.path.append(os.path.join(folder_name, "../"))
from utils import load_subs

# Sentence tokenizer models required by sent_tokenize
nltk.download("punkt")
nltk.download("punkt_tab")


class ThemeClassifier:
    """Zero-shot theme classification over episode scripts."""

    def __init__(self, theme_list):
        self.model = "facebook/bart-large-mnli"
        self.device = 0 if torch.cuda.is_available() else "cpu"  # GPU 0 if available
        self.theme_list = theme_list
        self.theme_classifier = self.load_model(self.device)

    def load_model(self, device):
        clf = pipeline(
            "zero-shot-classification",
            model=self.model,
            device=device,
        )
        return clf

    def get_theme_inference(self, script):
        # Split the script into sentences and group them into fixed-size batches
        script_sentences = sent_tokenize(script)
        sentence_batch_size = 20
        script_batches = []
        for index in range(0, len(script_sentences), sentence_batch_size):
            script_batches.append(" ".join(script_sentences[index:index + sentence_batch_size]))

        # Score every batch against all candidate themes
        theme_output = self.theme_classifier(
            script_batches,
            self.theme_list,
            multi_label=True,
        )

        # Collect the per-batch scores for each theme
        themes = {}
        for output in theme_output:
            for label, score in zip(output["labels"], output["scores"]):
                if label not in themes:
                    themes[label] = []
                themes[label].append(score)

        # Average each theme's scores across batches
        themes = {key: np.mean(np.array(value)) for key, value in themes.items()}
        return themes

    def get_themes(self, path, save_path=None):
        # Normalize the save path so it always points at a CSV file
        if save_path and not save_path.endswith(".csv"):
            save_path += ".csv"

        # Read saved output, if it exists and already contains the theme scores
        if save_path is not None and os.path.exists(save_path):
            df = pd.read_csv(save_path)
            if set(self.theme_list).issubset(df.columns):
                return df

        # Load dataset
        df = load_subs(path)

        # Run inference on every script and expand the scores into columns
        op = df["script"].apply(self.get_theme_inference)
        theme_df = pd.DataFrame(op.tolist())
        df[theme_df.columns] = theme_df

        # Save output
        if save_path:
            df.to_csv(save_path, index=False)

        return df
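

# Usage sketch (illustrative, not part of the original module): assumes a subtitles
# dataset readable by load_subs; the theme list, input path, and output path below
# are hypothetical placeholders.
if __name__ == "__main__":
    theme_list = ["friendship", "hope", "sacrifice", "battle", "betrayal", "love"]
    classifier = ThemeClassifier(theme_list)
    themes_df = classifier.get_themes(
        "data/Subtitles",               # placeholder path to the subtitles folder
        save_path="output/themes.csv",  # results are cached and re-used on later runs
    )
    print(themes_df.head())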