Spaces:
Sleeping
Sleeping
File size: 10,107 Bytes
86c402d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
from uuid import UUID
import numpy as np
from ntr_text_fragmentation import LinkerEntity
from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from components.dbo.models.entity import EntityModel
class ChunkRepository(SQLAlchemyEntityRepository):
def __init__(self, db: Session):
super().__init__(db)
def _entity_model_class(self):
return EntityModel
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel):
"""
Преобразует сущность из базы данных в LinkerEntity.
Args:
db_entity: Сущность из базы данных
Returns:
LinkerEntity
"""
# Преобразуем строковые ID в UUID
entity = LinkerEntity(
id=UUID(db_entity.uuid), # Преобразуем строку в UUID
name=db_entity.name,
text=db_entity.text,
type=db_entity.entity_type,
in_search_text=db_entity.in_search_text,
metadata=db_entity.metadata_json,
source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID
target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID
number_in_relation=db_entity.number_in_relation,
)
return LinkerEntity.deserialize(entity)
def add_entities(
self,
entities: list[LinkerEntity],
dataset_id: int,
embeddings: dict[str, np.ndarray],
):
"""
Добавляет сущности в базу данных.
Args:
entities: Список сущностей для добавления
dataset_id: ID датасета
embeddings: Словарь эмбеддингов {entity_id: embedding}
"""
with self.db() as session:
for entity in entities:
# Преобразуем UUID в строку для хранения в базе
entity_id = str(entity.id)
if entity_id in embeddings:
embedding = embeddings[entity_id]
else:
embedding = None
session.add(
EntityModel(
uuid=str(entity.id), # UUID в строку
name=entity.name,
text=entity.text,
entity_type=entity.type,
in_search_text=entity.in_search_text,
metadata_json=entity.metadata,
source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку
target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку
number_in_relation=entity.number_in_relation,
chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index
dataset_id=dataset_id,
embedding=embedding,
)
)
session.commit()
def get_searching_entities(
self,
dataset_id: int,
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
with self.db() as session:
models = (
session.query(EntityModel)
.filter(EntityModel.in_search_text is not None)
.filter(EntityModel.dataset_id == dataset_id)
.all()
)
return (
[self._map_db_entity_to_linker_entity(model) for model in models],
[model.embedding for model in models],
)
def get_chunks_by_ids(
self,
chunk_ids: list[str],
) -> list[LinkerEntity]:
"""
Получение чанков по их ID.
Args:
chunk_ids: Список ID чанков
Returns:
Список чанков
"""
# Преобразуем все ID в строки для единообразия
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
with self.db() as session:
models = (
session.query(EntityModel)
.filter(EntityModel.uuid.in_(str_chunk_ids))
.all()
)
return [self._map_db_entity_to_linker_entity(model) for model in models]
def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]:
"""
Получить сущности по списку идентификаторов.
Args:
entity_ids: Список идентификаторов сущностей
Returns:
Список сущностей, соответствующих указанным идентификаторам
"""
if not entity_ids:
return []
# Преобразуем UUID в строки
str_entity_ids = [str(entity_id) for entity_id in entity_ids]
with self.db() as session:
entity_model = self._entity_model_class()
db_entities = session.execute(
select(entity_model).where(entity_model.uuid.in_(str_entity_ids))
).scalars().all()
return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]:
"""
Получить соседние чанки для указанных чанков.
Args:
chunk_ids: Список идентификаторов чанков
max_distance: Максимальное расстояние до соседа
Returns:
Список соседних чанков
"""
if not chunk_ids:
return []
# Преобразуем UUID в строки
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
with self.db() as session:
entity_model = self._entity_model_class()
result = []
# Сначала получаем указанные чанки, чтобы узнать их индексы и документы
chunks = session.execute(
select(entity_model).where(
and_(
entity_model.uuid.in_(str_chunk_ids),
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
)
)
).scalars().all()
if not chunks:
return []
# Находим документы для чанков через связи
doc_ids = set()
chunk_indices = {}
for chunk in chunks:
chunk_indices[chunk.uuid] = chunk.chunk_index
# Находим связь от документа к чанку
links = session.execute(
select(entity_model).where(
and_(
entity_model.target_id == chunk.uuid,
entity_model.name == "document_to_chunk"
)
)
).scalars().all()
for link in links:
doc_ids.add(link.source_id)
if not doc_ids or not any(idx is not None for idx in chunk_indices.values()):
return []
# Для каждого документа находим все его чанки
for doc_id in doc_ids:
# Находим все связи от документа к чанкам
links = session.execute(
select(entity_model).where(
and_(
entity_model.source_id == doc_id,
entity_model.name == "document_to_chunk"
)
)
).scalars().all()
doc_chunk_ids = [link.target_id for link in links]
# Получаем все чанки документа
doc_chunks = session.execute(
select(entity_model).where(
and_(
entity_model.uuid.in_(doc_chunk_ids),
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
)
)
).scalars().all()
# Для каждого чанка в документе проверяем, является ли он соседом
for doc_chunk in doc_chunks:
if doc_chunk.uuid in str_chunk_ids:
continue
if doc_chunk.chunk_index is None:
continue
# Проверяем, является ли чанк соседом какого-либо из исходных чанков
is_neighbor = False
for orig_chunk_id, orig_index in chunk_indices.items():
if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance:
is_neighbor = True
break
if is_neighbor:
result.append(self._map_db_entity_to_linker_entity(doc_chunk))
return result
|