# File size: 7,717 Bytes
# b115d50
from typing import Any, Dict, List, Optional, Union, cast

from pydantic import Field

from steamship.base.client import Client
from steamship.base.error import SteamshipError
from steamship.base.model import CamelModel
from steamship.base.tasks import Task
from steamship.data.embeddings import EmbeddedItem, EmbeddingIndex, QueryResult, QueryResults
from steamship.data.plugin.plugin_instance import PluginInstance
from steamship.data.tags.tag import Tag


class EmbedderInvocation(CamelModel):
    """The parameters capable of creating/fetching an Embedder (Tagger) Plugin Instance.

    `EmbeddingIndexPluginInstance.create` parses its `config["embedder"]` entry into this model and
    forwards the resulting fields, unchanged, as keyword arguments to `client.use_plugin`.
    """

    # Handle of the embedder plugin to use; the only field without a default.
    plugin_handle: str
    # Presumably the handle of the plugin instance to create or fetch -- confirm against `client.use_plugin`.
    instance_handle: Optional[str] = None
    # Presumably the configuration handed to the embedder plugin instance -- confirm against `client.use_plugin`.
    config: Optional[Dict[str, Any]] = None
    # Presumably the plugin version to use -- confirm against `client.use_plugin`.
    version: Optional[str] = None
    # Defaults to True; presumably re-uses an existing instance rather than failing -- confirm against `client.use_plugin`.
    fetch_if_exists: bool = True


class SearchResult(CamelModel):
    """A single scored search result -- which is always a tag.

    This class is intended to eventually replace the QueryResult object currently used with the Embedding layer.
    """

    tag: Optional[Tag] = None
    score: Optional[float] = None

    @staticmethod
    def from_query_result(query_result: QueryResult) -> "SearchResult":
        """Project a `QueryResult` from the embedding layer into a tag-based `SearchResult`.

        The hit's `id`, `external_type`, `external_id` and `value` become the tag's `id`, `kind`, `name`
        and `text`. The hit's metadata becomes the tag's `value`, minus the `_block_id`, `_file_id` and
        `_tag_id` bookkeeping keys, which are lifted into the corresponding tag fields instead.

        Args:
            query_result: The scored hit to convert.

        Returns:
            A `SearchResult` wrapping the reconstructed tag together with the hit's score.
        """
        hit = query_result.value

        # Work on a shallow copy so that removing the bookkeeping keys below does not alter `hit.metadata`.
        value = dict(hit.metadata or {})

        # To make this change Python-only, some fields are stashed in `hit.metadata`.
        # This has the temporary consequence of these keys not being safe. This will be resolved when we spread
        # this refactor to the engine.
        #
        # `pop` with a default tolerates hits that lack one of the keys (for example, items that carry no
        # metadata at all); an unconditional `del` would raise KeyError for those.
        block_id = value.pop("_block_id", None)
        file_id = value.pop("_file_id", None)
        tag_id = value.pop("_tag_id", None)

        tag = Tag(
            id=hit.id,
            kind=hit.external_type,
            name=hit.external_id,
            block_id=block_id,
            tag_id=tag_id,
            file_id=file_id,
            text=hit.value,
            value=value,
        )
        return SearchResult(tag=tag, score=query_result.score)


class SearchResults(CamelModel):
    """The outcome of a search operation: a ranked list of tags.

    This class is intended to eventually replace the QueryResults object currently used with the Embedding layer.
    TODO: add in paging support.
    """

    items: List[SearchResult] = None

    @staticmethod
    def from_query_results(query_results: QueryResults) -> "SearchResults":
        """Convert every entry of `query_results.items` into a `SearchResult`, preserving order.

        A missing or empty `items` list yields a `SearchResults` whose `items` is an empty list.
        """
        hits = query_results.items or []
        converted = list(map(SearchResult.from_query_result, hits))
        return SearchResults(items=converted)


class EmbeddingIndexPluginInstance(PluginInstance):
    """A persistent, read-optimized index over embeddings.

    This is currently implemented as an object which behaves like a PluginInstance even though
    it isn't from an implementation perspective on the back-end.
    """

    client: Client = Field(None, exclude=True)
    embedder: PluginInstance = Field(None, exclude=True)
    index: EmbeddingIndex = Field(None, exclude=True)

    def delete(self):
        """Delete the EmbeddingIndexPluginInstance.

        For now, we will have this correspond to deleting the `index` but not the `embedder`. This is likely
        a temporary design.

        Returns:
            Whatever the wrapped index's `delete()` call returns.
        """
        return self.index.delete()

    def insert(self, tags: Union[Tag, List[Tag]], allow_long_records: bool = False):
        """Insert tags into the embedding index.

        Each tag is encoded as an `EmbeddedItem`: `text` becomes the item's value, `name` its external id,
        `kind` its external type, and `value` (plus bookkeeping ids) its metadata. The tags passed in are
        not modified.

        Args:
            tags: A single Tag or a list of Tags. Every tag needs a non-empty `text`, and a `value` that is
                either a dict or empty/None.
            allow_long_records: Passed through unchanged to the index's `insert_many`.

        Raises:
            SteamshipError: If a tag has no `text`, or its `value` is a non-empty non-dict.
        """

        # Make a list if a single tag was provided
        if isinstance(tags, Tag):
            tags = [tags]

        embedded_items = []
        for tag in tags:
            if not tag.text:
                raise SteamshipError(
                    message="Please set the `text` field of your Tag before inserting it into an index."
                )

            # Now we need to prepare an EmbeddingIndexItem of a particular shape that encodes the tag.
            metadata = tag.value or {}
            if not isinstance(metadata, dict):
                raise SteamshipError(
                    "Only Tags with a dict or None value can be embedded. "
                    + f"This tag had a value of type: {type(tag.value)}"
                )

            # To make this change Python-only, some fields are stashed in the item's metadata.
            # This has the temporary consequence of these keys not being safe. This will be resolved when we spread
            # this refactor to the engine.
            #
            # The ids go into a fresh dict so that the caller's `tag.value` is left untouched.
            item_metadata = {
                **metadata,
                "_file_id": tag.file_id,
                "_tag_id": tag.id,
                "_block_id": tag.block_id,
            }

            embedded_items.append(
                EmbeddedItem(
                    value=tag.text,
                    external_id=tag.name,
                    external_type=tag.kind,
                    metadata=item_metadata,
                )
            )

        # We always reindex in this new style; to not do so is to expose details (when embedding occurs) we'd rather
        # not have users exercise control over.
        self.index.insert_many(embedded_items, reindex=True, allow_long_records=allow_long_records)

    def search(self, query: str, k: Optional[int] = None) -> Task[SearchResults]:
        """Search the embedding index.

        This wrapper implementation simply projects the `Hit` data structure into a `Tag`.

        Args:
            query: The text to search for; must contain at least one non-whitespace character.
            k: Passed through unchanged to the index's `search`.

        Returns:
            The task returned by the index's `search`, already waited on, with its `output` replaced by
            a `SearchResults`.

        Raises:
            SteamshipError: If `query` is None, empty, or only whitespace.
        """
        if not query or not query.strip():
            raise SteamshipError(message="Query field must be non-empty.")

        # Metadata will always be included; this is the equivalent of Tag.value
        wrapped_result = self.index.search(query, k=k, include_metadata=True)

        # For now, we'll have to do this synchronously since we're trying to avoid changing things on the engine.
        wrapped_result.wait()

        # We're going to do a switcheroo on the output type of Task here.
        search_results = SearchResults.from_query_results(wrapped_result.output)
        wrapped_result.output = search_results

        # Return the index's search result, but projected into the data structure of Tags
        return cast(Task[SearchResults], wrapped_result)

    @staticmethod
    def create(
        client: Any,
        plugin_id: Optional[str] = None,
        plugin_handle: Optional[str] = None,
        plugin_version_id: Optional[str] = None,
        plugin_version_handle: Optional[str] = None,
        handle: Optional[str] = None,
        fetch_if_exists: bool = True,
        config: Optional[Dict[str, Any]] = None,
    ) -> "EmbeddingIndexPluginInstance":
        """Create a class that simulates an embedding index re-implemented as a PluginInstance.

        Args:
            client: The client used to create/fetch both the embedder and the index.
            plugin_id: Not used by this implementation.
            plugin_handle: Not used by this implementation.
            plugin_version_id: Not used by this implementation.
            plugin_version_handle: Not used by this implementation.
            handle: Handle passed to `EmbeddingIndex.create`.
            fetch_if_exists: Passed through unchanged to `EmbeddingIndex.create`.
            config: Must contain an `embedder` entry that parses as an `EmbedderInvocation`.

        Returns:
            A wrapper carrying the id and handle of the created index, plus the index and embedder objects.

        Raises:
            SteamshipError: If `config` is None or has no `embedder` key.
        """
        # NOTE(review): the four unused plugin_* parameters presumably mirror `PluginInstance.create` -- confirm.

        # Perform a manual config validation check since the configuration isn't actually being sent up to the Engine.
        # In this case, an embedding index has special behavior which is to instantiate/fetch an Embedder that it can use.
        # `config` defaults to None, so check for that first; a bare `in` test on None would raise TypeError.
        if config is None or "embedder" not in config:
            raise SteamshipError(
                message="Config key missing. Please include a field named `embedder` with type `EmbedderInvocation`."
            )

        # Just for pydantic validation.
        embedder_invocation = EmbedderInvocation.parse_obj(config["embedder"])

        # Create the embedder
        embedder = client.use_plugin(**embedder_invocation.dict())

        # Create the index
        index = EmbeddingIndex.create(
            client=client,
            handle=handle,
            embedder_plugin_instance_handle=embedder.handle,
            fetch_if_exists=fetch_if_exists,
        )

        # Now return the plugin wrapper
        # NOTE(review): `client` is not passed here, so the wrapper's own `client` field keeps its default
        # of None -- confirm whether that is intended.
        return EmbeddingIndexPluginInstance(
            id=index.id, handle=index.handle, index=index, embedder=embedder
        )