import base64
from io import BytesIO
from typing import Any, Dict

import numpy as np
import requests
from PIL import Image

from fashion_clip.fashion_clip import FashionCLIP


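# Custom handler for a Hugging Face Inference Endpoint: the toolkit imports this
# file, instantiates EndpointHandler(path) once at startup (path points to the
# repository contents), and calls the instance with each deserialized request body.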
class EndpointHandler:
    def __init__(self, path="."):
        # Preload all the elements you are going to need at inference.
        self.fclip = FashionCLIP('fashion-clip')

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data: A dictionary with an "inputs" key mapping to a dictionary with:
                - "text": List[str] (optional) - The list of text descriptions to embed.
                - "image": List[Union[str, bytes, PIL.Image.Image, np.ndarray]] (optional) - The
                  list of images to embed. Each image can be an HTTP(S) URL, a base64-encoded
                  string, raw bytes, a PIL Image, or a numpy array.

        Returns:
            Dict[str, Any]: A dictionary with "text_embeddings" and/or "image_embeddings",
            each a list of L2-normalized embedding vectors.
        """
        # Extract text and images from the input payload
        inputs = data.get("inputs", {})
        texts = inputs.get("text", [])
        images = inputs.get("image", [])

        # Convert every image to a PIL Image if needed
        images = [self._load_image(img) for img in images]

        results = {}

        if images:
            # Encode the images and L2-normalize each embedding vector
            image_embeddings = self.fclip.encode_images(images, batch_size=32)
            image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, ord=2, axis=-1, keepdims=True)
            results["image_embeddings"] = image_embeddings.tolist()

        if texts:
            # Encode the texts and L2-normalize each embedding vector
            text_embeddings = self.fclip.encode_text(texts, batch_size=32)
            text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, ord=2, axis=-1, keepdims=True)
            results["text_embeddings"] = text_embeddings.tolist()

        return results

    def _load_image(self, img):
        """Load an image given as a URL, base64 string, raw bytes, PIL Image, or numpy array."""
        if isinstance(img, str):
            if img.startswith("http"):
                # The image is a URL: download it
                response = requests.get(img, timeout=10)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            else:
                # The image is a base64-encoded string
                img = Image.open(BytesIO(base64.b64decode(img)))
        elif isinstance(img, bytes):
            # The image is raw bytes
            img = Image.open(BytesIO(img))
        elif isinstance(img, Image.Image):
            # The image is already a PIL Image
            pass
        elif isinstance(img, np.ndarray):
            # The image is a numpy array
            img = Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")
        return img
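

# A minimal local smoke test: a sketch for quick verification, not part of the
# Inference Endpoints request contract. The query text and the synthetic image
# below are made-up inputs, chosen so the script does not depend on a remote
# image URL (downloading the fashion-clip weights still requires network access).
if __name__ == "__main__":
    handler = EndpointHandler()
    # Synthetic black RGB image in place of a real product photo.
    dummy_image = np.zeros((224, 224, 3), dtype=np.uint8)
    payload = {
        "inputs": {
            "text": ["a red dress with floral print"],  # hypothetical query text
            "image": [dummy_image],
        }
    }
    output = handler(payload)
    # Print the number of vectors and the embedding dimension for each modality.
    print({k: (len(v), len(v[0])) for k, v in output.items()})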