# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Unit tests for building GraphRAG configs from env vars, input objects, dicts and YAML.

Covers ``create_graphrag_config`` (environment variables, ``GraphRagConfigInput``
objects and plain dicts) and ``create_pipeline_config`` (turning the resulting
config into a ``PipelineConfig``).
"""

import json
import os
import re
import unittest
from pathlib import Path
from typing import Any, cast
from unittest import mock

import pytest
import yaml
from pydantic import ValidationError

import graphrag.config.defaults as defs
from graphrag.config import (
    ApiKeyMissingError,
    AzureApiBaseMissingError,
    AzureDeploymentNameMissingError,
    CacheConfig,
    CacheConfigInput,
    CacheType,
    ChunkingConfig,
    ChunkingConfigInput,
    ClaimExtractionConfig,
    ClaimExtractionConfigInput,
    ClusterGraphConfig,
    ClusterGraphConfigInput,
    CommunityReportsConfig,
    CommunityReportsConfigInput,
    EmbedGraphConfig,
    EmbedGraphConfigInput,
    EntityExtractionConfig,
    EntityExtractionConfigInput,
    GlobalSearchConfig,
    GraphRagConfig,
    GraphRagConfigInput,
    InputConfig,
    InputConfigInput,
    InputFileType,
    InputType,
    LLMParameters,
    LLMParametersInput,
    LocalSearchConfig,
    ParallelizationParameters,
    ReportingConfig,
    ReportingConfigInput,
    ReportingType,
    SnapshotsConfig,
    SnapshotsConfigInput,
    StorageConfig,
    StorageConfigInput,
    StorageType,
    SummarizeDescriptionsConfig,
    SummarizeDescriptionsConfigInput,
    TextEmbeddingConfig,
    TextEmbeddingConfigInput,
    UmapConfig,
    UmapConfigInput,
    create_graphrag_config,
)
from graphrag.index import (
    PipelineConfig,
    PipelineCSVInputConfig,
    PipelineFileCacheConfig,
    PipelineFileReportingConfig,
    PipelineFileStorageConfig,
    PipelineInputConfig,
    PipelineTextInputConfig,
    PipelineWorkflowReference,
    create_pipeline_config,
)

current_dir = os.path.dirname(__file__)

# One non-default value for every GRAPHRAG_* env var the tests know about.
# test_all_env_vars_is_accurate checks the documented env vars against these
# keys, and test_create_parameters_from_env_vars checks how each value lands
# in the parsed config.
ALL_ENV_VARS = {
    "GRAPHRAG_API_BASE": "http://some/base",
    "GRAPHRAG_API_KEY": "test",
    "GRAPHRAG_API_ORGANIZATION": "test_org",
    "GRAPHRAG_API_PROXY": "http://some/proxy",
    "GRAPHRAG_API_VERSION": "v1234",
    "GRAPHRAG_ASYNC_MODE": "asyncio",
    "GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL": "cache_account_blob_url",
    "GRAPHRAG_CACHE_BASE_DIR": "/some/cache/dir",
    "GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1",
    "GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1",
    "GRAPHRAG_CACHE_TYPE": "blob",
    "GRAPHRAG_CHUNK_BY_COLUMNS": "a,b",
    "GRAPHRAG_CHUNK_OVERLAP": "12",
    "GRAPHRAG_CHUNK_SIZE": "500",
    "GRAPHRAG_CLAIM_EXTRACTION_ENABLED": "True",
    "GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
    "GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
    "GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
    "GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
    "GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
    "GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
    "GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000",
    "GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS": "12",
    "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "model-deployment-name",
    "GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
    "GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
    "GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
    "GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
    "GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
    "GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
    "GRAPHRAG_EMBEDDING_TARGET": "all",
    "GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
    "GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
    "GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
    "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
    "GRAPHRAG_ENCODING_MODEL": "test123",
    "GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
    "GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
    "GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
    "GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
    "GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
    "GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
    "GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
    "GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2",
    "GRAPHRAG_INPUT_ENCODING": "utf-16",
    "GRAPHRAG_INPUT_FILE_PATTERN": ".*\\test\\.txt$",
    "GRAPHRAG_INPUT_SOURCE_COLUMN": "test_source",
    "GRAPHRAG_INPUT_TYPE": "blob",
    "GRAPHRAG_INPUT_TEXT_COLUMN": "test_text",
    "GRAPHRAG_INPUT_TIMESTAMP_COLUMN": "test_timestamp",
    "GRAPHRAG_INPUT_TIMESTAMP_FORMAT": "test_format",
    "GRAPHRAG_INPUT_TITLE_COLUMN": "test_title",
    "GRAPHRAG_INPUT_FILE_TYPE": "text",
    "GRAPHRAG_LLM_CONCURRENT_REQUESTS": "12",
    "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
    "GRAPHRAG_LLM_MAX_RETRIES": "312",
    "GRAPHRAG_LLM_MAX_RETRY_WAIT": "0.1122",
    "GRAPHRAG_LLM_MAX_TOKENS": "15000",
    "GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
    "GRAPHRAG_LLM_MODEL": "test-llm",
    "GRAPHRAG_LLM_N": "1",
    "GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
    "GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
    "GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
    "GRAPHRAG_LLM_THREAD_COUNT": "987",
    "GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
    "GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
    "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
    "GRAPHRAG_MAX_CLUSTER_SIZE": "123",
    "GRAPHRAG_NODE2VEC_ENABLED": "true",
    "GRAPHRAG_NODE2VEC_ITERATIONS": "878787",
    "GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000",
    # Leading zero on purpose: the parsed value is asserted to be 10101.
    "GRAPHRAG_NODE2VEC_RANDOM_SEED": "010101",
    "GRAPHRAG_NODE2VEC_WALK_LENGTH": "555111",
    "GRAPHRAG_NODE2VEC_WINDOW_SIZE": "12345",
    "GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL": "reporting_account_blob_url",
    "GRAPHRAG_REPORTING_BASE_DIR": "/some/reporting/dir",
    "GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2",
    "GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2",
    "GRAPHRAG_REPORTING_TYPE": "blob",
    "GRAPHRAG_SKIP_WORKFLOWS": "a,b,c",
    "GRAPHRAG_SNAPSHOT_GRAPHML": "true",
    "GRAPHRAG_SNAPSHOT_RAW_ENTITIES": "true",
    "GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES": "true",
    "GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL": "storage_account_blob_url",
    "GRAPHRAG_STORAGE_BASE_DIR": "/some/storage/dir",
    "GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs",
    "GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn",
    "GRAPHRAG_STORAGE_TYPE": "blob",
    "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
    "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
    "GRAPHRAG_LLM_TEMPERATURE": "0.0",
    "GRAPHRAG_LLM_TOP_P": "1.0",
    "GRAPHRAG_UMAP_ENABLED": "true",
    "GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
    "GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
    "GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
    "GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
    "GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
    "GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
    "GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
    "GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
    "GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
    "GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
    "GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
    "GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
    "GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",
    "GRAPHRAG_GLOBAL_SEARCH_CONCURRENCY": "7",
    "GRAPHRAG_GLOBAL_SEARCH_REDUCE_MAX_TOKENS": "15432",
}


class TestDefaultConfig(unittest.TestCase):
    """Defaults, validation errors and overrides for ``create_graphrag_config``."""

    def test_clear_warnings(self):
        """Just clearing unused import warnings"""
        assert CacheConfig is not None
        assert ChunkingConfig is not None
        assert ClaimExtractionConfig is not None
        assert ClusterGraphConfig is not None
        assert CommunityReportsConfig is not None
        assert EmbedGraphConfig is not None
        assert EntityExtractionConfig is not None
        assert GlobalSearchConfig is not None
        assert GraphRagConfig is not None
        assert InputConfig is not None
        assert LLMParameters is not None
        assert LocalSearchConfig is not None
        assert ParallelizationParameters is not None
        assert ReportingConfig is not None
        assert SnapshotsConfig is not None
        assert StorageConfig is not None
        assert SummarizeDescriptionsConfig is not None
        assert TextEmbeddingConfig is not None
        assert UmapConfig is not None
        assert PipelineConfig is not None
        assert PipelineFileReportingConfig is not None
        assert PipelineFileStorageConfig is not None
        assert PipelineInputConfig is not None
        assert PipelineFileCacheConfig is not None
        assert PipelineWorkflowReference is not None

    @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
    def test_string_repr(self):
        """``str()`` is valid JSON and ``repr()`` is eval()-able for both configs."""
        # __str__ can be json loaded
        config = create_graphrag_config()
        string_repr = str(config)
        assert string_repr is not None
        assert json.loads(string_repr) is not None

        # __repr__ can be eval()'d. eval() resolves names against this module's
        # globals, so every class named in the repr must be imported above.
        repr_str = config.__repr__()
        # TODO: add __repr__ to datashaper enum
        # Until then the async mode renders as a plain Enum repr such as
        # <AsyncType.Threaded: 'threaded'>, which is not valid Python. Strip the
        # whole field; the regex works whatever the default async mode is.
        repr_str = re.sub(r"async_mode=<[^>]*>,", "", repr_str)
        assert eval(repr_str) is not None

        # Pipeline config __str__ can be json loaded
        pipeline_config = create_pipeline_config(config)
        string_repr = str(pipeline_config)
        assert string_repr is not None
        assert json.loads(string_repr) is not None

        # Pipeline config __repr__ can be eval()'d
        repr_str = pipeline_config.__repr__()
        # TODO: add __repr__ to datashaper enum
        # Same workaround, but the pipeline repr renders it as a dict entry.
        repr_str = re.sub(r"'async_mode': <[^>]*>,", "", repr_str)
        assert eval(repr_str) is not None

    @mock.patch.dict(os.environ, {}, clear=True)
    def test_default_config_with_no_env_vars_throws(self):
        """With no API key anywhere, building the pipeline raises ApiKeyMissingError."""
        with pytest.raises(ApiKeyMissingError):
            # This should throw an error because the API key is missing
            create_pipeline_config(create_graphrag_config())

    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_default_config_with_api_key_passes(self):
        """GRAPHRAG_API_KEY alone is enough to build a default pipeline config."""
        # doesn't throw
        config = create_pipeline_config(create_graphrag_config())
        assert config is not None

    @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
    def test_default_config_with_oai_key_passes_envvar(self):
        """OPENAI_API_KEY alone is enough to build a default pipeline config."""
        # doesn't throw
        config = create_pipeline_config(create_graphrag_config())
        assert config is not None

    # Clear the environment so ambient GRAPHRAG_*/OPENAI_* variables cannot make
    # this test pass or fail for reasons unrelated to the dict-supplied key.
    @mock.patch.dict(os.environ, {}, clear=True)
    def test_default_config_with_oai_key_passes_obj(self):
        """An API key supplied in the input dict is enough, with no env vars set."""
        # doesn't throw
        config = create_pipeline_config(
            create_graphrag_config({"llm": {"api_key": "test"}})
        )
        assert config is not None

    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_LLM_TYPE": "azure_openai_chat"},
        clear=True,
    )
    def test_throws_if_azure_is_used_without_api_base_envvar(self):
        """An Azure chat LLM set via env vars without an API base is rejected."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config()

    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_throws_if_azure_is_used_without_api_base_obj(self):
        """An Azure chat LLM set via an input object without an API base is rejected."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config(
                GraphRagConfigInput(llm=LLMParametersInput(type="azure_openai_chat"))
            )

    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_API_BASE": "http://some/base",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_llm_deployment_name_envvar(self):
        """An Azure chat LLM set via env vars without a deployment name is rejected."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()

    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_throws_if_azure_is_used_without_llm_deployment_name_obj(self):
        """An Azure chat LLM set via an input object without a deployment name is rejected."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config(
                GraphRagConfigInput(
                    llm=LLMParametersInput(
                        type="azure_openai_chat", api_base="http://some/base"
                    )
                )
            )

    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
            "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "x",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_embedding_api_base_envvar(self):
        """Azure embeddings set via env vars without an API base are rejected."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config()

    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_throws_if_azure_is_used_without_embedding_api_base_obj(self):
        """Azure embeddings set via an input object without an API base are rejected."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config(
                GraphRagConfigInput(
                    embeddings=TextEmbeddingConfigInput(
                        llm=LLMParametersInput(
                            type="azure_openai_embedding",
                            deployment_name="x",
                        )
                    ),
                )
            )

    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_API_BASE": "http://some/base",
            "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_embedding_deployment_name_envvar(self):
        """Azure embeddings set via env vars without a deployment name are rejected."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()

    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_throws_if_azure_is_used_without_embedding_deployment_name_obj(self):
        """Azure embeddings set via an input object without a deployment name are rejected."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config(
                GraphRagConfigInput(
                    llm=LLMParametersInput(
                        type="azure_openai_chat",
                        api_base="http://some/base",
                        deployment_name="model-deployment-name-x",
                    ),
                    embeddings=TextEmbeddingConfigInput(
                        llm=LLMParametersInput(
                            type="azure_openai_embedding",
                        )
                    ),
                )
            )

    # NOTE: "minimim" is a typo for "minimum"; the name is kept so existing
    # test selectors (e.g. ``-k`` filters) keep matching.
    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_minimim_azure_config_object(self):
        """A complete Azure chat and embedding config object builds without error."""
        config = create_graphrag_config(
            GraphRagConfigInput(
                llm=LLMParametersInput(
                    type="azure_openai_chat",
                    api_base="http://some/base",
                    deployment_name="model-deployment-name-x",
                ),
                embeddings=TextEmbeddingConfigInput(
                    llm=LLMParametersInput(
                        type="azure_openai_embedding",
                        deployment_name="model-deployment-name",
                    )
                ),
            )
        )
        assert config is not None

    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_api_base(self):
        """An Azure chat LLM with a deployment name but no API base is rejected."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config()

    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_LLM_API_BASE": "http://some/base",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_llm_deployment_name(self):
        """An Azure chat LLM with an LLM-specific API base but no deployment name is rejected."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()

    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_API_BASE": "http://some/base",
            "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
            "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_embedding_deployment_name(self):
        """A valid Azure chat LLM plus Azure embeddings lacking a deployment name is rejected."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()

    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "csv"},
        clear=True,
    )
    def test_csv_input_returns_correct_config(self):
        """CSV input yields a PipelineCSVInputConfig with the *.csv file pattern."""
        config = create_pipeline_config(create_graphrag_config(root_dir="/some/root"))
        assert config.root_dir == "/some/root"
        # Make sure the input is a CSV input
        assert isinstance(config.input, PipelineCSVInputConfig)
        assert (config.input.file_pattern or "") == ".*\\.csv$"  # type: ignore

    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "text"},
        clear=True,
    )
    def test_text_input_returns_correct_config(self):
        """Text input yields a PipelineTextInputConfig with the *.txt file pattern."""
        config = create_pipeline_config(create_graphrag_config(root_dir="."))
        assert isinstance(config.input, PipelineTextInputConfig)
        assert config.input is not None
        assert (config.input.file_pattern or "") == ".*\\.txt$"  # type: ignore

    def test_all_env_vars_is_accurate(self):
        """Every GRAPHRAG_* env var named in the docs is present in ALL_ENV_VARS.

        The docs are read relative to the current working directory, so this
        test has to run from the repository root.
        """
        env_var_docs_path = Path("docsite/posts/config/env_vars.md")
        query_docs_path = Path("docsite/posts/query/3-cli.md")

        env_var_docs = env_var_docs_path.read_text(encoding="utf-8")
        query_docs = query_docs_path.read_text(encoding="utf-8")

        def find_envvar_names(text: str) -> set[str]:
            # Backtick-quoted GRAPHRAG_* names. Names ending in "_" are skipped;
            # presumably they are documented prefixes, not complete var names.
            pattern = r"`(GRAPHRAG_[^`]+)`"
            return {
                name for name in re.findall(pattern, text) if not name.endswith("_")
            }

        graphrag_strings = find_envvar_names(env_var_docs) | find_envvar_names(
            query_docs
        )

        missing = {s for s in graphrag_strings if s not in ALL_ENV_VARS} - {
            # Remove configs covered by the base LLM connection configs
            "GRAPHRAG_LLM_API_KEY",
            "GRAPHRAG_LLM_API_BASE",
            "GRAPHRAG_LLM_API_VERSION",
            "GRAPHRAG_LLM_API_ORGANIZATION",
            "GRAPHRAG_LLM_API_PROXY",
            "GRAPHRAG_EMBEDDING_API_KEY",
            "GRAPHRAG_EMBEDDING_API_BASE",
            "GRAPHRAG_EMBEDDING_API_VERSION",
            "GRAPHRAG_EMBEDDING_API_ORGANIZATION",
            "GRAPHRAG_EMBEDDING_API_PROXY",
        }
        if missing:
            msg = f"{len(missing)} missing env vars: {missing}"
            print(msg)
            raise ValueError(msg)

    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test"},
        clear=True,
    )
    def test_malformed_input_dict_throws(self):
        """A wrongly typed section in the input dict raises a pydantic ValidationError."""
        with pytest.raises(ValidationError):
            create_graphrag_config(cast(Any, {"llm": 12}))

    @mock.patch.dict(
        os.environ,
        ALL_ENV_VARS,
        clear=True,
    )
    def test_create_parameters_from_env_vars(self) -> None:
        """Every value in ALL_ENV_VARS is parsed into the matching config field."""
        parameters = create_graphrag_config()
        assert parameters.async_mode == "asyncio"
        assert parameters.cache.storage_account_blob_url == "cache_account_blob_url"
        assert parameters.cache.base_dir == "/some/cache/dir"
        assert parameters.cache.connection_string == "test_cs1"
        assert parameters.cache.container_name == "test_cn1"
        assert parameters.cache.type == CacheType.blob
        assert parameters.chunks.group_by_columns == ["a", "b"]
        assert parameters.chunks.overlap == 12
        assert parameters.chunks.size == 500
        assert parameters.claim_extraction.enabled
        assert parameters.claim_extraction.description == "test 123"
        assert parameters.claim_extraction.max_gleanings == 5000
        assert parameters.claim_extraction.prompt == "tests/unit/config/prompt-a.txt"
        assert parameters.cluster_graph.max_cluster_size == 123
        assert parameters.community_reports.max_length == 23456
        assert parameters.community_reports.prompt == "tests/unit/config/prompt-b.txt"
        assert parameters.embed_graph.enabled
        assert parameters.embed_graph.iterations == 878787
        assert parameters.embed_graph.num_walks == 5_000_000
        assert parameters.embed_graph.random_seed == 10101
        assert parameters.embed_graph.walk_length == 555111
        assert parameters.embed_graph.window_size == 12345
        assert parameters.embeddings.batch_max_tokens == 17
        assert parameters.embeddings.batch_size == 1_000_000
        assert parameters.embeddings.llm.concurrent_requests == 12
        assert parameters.embeddings.llm.deployment_name == "model-deployment-name"
        assert parameters.embeddings.llm.max_retries == 3
        assert parameters.embeddings.llm.max_retry_wait == 0.1123
        assert parameters.embeddings.llm.model == "text-embedding-2"
        assert parameters.embeddings.llm.requests_per_minute == 500
        assert parameters.embeddings.llm.sleep_on_rate_limit_recommendation is False
        assert parameters.embeddings.llm.tokens_per_minute == 7000
        assert parameters.embeddings.llm.type == "azure_openai_embedding"
        assert parameters.embeddings.parallelization.num_threads == 2345
        assert parameters.embeddings.parallelization.stagger == 0.456
        assert parameters.embeddings.skip == ["a1", "b1", "c1"]
        assert parameters.embeddings.target == "all"
        assert parameters.encoding_model == "test123"
        assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"]
        assert parameters.entity_extraction.llm.api_base == "http://some/base"
        assert parameters.entity_extraction.max_gleanings == 112
        assert parameters.entity_extraction.prompt == "tests/unit/config/prompt-c.txt"
        assert parameters.input.storage_account_blob_url == "input_account_blob_url"
        assert parameters.input.base_dir == "/some/input/dir"
        assert parameters.input.connection_string == "input_cs"
        assert parameters.input.container_name == "input_cn"
        assert parameters.input.document_attribute_columns == ["test1", "test2"]
        assert parameters.input.encoding == "utf-16"
        assert parameters.input.file_pattern == ".*\\test\\.txt$"
        assert parameters.input.file_type == InputFileType.text
        assert parameters.input.source_column == "test_source"
        assert parameters.input.text_column == "test_text"
        assert parameters.input.timestamp_column == "test_timestamp"
        assert parameters.input.timestamp_format == "test_format"
        assert parameters.input.title_column == "test_title"
        assert parameters.input.type == InputType.blob
        assert parameters.llm.api_base == "http://some/base"
        assert parameters.llm.api_key == "test"
        assert parameters.llm.api_version == "v1234"
        assert parameters.llm.concurrent_requests == 12
        assert parameters.llm.deployment_name == "model-deployment-name-x"
        assert parameters.llm.max_retries == 312
        assert parameters.llm.max_retry_wait == 0.1122
        assert parameters.llm.max_tokens == 15000
        assert parameters.llm.model == "test-llm"
        assert parameters.llm.model_supports_json
        assert parameters.llm.n == 1
        assert parameters.llm.organization == "test_org"
        assert parameters.llm.proxy == "http://some/proxy"
        assert parameters.llm.request_timeout == 12.7
        assert parameters.llm.requests_per_minute == 900
        assert parameters.llm.sleep_on_rate_limit_recommendation is False
        assert parameters.llm.temperature == 0.0
        assert parameters.llm.top_p == 1.0
        assert parameters.llm.tokens_per_minute == 8000
        assert parameters.llm.type == "azure_openai_chat"
        assert parameters.parallelization.num_threads == 987
        assert parameters.parallelization.stagger == 0.123
        assert (
            parameters.reporting.storage_account_blob_url
            == "reporting_account_blob_url"
        )
        assert parameters.reporting.base_dir == "/some/reporting/dir"
        assert parameters.reporting.connection_string == "test_cs2"
        assert parameters.reporting.container_name == "test_cn2"
        assert parameters.reporting.type == ReportingType.blob
        assert parameters.skip_workflows == ["a", "b", "c"]
        assert parameters.snapshots.graphml
        assert parameters.snapshots.raw_entities
        assert parameters.snapshots.top_level_nodes
        assert parameters.storage.storage_account_blob_url == "storage_account_blob_url"
        assert parameters.storage.base_dir == "/some/storage/dir"
        assert parameters.storage.connection_string == "test_cs"
        assert parameters.storage.container_name == "test_cn"
        assert parameters.storage.type == StorageType.blob
        assert parameters.summarize_descriptions.max_length == 12345
        assert (
            parameters.summarize_descriptions.prompt == "tests/unit/config/prompt-d.txt"
        )
        assert parameters.umap.enabled
        assert parameters.local_search.text_unit_prop == 0.713
        assert parameters.local_search.community_prop == 0.1234
        assert parameters.local_search.llm_max_tokens == 12
        assert parameters.local_search.top_k_relationships == 15
        assert parameters.local_search.conversation_history_max_turns == 2
        assert parameters.local_search.top_k_entities == 14
        assert parameters.local_search.temperature == 0.1
        assert parameters.local_search.top_p == 0.9
        assert parameters.local_search.n == 2
        assert parameters.local_search.max_tokens == 142435

        assert parameters.global_search.temperature == 0.1
        assert parameters.global_search.top_p == 0.9
        assert parameters.global_search.n == 2
        assert parameters.global_search.max_tokens == 5123
        assert parameters.global_search.data_max_tokens == 123
        assert parameters.global_search.map_max_tokens == 4123
        assert parameters.global_search.concurrency == 7
        assert parameters.global_search.reduce_max_tokens == 15432

    @mock.patch.dict(os.environ, {"API_KEY_X": "test"}, clear=True)
    def test_create_parameters(self) -> None:
        """Values from a GraphRagConfigInput object, including ${VAR} substitution, are applied."""
        parameters = create_graphrag_config(
            GraphRagConfigInput(
                # "${API_KEY_X}" is resolved from the patched environment.
                llm=LLMParametersInput(api_key="${API_KEY_X}", model="test-llm"),
                storage=StorageConfigInput(
                    type=StorageType.blob,
                    connection_string="test_cs",
                    container_name="test_cn",
                    base_dir="/some/storage/dir",
                    storage_account_blob_url="storage_account_blob_url",
                ),
                cache=CacheConfigInput(
                    type=CacheType.blob,
                    connection_string="test_cs1",
                    container_name="test_cn1",
                    base_dir="/some/cache/dir",
                    storage_account_blob_url="cache_account_blob_url",
                ),
                reporting=ReportingConfigInput(
                    type=ReportingType.blob,
                    connection_string="test_cs2",
                    container_name="test_cn2",
                    base_dir="/some/reporting/dir",
                    storage_account_blob_url="reporting_account_blob_url",
                ),
                input=InputConfigInput(
                    file_type=InputFileType.text,
                    file_encoding="utf-16",
                    document_attribute_columns=["test1", "test2"],
                    base_dir="/some/input/dir",
                    connection_string="input_cs",
                    container_name="input_cn",
                    file_pattern=".*\\test\\.txt$",
                    source_column="test_source",
                    text_column="test_text",
                    timestamp_column="test_timestamp",
                    timestamp_format="test_format",
                    title_column="test_title",
                    type="blob",
                    storage_account_blob_url="input_account_blob_url",
                ),
                embed_graph=EmbedGraphConfigInput(
                    enabled=True,
                    num_walks=5_000_000,
                    iterations=878787,
                    random_seed=10101,
                    walk_length=555111,
                ),
                embeddings=TextEmbeddingConfigInput(
                    batch_size=1_000_000,
                    batch_max_tokens=8000,
                    skip=["a1", "b1", "c1"],
                    llm=LLMParametersInput(model="text-embedding-2"),
                ),
                chunks=ChunkingConfigInput(
                    size=500, overlap=12, group_by_columns=["a", "b"]
                ),
                snapshots=SnapshotsConfigInput(
                    graphml=True,
                    raw_entities=True,
                    top_level_nodes=True,
                ),
                entity_extraction=EntityExtractionConfigInput(
                    max_gleanings=112,
                    entity_types=["cat", "dog", "elephant"],
                    prompt="entity_extraction_prompt_file.txt",
                ),
                summarize_descriptions=SummarizeDescriptionsConfigInput(
                    max_length=12345, prompt="summarize_prompt_file.txt"
                ),
                community_reports=CommunityReportsConfigInput(
                    max_length=23456,
                    prompt="community_report_prompt_file.txt",
                    max_input_length=12345,
                ),
                claim_extraction=ClaimExtractionConfigInput(
                    description="test 123",
                    max_gleanings=5000,
                    prompt="claim_extraction_prompt_file.txt",
                ),
                cluster_graph=ClusterGraphConfigInput(
                    max_cluster_size=123,
                ),
                umap=UmapConfigInput(enabled=True),
                encoding_model="test123",
                skip_workflows=["a", "b", "c"],
            ),
            ".",
        )

        assert parameters.cache.base_dir == "/some/cache/dir"
        assert parameters.cache.connection_string == "test_cs1"
        assert parameters.cache.container_name == "test_cn1"
        assert parameters.cache.type == CacheType.blob
        assert parameters.cache.storage_account_blob_url == "cache_account_blob_url"
        assert parameters.chunks.group_by_columns == ["a", "b"]
        assert parameters.chunks.overlap == 12
        assert parameters.chunks.size == 500
        assert parameters.claim_extraction.description == "test 123"
        assert parameters.claim_extraction.max_gleanings == 5000
        assert parameters.claim_extraction.prompt == "claim_extraction_prompt_file.txt"
        assert parameters.cluster_graph.max_cluster_size == 123
        assert parameters.community_reports.max_input_length == 12345
        assert parameters.community_reports.max_length == 23456
        assert parameters.community_reports.prompt == "community_report_prompt_file.txt"
        assert parameters.embed_graph.enabled
        assert parameters.embed_graph.iterations == 878787
        assert parameters.embed_graph.num_walks == 5_000_000
        assert parameters.embed_graph.random_seed == 10101
        assert parameters.embed_graph.walk_length == 555111
        assert parameters.embeddings.batch_max_tokens == 8000
        assert parameters.embeddings.batch_size == 1_000_000
        assert parameters.embeddings.llm.model == "text-embedding-2"
        assert parameters.embeddings.skip == ["a1", "b1", "c1"]
        assert parameters.encoding_model == "test123"
        assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"]
        assert parameters.entity_extraction.max_gleanings == 112
        assert (
            parameters.entity_extraction.prompt == "entity_extraction_prompt_file.txt"
        )
        assert parameters.input.base_dir == "/some/input/dir"
        assert parameters.input.connection_string == "input_cs"
        assert parameters.input.container_name == "input_cn"
        assert parameters.input.document_attribute_columns == ["test1", "test2"]
        # The input object uses the "file_encoding" key; it surfaces as "encoding".
        assert parameters.input.encoding == "utf-16"
        assert parameters.input.file_pattern == ".*\\test\\.txt$"
        assert parameters.input.source_column == "test_source"
        assert parameters.input.type == "blob"
        assert parameters.input.text_column == "test_text"
        assert parameters.input.timestamp_column == "test_timestamp"
        assert parameters.input.timestamp_format == "test_format"
        assert parameters.input.title_column == "test_title"
        assert parameters.input.file_type == InputFileType.text
        assert parameters.input.storage_account_blob_url == "input_account_blob_url"
        assert parameters.llm.api_key == "test"
        assert parameters.llm.model == "test-llm"
        assert parameters.reporting.base_dir == "/some/reporting/dir"
        assert parameters.reporting.connection_string == "test_cs2"
        assert parameters.reporting.container_name == "test_cn2"
        assert parameters.reporting.type == ReportingType.blob
        assert (
            parameters.reporting.storage_account_blob_url
            == "reporting_account_blob_url"
        )
        assert parameters.skip_workflows == ["a", "b", "c"]
        assert parameters.snapshots.graphml
        assert parameters.snapshots.raw_entities
        assert parameters.snapshots.top_level_nodes
        assert parameters.storage.base_dir == "/some/storage/dir"
        assert parameters.storage.connection_string == "test_cs"
        assert parameters.storage.container_name == "test_cn"
        assert parameters.storage.type == StorageType.blob
        assert parameters.storage.storage_account_blob_url == "storage_account_blob_url"
        assert parameters.summarize_descriptions.max_length == 12345
        assert parameters.summarize_descriptions.prompt == "summarize_prompt_file.txt"
        assert parameters.umap.enabled

    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test"},
        clear=True,
    )
    def test_default_values(self) -> None:
        """With only an API key set, every field equals its value in ``graphrag.config.defaults``."""
        parameters = create_graphrag_config()
        assert parameters.async_mode == defs.ASYNC_MODE
        assert parameters.cache.base_dir == defs.CACHE_BASE_DIR
        assert parameters.cache.type == defs.CACHE_TYPE
        assert parameters.chunks.group_by_columns == defs.CHUNK_GROUP_BY_COLUMNS
        assert parameters.chunks.overlap == defs.CHUNK_OVERLAP
        assert parameters.chunks.size == defs.CHUNK_SIZE
        assert parameters.claim_extraction.description == defs.CLAIM_DESCRIPTION
        assert parameters.claim_extraction.max_gleanings == defs.CLAIM_MAX_GLEANINGS
        assert (
            parameters.community_reports.max_input_length
            == defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH
        )
        assert (
            parameters.community_reports.max_length == defs.COMMUNITY_REPORT_MAX_LENGTH
        )
        assert parameters.embeddings.batch_max_tokens == defs.EMBEDDING_BATCH_MAX_TOKENS
        assert parameters.embeddings.batch_size == defs.EMBEDDING_BATCH_SIZE
        assert parameters.embeddings.llm.model == defs.EMBEDDING_MODEL
        assert parameters.embeddings.target == defs.EMBEDDING_TARGET
        assert parameters.embeddings.llm.type == defs.EMBEDDING_TYPE
        assert (
            parameters.embeddings.llm.requests_per_minute
            == defs.LLM_REQUESTS_PER_MINUTE
        )
        assert parameters.embeddings.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE
        assert (
            parameters.embeddings.llm.sleep_on_rate_limit_recommendation
            == defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
        )
        assert (
            parameters.entity_extraction.entity_types
            == defs.ENTITY_EXTRACTION_ENTITY_TYPES
        )
        assert (
            parameters.entity_extraction.max_gleanings
            == defs.ENTITY_EXTRACTION_MAX_GLEANINGS
        )
        assert parameters.encoding_model == defs.ENCODING_MODEL
        assert parameters.input.base_dir == defs.INPUT_BASE_DIR
        # NOTE(review): the default pattern is asserted to be the CSV one while
        # file_type is checked against INPUT_FILE_TYPE below; confirm that this
        # pairing is the intended default.
        assert parameters.input.file_pattern == defs.INPUT_CSV_PATTERN
        assert parameters.input.encoding == defs.INPUT_FILE_ENCODING
        assert parameters.input.type == defs.INPUT_TYPE
        assert parameters.input.text_column == defs.INPUT_TEXT_COLUMN
        assert parameters.input.file_type == defs.INPUT_FILE_TYPE
        assert parameters.llm.concurrent_requests == defs.LLM_CONCURRENT_REQUESTS
        assert parameters.llm.max_retries == defs.LLM_MAX_RETRIES
        assert parameters.llm.max_retry_wait == defs.LLM_MAX_RETRY_WAIT
        assert parameters.llm.max_tokens == defs.LLM_MAX_TOKENS
        assert parameters.llm.model == defs.LLM_MODEL
        assert parameters.llm.request_timeout == defs.LLM_REQUEST_TIMEOUT
        assert parameters.llm.requests_per_minute == defs.LLM_REQUESTS_PER_MINUTE
        assert parameters.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE
        assert (
            parameters.llm.sleep_on_rate_limit_recommendation
            == defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION
        )
        assert parameters.llm.type == defs.LLM_TYPE
        assert parameters.cluster_graph.max_cluster_size == defs.MAX_CLUSTER_SIZE
        assert parameters.embed_graph.enabled == defs.NODE2VEC_ENABLED
        assert parameters.embed_graph.iterations == defs.NODE2VEC_ITERATIONS
        assert parameters.embed_graph.num_walks == defs.NODE2VEC_NUM_WALKS
        assert parameters.embed_graph.random_seed == defs.NODE2VEC_RANDOM_SEED
        assert parameters.embed_graph.walk_length == defs.NODE2VEC_WALK_LENGTH
        assert parameters.embed_graph.window_size == defs.NODE2VEC_WINDOW_SIZE
        assert (
            parameters.parallelization.num_threads == defs.PARALLELIZATION_NUM_THREADS
        )
        assert parameters.parallelization.stagger == defs.PARALLELIZATION_STAGGER
        assert parameters.reporting.type == defs.REPORTING_TYPE
        assert parameters.reporting.base_dir == defs.REPORTING_BASE_DIR
        assert parameters.snapshots.graphml == defs.SNAPSHOTS_GRAPHML
        assert parameters.snapshots.raw_entities == defs.SNAPSHOTS_RAW_ENTITIES
        assert parameters.snapshots.top_level_nodes == defs.SNAPSHOTS_TOP_LEVEL_NODES
        assert parameters.storage.base_dir == defs.STORAGE_BASE_DIR
        assert parameters.storage.type == defs.STORAGE_TYPE
        assert parameters.umap.enabled == defs.UMAP_ENABLED

    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test"},
        clear=True,
    )
    def test_prompt_file_reading(self):
        """Prompt file paths are read and their contents land in each resolved strategy."""
        config = create_graphrag_config({
            "entity_extraction": {"prompt": "tests/unit/config/prompt-a.txt"},
            "claim_extraction": {"prompt": "tests/unit/config/prompt-b.txt"},
            "community_reports": {"prompt": "tests/unit/config/prompt-c.txt"},
            "summarize_descriptions": {"prompt": "tests/unit/config/prompt-d.txt"},
        })
        strategy = config.entity_extraction.resolved_strategy(".", "abc123")
        assert strategy["extraction_prompt"] == "Hello, World! A"
        assert strategy["encoding_name"] == "abc123"

        strategy = config.claim_extraction.resolved_strategy(".")
        assert strategy["extraction_prompt"] == "Hello, World! B"

        strategy = config.community_reports.resolved_strategy(".")
        assert strategy["extraction_prompt"] == "Hello, World! C"

        strategy = config.summarize_descriptions.resolved_strategy(".")
        assert strategy["summarize_prompt"] == "Hello, World! D"


@mock.patch.dict(
    os.environ,
    {
        "PIPELINE_LLM_API_KEY": "test",
        "PIPELINE_LLM_API_BASE": "http://test",
        "PIPELINE_LLM_API_VERSION": "v1",
        "PIPELINE_LLM_MODEL": "test-llm",
        "PIPELINE_LLM_DEPLOYMENT_NAME": "test",
    },
    clear=True,
)
def test_yaml_load_e2e():
    """${VAR} placeholders in YAML settings are resolved and absent from the pipeline."""
    config_dict = yaml.safe_load(
        """
input:
  file_type: text

llm:
  type: azure_openai_chat
  api_key: ${PIPELINE_LLM_API_KEY}
  api_base: ${PIPELINE_LLM_API_BASE}
  api_version: ${PIPELINE_LLM_API_VERSION}
  model: ${PIPELINE_LLM_MODEL}
  deployment_name: ${PIPELINE_LLM_DEPLOYMENT_NAME}
  model_supports_json: True
  tokens_per_minute: 80000
  requests_per_minute: 900
  thread_count: 50
  concurrent_requests: 25
"""
    )
    # create default configuration pipeline parameters from the custom settings
    model = config_dict
    parameters = create_graphrag_config(model, ".")

    assert parameters.llm.api_key == "test"
    assert parameters.llm.model == "test-llm"
    assert parameters.llm.api_base == "http://test"
    assert parameters.llm.api_version == "v1"
    assert parameters.llm.deployment_name == "test"

    # generate the pipeline from the default parameters
    pipeline_config = create_pipeline_config(parameters, True)

    config_str = pipeline_config.model_dump_json()
    assert "${PIPELINE_LLM_API_KEY}" not in config_str
    assert "${PIPELINE_LLM_API_BASE}" not in config_str
    assert "${PIPELINE_LLM_API_VERSION}" not in config_str
    assert "${PIPELINE_LLM_MODEL}" not in config_str
    assert "${PIPELINE_LLM_DEPLOYMENT_NAME}" not in config_str