Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import pathlib | |
import spacy | |
import pandas as pd | |
from ast import literal_eval | |
from nltk.tokenize import sent_tokenize | |
folder_path = pathlib.Path().parent.resolve() | |
sys.path.append(os.path.join(folder_path, "../")) | |
from utils import load_subs | |
class NamedEntityRecognizer:
    """Extract person names from scripts with a spaCy NER pipeline.

    The pipeline is loaded once in ``__init__`` and reused for every call.
    """

    def __init__(self):
        self.model = self.load_model()

    def load_model(self):
        """Load and return the spaCy ``en_core_web_trf`` pipeline."""
        return spacy.load("en_core_web_trf")

    def get_chars_inference(self, script):
        """Return the person names found in each sentence of *script*.

        The script is split into sentences with ``sent_tokenize`` and each
        sentence is run through the model separately.

        Returns:
            A list with one ``set`` per sentence. Each set holds the first
            space-separated token of every entity labelled ``PERSON`` in that
            sentence (so "John Smith" contributes "John"); sentences without
            such entities contribute an empty set.
        """
        chars = []
        for sent in sent_tokenize(script):
            doc = self.model(sent)
            chars.append({
                entity.text.strip().split(" ")[0]
                for entity in doc.ents
                if entity.label_ == "PERSON"
            })
        return chars

    def get_chars(self, dataset_path, save_path=None):
        """Return the subtitles table with a ``chars`` column added.

        Args:
            dataset_path: Passed unchanged to ``load_subs``; only used when no
                cached result exists.
            save_path: Optional cache location (``str`` or path-like). A value
                ending in ``.csv`` is used as the file itself; anything else
                is treated as a directory and ``chars.csv`` inside it is used.
                When the file already exists it is loaded instead of
                re-running inference; otherwise the result is written there.

        Returns:
            A DataFrame whose ``chars`` column holds the output of
            ``get_chars_inference`` for each row's ``script`` value.
        """
        if save_path:
            # Accept pathlib.Path as well as str.
            save_path = os.fspath(save_path)
            if not save_path.endswith(".csv"):
                # Join rather than concatenate, so "out" -> "out/chars.csv"
                # instead of "outchars.csv". A trailing separator still works.
                save_path = os.path.join(save_path, "chars.csv")

            if os.path.exists(save_path):
                df = pd.read_csv(save_path)
                # The CSV stores each cell as the repr of a list of sets;
                # parse it back. Empty sets are written as "set()", which
                # ast.literal_eval only accepts on Python 3.9+.
                df["chars"] = df["chars"].apply(
                    lambda x: literal_eval(x) if isinstance(x, str) else x
                )
                return df

        # NOTE(review): load_subs is assumed to return a DataFrame with a
        # "script" column of strings - confirm against utils.load_subs.
        df = load_subs(dataset_path)
        df["chars"] = df["script"].apply(self.get_chars_inference)

        if save_path:
            # Make sure the target folder exists so the (expensive) inference
            # result is not lost to a missing-directory error.
            parent = os.path.dirname(save_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            df.to_csv(save_path, index=False)
        return df