Spaces:
Sleeping
Sleeping
import logging | |
from uuid import UUID | |
import numpy as np | |
from ntr_text_fragmentation import LinkerEntity | |
from ntr_text_fragmentation.integrations.sqlalchemy import \ | |
SQLAlchemyEntityRepository | |
from sqlalchemy import func, select | |
from sqlalchemy.orm import Session, sessionmaker | |
from components.dbo.models.entity import EntityModel | |
logger = logging.getLogger(__name__) | |
class ChunkRepository(SQLAlchemyEntityRepository): | |
""" | |
Репозиторий для работы с сущностями (чанками, документами, связями), | |
хранящимися в базе данных с использованием SQL Alchemy. | |
Наследуется от SQLAlchemyEntityRepository, предоставляя конкретную реализацию | |
для модели EntityModel. | |
""" | |
def __init__(self, db_session_factory: sessionmaker[Session]): | |
""" | |
Инициализация репозитория. | |
Args: | |
db_session_factory: Фабрика сессий SQLAlchemy. | |
""" | |
super().__init__(db_session_factory) | |
def _entity_model_class(self): | |
"""Возвращает класс модели SQLAlchemy.""" | |
return EntityModel | |
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel) -> LinkerEntity: | |
""" | |
Преобразует объект EntityModel из базы данных в объект LinkerEntity | |
или его соответствующий подкласс. | |
Args: | |
db_entity: Сущность EntityModel из базы данных. | |
Returns: | |
Объект LinkerEntity или его подкласс. | |
""" | |
# Создаем базовый LinkerEntity со всеми данными из БД | |
# Преобразуем строковые UUID обратно в объекты UUID | |
base_data = LinkerEntity( | |
id=UUID(db_entity.uuid), | |
name=db_entity.name, | |
text=db_entity.text, | |
in_search_text=db_entity.in_search_text, | |
metadata=db_entity.metadata_json or {}, | |
source_id=UUID(db_entity.source_id) if db_entity.source_id else None, | |
target_id=UUID(db_entity.target_id) if db_entity.target_id else None, | |
number_in_relation=db_entity.number_in_relation, | |
type=db_entity.entity_type, | |
groupper=db_entity.entity_type, | |
) | |
# Используем LinkerEntity._deserialize для получения объекта нужного типа | |
# на основе поля 'type', взятого из db_entity.entity_type | |
try: | |
deserialized_entity = base_data.deserialize() | |
return deserialized_entity | |
except Exception as e: | |
logger.error( | |
f"Error deserializing entity {base_data.id} of type {base_data.type}: {e}" | |
) | |
return base_data | |
def add_entities( | |
self, | |
entities: list[LinkerEntity], | |
dataset_id: int, | |
embeddings: dict[str, np.ndarray] | None = None, | |
): | |
""" | |
Добавляет список сущностей LinkerEntity в базу данных. | |
Args: | |
entities: Список сущностей LinkerEntity для добавления. | |
dataset_id: ID датасета, к которому принадлежат сущности. | |
embeddings: Словарь эмбеддингов {entity_id_str: embedding}, где entity_id_str - строка UUID. | |
""" | |
embeddings = embeddings or {} | |
with self.db() as session: | |
db_entities_to_add = [] | |
for entity in entities: | |
# Преобразуем UUID в строку для хранения в базе | |
entity_id_str = str(entity.id) | |
embedding = embeddings.get(entity_id_str) | |
db_entity = EntityModel( | |
uuid=entity_id_str, | |
name=entity.name, | |
text=entity.text, | |
entity_type=entity.type, | |
in_search_text=entity.in_search_text, | |
metadata_json=( | |
entity.metadata if isinstance(entity.metadata, dict) else {} | |
), | |
source_id=str(entity.source_id) if entity.source_id else None, | |
target_id=str(entity.target_id) if entity.target_id else None, | |
number_in_relation=entity.number_in_relation, | |
dataset_id=dataset_id, | |
embedding=embedding, | |
) | |
db_entities_to_add.append(db_entity) | |
session.add_all(db_entities_to_add) | |
session.commit() | |
def get_searching_entities( | |
self, | |
dataset_id: int, | |
) -> tuple[list[LinkerEntity], list[np.ndarray]]: | |
""" | |
Получает сущности из указанного датасета, которые имеют текст для поиска | |
(in_search_text не None), вместе с их эмбеддингами. | |
Args: | |
dataset_id: ID датасета. | |
Returns: | |
Кортеж из двух списков: список LinkerEntity и список их эмбеддингов (numpy array). | |
Порядок эмбеддингов соответствует порядку сущностей. | |
""" | |
entity_model = self._entity_model_class | |
linker_entities = [] | |
embeddings_list = [] | |
with self.db() as session: | |
stmt = select(entity_model).where( | |
entity_model.in_search_text.isnot(None), | |
entity_model.dataset_id == dataset_id, | |
entity_model.embedding.isnot(None) | |
) | |
db_models = session.execute(stmt).scalars().all() | |
# Переносим цикл внутрь сессии | |
for model in db_models: | |
# Теперь маппинг происходит при активной сессии | |
linker_entity = self._map_db_entity_to_linker_entity(model) | |
linker_entities.append(linker_entity) | |
# Извлекаем эмбеддинг. | |
# _map_db_entity_to_linker_entity может поместить его в метаданные. | |
embedding = linker_entity.metadata.get('_embedding') | |
if embedding is None and hasattr(model, 'embedding'): # Fallback | |
embedding = model.embedding # Доступ к model.embedding тоже должен быть внутри сессии | |
if embedding is not None: | |
embeddings_list.append(embedding) | |
else: | |
# Обработка случая отсутствия эмбеддинга | |
print(f"Warning: Entity {model.uuid} has in_search_text but no embedding.") | |
linker_entities.pop() | |
# Возвращаем результаты после закрытия сессии | |
return linker_entities, embeddings_list | |
def count_entities_by_dataset_id(self, dataset_id: int) -> int: | |
""" | |
Подсчитывает общее количество сущностей для указанного датасета. | |
Args: | |
dataset_id: ID датасета. | |
Returns: | |
Общее количество сущностей в датасете. | |
""" | |
entity_model = self._entity_model_class | |
id_column = self._get_id_column() # Получаем колонку ID (uuid или id) | |
with self.db() as session: | |
stmt = select(func.count(id_column)).where( | |
entity_model.dataset_id == dataset_id | |
) | |
count = session.execute(stmt).scalar_one() | |
return count | |