# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import asyncio
import json
import logging
import os
import shutil
import subprocess
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import Any, ClassVar
from unittest import mock

import pandas as pd
import pytest

from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage

log = logging.getLogger(__name__)

debug = os.environ.get("DEBUG") is not None
gh_pages = os.environ.get("GH_PAGES") is not None
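
# The connection string below targets the local Azurite storage emulator. The
# account name and key are Azurite's single well-known development credential,
# published in its docs, so committing them leaks nothing.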
# cspell:disable-next-line well-known-key
WELL_KNOWN_AZURITE_CONNECTION_STRING = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1"


def _load_fixtures():
    """Load all fixtures from the tests/fixtures folder."""
    params = []
    fixtures_path = Path("./tests/fixtures/")

    # use the min-csv smoke test to hydrate the docsite parquet artifacts (see gh-pages.yml)
    subfolders = ["min-csv"] if gh_pages else sorted(os.listdir(fixtures_path))

    for subfolder in subfolders:
        if not os.path.isdir(fixtures_path / subfolder):
            continue

        config_file = fixtures_path / subfolder / "config.json"
        with config_file.open() as f:
            params.append((subfolder, json.load(f)))

    return params
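
# For reference, each fixture's config.json supplies the arguments that
# pytest_generate_tests forwards to test_fixture, plus optional flags read in
# this module ("azure" and "skip_assert" also live inside workflow_config).
# A minimal sketch — the key names come from the code in this file, but the
# values are purely illustrative:
#
# {
#     "input_path": "./tests/fixtures/min-csv",
#     "input_file_type": "text",
#     "slow": false,
#     "workflow_config": {
#         "skip": false,
#         "create_base_text_units": {
#             "row_range": [1, 2000],
#             "max_runtime": 30,
#             "nan_allowed_columns": []
#         }
#     },
#     "query_config": [
#         {"method": "global", "query": "What are the top themes in this data?"}
#     ]
# }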


def pytest_generate_tests(metafunc):
    """Generate tests for all test functions in this module."""
    run_slow = metafunc.config.getoption("run_slow")
    configs = metafunc.cls.params[metafunc.function.__name__]

    if not run_slow:
        # Only run tests that are not marked as slow
        configs = [config for config in configs if not config[1].get("slow", False)]

    funcarglist = [params[1] for params in configs]
    id_list = [params[0] for params in configs]

    argnames = sorted(arg for arg in funcarglist[0] if arg != "slow")
    metafunc.parametrize(
        argnames,
        [[funcargs[name] for name in argnames] for funcargs in funcarglist],
        ids=id_list,
    )
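
# Net effect of the hook above: every fixture subfolder becomes one test case
# whose id is the subfolder name, and every config.json key except "slow"
# (which only gates whether the case runs at all) is passed as an argument.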


def cleanup(skip: bool = False):
    """Decorator to clean up the output and cache folders after each test."""

    def decorator(func):
        @wraps(func)  # preserve the test's metadata so pytest parametrization still resolves it
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except AssertionError:
                raise
            finally:
                if not skip:
                    root = Path(kwargs["input_path"])
                    shutil.rmtree(root / "output", ignore_errors=True)
                    shutil.rmtree(root / "cache", ignore_errors=True)

        return wrapper

    return decorator
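
# A hypothetical usage sketch: the decorated test must receive input_path as a
# keyword argument, since wrapper() reads kwargs["input_path"] to locate the
# folders to remove:
#
#     @cleanup(skip=debug)
#     def test_something(input_path: str): ...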


async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], None]:
    """Prepare the data for the Azurite tests."""
    input_container = azure["input_container"]
    input_base_dir = azure.get("input_base_dir")

    root = Path(input_path)
    input_storage = BlobPipelineStorage(
        connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING,
        container_name=input_container,
    )

    # Bounce the container if it exists to clear out old run data
    input_storage.delete_container()
    input_storage.create_container()

    # Upload data files
    txt_files = list((root / "input").glob("*.txt"))
    csv_files = list((root / "input").glob("*.csv"))
    data_files = txt_files + csv_files
    for data_file in data_files:
        with data_file.open(encoding="utf8") as f:
            text = f.read()
        file_path = (
            str(Path(input_base_dir) / data_file.name)
            if input_base_dir
            else data_file.name
        )
        await input_storage.set(file_path, text, encoding="utf-8")

    return lambda: input_storage.delete_container()
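
# The callable returned above deletes the container; test_fixture invokes it
# once indexing finishes so stale blobs never leak into the next Azurite run.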


class TestIndexer:
    params: ClassVar[dict[str, list[tuple[str, dict[str, Any]]]]] = {
        "test_fixture": _load_fixtures()
    }

    def __run_indexer(
        self,
        root: Path,
        input_file_type: str,
    ):
        command = [
            "poetry",
            "run",
            "poe",
            "index",
            "--verbose" if debug else None,
            "--root",
            root.absolute().as_posix(),
            "--reporter",
            "print",
        ]
        # drop the None placeholder left behind when debug is off
        command = [arg for arg in command if arg]
        log.info("running command: %s", " ".join(command))
        completion = subprocess.run(
            command, env={**os.environ, "GRAPHRAG_INPUT_FILE_TYPE": input_file_type}
        )
        assert (
            completion.returncode == 0
        ), f"Indexer failed with return code: {completion.returncode}"

    def __assert_indexer_outputs(
        self, root: Path, workflow_config: dict[str, dict[str, Any]]
    ):
        outputs_path = root / "output"
        output_entries = list(outputs_path.iterdir())

        # Sort the output folders by creation time, most recent first
        output_entries.sort(key=lambda entry: entry.stat().st_ctime, reverse=True)

        if not debug:
            assert (
                len(output_entries) == 1
            ), f"Expected one output folder, found {len(output_entries)}"

        output_path = output_entries[0]
        assert output_path.exists(), "output folder does not exist"

        artifacts = output_path / "artifacts"
        assert artifacts.exists(), "artifact folder does not exist"

        # Check the stats recorded for all workflows
        with (artifacts / "stats.json").open() as f:
            stats = json.load(f)

        # Check that all expected workflows ran, and no unexpected ones
        expected_workflows = set(workflow_config.keys())
        workflows = set(stats["workflows"].keys())
        assert (
            workflows == expected_workflows
        ), f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"

        # [OPTIONAL] Check subworkflows
        for workflow in expected_workflows:
            if "subworkflows" in workflow_config[workflow]:
                # Check the number of subworkflows; each workflow's stats carry
                # an "overall" runtime entry alongside one entry per subworkflow,
                # hence the - 1 below
                subworkflows = stats["workflows"][workflow]
                expected_subworkflows = workflow_config[workflow].get(
                    "subworkflows", None
                )
                if expected_subworkflows:
                    assert (
                        len(subworkflows) - 1 == expected_subworkflows
                    ), f"Expected {expected_subworkflows} subworkflows, found: {len(subworkflows) - 1} for workflow: {workflow}: [{subworkflows}]"

            # Check max runtime
            max_runtime = workflow_config[workflow].get("max_runtime", None)
            if max_runtime:
                assert (
                    stats["workflows"][workflow]["overall"] <= max_runtime
                ), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"

        # Check artifacts: one parquet file per workflow, plus stats.json
        artifact_files = os.listdir(artifacts)
        assert (
            len(artifact_files) == len(expected_workflows) + 1
        ), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"

        for artifact in artifact_files:
            if artifact.endswith(".parquet"):
                output_df = pd.read_parquet(artifacts / artifact)
                artifact_name = artifact.split(".")[0]
                workflow = workflow_config[artifact_name]

                # Check that the row count falls within the expected range
                assert (
                    workflow["row_range"][0]
                    <= len(output_df)
                    <= workflow["row_range"][1]
                ), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"

                # Select the columns that must not contain NaNs...
                nan_df = output_df.loc[
                    :, ~output_df.columns.isin(workflow.get("nan_allowed_columns", []))
                ]
                # ...and assert that no row has a NaN in any of them
                nan_df = nan_df[nan_df.isna().any(axis=1)]
                assert (
                    len(nan_df) == 0
                ), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"

    def __run_query(self, root: Path, query_config: dict[str, str]):
        command = [
            "poetry",
            "run",
            "poe",
            "query",
            "--root",
            root.absolute().as_posix(),
            "--method",
            query_config["method"],
            "--community_level",
            str(query_config.get("community_level", 2)),
            query_config["query"],
        ]
        log.info("running command: %s", " ".join(command))
        return subprocess.run(command, capture_output=True, text=True)

    @cleanup(skip=debug)  # keep output/cache around when debugging
    @pytest.mark.timeout(600)  # Extend the timeout to 600 seconds (10 minutes)
    def test_fixture(
        self,
        input_path: str,
        input_file_type: str,
        workflow_config: dict[str, dict[str, Any]],
        query_config: list[dict[str, str]],
    ):
        if workflow_config.get("skip", False):
            print(f"skipping smoke test {input_path}")
            return

        azure = workflow_config.get("azure")
        root = Path(input_path)
        dispose = None
        if azure is not None:
            dispose = asyncio.run(prepare_azurite_data(input_path, azure))

        print("running indexer")
        self.__run_indexer(root, input_file_type)
        print("indexer complete")

        if dispose is not None:
            dispose()

        if not workflow_config.get("skip_assert", False):
            print("performing dataset assertions")
            self.__assert_indexer_outputs(root, workflow_config)

        print("running queries")
        for query in query_config:
            result = self.__run_query(root, query)
            print(f"Query: {query}\nResponse: {result.stdout}")

            # Filter stderr because lancedb logs path creation as a WARN,
            # which would otherwise cause false negatives here
            stderror = (
                result.stderr if "No existing dataset at" not in result.stderr else ""
            )
            assert stderror == "", f"Query failed with error: {stderror}"
            assert result.stdout is not None, "Query returned no output"
            assert len(result.stdout) > 0, "Query returned empty output"