johaunh
Add streamlit app
3273f67
import os
import re
from time import time
from argparse import ArgumentParser
from typing import List
from zipfile import ZipFile
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from . import VALIDATION_CONFIG, OUTPUT_DIR
from .output_parsers import QuestionOutputParser
def parse_args():
parser = ArgumentParser()
parser.add_argument(
"run_ids", nargs="+", help="Run IDs to evaluate. Format: YYYY-MM-DD-XXX"
)
parser.add_argument(
"--search_dir", help="Directory to search for output", default=OUTPUT_DIR
)
return parser.parse_args()
def format_system_prompt():
"""Format instructional prompt from instructions.yml"""
def format_question(question: dict):
# Question format: {text} {additional} {options}
# Ex. Can pigs fly? Explain. (Yes/No)
formatted = (
question["title"]
+ " "
+ question["text"]
+ " "
+ (question["additional"] + " " if question["additional"] else "")
+ "("
+ "; ".join(question["options"])
+ ")\n"
)
return formatted
def format_example(example: dict):
formatted = (
f"PASSAGE: {example['passage']}\n"
f"TRIPLET: {example['triplet']}\n\n"
+ "".join([
f"{i}) {answer}\n"
for i, answer in enumerate(example["answers"].values(), start=1)
])
)
return formatted
instruction = VALIDATION_CONFIG["instruction"]
questions = "".join([
f"{i}) {format_question(q)}"
for i, q in enumerate(VALIDATION_CONFIG["questions"].values(), start=1)
])
example = format_example(VALIDATION_CONFIG["example"])
system_prompt = (
f"{instruction}\n" f"QUESTIONS:\n{questions}\n" f"[* EXAMPLE *]\n\n{example}"
)
return system_prompt
def validate_triplets(
llm: ChatOpenAI, instruction: str, passage: str, triplets: List[List[str]]
) -> List[pd.DataFrame]:
"""Validate triplets with respect to passage."""
print(
f"Validating {len(triplets):>2} triplet{'s' if len(triplets) else ''}...",
end=" ",
)
prompt = ChatPromptTemplate.from_messages([
("system", "{instruction}"),
("user", "PASSAGE: {passage}\n\nTRIPLET: {triplet}"),
])
chain = prompt | llm | QuestionOutputParser()
output = chain.batch([
{"instruction": instruction, "passage": passage, "triplet": triplet}
for triplet in triplets
])
print("Done!", end="\n")
return output
def validate_knowledge_graph(llm: ChatOpenAI, output_zip: str):
"""Validate all triplets in a knowledge graph."""
run_id = re.findall(r"output-(.*)\.zip", output_zip)[0]
run_dir = os.path.dirname(output_zip)
# read output zip
with ZipFile(output_zip) as z:
# load knowledge graph
with z.open("kg.csv") as f:
graph = pd.read_csv(f)
# load text batches
with z.open("text_batches.csv") as f:
text_batches = pd.read_csv(f)
print("Initializing knowledge graph validation. Run:", run_id)
print()
# start stopwatch
start = time()
# container for evaluation responses
responses = []
# instructions
instruction = format_system_prompt()
# triplets are batched by passage
# so we iterate over passages
for idx, passage in enumerate(text_batches.text):
triplets = (
graph[graph["batch_id"] == idx].drop(columns=["batch_id"]).values.tolist()
)
print(f"Starting excerpt {idx + 1:>2} of {len(text_batches)}.", end=" ")
# if batch has no triplets to validate, skip batch
if len(triplets) == 0:
print("Excerpt has no triplets to validate.", end="\n")
continue
# validate triplets by batch
response = validate_triplets(
llm=llm,
instruction=instruction,
passage=passage,
triplets=triplets
)
responses.extend(response)
validation = pd.concat(responses, ignore_index=True)
# merge with knowlege graph data
validation_merged = (
text_batches.merge(graph)
.drop(columns=["batch_id"])
.merge(validation, left_index=True, right_index=True)
)
savepath = os.path.join(run_dir, f"validation-{run_id}.csv")
validation_merged.to_csv(savepath, index=False)
# stop stopwatch
end = time()
print("\nKnowledge graph validation complete!")
print(f"It took {end - start:0.3f} seconds to validate {len(validation)} triplets.")
print("Saved to:", savepath)
return savepath
if __name__ == "__main__":
args = parse_args()
llm = ChatOpenAI(
model="gpt-4-turbo-preview",
openai_api_key=os.getenv("OPENAI_API_KEY")
)
for run_id in args.run_ids:
zipfile = os.path.join(args.search_dir, f"output-{run_id}.zip")
validate_knowledge_graph(llm, run_id)
print("* " * 25)