File size: 2,953 Bytes
e331e72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import logging
import os
import unittest

from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult

log = logging.getLogger(__name__)


class TestRun(unittest.IsolatedAsyncioTestCase):
    """End-to-end test that runs the pipeline described by ``megapipeline.yml``."""

    async def test_megapipeline(self):
        """Run the mega-pipeline and validate its outputs.

        Fails if any workflow reports errors. Then checks that the
        ``create_final_text_units`` and ``create_final_entities`` outputs only
        reference ids that exist in each other.
        """
        pipeline_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "./megapipeline.yml",
        )
        pipeline_result = [gen async for gen in run_pipeline_with_config(pipeline_path)]

        errors = []
        for result in pipeline_result:
            if result.errors:
                errors.extend(result.errors)

        if errors:
            log.error("Errors: %s", errors)
        # The collected errors may be exception objects rather than strings.
        # Stringify them so that building the failure message cannot raise a
        # TypeError and hide the real failure.
        assert not errors, "received errors!\n" + "\n".join(str(e) for e in errors)

        self._assert_text_units_and_entities_reference_each_other(pipeline_result)

    @staticmethod
    def _get_workflow_result(
        pipeline_result: list[PipelineRunResult], workflow: str
    ):
        """Return the ``result`` of the first run whose ``workflow`` matches.

        Fails with a descriptive assertion if the workflow is absent. A bare
        ``next()`` would raise StopIteration, which escapes the coroutine test
        as an opaque RuntimeError instead of a readable test failure.
        """
        match = next((r for r in pipeline_result if r.workflow == workflow), None)
        assert match is not None, f"Workflow {workflow} not found in pipeline results"
        return match.result

    @staticmethod
    def _as_id_set(values) -> set:
        """Convert a row's id-list cell to a set, treating ``None`` as empty."""
        return set() if values is None else set(values)

    def _assert_text_units_and_entities_reference_each_other(
        self, pipeline_result: list[PipelineRunResult]
    ):
        """Assert that text units and entities only reference each other's ids.

        Every id in a text unit's ``entity_ids`` must be an entity ``id``.
        Every id in an entity's ``text_unit_ids`` must be a text unit ``id``.
        A missing column or a ``None`` cell is treated as an empty reference list.
        """
        text_unit_df = self._get_workflow_result(
            pipeline_result, "create_final_text_units"
        )
        entity_df = self._get_workflow_result(pipeline_result, "create_final_entities")

        assert text_unit_df is not None, "Text unit dataframe should not be None"
        assert entity_df is not None, "Entity dataframe should not be None"

        # Explicit narrowing for type checkers. Asserts are also stripped under -O.
        if text_unit_df is None or entity_df is None:
            return

        assert len(text_unit_df) > 0, "Text unit dataframe should not be empty"
        assert len(entity_df) > 0, "Entity dataframe should not be empty"

        log.info("text_unit_df %s", text_unit_df.columns)

        text_unit_entity_map = {
            row["id"]: self._as_id_set(row.get("entity_ids", []))
            for _, row in text_unit_df.iterrows()
        }
        # Entities are expected to carry text_unit_ids. An entity with none
        # (missing column or None cell) yields an empty set, which this check
        # does not flag.
        entity_text_unit_map = {
            row["id"]: self._as_id_set(row.get("text_unit_ids", []))
            for _, row in entity_df.iterrows()
        }

        text_unit_ids = set(text_unit_entity_map)
        entity_ids = set(entity_text_unit_map)

        for text_unit_id, text_unit_entities in text_unit_entity_map.items():
            assert text_unit_entities.issubset(
                entity_ids
            ), f"Text unit {text_unit_id} has entities {text_unit_entities} that are not in the entity set"
        for entity_id, entity_text_units in entity_text_unit_map.items():
            assert entity_text_units.issubset(
                text_unit_ids
            ), f"Entity {entity_id} has text units {entity_text_units} that are not in the text unit set"