|
import asyncio |
|
import os |
|
from dataclasses import asdict, dataclass, field |
|
from datetime import datetime |
|
from functools import partial |
|
from typing import Type, cast |
|
|
|
from .llm import ( |
|
gpt_4o_mini_complete, |
|
openai_embedding, |
|
) |
|
from .operate import ( |
|
chunking_by_token_size, |
|
extract_entities, |
|
local_query, |
|
global_query, |
|
hybrid_query, |
|
naive_query, |
|
) |
|
|
|
from .utils import ( |
|
EmbeddingFunc, |
|
compute_mdhash_id, |
|
limit_async_func_call, |
|
convert_response_to_json, |
|
logger, |
|
set_logger, |
|
) |
|
from .base import ( |
|
BaseGraphStorage, |
|
BaseKVStorage, |
|
BaseVectorStorage, |
|
StorageNameSpace, |
|
QueryParam, |
|
) |
|
|
|
from .storage import ( |
|
JsonKVStorage, |
|
NanoVectorDBStorage, |
|
NetworkXStorage, |
|
) |
|
|
|
from .kg.neo4j_impl import Neo4JStorage |
|
|
|
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return a usable event loop for the current thread, creating one if needed.

    The loop reported by ``asyncio.get_event_loop()`` is reused when it exists
    and is still open. When no loop is available (``RuntimeError``) or the
    registered loop has already been closed, a new loop is created, installed
    as the current loop via ``asyncio.set_event_loop`` and returned.

    Returns:
        An event loop that is not closed.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop is available for this thread.
        loop = None

    # A closed loop can still be registered as "current"; handing it to
    # run_until_complete() would fail, so replace it as well.
    if loop is None or loop.is_closed():
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop
|
|
|
|
|
@dataclass
class LightRAG:
    """Wire key-value, vector and graph storages to an LLM and an embedding function.

    The storage backends are selected by name (``kv_storage``,
    ``vector_storage``, ``graph_storage``) and instantiated in
    ``__post_init__``. Every storage and every query/extraction helper
    receives ``asdict(self)`` as its ``global_config``, so all dataclass
    fields below are visible to those components even when this class does
    not read them itself.

    Each public operation exists as a blocking method (``insert``, ``query``,
    ``delete_by_entity``) and as a coroutine (``ainsert``, ``aquery``,
    ``adelete_by_entity``); the blocking variants run the coroutine on the
    loop returned by ``always_get_an_event_loop()``.
    """

    # Directory that receives the log file; it is created on construction
    # when it does not exist yet.
    # NOTE(review): the default name contains ':' characters, which some
    # filesystems (e.g. on Windows) reject in path names -- confirm whether
    # the default needs to be portable.
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )

    # Names of the storage backends; must be keys of _get_storage_class().
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    # Not annotated on purpose: a plain class attribute, not a dataclass
    # field. It captures the logger level at class-definition time.
    current_log_level = logger.level
    # Passed to logger.setLevel() on construction.
    # NOTE(review): annotated ``str`` but the default is ``logger.level``.
    log_level: str = field(default=current_log_level)

    # Text chunking: forwarded to chunking_by_token_size() in ainsert().
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # Entity extraction settings (not read in this class; exposed through
    # global_config).
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # Node embedding settings (not read in this class; exposed through
    # global_config).
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # Embedding function; wrapped with limit_async_func_call() using
    # embedding_func_max_async during __post_init__.
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM completion function; during __post_init__ it is bound to the
    # response cache and llm_model_kwargs, then wrapped with
    # limit_async_func_call() using llm_model_max_async.
    llm_model_func: callable = gpt_4o_mini_complete
    llm_model_name: str = (
        "meta-llama/Llama-3.2-1B-Instruct"
    )
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
    llm_model_kwargs: dict = field(default_factory=dict)

    # Extra settings for the vector storage (not read in this class; exposed
    # through global_config).
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

    # When False, no "llm_response_cache" storage is created.
    enable_llm_cache: bool = True

    # Extension points (not read in this class; exposed through
    # global_config).
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        """Set up logging, resolve storage classes and build all storages.

        Raises:
            KeyError: if ``kv_storage``, ``vector_storage`` or
                ``graph_storage`` is not a key of ``_get_storage_class()``.
        """
        # The log file lives inside working_dir, so the directory has to
        # exist before its path is handed to set_logger().
        working_dir_created = not os.path.exists(self.working_dir)
        if working_dir_created:
            os.makedirs(self.working_dir, exist_ok=True)

        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")
        if working_dir_created:
            logger.info(f"Creating working directory {self.working_dir}")

        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # Resolve the configured backend names to their classes.
        storage_classes = self._get_storage_class()
        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            storage_classes[self.kv_storage]
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = storage_classes[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = storage_classes[
            self.graph_storage
        ]

        # This snapshot is taken before embedding_func is wrapped below.
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache",
                global_config=asdict(self),
                embedding_func=None,
            )
            if self.enable_llm_cache
            else None
        )

        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        # Every storage gets its own asdict() snapshot; the snapshots below
        # are taken after embedding_func was wrapped and before
        # llm_model_func is wrapped.
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )

        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache,
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> dict:
        """Return the mapping from storage backend name to storage class."""
        return {
            # key-value storages
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            # vector storages
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            # graph storages
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
        }

    async def _run_index_done_callbacks(self, storage_insts):
        """Await ``index_done_callback()`` of every non-``None`` storage concurrently."""
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in storage_insts
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)

    def insert(self, string_or_strings):
        """Blocking wrapper around ``ainsert``.

        Args:
            string_or_strings: a single document string or an iterable of
                document strings.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Insert documents: chunk them, extract entities and persist everything.

        Documents and chunks are keyed by ``compute_mdhash_id`` of their
        content; keys not returned by the storage's ``filter_keys`` are
        dropped before any further work. The method returns early (with a
        warning) when nothing new remains or when entity extraction yields no
        graph.

        Args:
            string_or_strings: a single document string or an iterable of
                document strings. Each document is stripped before hashing
                and storing.
        """
        update_storage = False
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            new_docs = {
                compute_mdhash_id(content, prefix="doc-"): {"content": content}
                for content in (c.strip() for c in string_or_strings)
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
                logger.warning("All docs are already in the storage")
                return
            # From here on the finally-clause notifies the storages, even on
            # an early return or an exception.
            update_storage = True
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks)
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)

            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            # Documents and chunks are written to the KV storages only after
            # entity extraction produced a graph.
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            if update_storage:
                await self._insert_done()

    async def _insert_done(self):
        """Notify every storage touched by an insert that indexing finished."""
        await self._run_index_done_callbacks(
            [
                self.full_docs,
                self.text_chunks,
                self.llm_response_cache,
                self.entities_vdb,
                self.relationships_vdb,
                self.chunks_vdb,
                self.chunk_entity_relation_graph,
            ]
        )

    def query(self, query: str, param: QueryParam = QueryParam()):
        """Blocking wrapper around ``aquery``; returns its response."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        """Answer ``query`` using the strategy selected by ``param.mode``.

        Args:
            query: the query text.
            param: query parameters; ``param.mode`` must be one of
                ``"local"``, ``"global"``, ``"hybrid"`` or ``"naive"``.

        Returns:
            Whatever the selected query function returns.

        Raises:
            ValueError: if ``param.mode`` is not one of the modes above.
        """
        # The three graph-based modes take exactly the same arguments.
        graph_query_funcs = {
            "local": local_query,
            "global": global_query,
            "hybrid": hybrid_query,
        }
        if param.mode in graph_query_funcs:
            response = await graph_query_funcs[param.mode](
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def _query_done(self):
        """Notify the LLM response cache (if enabled) that a query finished."""
        await self._run_index_done_callbacks([self.llm_response_cache])

    def delete_by_entity(self, entity_name: str):
        """Blocking wrapper around ``adelete_by_entity``."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        """Delete an entity and its relationships from the vector and graph storages.

        The name is upper-cased and wrapped in double quotes before it is
        passed to the storages. Failures are logged and not re-raised.

        Args:
            entity_name: the entity name as supplied by the caller.
        """
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            # Best-effort operation: report the failure without propagating it.
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        """Notify the storages touched by an entity deletion that it finished."""
        await self._run_index_done_callbacks(
            [
                self.entities_vdb,
                self.relationships_vdb,
                self.chunk_entity_relation_graph,
            ]
        )
|
|