File size: 1,469 Bytes
70a65c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import sys
import pathlib
import spacy
import pandas as pd
from ast import literal_eval
from nltk.tokenize import sent_tokenize

# NOTE: pathlib.Path() with no arguments is the relative path ".", and the
# parent of "." is still ".", so this resolves to the current working
# directory -- not to the directory that contains this file.
folder_path = pathlib.Path().parent.resolve()
# Add "<folder_path>/../" to the import search path; presumably this is what
# lets the `utils` import below resolve -- it only works when the process is
# started from the expected directory.
sys.path.append(os.path.join(folder_path, "../"))

from utils import load_subs

class NamedEntityRecognizer:
    """Extract character (PERSON) names from scripts with a spaCy pipeline.

    The pipeline is loaded once when the instance is created and reused for
    every inference call. Results computed by ``get_chars`` can optionally be
    cached in a CSV file so the model does not have to be run again.
    """

    def __init__(self):
        # Loading the pipeline is expensive, so do it once per instance.
        self.model = self.load_model()

    def load_model(self):
        """Load and return the spaCy ``en_core_web_trf`` pipeline."""
        return spacy.load("en_core_web_trf")

    def get_chars_inference(self, script):
        """Return the PERSON first names found in each sentence of ``script``.

        Args:
            script: Text to analyse. It is split into sentences with NLTK's
                ``sent_tokenize`` and each sentence is run through the model.

        Returns:
            A list with one ``set`` of names per sentence, in sentence order.
            A sentence without PERSON entities contributes an empty set.
        """
        chars = []

        for sent in sent_tokenize(script):
            doc = self.model(sent)
            names = set()

            for entity in doc.ents:
                if entity.label_ != "PERSON":
                    continue
                # Keep only the first token of the entity (the first name).
                # Split on any whitespace so a name broken across a newline
                # or tab is handled, and skip entities that are blank.
                parts = entity.text.split(maxsplit=1)
                if parts:
                    names.add(parts[0])

            chars.append(names)

        return chars

    @staticmethod
    def _resolve_save_path(save_path):
        """Return the CSV file path that ``save_path`` refers to.

        A value ending in ``.csv`` (or a falsy value) is returned unchanged.
        An existing directory gets ``chars.csv`` placed inside it, with or
        without a trailing separator. Anything else keeps the historical
        behaviour of appending ``"chars.csv"`` directly to the string.
        """
        if not save_path or save_path.endswith(".csv"):
            return save_path
        if os.path.isdir(save_path):
            # os.path.join inserts the separator only when it is missing, so
            # "out" and "out/" both resolve to "out/chars.csv".
            return os.path.join(save_path, "chars.csv")
        return save_path + "chars.csv"

    def get_chars(self, dataset_path, save_path=None):
        """Return the dataset with a ``chars`` column of per-sentence names.

        Args:
            dataset_path: Passed to ``load_subs``; its result is used as a
                DataFrame that has a ``script`` column.
            save_path: Optional cache location. Either a ``.csv`` file path
                or a directory/prefix to which ``chars.csv`` is added. If the
                resolved file already exists it is loaded and returned
                without running the model; otherwise the computed result is
                written there.

        Returns:
            A pandas DataFrame whose ``chars`` column holds, per row, the
            list of name sets produced by ``get_chars_inference``.
        """
        save_path = self._resolve_save_path(save_path)

        if save_path and os.path.exists(save_path):
            df = pd.read_csv(save_path)
            # The column was written as the repr of a list of sets; parse it
            # back. Non-string cells (e.g. NaN) are left untouched.
            # NOTE: literal_eval accepts "set()" only on Python 3.9+.
            df["chars"] = df["chars"].apply(
                lambda x: literal_eval(x) if isinstance(x, str) else x
            )
            return df

        df = load_subs(dataset_path)

        df["chars"] = df["script"].apply(self.get_chars_inference)

        if save_path:
            # Make sure the target directory exists so the (expensive)
            # inference results are not lost to a failed write.
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            df.to_csv(save_path, index=False)

        return df