# Source: uploaded to Hugging Face by chicelli ("Upload 21 files"),
# commit 9f68e7c (verified), 971 bytes.
import os
from typing import Any
import torch
from pinecone import Pinecone, ServerlessSpec
from img2art_search.data.transforms import inversetransform
# Pinecone API key taken from the environment (None when PINECONE_API_KEY is
# unset). Uses ``os.environ.get`` rather than indexing so that importing this
# module does not raise KeyError without credentials -- helpers that never
# touch Pinecone (e.g. inverse_transform_img) stay usable. NOTE(review):
# presumably the Pinecone SDK reports the missing key when a client is built
# with api_key=None; confirm for the installed SDK version.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
def inverse_transform_img(img: torch.Tensor) -> torch.Tensor:
    """Undo the project's image transform and return a uint8 channels-last image.

    Args:
        img: Transformed image tensor. Presumably (C, H, W) -- the final
            ``permute(1, 2, 0)`` requires a 3-D tensor. TODO confirm against
            ``img2art_search.data.transforms``.

    Returns:
        A ``torch.uint8`` tensor with the first axis moved last, i.e.
        (H, W, C) for a (C, H, W) input.
    """
    inv_tensor = inversetransform(img)
    # Scaling by 255 assumes inversetransform yields values in [0, 1]
    # (NOTE(review): confirm). Clamp before the uint8 cast: PyTorch does not
    # saturate float->uint8 conversions of out-of-range values, so any pixel
    # outside [0, 255] could otherwise wrap around to a wrong colour.
    # In-range values are truncated exactly as before.
    tensor_image = (inv_tensor * 255).clamp(0, 255).byte()
    return tensor_image.permute(1, 2, 0)
def get_pinecone_client() -> Pinecone:
    """Construct a Pinecone client using the module-level ``pinecone_api_key``.

    Returns:
        A new ``Pinecone`` instance created with ``api_key=pinecone_api_key``.
    """
    return Pinecone(api_key=pinecone_api_key)
def get_or_create_pinecone_index(
    pc: Pinecone,
    index_name: str = "img2art-search",
    embeddings_dim: int = 768,
    *,
    metric: str = "cosine",
    cloud: str = "aws",
    region: str = "us-east-1",
) -> Any:
    """Return a handle to a Pinecone index, creating the index if it is missing.

    Args:
        pc: Pinecone client used to list, create and open indexes.
        index_name: Name of the index to open (or create).
        embeddings_dim: Vector dimension the index is created with, and the
            dimension an already-existing index is expected to have.
        metric: Similarity metric, used only when the index is created.
        cloud: Serverless cloud provider, used only when the index is created.
        region: Serverless region, used only when the index is created.

    Returns:
        The object returned by ``pc.Index(index_name)``.

    Raises:
        ValueError: If an index named ``index_name`` already exists and
            reports a dimension different from ``embeddings_dim``.
    """
    existing = {index.name: index for index in pc.list_indexes()}
    if index_name not in existing:
        pc.create_index(
            name=index_name,
            dimension=embeddings_dim,
            metric=metric,
            spec=ServerlessSpec(cloud=cloud, region=region),
        )
    else:
        # An existing index is reused as-is. Catch a dimension mismatch here
        # instead of letting later upserts/queries fail against it. getattr
        # tolerates index descriptions that expose no dimension.
        existing_dim = getattr(existing[index_name], "dimension", None)
        if existing_dim is not None and existing_dim != embeddings_dim:
            raise ValueError(
                f"Pinecone index {index_name!r} already exists with dimension "
                f"{existing_dim}, expected {embeddings_dim}"
            )
    return pc.Index(index_name)