File size: 6,099 Bytes
4b9251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3273f67
4b9251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import re
from typing import List

import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from nltk.tokenize import sent_tokenize

from . import MODULES_CONFIG, STOP_WORDS
from .config import Config
from .output_parsers import ClauseParser, TripletParser
from .utils import partition_sentences


# Registry of LLM output parsers, keyed by the parser name used in the
# modules config (schema.yml).
PARSERS = dict(
    StrOutputParser=StrOutputParser(),
    ClauseParser=ClauseParser(),
    TripletParser=TripletParser(),
)


class Module:
    """Words2Wisdom module class.

    Wraps a single named entry of MODULES_CONFIG, exposing its chat
    prompt template, output parser, and declared module type.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.parser = self.get_parser()
        self.prompts = self.get_prompts()
        self.type = self.get_module_type()

    def __repr__(self):
        # e.g. "triplet_extraction" -> "TripletExtraction()"
        camel = self.name.replace("_", " ").title().replace(" ", "")
        return camel + "()"

    def get_prompts(self):
        """Build a ChatPromptTemplate from this module's configured prompts."""
        prompt_messages = MODULES_CONFIG[self.name]["prompts"].items()
        return ChatPromptTemplate.from_messages(prompt_messages)

    def get_parser(self):
        """Look up this module's output parser; fall back to StrOutputParser."""
        parser_name = MODULES_CONFIG[self.name]["parser"]
        return PARSERS.get(parser_name, StrOutputParser())

    def get_module_type(self):
        """Return the module type declared in the config (e.g. 'preprocess')."""
        return MODULES_CONFIG[self.name]["type"]


class Pipeline:
    """Words2Wisdom pipeline class.

    Chains the configured preprocess modules and one extraction module
    into a single LangChain runnable mapping raw text to a knowledge
    graph DataFrame.
    """

    def __init__(self, config: Config):
        """Build and initialize the pipeline from a Config object."""
        self.config = config
        self.initialize(config)

    def __call__(self, text: str, clean: bool = True) -> tuple[List[str], pd.DataFrame]:
        # Corrected return annotation: run() returns (text_batches, kg),
        # not a bare DataFrame.
        return self.run(text, clean)

    def __repr__(self) -> str:
        return f"Words2Wisdom(\n\tconfig.pipeline={self.config.pipeline}\n\tconfig.llm={self.config.llm}\n)"

    def __str__(self) -> str:
        # Human-readable dataflow diagram of the module chain.
        return ("[INPUT: text] -> "
                + " -> ".join([str(m) for m in self.modules])
                + " -> [OUTPUT: knowledge graph]")

    @classmethod
    def from_ini(cls, config_path: str):
        """Alternate constructor: build a Pipeline from an INI config file."""
        return cls(Config.read_ini(config_path))

    def initialize(self, config: Config):
        """Initialize Words2Wisdom pipeline from config.

        Validates that each configured module has the expected type,
        builds one (prompt | llm | parser) chain per module, then
        stitches the chains so each stage's output feeds the next
        stage's "text" input.

        Raises:
            ValueError: If a module's declared type does not match its role.
        """
        # validate preprocess steps
        preprocess_modules = [Module(name) for name in config.pipeline["preprocess"]]

        for item in preprocess_modules:
            # use the type resolved once in Module.__init__ rather than
            # re-reading the schema via get_module_type()
            if item.type != "preprocess":
                raise ValueError(f"Expected preprocess step `{item.name}` to"
                                 f" have module type='preprocess'. Consider reviewing"
                                 f" schema.yml")

        # validate extraction process
        extraction_module = Module(config.pipeline["extraction"])

        if extraction_module.type != "extraction":
            raise ValueError(f"Expected `{extraction_module.name}` to have module"
                             f" type='extraction'. Consider reviewing schema.yml")

        # combine preprocess + extraction
        self.modules = preprocess_modules + [extraction_module]

        # prompts & parsers were already resolved in Module.__init__
        prompts = [m.prompts for m in self.modules]
        parsers = [m.parser for m in self.modules]

        # init llm
        llm = ChatOpenAI(**config.llm)

        # one (prompt | llm | parser) runnable per module
        chains = [(prompt | llm | parser)
                  for prompt, parser in zip(prompts, parsers)]

        # stitch chains together: each stage's output becomes the
        # "text" input of the next stage
        self.pipeline = {"text": RunnablePassthrough()} | chains[0]
        for i in range(1, len(chains)):
            self.pipeline = {"text": self.pipeline} | chains[i]

        # print pipeline
        print("Initialized Words2Wisdom pipeline:")
        print(str(self))

    def run(self, text: str, clean: bool = True) -> tuple[List[str], pd.DataFrame]:
        """Run Words2Wisdom pipeline on passed text.

        Args:
            text (str): The text input
            clean (bool): Whether to clean the raw KG or not

        Returns:
            text_batches (list): Batched text
            knowledge_graph (DataFrame): A dataframe containing the extracted KG triplets,
                indexed by batch
        """
        print("Running Words2Wisdom pipeline:")
        # split text into sentence batches of roughly words_per_batch words
        text_batches = list(partition_sentences(
            sentences=sent_tokenize(text),
            min_words=self.config.pipeline["words_per_batch"]
        ))

        # run pipeline in parallel; convert to dataframe
        print("Extracting knowledge graph...", end=' ')
        output = self.pipeline.batch(text_batches)

        # flatten: one row per triplet, tagged with its source batch index
        knowledge_graph = pd.DataFrame([{'batch_id': i, **triplet}
                                        for i, batch in enumerate(output)
                                        for triplet in batch])

        if clean:
            knowledge_graph = self._clean(knowledge_graph)

        print("Done!", end='\n')

        return text_batches, knowledge_graph

    def _clean(self, kg: pd.DataFrame) -> pd.DataFrame:
        """Words2Wisdom post-processing.

        Drops rows with stopword nodes or NaN components, lowercases the
        remaining triplet parts, and strips leading articles
        ("the"/"a"/"an") from nodes and copula prefixes ("is"/"are")
        from relations.
        """
        print("Cleaning knowledge graph components...", end=' ')
        drop_list = []

        # compiled once here instead of rebuilt on every row
        article_pattern = re.compile(r'^(the|a|an) (.+)')
        be_pattern = re.compile(r'^(are|is) (a |an )?(.+)')

        for i, row in kg.iterrows():
            # drop broken triplets first, so later branches never touch NaNs
            if row.hasnans:
                drop_list.append(i)

            # drop stopwords (e.g. pronouns)
            elif (row.subject in STOP_WORDS) or (row.object in STOP_WORDS):
                drop_list.append(i)

            # lowercase nodes/edges, drop articles
            else:
                kg.at[i, "subject"] = article_pattern.sub(r'\2', row.subject.lower())
                kg.at[i, "relation"] = be_pattern.sub(r'\3', row.relation.lower())
                kg.at[i, "object"] = article_pattern.sub(r'\2', row.object.lower()).strip('.')

        return kg.drop(drop_list)

    def _normalize(self):
        """Unused."""
        return

    def serialize(self):
        """Serialize this pipeline's configuration (delegates to Config)."""
        return self.config.serialize()