File size: 3,882 Bytes
0d38ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import functools
import os

import torch
from pinecone import Pinecone
from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor, DistilBertModel


# Pinecone client plus a handle on the "clipmodel" index whose namespaces are
# queried by `search_similarity`.
# NOTE(review): Pinecone() gets no explicit arguments, so credentials must be
# supplied by the environment — confirm the deployment provides them.
pc = Pinecone()
index = pc.Index(name="clipmodel")


from io import BytesIO
import base64
from PIL import Image

import sys

# Make the project's `src/` directory (which provides `model`) importable.
# '../src' is resolved against the current working directory, so it only works
# when the script is launched from a sibling directory of `src/`. Also add the
# same path relative to this file so the import works from any CWD. `__file__`
# is undefined in notebooks/REPLs, hence the guard.
sys.path.append('../src')
if '__file__' in globals():
    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'src'))

from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead


# Pretrained backbones wrapped by the project's encoder heads, assembled into
# the CLIP model and restored from a saved state dict.
ENCODER_BASE = DistilBertModel.from_pretrained("distilbert-base-uncased")
IMAGE_BASE = ViTModel.from_pretrained("google/vit-base-patch16-224")
text_encoder = TextEncoderHead(model=ENCODER_BASE)
image_encoder = ImageEncoderHead(model=IMAGE_BASE)

clip_model = CLIPChemistryModel(text_encoder=text_encoder, image_encoder=image_encoder)

# Checkpoint location. Override with the CLIP_MODEL_PATH environment variable;
# the default is the original developer-machine path.
MODEL_PATH = os.environ.get(
    "CLIP_MODEL_PATH",
    '/Users/sebastianalejandrosarastizambonino/Documents/projects/CLIP_Pytorch/src/best_model_fashion.pth',
)
# NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
clip_model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
# This module only runs inference. Switch off train-time behaviour such as
# dropout (DistilBERT uses it; the heads may too) so embeddings are
# deterministic. Assumes CLIPChemistryModel is an nn.Module, as load_state_dict
# suggests, so eval() recurses into its registered sub-encoders.
clip_model.eval()

te_final = clip_model.text_encoder
ie_final = clip_model.image_encoder

@functools.lru_cache(maxsize=1)
def _load_text_tokenizer():
    """Load the DistilBERT tokenizer once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained("distilbert-base-uncased")


def process_text_for_encoder(text, model):
    """Embed ``text`` with ``model`` and return the first output row as a list.

    The text is tokenized with the ``distilbert-base-uncased`` tokenizer,
    padded/truncated to exactly 256 tokens, and passed to ``model`` as
    ``input_ids`` and ``attention_mask`` keyword arguments.

    Args:
        text: Input accepted by the tokenizer (e.g. a single query string).
        model: Text encoder callable returning a tensor (e.g. ``te_final``).

    Returns:
        ``output.tolist()[0]`` — presumably the embedding vector for the first
        (only) input; TODO confirm the encoder's output shape.
    """
    encoded_input = _load_text_tokenizer()(
        text, return_tensors='pt', padding='max_length', truncation=True, max_length=256
    )
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        output = model(
            input_ids=encoded_input['input_ids'],
            attention_mask=encoded_input['attention_mask'],
        )
    return output.detach().numpy().tolist()[0]

@functools.lru_cache(maxsize=1)
def _load_vit_image_processor():
    """Load the ViT image processor once and reuse it on later calls."""
    return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")


def process_image_for_encoder(image, model):
    """Embed ``image`` with ``model`` and return the first output row as a list.

    The image is preprocessed (resized, to a PyTorch tensor) with the
    ``google/vit-base-patch16-224`` processor and passed to ``model`` as the
    ``pixel_values`` keyword argument.

    Args:
        image: Any image input accepted by ``ViTImageProcessor`` (e.g. a PIL
            image).
        model: Image encoder callable returning a tensor (e.g. ``ie_final``).

    Returns:
        ``output.tolist()[0]`` — presumably the embedding vector for the single
        input image; TODO confirm the encoder's output shape.
    """
    pixel_values = _load_vit_image_processor()(
        image,
        return_tensors="pt",
        do_resize=True,
    )['pixel_values']
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        output = model(pixel_values=pixel_values)
    return output.detach().numpy().tolist()[0]

# Query modality -> (Pinecone namespace searched, metadata field returned).
# Search is cross-modal: a text query is matched against the image namespace
# and an image query against the text namespace.
_SEARCH_TARGETS = {
    'text': ("space-image-fashion", 'image'),
    'image': ("space-text-fashion", 'text'),
}


def search_similarity(input, mode, top_k=5):
    """Return metadata of the ``top_k`` best cross-modal matches in ``index``.

    Args:
        input: For ``mode='text'``, the text query; it is embedded with
            ``te_final`` via ``process_text_for_encoder``. For
            ``mode='image'``, a ready-made query vector (e.g. an image
            embedding), sent to Pinecone as-is.
        mode: ``'text'`` or ``'image'`` — the modality of ``input``.
        top_k: Number of matches requested from Pinecone.

    Returns:
        For ``mode='text'``, the ``'image'`` metadata of each match in the
        ``space-image-fashion`` namespace; for ``mode='image'``, the
        ``'text'`` metadata of each match in ``space-text-fashion``.

    Raises:
        ValueError: If ``mode`` is neither ``'text'`` nor ``'image'``.
    """
    # Validate first so a bad mode never triggers encoding or a network call.
    if mode not in ('text', 'image'):
        raise ValueError("mode must be either 'text' or 'image'")
    namespace, metadata_key = _SEARCH_TARGETS[mode]

    if mode == 'text':
        vector = process_text_for_encoder(input, model=te_final)
    else:
        vector = input

    response = index.query(
        namespace=namespace,
        vector=vector,
        top_k=top_k,
        # Only metadata is returned to the caller; don't transfer the vectors.
        include_values=False,
        include_metadata=True,
    )
    return [match['metadata'][metadata_key] for match in response['matches']]
    
def process_image_for_encoder_gradio(image, is_bytes=True):
    """Embed an image given either as encoded bytes or as a PIL/array image.

    The input is decoded to a PIL image, converted to RGB, and embedded with
    ``ie_final`` through ``process_image_for_encoder``.

    Args:
        image: Encoded image bytes when ``is_bytes`` is true. Otherwise a
            ``PIL.Image.Image`` or an array accepted by ``Image.fromarray``
            (presumably what a Gradio image input delivers — confirm).
        is_bytes: Whether ``image`` holds raw encoded bytes.

    Returns:
        The embedding list produced by ``process_image_for_encoder``.

    Raises:
        Any exception from decoding or encoding; it is printed, then re-raised.
    """
    try:
        if is_bytes:
            image = Image.open(BytesIO(image))
        elif not isinstance(image, Image.Image):
            # e.g. a NumPy array coming from Gradio
            image = Image.fromarray(image)
        # The ViT processor normalises with 3-channel (RGB) statistics, so
        # RGBA, palette or grayscale images would fail or be misread.
        image = image.convert("RGB")
        # Reuse the shared preprocessing/encoding path; no_grad because this
        # is inference only.
        with torch.no_grad():
            return process_image_for_encoder(image, model=ie_final)
    except Exception as e:
        print(f"Error en process_image_for_encoder: {e}")
        raise