File size: 9,649 Bytes
b115d50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from __future__ import annotations

import json
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field

from steamship import SteamshipError
from steamship.base import Task
from steamship.base.client import Client
from steamship.base.model import CamelModel
from steamship.base.request import DeleteRequest, Request
from steamship.base.response import Response
from steamship.data.search import Hit
from steamship.utils.metadata import metadata_to_str

# Upper bound (in characters) on a single inserted value. EmbeddingIndex.insert
# and EmbeddingIndex.insert_many raise SteamshipError for longer values unless
# the caller passes allow_long_records=True.
MAX_RECOMMENDED_ITEM_LENGTH = 5_000


class EmbedAndSearchRequest(Request):
    """Request payload pairing a query with candidate documents and a plugin instance.

    Not constructed anywhere in this module. Presumably the plugin instance embeds
    ``query`` and ``docs`` and returns the top ``k`` matches — confirm against the
    server API.
    """

    query: str
    docs: List[str]
    plugin_instance: str
    # Number of results requested; defaults to 1.
    k: int = 1


class QueryResult(CamelModel):
    """One result of a search query; every field is optional and defaults to None."""

    # The matched item.
    value: Optional[Hit] = None
    # Presumably the similarity/relevance score — confirm the scale with the server.
    score: Optional[float] = None
    # NOTE(review): meaning not visible here (possibly the position of the match or of the query) — confirm.
    index: Optional[int] = None
    id: Optional[str] = None


class QueryResults(Request):
    """Container for search results; the ``expect`` type of ``EmbeddingIndex.search``."""

    # NOTE(review): used as a response type but subclasses Request rather than Response — confirm intended.
    items: List[QueryResult] = None


class EmbeddedItem(CamelModel):
    """A single item stored in, or destined for, an embedding index."""

    id: str = None
    index_id: str = None
    file_id: str = None
    block_id: str = None
    tag_id: str = None
    # The text that is embedded.
    value: str = None
    external_id: str = None
    external_type: str = None
    # May be any value; dict/list values are JSON-encoded by clone_for_insert.
    metadata: Any = None
    embedding: List[float] = None

    def clone_for_insert(self) -> EmbeddedItem:
        """Return a copy of this item suitable for an insert request.

        Every field is copied as-is except ``metadata``: when it is a ``dict`` or
        ``list``, the clone holds its JSON string encoding instead. ``self`` is
        not mutated.
        """
        ret = EmbeddedItem(
            id=self.id,
            index_id=self.index_id,
            file_id=self.file_id,
            block_id=self.block_id,
            tag_id=self.tag_id,
            value=self.value,
            external_id=self.external_id,
            external_type=self.external_type,
            metadata=self.metadata,
            embedding=self.embedding,
        )
        # Structured metadata is sent as a JSON string; scalars pass through unchanged.
        if isinstance(ret.metadata, (dict, list)):
            ret.metadata = json.dumps(ret.metadata)
        return ret


class IndexCreateRequest(Request):
    """Payload posted to ``embedding-index/create`` by ``EmbeddingIndex.create``."""

    handle: str = None
    name: str = None
    # Handle of the embedder plugin instance the index should use.
    plugin_instance: str = None
    # Defaults to True; presumably returns an existing index with the same handle instead of failing — confirm.
    fetch_if_exists: bool = True
    external_id: str = None
    external_type: str = None
    # Forwarded as given; EmbeddingIndex.create does not serialize it.
    metadata: Any = None


class IndexInsertRequest(Request):
    """Payload posted to ``embedding-index/item/create`` by the EmbeddingIndex insert methods.

    Callers in this module populate ``items`` (batch insert), ``value`` (single
    item), or ``file_id`` (insert from a file).
    """

    index_id: str
    items: List[EmbeddedItem] = None
    value: str = None
    file_id: str = None
    # Only set by EmbeddingIndex.insert_file.
    block_type: str = None
    external_id: str = None
    external_type: str = None
    # EmbeddingIndex.insert_file JSON-encodes dict/list metadata before setting this field.
    metadata: Any = None
    # Defaults to True; presumably controls whether embeddings are recomputed after insert — confirm.
    reindex: bool = True


class IndexItemId(CamelModel):
    """Identifies one item: the index it belongs to plus the item's own id."""

    index_id: str = None
    id: str = None


class IndexInsertResponse(Response):
    """Response to an ``embedding-index/item/create`` request."""

    # Presumably the ids of the items created by the insert — confirm.
    item_ids: List[IndexItemId] = None


class IndexEmbedRequest(Request):
    """Payload posted to ``embedding-index/embed`` by ``EmbeddingIndex.embed``."""

    # Id of the index to embed.
    id: str


class IndexEmbedResponse(Response):
    """Response to an ``embedding-index/embed`` request."""

    # Presumably the id of the index that was embedded — confirm.
    id: Optional[str] = None


class IndexSearchRequest(Request):
    """Payload posted to ``embedding-index/search`` by ``EmbeddingIndex.search``.

    ``EmbeddingIndex.search`` sets exactly one of ``query`` (a single query) or
    ``queries`` (a batch of queries).
    """

    # Id of the index to search.
    id: str
    query: str = None
    queries: List[str] = None
    # Number of results requested (default 1).
    k: int = 1
    # Whether item metadata should be included in the results (default False).
    include_metadata: bool = False


class ListItemsRequest(Request):
    """Payload posted to ``embedding-index/item/list`` by ``EmbeddingIndex.list_items``."""

    # Id of the index whose items are listed.
    id: str = None
    # Optional filters; left as None when the caller does not supply them.
    file_id: str = None
    block_id: str = None
    span_id: str = None


class ListItemsResponse(Response):
    """Response to an ``embedding-index/item/list`` request."""

    # Required: the items returned by the server.
    items: List[EmbeddedItem]


class EmbeddingIndex(CamelModel):
    """A persistent, read-optimized index over embeddings.

    Use :meth:`create` to obtain an index. Every other method sends a request
    about the index identified by ``self.id`` through ``self.client``.
    """

    # Excluded so the client is not part of the model's serialized form.
    client: Client = Field(None, exclude=True)
    id: str = None
    handle: str = None
    name: str = None
    plugin: str = None
    external_id: str = None
    external_type: str = None
    metadata: str = None

    @classmethod
    def parse_obj(cls: Type[BaseModel], obj: Any) -> BaseModel:
        """Parse an index, unwrapping an ``embeddingIndex`` or ``index`` envelope if present."""
        # TODO (enias): This needs to be solved at the engine side
        if "embeddingIndex" in obj:
            obj = obj["embeddingIndex"]
        elif "index" in obj:
            obj = obj["index"]
        return super().parse_obj(obj)

    def insert_file(
        self,
        file_id: str,
        block_type: str = None,
        external_id: str = None,
        external_type: str = None,
        metadata: Union[int, float, bool, str, List, Dict] = None,
        reindex: bool = True,
    ) -> IndexInsertResponse:
        """Ask the server to insert content from an existing file into this index.

        Args:
            file_id: Id of the file whose content should be inserted.
            block_type: Optional block type forwarded to the server (presumably
                restricts which blocks of the file are inserted — confirm).
            external_id: Optional caller-defined id forwarded with the request.
            external_type: Optional caller-defined type forwarded with the request.
            metadata: Optional metadata; ``dict``/``list`` values are JSON-encoded
                before sending, other values are sent unchanged.
            reindex: Forwarded to the server (default True).

        Returns:
            The response of ``embedding-index/item/create``.
        """
        if isinstance(metadata, (dict, list)):
            metadata = json.dumps(metadata)

        req = IndexInsertRequest(
            index_id=self.id,
            file_id=file_id,
            block_type=block_type,
            external_id=external_id,
            external_type=external_type,
            metadata=metadata,
            reindex=reindex,
        )
        return self.client.post(
            "embedding-index/item/create",
            req,
            expect=IndexInsertResponse,
        )

    def _check_input(self, request: IndexInsertRequest, allow_long_records: bool):
        """Reject values longer than MAX_RECOMMENDED_ITEM_LENGTH unless explicitly allowed.

        Checks ``request.value`` and the text of every entry in ``request.items``
        (raw strings and ``EmbeddedItem.value``). Does nothing when
        ``allow_long_records`` is true.

        Raises:
            SteamshipError: if any checked value exceeds the recommended length.
        """
        if allow_long_records:
            return

        if request.value is not None and len(request.value) > MAX_RECOMMENDED_ITEM_LENGTH:
            raise SteamshipError(
                f"Inserted item of length {len(request.value)} exceeded maximum recommended length of {MAX_RECOMMENDED_ITEM_LENGTH} characters. You may insert it anyway by passing allow_long_records=True."
            )

        for i, item in enumerate(request.items or []):
            # Items may be raw strings or EmbeddedItems; measure the text either way.
            if isinstance(item, EmbeddedItem):
                text = item.value
            elif isinstance(item, str):
                text = item
            else:
                continue
            if text is not None and len(text) > MAX_RECOMMENDED_ITEM_LENGTH:
                raise SteamshipError(
                    f"Inserted item {i} of length {len(text)} exceeded maximum recommended length of {MAX_RECOMMENDED_ITEM_LENGTH} characters. You may insert it anyway by passing allow_long_records=True."
                )

    def insert_many(
        self,
        items: List[Union[EmbeddedItem, str]],
        reindex: bool = True,
        allow_long_records=False,
    ) -> IndexInsertResponse:
        """Insert several items into this index in one request.

        Args:
            items: Items to insert; plain strings are wrapped as
                ``EmbeddedItem(value=...)``. Each item is sent via
                ``clone_for_insert`` so dict/list metadata is JSON-encoded.
            reindex: Forwarded to the server (default True).
            allow_long_records: Skip the MAX_RECOMMENDED_ITEM_LENGTH check.

        Returns:
            The response of ``embedding-index/item/create``.

        Raises:
            SteamshipError: if an item is too long and ``allow_long_records`` is false.
        """
        new_items = [EmbeddedItem(value=item) if isinstance(item, str) else item for item in items]

        req = IndexInsertRequest(
            index_id=self.id,
            items=[item.clone_for_insert() for item in new_items],
            reindex=reindex,
        )
        self._check_input(req, allow_long_records)
        return self.client.post(
            "embedding-index/item/create",
            req,
            expect=IndexInsertResponse,
        )

    def insert(
        self,
        value: str,
        external_id: str = None,
        external_type: str = None,
        metadata: Union[int, float, bool, str, List, Dict] = None,
        reindex: bool = True,
        allow_long_records=False,
    ) -> IndexInsertResponse:
        """Insert a single text value into this index.

        Args:
            value: The text to insert.
            external_id: Optional caller-defined id forwarded with the request.
            external_type: Optional caller-defined type forwarded with the request.
            metadata: Optional metadata, converted with ``metadata_to_str`` before sending.
            reindex: Forwarded to the server (default True).
            allow_long_records: Skip the MAX_RECOMMENDED_ITEM_LENGTH check.

        Returns:
            The response of ``embedding-index/item/create``.

        Raises:
            SteamshipError: if ``value`` is too long and ``allow_long_records`` is false.
        """
        req = IndexInsertRequest(
            index_id=self.id,
            value=value,
            external_id=external_id,
            external_type=external_type,
            metadata=metadata_to_str(metadata),
            reindex=reindex,
        )
        self._check_input(req, allow_long_records)
        return self.client.post(
            "embedding-index/item/create",
            req,
            expect=IndexInsertResponse,
        )

    def embed(
        self,
    ) -> Task[IndexEmbedResponse]:
        """Post an embed request for this index to ``embedding-index/embed``.

        Presumably computes embeddings for pending items (e.g. those inserted
        with ``reindex=False``) — confirm against the server API.
        """
        req = IndexEmbedRequest(id=self.id)
        return self.client.post(
            "embedding-index/embed",
            req,
            expect=IndexEmbedResponse,
        )

    def list_items(
        self,
        file_id: str = None,
        block_id: str = None,
        span_id: str = None,
    ) -> ListItemsResponse:
        """List the items in this index.

        ``file_id``, ``block_id`` and ``span_id`` are optional filters forwarded
        to the server unchanged.
        """
        req = ListItemsRequest(id=self.id, file_id=file_id, block_id=block_id, span_id=span_id)
        return self.client.post(
            "embedding-index/item/list",
            req,
            expect=ListItemsResponse,
        )

    def delete(self) -> EmbeddingIndex:
        """Delete this index on the server and return the server's view of it."""
        return self.client.post(
            "embedding-index/delete",
            DeleteRequest(id=self.id),
            expect=EmbeddingIndex,
        )

    def search(
        self,
        query: Union[str, List[str]],
        k: int = 1,
        include_metadata: bool = False,
    ) -> Task[QueryResults]:
        """Search this index.

        Args:
            query: A single query string, or a list of query strings sent as one batch.
            k: Number of results requested (default 1).
            include_metadata: Ask for item metadata in the results.

        Returns:
            The result of ``embedding-index/search``, parsed as ``QueryResults``.
        """
        # A list goes in the ``queries`` field; anything else is a single ``query``.
        if isinstance(query, list):
            req = IndexSearchRequest(
                id=self.id, queries=query, k=k, include_metadata=include_metadata
            )
        else:
            req = IndexSearchRequest(
                id=self.id, query=query, k=k, include_metadata=include_metadata
            )
        return self.client.post(
            "embedding-index/search",
            req,
            expect=QueryResults,
        )

    @staticmethod
    def create(
        client: Client,
        handle: str = None,
        name: str = None,
        embedder_plugin_instance_handle: str = None,
        fetch_if_exists: bool = True,
        external_id: str = None,
        external_type: str = None,
        metadata: Any = None,
    ) -> EmbeddingIndex:
        """Create an embedding index via ``embedding-index/create``.

        Args:
            client: Client used to send the request.
            handle: Optional handle for the index.
            name: Optional display name.
            embedder_plugin_instance_handle: Sent as ``plugin_instance``.
            fetch_if_exists: Forwarded to the server (default True); presumably
                returns an existing index with the same handle instead of
                failing — confirm.
            external_id: Optional caller-defined id.
            external_type: Optional caller-defined type.
            metadata: Sent as given (not serialized here).

        Returns:
            The created (or fetched) index.
        """
        req = IndexCreateRequest(
            handle=handle,
            name=name,
            plugin_instance=embedder_plugin_instance_handle,
            fetch_if_exists=fetch_if_exists,
            external_id=external_id,
            external_type=external_type,
            metadata=metadata,
        )
        return client.post(
            "embedding-index/create",
            req,
            expect=EmbeddingIndex,
        )