File size: 7,852 Bytes
e53c320
 
 
 
 
5638e50
e53c320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7774775
 
e53c320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7774775
e53c320
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch

# Previously these checkpoint paths pointed at remote URLs:
#   shoes:   http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/caption_model_shoes
#   dresses: http://www.dcs.gla.ac.uk/~craigm/fcrs/captioners/dresses_cap_caption_models
# They are now read from a local resource directory instead.

# Local resource root, relative to the working directory. It ends with "/",
# so the paths below are formed by plain concatenation.
drive_path = 'mmir_usersim_resources/'

# Garment categories; entry i here corresponds to entry i of ``usersim_path``.
data_type = ["shoes", "dresses", "shirts", "tops&tees"]

# Per-category checkpoint directories for the relative-caption user simulator.
usersim_path_shoes = f"{drive_path}checkpoints_usersim/shoes"
usersim_path_dresses = f"{drive_path}checkpoints_usersim/dresses"
usersim_path_shirts = f"{drive_path}checkpoints_usersim/shirts"
usersim_path_topstees = f"{drive_path}checkpoints_usersim/topstees"
usersim_path = [
    usersim_path_shoes,
    usersim_path_dresses,
    usersim_path_shirts,
    usersim_path_topstees,
]

import captioning.captioner as captioner
# Feature-extractor settings shared by every captioner: a ResNet-101 backbone
# whose weights are read from ``imagenet_weights`` under ``drive_path``.
# ``att_size`` is passed through as 7 (presumably the spatial size of the
# attention feature grid -- confirm in captioning.captioner).
image_feat_params = {
    'model': 'resnet101',
    'model_root': drive_path + 'imagenet_weights',
    'att_size': 7,
}

# One relative captioner per category, constructed in ``data_type`` order
# (shoes, dresses, shirts, tops&tees) from the matching ``usersim_path`` entry.
(
    captioner_relative_shoes,
    captioner_relative_dresses,
    captioner_relative_shirts,
    captioner_relative_topstees,
) = [
    captioner.Captioner(
        is_relative=True,
        model_path=checkpoint_dir,
        image_feat_params=image_feat_params,
        data_type=category,
        load_resnet=True,
    )
    for checkpoint_dir, category in zip(usersim_path, data_type)
]

# Returned instead of a caption when either image input is empty
# (a ``type="filepath"`` gr.Image yields None when nothing is uploaded).
MISSING_IMAGES_MESSAGE = "Please provide both a target image and a candidate image."


def generate_relative_sentence(model, image_path_1, image_path_2):
    """Generate a relative caption for an image pair with ``model``.

    Shared implementation behind the per-category ``generate_sentence_*``
    handlers.

    Args:
        model: A relative captioner exposing ``get_img_feat(path)``, which
            returns an ``(fc_feat, att_feat)`` pair of tensors, and
            ``gen_caption_from_feat(feats, ref_feats)``, which returns
            ``(seq, sents)``.
        image_path_1: Path of the first image. The app's click handlers pass
            the "Target Image" here.
        image_path_2: Path of the reference image. The app's click handlers
            pass the "Candidate Image" here.

    Returns:
        The first generated sentence (``sents[0]``), or
        ``MISSING_IMAGES_MESSAGE`` if either path is missing or empty.
    """
    # Without this guard, get_img_feat would be handed None when the user
    # clicks "Generate" before uploading both images.
    if not image_path_1 or not image_path_2:
        return MISSING_IMAGES_MESSAGE

    # Pure inference: avoid building an autograd graph for the feature
    # extraction and decoding (harmless if the captioner already does this).
    with torch.no_grad():
        fc_feat, att_feat = model.get_img_feat(image_path_1)
        fc_feat_ref, att_feat_ref = model.get_img_feat(image_path_2)

        # Prepend a batch dimension of size 1 to every feature tensor.
        feats = (torch.unsqueeze(fc_feat, dim=0), torch.unsqueeze(att_feat, dim=0))
        ref_feats = (
            torch.unsqueeze(fc_feat_ref, dim=0),
            torch.unsqueeze(att_feat_ref, dim=0),
        )

        _seq, sents = model.gen_caption_from_feat(feats, ref_feats)

    return sents[0]


def generate_sentence_shoes(image_path_1, image_path_2):
    """Relative caption for a pair of shoe images."""
    return generate_relative_sentence(captioner_relative_shoes, image_path_1, image_path_2)


def generate_sentence_dresses(image_path_1, image_path_2):
    """Relative caption for a pair of dress images."""
    return generate_relative_sentence(captioner_relative_dresses, image_path_1, image_path_2)


def generate_sentence_shirts(image_path_1, image_path_2):
    """Relative caption for a pair of shirt images."""
    return generate_relative_sentence(captioner_relative_shirts, image_path_1, image_path_2)


def generate_sentence_topstees(image_path_1, image_path_2):
    """Relative caption for a pair of tops & tees images."""
    return generate_relative_sentence(captioner_relative_topstees, image_path_1, image_path_2)

import numpy as np
import gradio as gr

# Example (target, candidate) image-path pairs offered under each category
# tab. Paths are relative to the working directory.
examples_shoes = [
    ["images/shoes/img_womens_athletic_shoes_1223.jpg", "images/shoes/img_womens_athletic_shoes_830.jpg"],
    ["images/shoes/img_womens_athletic_shoes_830.jpg", "images/shoes/img_womens_athletic_shoes_1223.jpg"],
    ["images/shoes/img_womens_high_heels_559.jpg", "images/shoes/img_womens_high_heels_690.jpg"],
    ["images/shoes/img_womens_high_heels_690.jpg", "images/shoes/img_womens_high_heels_559.jpg"],
]

examples_dresses = [
    ["images/dresses/B007UZSPC8.jpg", "images/dresses/B006MPVW4U.jpg"],
    ["images/dresses/B005KMQQFQ.jpg", "images/dresses/B005QYY5W4.jpg"],
    ["images/dresses/B005OBAGD6.jpg", "images/dresses/B006U07GW4.jpg"],
    ["images/dresses/B0047Y0K0U.jpg", "images/dresses/B006TAM4CW.jpg"],
]

examples_shirts = [
    ["images/shirts/B00305G9I4.jpg", "images/shirts/B005BLUUJY.jpg"],
    ["images/shirts/B004WSVYX8.jpg", "images/shirts/B008TP27PY.jpg"],
    ["images/shirts/B003INE0Q6.jpg", "images/shirts/B0051D0X2Q.jpg"],
    ["images/shirts/B00EZUKCCM.jpg", "images/shirts/B00B88ZKXA.jpg"],
]

examples_topstees = [
    ["images/topstees/B0082993AO.jpg", "images/topstees/B008293HO2.jpg"],
    ["images/topstees/B006YN4J2C.jpg", "images/topstees/B0035EPUBW.jpg"],
    ["images/topstees/B00B5SKOMU.jpg", "images/topstees/B004H3XMYM.jpg"],
    ["images/topstees/B008DVXGO0.jpg", "images/topstees/B008JYNN30.jpg"],
]

def _build_category_tab(title, examples):
    """Lay out one category tab inside the currently open ``gr.Blocks``.

    The tab holds a target/candidate image pair and an output textbox in one
    row, followed by a "Generate" button and an ``gr.Examples`` gallery wired
    to the two image inputs.

    Returns:
        ``(target, candidate, output, button)`` so the caller can attach the
        click handler after every tab has been created.
    """
    # NOTE(review): ``source="upload"`` is the Gradio 3.x spelling; Gradio 4
    # uses ``sources=[...]``. Confirm the pinned gradio version before upgrading.
    with gr.Tab(title):
        with gr.Row():
            target = gr.Image(source="upload", type="filepath", label="Target Image")
            candidate = gr.Image(source="upload", type="filepath", label="Candidate Image")
            output = gr.Textbox(label="Generated Sentence")
        button = gr.Button("Generate")
        gr.Examples(examples, inputs=[target, candidate])
    return target, candidate, output, button


with gr.Blocks() as demo:
    gr.Markdown("Relative Captioning for Fashion.")

    # Tabs are created in the fixed order Shoes, Dresses, Shirts, Tops&Tees.
    target_shoes, candidate_shoes, output_text_shoes, shoes_btn = _build_category_tab("Shoes", examples_shoes)
    target_dresses, candidate_dresses, output_text_dresses, dresses_btn = _build_category_tab("Dresses", examples_dresses)
    target_shirts, candidate_shirts, output_text_shirts, shirts_btn = _build_category_tab("Shirts", examples_shirts)
    target_topstees, candidate_topstees, output_text_topstees, topstees_btn = _build_category_tab("Tops&Tees", examples_topstees)

    # Each button feeds (target, candidate) to its category's generator and
    # writes the returned sentence into that tab's textbox. Handlers are
    # attached after all tabs exist, in the same order as the tabs.
    shoes_btn.click(generate_sentence_shoes, inputs=[target_shoes, candidate_shoes], outputs=output_text_shoes)
    dresses_btn.click(generate_sentence_dresses, inputs=[target_dresses, candidate_dresses], outputs=output_text_dresses)
    shirts_btn.click(generate_sentence_shirts, inputs=[target_shirts, candidate_shirts], outputs=output_text_shirts)
    topstees_btn.click(generate_sentence_topstees, inputs=[target_topstees, candidate_topstees], outputs=output_text_topstees)

# Turn on request queuing with three concurrent workers, then start the server.
# ``queue`` returns the Blocks instance itself, so the two calls chain.
# NOTE(review): ``concurrency_count`` is a Gradio 3.x argument (removed in 4.x).
demo.queue(concurrency_count=3).launch()