Spaces:
Runtime error
Runtime error
from typing import Any, Dict, List, Optional, Union, cast | |
from pydantic import Field | |
from steamship.base.client import Client | |
from steamship.base.error import SteamshipError | |
from steamship.base.model import CamelModel | |
from steamship.base.tasks import Task | |
from steamship.data.embeddings import EmbeddedItem, EmbeddingIndex, QueryResult, QueryResults | |
from steamship.data.plugin.plugin_instance import PluginInstance | |
from steamship.data.tags.tag import Tag | |
class EmbedderInvocation(CamelModel):
    """Parameters sufficient to create or fetch an Embedder (Tagger) Plugin Instance.

    The field names line up with the keyword arguments this file forwards to
    `client.use_plugin(...)` when an embedding index is created.
    """

    # Handle of the embedder plugin to use. The only required field.
    plugin_handle: str

    # Handle of a specific plugin instance, if one is wanted.
    instance_handle: Optional[str] = None

    # Configuration dict for the plugin instance.
    config: Optional[Dict[str, Any]] = None

    # Plugin version, if one is wanted.
    version: Optional[str] = None

    # Defaults to True.
    fetch_if_exists: bool = True
class SearchResult(CamelModel):
    """A single scored search result -- which is always a tag.

    This class is intended to eventually replace the QueryResult object currently used with the
    Embedding layer.
    """

    tag: Optional[Tag] = None
    score: Optional[float] = None

    @staticmethod
    def from_query_result(query_result: QueryResult) -> "SearchResult":
        """Project a `QueryResult` from the embedding layer into a `SearchResult`.

        Args:
            query_result: The scored hit to convert. Its `value` is the hit and its `score`
                is carried over unchanged.

        Returns:
            A `SearchResult` whose `tag` is rebuilt from the hit's fields and metadata.
        """
        hit = query_result.value

        # To make this change Python-only, some fields are stashed in `hit.metadata`.
        # This has the temporary consequence of these keys not being safe. This will be resolved
        # when we spread this refactor to the engine.
        block_id = None
        file_id = None
        tag_id = None
        value = hit.metadata or {}
        if isinstance(value, dict):
            # Work on a shallow copy: stripping the bookkeeping keys must not alter the metadata
            # dict still owned by the caller's `query_result`.
            value = dict(value)
            block_id = value.pop("_block_id", None)
            file_id = value.pop("_file_id", None)
            tag_id = value.pop("_tag_id", None)

        tag = Tag(
            id=hit.id,
            kind=hit.external_type,
            name=hit.external_id,
            block_id=block_id,
            tag_id=tag_id,
            file_id=file_id,
            text=hit.value,
            value=value,
        )
        return SearchResult(tag=tag, score=query_result.score)
class SearchResults(CamelModel):
    """Results of a search operation -- which is always a list of ranked tags.

    This class is intended to eventually replace the QueryResults object currently used with the
    Embedding layer.

    TODO: add in paging support.
    """

    # Defaults to None, so the annotation is Optional to match.
    items: Optional[List[SearchResult]] = None

    @staticmethod
    def from_query_results(query_results: QueryResults) -> "SearchResults":
        """Convert a `QueryResults` object into a `SearchResults` object.

        Args:
            query_results: The embedding-layer results. A missing or empty `items` yields an
                empty list.

        Returns:
            A `SearchResults` holding one `SearchResult` per query result, in the same order.
        """
        items = [SearchResult.from_query_result(qr) for qr in query_results.items or []]
        return SearchResults(items=items)
class EmbeddingIndexPluginInstance(PluginInstance):
    """A persistent, read-optimized index over embeddings.

    This is currently implemented as an object which behaves like a PluginInstance even though
    it isn't from an implementation perspective on the back-end.
    """

    # All three are declared with `exclude=True`. `create` below populates `embedder` and `index`.
    client: Client = Field(None, exclude=True)
    embedder: PluginInstance = Field(None, exclude=True)
    index: EmbeddingIndex = Field(None, exclude=True)

    def delete(self):
        """Delete the EmbeddingIndexPluginInstance.

        For now, we will have this correspond to deleting the `index` but not the `embedder`.
        This is likely a temporary design.

        Returns:
            Whatever `self.index.delete()` returns.
        """
        return self.index.delete()

    def insert(self, tags: Union[Tag, List[Tag]], allow_long_records: bool = False):
        """Insert tags into the embedding index.

        Args:
            tags: A single Tag or a list of Tags. Each must have a non-empty `text` and a
                `value` that is either a dict or None.
            allow_long_records: Forwarded unchanged to `self.index.insert_many`.

        Raises:
            SteamshipError: If a tag has no `text`, or its `value` is neither a dict nor None.
        """
        # Make a list if a single tag was provided
        if isinstance(tags, Tag):
            tags = [tags]

        embedded_items = []
        for tag in tags:
            if not tag.text:
                raise SteamshipError(
                    message="Please set the `text` field of your Tag before inserting it into an index."
                )

            # Now we need to prepare an EmbeddedItem of a particular shape that encodes the tag.
            tag_value = tag.value or {}
            if not isinstance(tag_value, dict):
                raise SteamshipError(
                    "Only Tags with a dict or None value can be embedded. "
                    + f"This tag had a value of type: {type(tag.value)}"
                )

            # To make this change Python-only, some fields are stashed in the item's metadata.
            # This has the temporary consequence of these keys not being safe. This will be
            # resolved when we spread this refactor to the engine.
            #
            # A fresh dict is built so the caller's `tag.value` is not polluted with these
            # bookkeeping keys.
            metadata = {
                **tag_value,
                "_file_id": tag.file_id,
                "_tag_id": tag.id,
                "_block_id": tag.block_id,
            }
            embedded_items.append(
                EmbeddedItem(
                    value=tag.text,
                    external_id=tag.name,
                    external_type=tag.kind,
                    metadata=metadata,
                )
            )

        # We always reindex in this new style; to not do so is to expose details (when embedding
        # occurs) we'd rather not have users exercise control over.
        self.index.insert_many(embedded_items, reindex=True, allow_long_records=allow_long_records)

    def search(self, query: str, k: Optional[int] = None) -> Task[SearchResults]:
        """Search the embedding index.

        This wrapper implementation simply projects the `Hit` data structure into a `Tag`.

        Args:
            query: The query text. Must contain at least one non-whitespace character.
            k: Forwarded unchanged to `self.index.search`.

        Returns:
            The task returned by the underlying index search, already waited on, with its
            `output` replaced by a `SearchResults` object.

        Raises:
            SteamshipError: If `query` is None or blank.
        """
        if query is None or not query.strip():
            raise SteamshipError(message="Query field must be non-empty.")

        # Metadata will always be included; this is the equivalent of Tag.value
        wrapped_result = self.index.search(query, k=k, include_metadata=True)

        # For now, we'll have to do this synchronously since we're trying to avoid changing
        # things on the engine.
        wrapped_result.wait()

        # We're going to do a switcheroo on the output type of Task here.
        wrapped_result.output = SearchResults.from_query_results(wrapped_result.output)

        # Return the index's search result, but projected into the data structure of Tags
        return cast(Task[SearchResults], wrapped_result)

    @staticmethod
    def create(
        client: Any,
        plugin_id: Optional[str] = None,
        plugin_handle: Optional[str] = None,
        plugin_version_id: Optional[str] = None,
        plugin_version_handle: Optional[str] = None,
        handle: Optional[str] = None,
        fetch_if_exists: bool = True,
        config: Optional[Dict[str, Any]] = None,
    ) -> "EmbeddingIndexPluginInstance":
        """Create a class that simulates an embedding index re-implemented as a PluginInstance.

        Args:
            client: Used to create/fetch the embedder plugin and the underlying index.
            plugin_id: Not used by this implementation.
            plugin_handle: Not used by this implementation.
            plugin_version_id: Not used by this implementation.
            plugin_version_handle: Not used by this implementation.
            handle: Handle passed to `EmbeddingIndex.create`.
            fetch_if_exists: Passed to `EmbeddingIndex.create`.
            config: Must contain an `embedder` entry that parses as an `EmbedderInvocation`.

        Returns:
            An `EmbeddingIndexPluginInstance` wrapping the created index and embedder.

        Raises:
            SteamshipError: If `config` is missing or has no `embedder` key.
        """
        # Perform a manual config validation check since the configuration isn't actually being
        # sent up to the Engine. In this case, an embedding index has special behavior which is
        # to instantiate/fetch an Embedder that it can use.
        #
        # `config` defaults to None, so guard against that before the membership test.
        if config is None or "embedder" not in config:
            raise SteamshipError(
                message="Config key missing. Please include a field named `embedder` with type `EmbedderInvocation`."
            )

        # Parse through pydantic so the embedder settings are validated before use.
        embedder_invocation = EmbedderInvocation.parse_obj(config["embedder"])

        # Create the embedder
        embedder = client.use_plugin(**embedder_invocation.dict())

        # Create the index
        index = EmbeddingIndex.create(
            client=client,
            handle=handle,
            embedder_plugin_instance_handle=embedder.handle,
            fetch_if_exists=fetch_if_exists,
        )

        # Now return the plugin wrapper
        # NOTE(review): `client` is not passed to the wrapper, so its `client` field keeps the
        # default of None -- confirm whether that is intended.
        return EmbeddingIndexPluginInstance(
            id=index.id, handle=index.handle, index=index, embedder=embedder
        )