"""Gradio demo for relative captioning of fashion images.

For each of four product categories (shoes, dresses, shirts, tops & tees) a
relative captioner is loaded from a checkpoint under ``drive_path``.  Each
category gets a tab in which the user supplies a target image and a candidate
image and receives the caption generated for that pair.
"""
import torch
import numpy as np
import gradio as gr

# Original remote checkpoint locations, kept for reference:
# usersim_path_shoes = "http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/caption_model_shoes"
# usersim_path_dresses = "http://www.dcs.gla.ac.uk/~craigm/fcrs/captioners/dresses_cap_caption_models"

drive_path = '/content/drive/MyDrive/Datasets/mmir_usersim_resources/'

# Category names; entry i corresponds to entry i of ``usersim_path`` below.
data_type = ["shoes", "dresses", "shirts", "tops&tees"]

usersim_path_shoes = drive_path + "checkpoints_usersim/shoes"
usersim_path_dresses = drive_path + "checkpoints_usersim/dresses"
usersim_path_shirts = drive_path + "checkpoints_usersim/shirts"
usersim_path_topstees = drive_path + "checkpoints_usersim/topstees"
usersim_path = [usersim_path_shoes, usersim_path_dresses, usersim_path_shirts, usersim_path_topstees]

import captioning.captioner as captioner

image_feat_params = {'model': 'resnet101', 'model_root': drive_path + 'imagenet_weights', 'att_size': 7}
# Alternative without a local weights directory:
# image_feat_params = {'model':'resnet101','model_root':'','att_size':7}


def _load_captioner(index):
    """Build the relative captioner for ``data_type[index]`` from ``usersim_path[index]``."""
    return captioner.Captioner(
        is_relative=True,
        model_path=usersim_path[index],
        image_feat_params=image_feat_params,
        data_type=data_type[index],
        load_resnet=True,
    )


# Constructed in the same order as before: shoes, dresses, shirts, tops&tees.
captioner_relative_shoes = _load_captioner(0)
captioner_relative_dresses = _load_captioner(1)
captioner_relative_shirts = _load_captioner(2)
captioner_relative_topstees = _load_captioner(3)


def _generate_relative_caption(model, image_path_1, image_path_2):
    """Return the first caption ``model`` generates for an image pair.

    Args:
        model: a relative ``captioner.Captioner`` instance.
        image_path_1: path to the first image (the UI passes the target image).
        image_path_2: path to the second image (the UI passes the candidate
            image); its features are passed as the reference pair.

    Returns:
        ``sents[0]`` from ``model.gen_caption_from_feat``.

    Raises:
        gr.Error: if either path is missing, e.g. Generate was clicked before
            both images were uploaded.
    """
    if not image_path_1 or not image_path_2:
        raise gr.Error("Please provide both a target image and a candidate image.")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        fc_feat, att_feat = model.get_img_feat(image_path_1)
        fc_feat_ref, att_feat_ref = model.get_img_feat(image_path_2)
        # Prepend a size-1 leading dimension to every feature tensor
        # (presumably the batch axis gen_caption_from_feat expects -- TODO confirm).
        fc_feat = torch.unsqueeze(fc_feat, dim=0)
        att_feat = torch.unsqueeze(att_feat, dim=0)
        fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
        att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
        _seq, sents = model.gen_caption_from_feat((fc_feat, att_feat), (fc_feat_ref, att_feat_ref))
    return sents[0]


def generate_sentence_shoes(image_path_1, image_path_2):
    """Caption an image pair with the shoes captioner."""
    return _generate_relative_caption(captioner_relative_shoes, image_path_1, image_path_2)


def generate_sentence_dresses(image_path_1, image_path_2):
    """Caption an image pair with the dresses captioner."""
    return _generate_relative_caption(captioner_relative_dresses, image_path_1, image_path_2)


def generate_sentence_shirts(image_path_1, image_path_2):
    """Caption an image pair with the shirts captioner."""
    return _generate_relative_caption(captioner_relative_shirts, image_path_1, image_path_2)


def generate_sentence_topstees(image_path_1, image_path_2):
    """Caption an image pair with the tops & tees captioner."""
    return _generate_relative_caption(captioner_relative_topstees, image_path_1, image_path_2)


# Example (target, candidate) image pairs shown under each tab.
examples_shoes = [
    ["images/shoes/img_womens_athletic_shoes_1223.jpg", "images/shoes/img_womens_athletic_shoes_830.jpg"],
    ["images/shoes/img_womens_athletic_shoes_830.jpg", "images/shoes/img_womens_athletic_shoes_1223.jpg"],
    ["images/shoes/img_womens_high_heels_559.jpg", "images/shoes/img_womens_high_heels_690.jpg"],
    ["images/shoes/img_womens_high_heels_690.jpg", "images/shoes/img_womens_high_heels_559.jpg"],
]
examples_dresses = [
    ["images/dresses/B007UZSPC8.jpg", "images/dresses/B006MPVW4U.jpg"],
    ["images/dresses/B005KMQQFQ.jpg", "images/dresses/B005QYY5W4.jpg"],
    ["images/dresses/B005OBAGD6.jpg", "images/dresses/B006U07GW4.jpg"],
    ["images/dresses/B0047Y0K0U.jpg", "images/dresses/B006TAM4CW.jpg"],
]
examples_shirts = [
    ["images/shirts/B00305G9I4.jpg", "images/shirts/B005BLUUJY.jpg"],
    ["images/shirts/B004WSVYX8.jpg", "images/shirts/B008TP27PY.jpg"],
    ["images/shirts/B003INE0Q6.jpg", "images/shirts/B0051D0X2Q.jpg"],
    ["images/shirts/B00EZUKCCM.jpg", "images/shirts/B00B88ZKXA.jpg"],
]
examples_topstees = [
    ["images/topstees/B0082993AO.jpg", "images/topstees/B008293HO2.jpg"],
    ["images/topstees/B006YN4J2C.jpg", "images/topstees/B0035EPUBW.jpg"],
    ["images/topstees/B00B5SKOMU.jpg", "images/topstees/B004H3XMYM.jpg"],
    ["images/topstees/B008DVXGO0.jpg", "images/topstees/B008JYNN30.jpg"],
]


def _build_category_tab(tab_label, examples):
    """Create one category tab inside the active Blocks context.

    Returns:
        ``(target_image, candidate_image, output_textbox, generate_button)``.
    """
    with gr.Tab(tab_label):
        with gr.Row():
            target = gr.Image(source="upload", type="filepath", label="Target Image")
            candidate = gr.Image(source="upload", type="filepath", label="Candidate Image")
        output_text = gr.Textbox(label="Generated Sentence")
        btn = gr.Button("Generate")
        gr.Examples(examples, inputs=[target, candidate])
    return target, candidate, output_text, btn


with gr.Blocks() as demo:
    gr.Markdown("Relative Captioning for Fashion.")
    target_shoes, candidate_shoes, output_text_shoes, shoes_btn = _build_category_tab("Shoes", examples_shoes)
    target_dresses, candidate_dresses, output_text_dresses, dresses_btn = _build_category_tab("Dresses", examples_dresses)
    target_shirts, candidate_shirts, output_text_shirts, shirts_btn = _build_category_tab("Shirts", examples_shirts)
    target_topstees, candidate_topstees, output_text_topstees, topstees_btn = _build_category_tab("Tops&Tees", examples_topstees)

    # The target image is passed as image_path_1 and the candidate as image_path_2.
    shoes_btn.click(generate_sentence_shoes, inputs=[target_shoes, candidate_shoes], outputs=output_text_shoes)
    dresses_btn.click(generate_sentence_dresses, inputs=[target_dresses, candidate_dresses], outputs=output_text_dresses)
    shirts_btn.click(generate_sentence_shirts, inputs=[target_shirts, candidate_shirts], outputs=output_text_shirts)
    topstees_btn.click(generate_sentence_topstees, inputs=[target_topstees, candidate_topstees], outputs=output_text_topstees)

demo.queue(concurrency_count=3)
demo.launch()