akshatsanghvi commited on
Commit
70a65c2
·
1 Parent(s): 5ba7c45

Create named_entity_recognizer.py

Browse files
characters/named_entity_recognizer.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import pathlib
4
+ import spacy
5
+ import pandas as pd
6
+ from ast import literal_eval
7
+ from nltk.tokenize import sent_tokenize
8
+
9
+ folder_path = pathlib.Path().parent.resolve()
10
+ sys.path.append(os.path.join(folder_path, "../"))
11
+
12
+ from utils import load_subs
13
+
14
+ class NamedEntityRecognizer:
15
+ def __init__(self):
16
+
17
+ self.model = self.load_model()
18
+
19
+ def load_model(self):
20
+ nlp = spacy.load("en_core_web_trf")
21
+ return nlp
22
+
23
+ def get_chars_inference(self, script):
24
+ script_sents = sent_tokenize(script)
25
+ chars = []
26
+
27
+ for sent in script_sents:
28
+ doc = self.model(sent)
29
+ char = set()
30
+
31
+ for entity in doc.ents:
32
+ if entity.label_ == "PERSON":
33
+ name = entity.text.strip().split(" ")[0]
34
+ char.add(name)
35
+
36
+ chars.append(char)
37
+
38
+ return chars
39
+
40
+ def get_chars(self, dataset_path, save_path=None):
41
+
42
+ if save_path and not save_path.endswith(".csv"):
43
+ save_path += "chars.csv"
44
+
45
+ if save_path and os.path.exists(save_path):
46
+ df = pd.read_csv(save_path)
47
+ df["chars"] = df["chars"].apply(lambda x: literal_eval(x) if isinstance(x, str) else x)
48
+ return df
49
+
50
+ df = load_subs(dataset_path)
51
+
52
+ df["chars"] = df["script"].apply(self.get_chars_inference)
53
+
54
+ if save_path:
55
+ df.to_csv(save_path, index=False)
56
+
57
+ return df