File size: 3,838 Bytes
54ec520
 
 
 
 
 
 
 
 
 
54eaeff
54ec520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import numpy as np
from PIL import Image
from io import BytesIO
import requests
import spaces
import gradio as gr
import re
import emoji
from ..prompts.prompt_templates import PromptTemplates
import faiss

class ImageRecommender:
    """Retrieve visually similar images using SigLIP embeddings + a FAISS index.

    Expects ``config`` to provide (assumed from usage — confirm against the
    config class): ``processor`` / ``model`` (SigLIP), ``device``, ``index``
    (a FAISS index searched with inner product over L2-normalized vectors),
    ``df`` (a DataFrame with a ``"Link"`` column of image URLs), and
    ``get_messages(language)``.
    """

    # Seconds before a remote image download is abandoned.
    _REQUEST_TIMEOUT = 10

    def __init__(self, config):
        self.config = config

    def read_image_from_url(self, url):
        """Download *url* and return it as an RGB PIL image.

        Raises:
            requests.HTTPError: on a non-2xx response, instead of letting
                PIL fail opaquely while trying to decode an HTML error page.
            requests.Timeout: if the server does not respond in time.
        """
        # Fix: the original call had no timeout, so a stalled server could
        # hang this handler forever.
        response = requests.get(url, timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        return img

    def extract_features_siglip(self, image):
        """Return the SigLIP image-feature tensor for a PIL *image*.

        Runs under ``torch.no_grad()`` since the features are only used for
        similarity search, never for backprop.
        """
        with torch.no_grad():
            inputs = self.config.processor(images=image, return_tensors="pt").to(self.config.device)
            image_features = self.config.model.get_image_features(**inputs)
        return image_features

    def process_image(self, image_path, num_results=2):
        """Return the *num_results* most similar gallery images for *image_path*.

        The query embedding is L2-normalized in place (FAISS convention for
        cosine-similarity search over a normalized index) before searching.

        Returns:
            list[PIL.Image.Image]: retrieved images, most similar first.
        """
        input_image = Image.open(image_path).convert("RGB")
        input_features = self.extract_features_siglip(input_image)
        input_features = input_features.detach().cpu().numpy()
        input_features = np.float32(input_features)
        faiss.normalize_L2(input_features)

        # distances are unused downstream; only the row indices matter here.
        _distances, indices = self.config.index.search(input_features, num_results)

        gallery_output = []
        for idx in indices[0]:
            image_url = self.config.df.iloc[idx]["Link"]
            gallery_output.append(self.read_image_from_url(image_url))
        return gallery_output

    @spaces.GPU
    def infer(self, crop_image_path, full_image_path, state, language, task_type=None):
        """Produce item- and style-level recommendations for the chat UI.

        When a cropped region is provided, two item matches (for the crop) and
        two style matches (for the full image) are returned; otherwise four
        style matches for the full image only.

        Returns:
            tuple: (item_gallery, style_gallery, state, state) — ``state`` is
            returned twice because two gradio outputs are bound to it.
        """
        style_gallery_output = []
        item_gallery_output = []

        if crop_image_path:
            item_gallery_output = self.process_image(crop_image_path, 2)
            style_gallery_output = self.process_image(full_image_path, 2)
        else:
            style_gallery_output = self.process_image(full_image_path, 4)

        msg = self.config.get_messages(language)
        state += [(None, msg)]

        return item_gallery_output, style_gallery_output, state, state

    @staticmethod
    def _associate_outputs(rec_path, state, log_state, recommend_type):
        """Build the 9-element gradio output tuple shared by both gallery
        selection handlers (previously duplicated in each)."""
        return (
            state,
            state,
            None,
            log_state,
            None,
            gr.update(value=[]),
            rec_path,
            rec_path,
            recommend_type,
        )

    async def item_associate(self, new_crop, openai_api_key, language, autoplay, length,
                             log_state, sort_score, narrative, state, evt: gr.SelectData):
        """Handle selection of an item-level recommendation in the gallery."""
        # NOTE(review): evt._data is a private gradio attribute; verify the
        # pinned gradio version doesn't expose this via evt.value instead.
        rec_path = evt._data['value']['image']['path']
        return self._associate_outputs(rec_path, state, log_state, "Item")

    async def style_associate(self, image_path, openai_api_key, language, autoplay,
                              length, log_state, sort_score, narrative, state, artist,
                              evt: gr.SelectData):
        """Handle selection of a style-level recommendation in the gallery."""
        # NOTE(review): same private-attribute access as item_associate.
        rec_path = evt._data['value']['image']['path']
        return self._associate_outputs(rec_path, state, log_state, "Style")

    def generate_recommendation_prompt(self, recommend_type, narrative, language, length, artist=None):
        """Format the LLM prompt template for a recommendation.

        Selects the template by recommendation kind ("Item" → index 0,
        anything else → index 1) and by the mapped narrative value; only the
        (narrative == 1, "Style") combination also interpolates the artist.
        """
        narrative_value = PromptTemplates.NARRATIVE_MAPPING[narrative]
        prompt_type = 0 if recommend_type == "Item" else 1

        if narrative_value == 1 and recommend_type == "Style":
            # artist[8:] strips a fixed-length prefix from the artist string —
            # presumably a UI label such as "Artist: "; TODO confirm against
            # the caller that supplies `artist`.
            return PromptTemplates.RECOMMENDATION_PROMPTS[prompt_type][narrative_value].format(
                language=language,
                length=length,
                artist=artist[8:] if artist else ""
            )
        return PromptTemplates.RECOMMENDATION_PROMPTS[prompt_type][narrative_value].format(
            language=language,
            length=length
        )