import torch
import numpy as np
from PIL import Image
from io import BytesIO
import requests
import spaces
import gradio as gr
import re
import emoji
from ..prompts.prompt_templates import PromptTemplates
import faiss


class ImageRecommender:
    def __init__(self, config):
        self.config = config

    def read_image_from_url(self, url):
        # Fetch a remote image and normalize it to RGB before encoding.
        response = requests.get(url)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        return img

    def extract_features_siglip(self, image):
        # Encode the image with SigLIP; no gradients are needed at inference.
        with torch.no_grad():
            inputs = self.config.processor(images=image, return_tensors="pt").to(self.config.device)
            image_features = self.config.model.get_image_features(**inputs)
        return image_features
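
    # Sketch of the expected setup (an assumption; nothing below is defined in
    # this file): `config.model` and `config.processor` would be a SigLIP
    # checkpoint loaded with transformers, e.g.
    #
    #   from transformers import AutoModel, AutoProcessor
    #   model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(device)
    #   processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
    #
    # so that get_image_features returns one embedding row per input image.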

    def process_image(self, image_path, num_results=2):
        # Embed the query image, L2-normalize it, and retrieve the nearest
        # neighbors from the FAISS index.
        input_image = Image.open(image_path).convert("RGB")
        input_features = self.extract_features_siglip(input_image)
        input_features = input_features.detach().cpu().numpy()
        input_features = np.float32(input_features)
        faiss.normalize_L2(input_features)
        distances, indices = self.config.index.search(input_features, num_results)
        # Resolve each hit to its source URL and fetch the image; the returned
        # distances are not needed for the gallery itself.
        gallery_output = []
        for v in indices[0]:
            image_url = self.config.df.iloc[v]["Link"]
            img_retrieved = self.read_image_from_url(image_url)
            gallery_output.append(img_retrieved)
        return gallery_output
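
    # Note (assumption): faiss.normalize_L2 plus index.search above suggests
    # `config.index` was built as an inner-product index over the same
    # L2-normalized SigLIP embeddings, e.g.
    #
    #   embeddings = np.float32(stacked_image_features)  # hypothetical (N, d) array
    #   faiss.normalize_L2(embeddings)
    #   index = faiss.IndexFlatIP(embeddings.shape[1])
    #   index.add(embeddings)
    #
    # in which case the returned distances are cosine similarities in [-1, 1].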

    def infer(self, crop_image_path, full_image_path, state, language, task_type=None):
        # With a cropped region, recommend two similar items for the crop and
        # two style matches for the full image; otherwise return four style
        # matches for the full image alone.
        item_gallery_output = []
        style_gallery_output = []
        if crop_image_path:
            item_gallery_output = self.process_image(crop_image_path, 2)
            style_gallery_output = self.process_image(full_image_path, 2)
        else:
            style_gallery_output = self.process_image(full_image_path, 4)
        msg = self.config.get_messages(language)
        state += [(None, msg)]
        return item_gallery_output, style_gallery_output, state, state

    async def item_associate(self, new_crop, openai_api_key, language, autoplay, length,
                             log_state, sort_score, narrative, state, evt: gr.SelectData):
        # Path of the gallery image the user clicked, read from the event's
        # private payload.
        rec_path = evt._data['value']['image']['path']
        return (
            state,
            state,
            None,
            log_state,
            None,
            gr.update(value=[]),  # reset the bound output component to empty
            rec_path,
            rec_path,
            "Item",
        )

    async def style_associate(self, image_path, openai_api_key, language, autoplay,
                              length, log_state, sort_score, narrative, state, artist,
                              evt: gr.SelectData):
        # Same flow as item_associate, but tagged as a "Style" recommendation.
        rec_path = evt._data['value']['image']['path']
        return (
            state,
            state,
            None,
            log_state,
            None,
            gr.update(value=[]),  # reset the bound output component to empty
            rec_path,
            rec_path,
            "Style",
        )

    def generate_recommendation_prompt(self, recommend_type, narrative, language, length, artist=None):
        # Select the prompt template by recommendation type and narrative
        # perspective. Style prompts in narrative mode 1 also take the artist
        # name, stored with an 8-character prefix that is stripped here.
        narrative_value = PromptTemplates.NARRATIVE_MAPPING[narrative]
        prompt_type = 0 if recommend_type == "Item" else 1
        if narrative_value == 1 and recommend_type == "Style":
            return PromptTemplates.RECOMMENDATION_PROMPTS[prompt_type][narrative_value].format(
                language=language,
                length=length,
                artist=artist[8:] if artist else ""
            )
        return PromptTemplates.RECOMMENDATION_PROMPTS[prompt_type][narrative_value].format(
            language=language,
            length=length
        )
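

# Minimal usage sketch (assumptions: `config` supplies the SigLIP model and
# processor, a FAISS index, a metadata DataFrame with a "Link" column,
# `device`, and `get_messages`; none of these are defined in this file):
#
#   recommender = ImageRecommender(config)
#   items, styles, state, _ = recommender.infer(
#       crop_image_path="crop.jpg",
#       full_image_path="full.jpg",
#       state=[],
#       language="English",
#   )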