import copy
import os
import time
import uuid

# Set before the third-party imports below, presumably so that huggingface_hub
# sees it when it reads its environment configuration at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"

from PIL import Image
import gradio as gr
import numpy as np
import torch
from transformers import logging

# Silence transformers warnings before the model code is imported and loaded.
logging.set_verbosity_error()

from fromage import models
from fromage import utils

BASE_WIDTH = 512
MODEL_DIR = './fromage_model/fromage_vis4'

# Header text rendered at the top of the demo page.
INTRO_MARKDOWN = """
## FROMAGe
### Grounding Language Models to Images for Multimodal Generation
Jing Yu Koh, Ruslan Salakhutdinov, Daniel Fried

[Paper](https://arxiv.org/abs/2301.13823) [Github](https://github.com/kohjingyu/fromage)

- Upload an image (optional)
- Chat with FROMAGe!
- Check out the examples at the bottom!
"""


def upload_image(file):
    """Open an uploaded file as a PIL image (loaded lazily by Pillow)."""
    return Image.open(file)


def upload_button_config():
    """Return a Gradio update that hides the target component."""
    return gr.update(visible=False)


def upload_textbox_config(text_in):
    """Return a Gradio update that shows the target component; ``text_in`` is unused."""
    return gr.update(visible=True)


class ChatBotCheese:
    """Gradio chat front-end around a FROMAGe model.

    NOTE(review): ``chat_history`` and ``curr_image`` live on the single
    instance created in ``main()``, so they are shared by every browser
    session connected to the demo, while ``gr.State`` is per session.
    """

    def __init__(self):
        """Download the pretrained checkpoint from the HF Hub and load the model."""
        from huggingface_hub import hf_hub_download
        model_ckpt_path = hf_hub_download("alvanlii/fromage", "pretrained_ckpt.pth.tar")
        self.model = models.load_fromage(MODEL_DIR, model_ckpt_path)
        self.curr_image = None   # image to condition the next turn on, if any
        self.chat_history = ''   # running "Q: ... \nA: ..." transcript fed to the model

    def add_image(self, state, image_in):
        """Record an uploaded image so it conditions the next chat turn.

        Args:
            state: List of (user_message, bot_response) pairs for the Chatbot.
            image_in: Uploaded file object from ``gr.UploadButton``; its
                ``.name`` is a path on disk.

        Returns:
            ``(state, state)`` for the ``gr.State`` and ``gr.Chatbot`` outputs.
        """
        state = state + [(f"![](/file={image_in.name})", "Ok, now type your message")]
        # convert() returns a new, fully loaded image, so the file can be closed.
        with Image.open(image_in.name) as img:
            self.curr_image = img.convert('RGB')
        return state, state

    def save_im(self, image_pil):
        """Save a PIL image as PNG in the working directory and return its file name."""
        # A UUID avoids collisions when several images are saved within the same
        # second; the old timestamp + randint(100) scheme could overwrite files.
        file_name = f"{int(time.time())}_{uuid.uuid4().hex}.png"
        image_pil.save(file_name)
        return file_name

    def chat(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
        """Run one dialogue turn through FROMAGe and append it to the chat.

        Args:
            input_text: The user's message for this turn.
            state: List of (user_message, bot_response) pairs for the Chatbot.
            ret_scale_factor: Forwarded to the model as ``ret_scale_factor``
                (the UI labels it "Increased prob of returning images").
            num_ims: Maximum number of images to return (``max_num_rets``).
            num_words: Maximum number of words to generate (``num_words``).
            temp: Sampling temperature.

        Returns:
            ``(state, state)`` for the ``gr.State`` and ``gr.Chatbot`` outputs.
        """
        # gr.Number can deliver floats; the model parameters are counts.
        num_ims = int(num_ims)
        num_words = int(num_words)

        self.chat_history += f'Q: {input_text} \nA:'
        prompt = [self.chat_history]
        if self.curr_image is not None:
            # The pending uploaded image is placed before the text prompt.
            prompt = [self.curr_image, self.chat_history]
        model_outputs = self.model.generate_for_images_and_texts(
            prompt,
            num_words=num_words,
            max_num_rets=num_ims,
            ret_scale_factor=ret_scale_factor,
            temperature=temp,
        )
        self.chat_history += ' '.join(s for s in model_outputs if isinstance(s, str)) + '\n'

        # Assumes model_outputs[0] is the generated text and model_outputs[1]
        # (when present) is an iterable of PIL images. TODO confirm against fromage.
        im_names = []
        if len(model_outputs) > 1:
            im_names = [self.save_im(im) for im in model_outputs[1]]

        response = model_outputs[0]
        for im_name in im_names:
            # Same markdown form add_image uses, so the Chatbot renders the image.
            response += f'![](/file={im_name})'
        state.append((input_text, response.replace("[RET]", "")))
        # An uploaded image only applies to the turn immediately following it.
        self.curr_image = None
        return state, state

    def reset(self):
        """Clear the transcript and pending image; return empty state and chat."""
        self.chat_history = ""
        self.curr_image = None
        return [], []

    def main(self):
        """Build the Gradio Blocks UI, wire up the event handlers and launch it."""
        with gr.Blocks(css="#chatbot .overflow-y-auto{height:1500px}") as demo:
            gr.Markdown(INTRO_MARKDOWN)

            chatbot = gr.Chatbot(elem_id="chatbot")
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    txt = gr.Textbox(show_label=False, placeholder="Upload an image first [Optional]. Then enter text and press enter,").style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    btn = gr.UploadButton("🖼️", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    reset_btn = gr.Button("Reset Messages")
                    gr_ret_scale_factor = gr.Number(value=1.0, label="Increased prob of returning images", interactive=True)
                    # precision=0 makes Gradio return ints; these two values are counts.
                    gr_num_ims = gr.Number(value=3, precision=0, label="Max # of Images returned", interactive=True)
                    gr_num_words = gr.Number(value=32, precision=0, label="Max # of words returned", interactive=True)
                    gr_temp = gr.Number(value=0.0, label="Temperature", interactive=True)

            with gr.Row():
                gr.Image("example_1.png", label="Example 1")
                gr.Image("example_2.png", label="Example 2")
                gr.Image("example_3.png", label="Example 3")

            txt.submit(self.chat, [txt, gr_state, gr_ret_scale_factor, gr_num_ims, gr_num_words, gr_temp], [gr_state, chatbot])
            # Second submit handler clears the textbox.
            txt.submit(lambda: "", None, txt)
            btn.upload(self.add_image, [gr_state, btn], [gr_state, chatbot])
            reset_btn.click(self.reset, [], [gr_state, chatbot])

        # server_name="0.0.0.0" listens on all network interfaces.
        demo.launch(share=False, server_name="0.0.0.0")


def main():
    """Construct the chatbot (downloads and loads the model) and launch the demo."""
    cheddar = ChatBotCheese()
    cheddar.main()


if __name__ == "__main__":
    main()