File size: 10,107 Bytes
86c402d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
from uuid import UUID

import numpy as np
from ntr_text_fragmentation import LinkerEntity
from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository
from sqlalchemy import and_, select
from sqlalchemy.orm import Session

from components.dbo.models.entity import EntityModel


class ChunkRepository(SQLAlchemyEntityRepository):
    def __init__(self, db: Session):
        super().__init__(db)

    def _entity_model_class(self):
        return EntityModel

    def _map_db_entity_to_linker_entity(self, db_entity: EntityModel):
        """
        Преобразует сущность из базы данных в LinkerEntity.
        
        Args:
            db_entity: Сущность из базы данных
            
        Returns:
            LinkerEntity
        """
        # Преобразуем строковые ID в UUID
        entity = LinkerEntity(
            id=UUID(db_entity.uuid),  # Преобразуем строку в UUID
            name=db_entity.name,
            text=db_entity.text,
            type=db_entity.entity_type,
            in_search_text=db_entity.in_search_text,
            metadata=db_entity.metadata_json,
            source_id=UUID(db_entity.source_id) if db_entity.source_id else None,  # Преобразуем строку в UUID
            target_id=UUID(db_entity.target_id) if db_entity.target_id else None,  # Преобразуем строку в UUID
            number_in_relation=db_entity.number_in_relation,
        )
        return LinkerEntity.deserialize(entity)

    def add_entities(
        self,
        entities: list[LinkerEntity],
        dataset_id: int,
        embeddings: dict[str, np.ndarray],
    ):
        """
        Добавляет сущности в базу данных.
        
        Args:
            entities: Список сущностей для добавления
            dataset_id: ID датасета
            embeddings: Словарь эмбеддингов {entity_id: embedding}
        """
        with self.db() as session:
            for entity in entities:
                # Преобразуем UUID в строку для хранения в базе
                entity_id = str(entity.id)
                
                if entity_id in embeddings:
                    embedding = embeddings[entity_id]
                else:
                    embedding = None
                    
                session.add(
                    EntityModel(
                        uuid=str(entity.id),  # UUID в строку
                        name=entity.name,
                        text=entity.text,
                        entity_type=entity.type,
                        in_search_text=entity.in_search_text,
                        metadata_json=entity.metadata,
                        source_id=str(entity.source_id) if entity.source_id else None,  # UUID в строку
                        target_id=str(entity.target_id) if entity.target_id else None,  # UUID в строку
                        number_in_relation=entity.number_in_relation,
                        chunk_index=getattr(entity, "chunk_index", None),  # Добавляем chunk_index
                        dataset_id=dataset_id,
                        embedding=embedding,
                    )
                )

            session.commit()

    def get_searching_entities(
        self,
        dataset_id: int,
    ) -> tuple[list[LinkerEntity], list[np.ndarray]]:
        with self.db() as session:
            models = (
                session.query(EntityModel)
                .filter(EntityModel.in_search_text is not None)
                .filter(EntityModel.dataset_id == dataset_id)
                .all()
            )
        return (
            [self._map_db_entity_to_linker_entity(model) for model in models],
            [model.embedding for model in models],
        )

    def get_chunks_by_ids(
        self,
        chunk_ids: list[str],
    ) -> list[LinkerEntity]:
        """
        Получение чанков по их ID.
        
        Args:
            chunk_ids: Список ID чанков
            
        Returns:
            Список чанков
        """
        # Преобразуем все ID в строки для единообразия
        str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
        
        with self.db() as session:
            models = (
                session.query(EntityModel)
                .filter(EntityModel.uuid.in_(str_chunk_ids))
                .all()
            )
        return [self._map_db_entity_to_linker_entity(model) for model in models]

    def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]:
        """
        Получить сущности по списку идентификаторов.
        
        Args:
            entity_ids: Список идентификаторов сущностей
            
        Returns:
            Список сущностей, соответствующих указанным идентификаторам
        """
        if not entity_ids:
            return []
        
        # Преобразуем UUID в строки
        str_entity_ids = [str(entity_id) for entity_id in entity_ids]
        
        with self.db() as session:
            entity_model = self._entity_model_class()
            db_entities = session.execute(
                select(entity_model).where(entity_model.uuid.in_(str_entity_ids))
            ).scalars().all()
        
        return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]

    def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]:
        """
        Получить соседние чанки для указанных чанков.
        
        Args:
            chunk_ids: Список идентификаторов чанков
            max_distance: Максимальное расстояние до соседа
            
        Returns:
            Список соседних чанков
        """
        if not chunk_ids:
            return []
        
        # Преобразуем UUID в строки
        str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
        
        with self.db() as session:
            entity_model = self._entity_model_class()
            result = []
            
            # Сначала получаем указанные чанки, чтобы узнать их индексы и документы
            chunks = session.execute(
                select(entity_model).where(
                    and_(
                        entity_model.uuid.in_(str_chunk_ids),
                        entity_model.entity_type == "Chunk"  # Используем entity_type вместо type
                    )
                )
            ).scalars().all()
        
            if not chunks:
                return []
            
            # Находим документы для чанков через связи
            doc_ids = set()
            chunk_indices = {}
            
            for chunk in chunks:
                chunk_indices[chunk.uuid] = chunk.chunk_index
                
                # Находим связь от документа к чанку
                links = session.execute(
                    select(entity_model).where(
                        and_(
                            entity_model.target_id == chunk.uuid,
                            entity_model.name == "document_to_chunk"
                        )
                    )
                ).scalars().all()
                
                for link in links:
                    doc_ids.add(link.source_id)
            
            if not doc_ids or not any(idx is not None for idx in chunk_indices.values()):
                return []
        
            # Для каждого документа находим все его чанки
            for doc_id in doc_ids:
                # Находим все связи от документа к чанкам
                links = session.execute(
                    select(entity_model).where(
                        and_(
                            entity_model.source_id == doc_id,
                            entity_model.name == "document_to_chunk"
                        )
                    )
                ).scalars().all()
                
                doc_chunk_ids = [link.target_id for link in links]
                
                # Получаем все чанки документа
                doc_chunks = session.execute(
                    select(entity_model).where(
                        and_(
                            entity_model.uuid.in_(doc_chunk_ids),
                            entity_model.entity_type == "Chunk"  # Используем entity_type вместо type
                        )
                    )
                ).scalars().all()
            
            # Для каждого чанка в документе проверяем, является ли он соседом
            for doc_chunk in doc_chunks:
                if doc_chunk.uuid in str_chunk_ids:
                    continue
                
                if doc_chunk.chunk_index is None:
                    continue
                
                # Проверяем, является ли чанк соседом какого-либо из исходных чанков
                is_neighbor = False
                for orig_chunk_id, orig_index in chunk_indices.items():
                    if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance:
                        is_neighbor = True
                        break
                
                if is_neighbor:
                    result.append(self._map_db_entity_to_linker_entity(doc_chunk))
        
        return result