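"""Words2Wisdom pipeline: chains preprocess and extraction modules into a
single LangChain runnable that extracts knowledge-graph triplets from text."""
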
import re
from typing import List, Tuple
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from nltk.tokenize import sent_tokenize
from . import MODULES_CONFIG, STOP_WORDS
from .config import Config
from .output_parsers import ClauseParser, TripletParser
from .utils import partition_sentences

# LLM output parsers available to pipeline modules
PARSERS = {
    "StrOutputParser": StrOutputParser(),
    "ClauseParser": ClauseParser(),
    "TripletParser": TripletParser()
}


class Module:
    """Words2Wisdom module class."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parser = self.get_parser()
        self.prompts = self.get_prompts()
        self.type = self.get_module_type()

    def __repr__(self):
        # e.g. a module named "triplet_extraction" reprs as TripletExtraction()
        return self.name.replace("_", " ").title().replace(" ", "") + "()"

    def get_prompts(self):
        return ChatPromptTemplate.from_messages(MODULES_CONFIG[self.name]["prompts"].items())

    def get_parser(self):
        # fall back to a plain string parser if the configured parser is unknown
        return PARSERS.get(MODULES_CONFIG[self.name]["parser"], StrOutputParser())

    def get_module_type(self):
        return MODULES_CONFIG[self.name]["type"]


class Pipeline:
    """Words2Wisdom pipeline class."""

    def __init__(self, config: Config):
        self.config = config
        self.initialize(config)

    def __call__(self, text: str, clean: bool = True) -> Tuple[List[str], pd.DataFrame]:
        return self.run(text, clean)

    def __repr__(self) -> str:
        return f"Words2Wisdom(\n\tconfig.pipeline={self.config.pipeline}\n\tconfig.llm={self.config.llm}\n)"

    def __str__(self) -> str:
        return ("[INPUT: text] -> "
                + " -> ".join([str(m) for m in self.modules])
                + " -> [OUTPUT: knowledge graph]")

    @classmethod
    def from_ini(cls, config_path: str):
        """Construct a Pipeline from an INI config file."""
        return cls(Config.read_ini(config_path))
    def initialize(self, config: Config):
        """Initialize Words2Wisdom pipeline from config."""
        # validate preprocess modules
        preprocess_modules = [Module(name) for name in config.pipeline["preprocess"]]
        for item in preprocess_modules:
            if item.type != "preprocess":
                raise ValueError(f"Expected preprocess step `{item.name}` to"
                                 " have module type='preprocess'. Consider"
                                 " reviewing schema.yml")
        # validate extraction module
        extraction_module = Module(config.pipeline["extraction"])
        if extraction_module.type != "extraction":
            raise ValueError(f"Expected `{extraction_module.name}` to have module"
                             " type='extraction'. Consider reviewing schema.yml")
        # combine preprocess + extraction
        self.modules = preprocess_modules + [extraction_module]
        # collect prompts & parsers computed at Module init
        prompts = [m.prompts for m in self.modules]
        parsers = [m.parser for m in self.modules]
        # init llm
        llm = ChatOpenAI(**config.llm)
        # init chains: prompt -> llm -> parser
        chains = [(prompt | llm | parser)
                  for prompt, parser in zip(prompts, parsers)]
        # stitch chains together: each module's output becomes the
        # `text` input of the next module's prompt
        self.pipeline = {"text": RunnablePassthrough()} | chains[0]
        for i in range(1, len(chains)):
            self.pipeline = {"text": self.pipeline} | chains[i]
        # print pipeline
        print("Initialized Words2Wisdom pipeline:")
        print(str(self))
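
    # For example, with two modules [A, B], the loop above builds (structurally):
    #
    #   {"text": {"text": RunnablePassthrough()} | prompt_A | llm | parser_A}
    #       | prompt_B | llm | parser_B
    #
    # so each module's parsed output is passed as the `text` variable of the
    # next module's prompt.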
    def run(self, text: str, clean: bool = True) -> Tuple[List[str], pd.DataFrame]:
        """Run Words2Wisdom pipeline on the passed text.

        Args:
            text (str): The input text
            clean (bool): Whether to post-process the raw knowledge graph

        Returns:
            text_batches (list): Batched text
            knowledge_graph (DataFrame): A dataframe containing the extracted KG
                triplets, indexed by batch
        """
        print("Running Words2Wisdom pipeline:")
        # split text into batches of whole sentences, each containing
        # at least `words_per_batch` words
        text_batches = list(partition_sentences(
            sentences=sent_tokenize(text),
            min_words=self.config.pipeline["words_per_batch"]
        ))
        # run pipeline on all batches in parallel; convert to dataframe
        print("Extracting knowledge graph...", end=' ')
        output = self.pipeline.batch(text_batches)
        knowledge_graph = pd.DataFrame([{'batch_id': i, **triplet}
                                        for i, batch in enumerate(output)
                                        for triplet in batch])
        if clean:
            knowledge_graph = self._clean(knowledge_graph)
        print("Done!")
        return text_batches, knowledge_graph
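
    # Raw KG columns (one row per extracted triplet):
    #   batch_id | subject | relation | object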
    def _clean(self, kg: pd.DataFrame) -> pd.DataFrame:
        """Words2Wisdom post-processing."""
        print("Cleaning knowledge graph components...", end=' ')
        article_pattern = r'^(the|a|an) (.+)'
        be_pattern = r'^(are|is) (a |an )?(.+)'
        drop_list = []
        for i, row in kg.iterrows():
            # drop broken triplets
            if row.hasnans:
                drop_list.append(i)
            # drop stopwords (e.g. pronouns)
            elif (row.subject in STOP_WORDS) or (row.object in STOP_WORDS):
                drop_list.append(i)
            # lowercase nodes/edges; strip leading articles and copulas
            else:
                kg.at[i, "subject"] = re.sub(article_pattern, r'\2', row.subject.lower())
                kg.at[i, "relation"] = re.sub(be_pattern, r'\3', row.relation.lower())
                kg.at[i, "object"] = re.sub(article_pattern, r'\2', row.object.lower()).strip('.')
        return kg.drop(drop_list)
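
    # Illustrative effect of the patterns above:
    #   subject  "The mitochondria" -> "mitochondria"
    #   relation "is a type of"     -> "type of"
    #   object   "An organelle."    -> "organelle"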
    def _normalize(self):
        """Unused."""
        return

    def serialize(self):
        """Serialize the pipeline's config."""
        return self.config.serialize()
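

# Minimal usage sketch (file name, sample text, and import path are
# illustrative; assumes an OpenAI API key is available to ChatOpenAI):
#
#   from words2wisdom.pipeline import Pipeline
#
#   pipeline = Pipeline.from_ini("config.ini")
#   batches, kg = pipeline("Mitochondria are organelles. They produce ATP.")
#   print(kg[["subject", "relation", "object"]])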