himym-analysis / classifier /classifier.py
akshatsanghvi's picture
Fix Bugs
600c17e
raw
history blame
2.4 kB
import os
import sys
import pathlib
import nltk
import torch
import numpy as np
import pandas as pd
from transformers import pipeline
from nltk.tokenize import sent_tokenize
# Absolute path of the directory that contains this file.
folder_name = pathlib.Path(__file__).parent.resolve()
# Put the parent of that directory on the import path; this must happen before
# the `utils` import below (presumably `utils` lives one level up -- confirm).
sys.path.append(os.path.join(folder_name, "../"))
from utils import load_subs
# Fetch the NLTK tokenizer data at import time; `sent_tokenize` is used by
# ThemeClassifier.get_theme_inference.
# NOTE(review): "punkt_tab" appears to be the resource name newer NLTK releases
# look up, with "punkt" kept for older ones -- confirm against the pinned NLTK.
nltk.download("punkt")
nltk.download("punkt_tab")
class ThemeClassifier:
    """Score scripts against a list of themes with a zero-shot classifier.

    The classifier is a Hugging Face ``zero-shot-classification`` pipeline
    built on ``facebook/bart-large-mnli``. It runs on CUDA device 0 when
    ``torch.cuda.is_available()`` is true and on the CPU otherwise.
    """

    def __init__(self, theme_list):
        """Store the candidate themes and load the classification pipeline.

        Args:
            theme_list: Candidate labels passed to the zero-shot pipeline.
        """
        self.model = "facebook/bart-large-mnli"
        self.device = 0 if torch.cuda.is_available() else "cpu"
        self.theme_list = theme_list
        self.theme_classifier = self.load_model(self.device)

    def load_model(self, device):
        """Build the zero-shot classification pipeline for ``self.model``.

        Args:
            device: Device argument forwarded to ``transformers.pipeline``.

        Returns:
            The constructed pipeline object.
        """
        return pipeline(
            "zero-shot-classification",
            model=self.model,
            device=device,
        )

    def get_theme_inference(self, script, sentence_batch_size=20):
        """Return the mean score of every theme for one script.

        The script is split into sentences, consecutive sentences are joined
        into batches, every batch is classified against ``self.theme_list``
        with ``multi_label=True``, and the per-batch scores are averaged.

        Args:
            script: Text of a single script.
            sentence_batch_size: Number of sentences joined into one
                classifier input. Defaults to 20.

        Returns:
            A dict mapping theme label to its mean score. When the script
            contains no sentences, every theme maps to NaN and the model is
            not called.
        """
        sentences = sent_tokenize(script)
        batches = [
            " ".join(sentences[start:start + sentence_batch_size])
            for start in range(0, len(sentences), sentence_batch_size)
        ]

        # Nothing to classify: skip the model call (an empty input list is not
        # a valid pipeline input) but keep one entry per theme so downstream
        # frames always have the same columns.
        if not batches:
            return {theme: np.nan for theme in self.theme_list}

        outputs = self.theme_classifier(
            batches,
            self.theme_list,
            multi_label=True,
        )
        # Some transformers releases may hand back a bare dict when only one
        # sequence is supplied; normalise so the loop below always iterates
        # over per-batch result dicts.
        if isinstance(outputs, dict):
            outputs = [outputs]

        scores_by_theme = {}
        for result in outputs:
            for label, score in zip(result["labels"], result["scores"]):
                scores_by_theme.setdefault(label, []).append(score)

        return {
            label: np.mean(np.array(scores))
            for label, scores in scores_by_theme.items()
        }

    def get_themes(self, path, save_path=None):
        """Return the subtitle dataset with one score column per theme.

        Args:
            path: Location forwarded to ``load_subs``.
            save_path: Optional CSV location (``str`` or ``os.PathLike``) used
                as a cache. If it does not end with ``".csv"``, the literal
                text ``"series.csv"`` is appended to it (no path separator is
                inserted). A falsy value disables both reading and saving.

        Returns:
            A DataFrame. If ``save_path`` exists and already contains a column
            for every entry of ``self.theme_list`` it is returned as read from
            disk; otherwise the data is loaded, scored, optionally saved, and
            returned.
        """
        if save_path:
            save_path = os.fspath(save_path)
            if not save_path.endswith(".csv"):
                save_path += "series.csv"

            # Reuse saved output when it covers every requested theme. The
            # saved file also holds the dataset's own columns (e.g. "script"),
            # so this must be a subset test rather than an equality test.
            if os.path.exists(save_path):
                cached = pd.read_csv(save_path)
                if set(self.theme_list).issubset(cached.columns):
                    return cached

        # Load dataset
        df = load_subs(path)

        # Run inference; reuse df's index so the score rows line up with their
        # scripts even when that index is not 0..n-1.
        per_script = df["script"].apply(self.get_theme_inference)
        theme_df = pd.DataFrame(per_script.tolist(), index=df.index)
        df[theme_df.columns] = theme_df

        # Save output, creating the target directory first so a finished
        # inference run is not lost to a missing folder.
        if save_path:
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            df.to_csv(save_path, index=False)

        return df