Spaces:
Sleeping
Sleeping
import os | |
import re | |
from time import time | |
from argparse import ArgumentParser | |
from typing import List | |
from zipfile import ZipFile | |
import pandas as pd | |
from langchain_openai import ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate | |
from . import VALIDATION_CONFIG, OUTPUT_DIR | |
from .output_parsers import QuestionOutputParser | |
def parse_args():
    """Parse the command-line arguments of the validation script.

    Returns:
        The ``argparse`` namespace with two attributes: ``run_ids`` (a list of
        one or more run ID strings) and ``search_dir`` (the directory searched
        for run output, ``OUTPUT_DIR`` unless overridden).
    """
    cli = ArgumentParser()
    # nargs="+" makes argparse reject an invocation with no run IDs at all.
    cli.add_argument(
        "run_ids",
        nargs="+",
        help="Run IDs to evaluate. Format: YYYY-MM-DD-XXX",
    )
    cli.add_argument(
        "--search_dir",
        default=OUTPUT_DIR,
        help="Directory to search for output",
    )
    namespace = cli.parse_args()
    return namespace
def format_system_prompt():
    """Assemble the system prompt from ``VALIDATION_CONFIG``.

    The prompt consists of the configured instruction, a numbered list of the
    configured questions, and one worked example (passage, triplet, and
    numbered answers).

    Returns:
        The complete prompt as a single string.
    """

    def format_question(question: dict):
        # Layout: {title} {text} [{additional}] ({options joined by "; "})
        # The "additional" part is left out entirely when it is empty.
        pieces = [question["title"], question["text"]]
        if question["additional"]:
            pieces.append(question["additional"])
        choices = "; ".join(question["options"])
        return " ".join(pieces) + " (" + choices + ")\n"

    def format_example(example: dict):
        header = (
            f"PASSAGE: {example['passage']}\n"
            f"TRIPLET: {example['triplet']}\n\n"
        )
        # Answers are numbered from 1 in the mapping's iteration order.
        numbered_answers = "".join(
            f"{number}) {answer}\n"
            for number, answer in enumerate(example["answers"].values(), start=1)
        )
        return header + numbered_answers

    instruction = VALIDATION_CONFIG["instruction"]

    # Questions are numbered from 1 in the config's iteration order.
    question_lines = []
    for number, entry in enumerate(VALIDATION_CONFIG["questions"].values(), start=1):
        question_lines.append(f"{number}) {format_question(entry)}")
    questions = "".join(question_lines)

    example = format_example(VALIDATION_CONFIG["example"])

    return f"{instruction}\nQUESTIONS:\n{questions}\n[* EXAMPLE *]\n\n{example}"
def validate_triplets(
    llm: ChatOpenAI, instruction: str, passage: str, triplets: List[List[str]]
) -> List[pd.DataFrame]:
    """Validate triplets with respect to passage.

    Every triplet becomes its own request made of the system instruction plus
    a user message holding the passage and that triplet. The requests are
    submitted together through the chain's ``batch`` call.

    Args:
        llm: Chat model used to answer the validation questions.
        instruction: System prompt text sent with every request.
        passage: Text excerpt the triplets are checked against.
        triplets: Triplets to validate, one request per entry.

    Returns:
        The list returned by ``chain.batch``: one ``QuestionOutputParser``
        result per submitted triplet.
    """
    count = len(triplets)
    # Only exactly one item takes the singular form ("0 triplets", "1 triplet").
    plural = "" if count == 1 else "s"
    # flush=True: the line has no trailing newline, so without it the message
    # could sit in the buffer until the (slow) batch call has finished.
    print(f"Validating {count:>2} triplet{plural}...", end=" ", flush=True)

    prompt = ChatPromptTemplate.from_messages([
        ("system", "{instruction}"),
        ("user", "PASSAGE: {passage}\n\nTRIPLET: {triplet}"),
    ])
    chain = prompt | llm | QuestionOutputParser()
    output = chain.batch([
        {"instruction": instruction, "passage": passage, "triplet": triplet}
        for triplet in triplets
    ])
    print("Done!", end="\n")
    return output
def validate_knowledge_graph(llm: ChatOpenAI, output_zip: str):
    """Validate all triplets in a knowledge graph.

    Reads ``kg.csv`` and ``text_batches.csv`` from the run archive, validates
    the triplets of each batch against that batch's passage, and writes the
    merged result to ``validation-<run_id>.csv`` in the archive's directory.

    Args:
        llm: Chat model used to answer the validation questions.
        output_zip: Path to a run archive whose file name has the form
            ``output-<run_id>.zip``.

    Returns:
        Path of the CSV file that was written.

    Raises:
        ValueError: If the archive's file name does not contain
            ``output-<run_id>.zip``, or if no triplet was validated (nothing
            to write).
    """
    # Match against the file name only: on the full path the greedy ".*"
    # would swallow any directory component that also contains "output-".
    match = re.search(r"output-(.*)\.zip", os.path.basename(output_zip))
    if match is None:
        raise ValueError(
            f"Expected a path to 'output-<run_id>.zip', got: {output_zip!r}"
        )
    run_id = match.group(1)
    run_dir = os.path.dirname(output_zip)

    # read output zip
    with ZipFile(output_zip) as z:
        # load knowledge graph
        with z.open("kg.csv") as f:
            graph = pd.read_csv(f)
        # load text batches
        with z.open("text_batches.csv") as f:
            text_batches = pd.read_csv(f)

    print("Initializing knowledge graph validation. Run:", run_id)
    print()

    # start stopwatch
    start = time()

    # container for evaluation responses
    responses = []

    # instructions
    instruction = format_system_prompt()

    # triplets are batched by passage
    # so we iterate over passages
    for idx, passage in enumerate(text_batches.text):
        # rows of the graph whose batch_id equals this passage's position
        triplets = (
            graph[graph["batch_id"] == idx].drop(columns=["batch_id"]).values.tolist()
        )
        print(f"Starting excerpt {idx + 1:>2} of {len(text_batches)}.", end=" ")

        # if batch has no triplets to validate, skip batch
        if len(triplets) == 0:
            print("Excerpt has no triplets to validate.", end="\n")
            continue

        # validate triplets by batch
        response = validate_triplets(
            llm=llm,
            instruction=instruction,
            passage=passage,
            triplets=triplets,
        )
        responses.extend(response)

    # pd.concat raises an unhelpful "No objects to concatenate" on an empty
    # list; fail with a message that names the run instead.
    if not responses:
        raise ValueError(f"Run {run_id} has no triplets to validate.")
    validation = pd.concat(responses, ignore_index=True)

    # merge with knowledge graph data
    # NOTE(review): the second merge pairs rows purely by position, so it
    # assumes text_batches.merge(graph) yields rows in the same order in which
    # the responses were collected above -- TODO confirm.
    validation_merged = (
        text_batches.merge(graph)
        .drop(columns=["batch_id"])
        .merge(validation, left_index=True, right_index=True)
    )
    savepath = os.path.join(run_dir, f"validation-{run_id}.csv")
    validation_merged.to_csv(savepath, index=False)

    # stop stopwatch
    end = time()

    print("\nKnowledge graph validation complete!")
    print(f"It took {end - start:0.3f} seconds to validate {len(validation)} triplets.")
    print("Saved to:", savepath)
    return savepath
if __name__ == "__main__":
    # Script entry point: validate the knowledge graph of every requested run.
    args = parse_args()
    llm = ChatOpenAI(
        model="gpt-4-turbo-preview",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )
    for run_id in args.run_ids:
        output_zip = os.path.join(args.search_dir, f"output-{run_id}.zip")
        # Pass the archive path, not the bare run ID: validate_knowledge_graph
        # parses the run ID out of an "output-<run_id>.zip" file name.
        validate_knowledge_graph(llm, output_zip)
        # visual separator after each run
        print("* " * 25)