from transformers import ViTModel, AutoModelForMaskedLM, AutoTokenizer, ViTImageProcessor, DistilBertModel
from pinecone import Pinecone
import torch
from huggingface_hub import hf_hub_download
from io import BytesIO
import base64
from PIL import Image
from model import CLIPChemistryModel, TextEncoderHead, ImageEncoderHead

# Pinecone credentials are expected via the PINECONE_API_KEY environment variable
pc = Pinecone()
index = pc.Index("clipmodel")
# Backbone encoders: DistilBERT for text, ViT for images
ENCODER_BASE = DistilBertModel.from_pretrained("distilbert-base-uncased")
IMAGE_BASE = ViTModel.from_pretrained("google/vit-base-patch16-224")

# Projection heads and the combined CLIP-style model
text_encoder = TextEncoderHead(model=ENCODER_BASE)
image_encoder = ImageEncoderHead(model=IMAGE_BASE)
clip_model = CLIPChemistryModel(text_encoder=text_encoder, image_encoder=image_encoder)

# Download the fine-tuned weights from the Hub and load them on CPU
model_name = "sebastiansarasti/clip_fashion"
filename = "best_model.pth"
file_path = hf_hub_download(repo_id=model_name, filename=filename)
clip_model.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))

# Keep only the two trained heads for inference
te_final = clip_model.text_encoder
ie_final = clip_model.image_encoder
def process_text_for_encoder(text, model):
    """Tokenize a text query and return its embedding as a plain Python list."""
    # tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoded_input = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=256)
    input_ids = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    model.eval()
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    return output.detach().numpy().tolist()[0]
def process_image_for_encoder(image, model):
    """Run a PIL image through the ViT processor and return its embedding as a list."""
    # image = Image.open(BytesIO(image))
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    image_tensor = image_processor(image,
                                   return_tensors="pt",
                                   do_resize=True
                                   )['pixel_values']
    model.eval()
    with torch.no_grad():
        output = model(pixel_values=image_tensor)
    return output.detach().numpy().tolist()[0]
def search_similarity(input, mode, top_k=5):
    """Query the Pinecone index for the top_k nearest cross-modal matches.

    For mode='text' the query string is encoded here; for mode='image' the
    caller must pass an already-encoded embedding vector.
    """
    if mode == 'text':
        # encode the text query and search against the image namespace
        output = process_text_for_encoder(input, model=te_final)
        mode_search = 'image'
        metadata_field = 'image'
    elif mode == 'image':
        # the input is already an embedding vector; search against the text namespace
        output = input
        mode_search = 'text'
        metadata_field = 'text'
    else:
        raise ValueError("mode must be either 'text' or 'image'")
    response = index.query(
        namespace="space-" + mode_search + "-fashion",
        vector=output,
        top_k=top_k,
        include_values=True,
        include_metadata=True
    )
    return [match['metadata'][metadata_field] for match in response['matches']]
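# Hypothetical convenience wrapper (not part of the original file): for
# mode='image', search_similarity expects an already-encoded embedding, so a
# caller would typically encode the image first and then query, for example:
def search_similar_text_for_image(image, top_k=5):
    # encode the PIL image with the trained image head, then look up captions
    vector = process_image_for_encoder(image, model=ie_final)
    return search_similarity(vector, mode='image', top_k=top_k)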
def process_image_for_encoder_gradio(image, is_bytes=True):
    """Handle both raw image bytes and PIL Image / numpy inputs coming from Gradio."""
    try:
        if is_bytes:
            # The image arrives as raw bytes
            image = Image.open(BytesIO(image))
        else:
            # The image is already a PIL Image or comes from Gradio
            if not isinstance(image, Image.Image):
                # Inputs coming from Gradio may be numpy arrays
                image = Image.fromarray(image)
        image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        image_tensor = image_processor(image,
                                       return_tensors="pt",
                                       do_resize=True
                                       )['pixel_values']
        ie_final.eval()
        with torch.no_grad():
            output = ie_final(pixel_values=image_tensor)
        return output.detach().numpy().tolist()[0]
    except Exception as e:
        print(f"Error in process_image_for_encoder_gradio: {e}")
        raise
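# --- Hypothetical usage sketch ---
# A minimal illustration (an assumption, not the Space's actual app wiring) of
# how these helpers could be exposed through Gradio for text-to-image search.
# It assumes the Pinecone 'image' metadata holds base64-encoded images, which
# the base64/BytesIO imports above suggest.
if __name__ == "__main__":
    import gradio as gr

    def text_to_images(query):
        # encode the query, fetch the nearest image payloads, decode to PIL
        matches = search_similarity(query, mode='text', top_k=5)
        return [Image.open(BytesIO(base64.b64decode(m))) for m in matches]

    demo = gr.Interface(
        fn=text_to_images,
        inputs=gr.Textbox(label="Describe a garment"),
        outputs=gr.Gallery(label="Most similar items"),
    )
    demo.launch()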