import io
import os
import sys

import cv2
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# Let `model_loader` be found in the current working directory or its parent.
sys.path.append(".")
sys.path.append("..")
from model_loader import Model  # noqa: E402  (needs the sys.path entries above)

# Fetch the pretrained checkpoints from the Hugging Face Hub; the call returns
# the local directory that holds them.
models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")

# Checkpoint files from the pretrained/latent_transformer folder, keyed by the
# name offered in the model dropdown.
models_files = {
    "anime": "anime.pt",
    "car": "car.pt",
    "cat": "cat.pt",
    "church": "church.pt",
    "ffhq": "ffhq.pt",
}

# Every model is loaded once at start-up; all handlers below share this dict.
models = {
    name: Model(os.path.join(models_path, path)) for name, path in models_files.items()
}


def cv_to_pil(img) -> Image.Image:
    """Convert an OpenCV-style BGR array into an RGB PIL image.

    ``img`` is cast to ``uint8`` first (a plain dtype cast: values are not
    clipped or rescaled). Presumably ``img`` is a numpy array produced by the
    model — TODO confirm its value range is already 0..255.
    """
    return Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))


def random_sample(model_name: str):
    """Draw a random sample from the model named ``model_name``.

    Returns ``(image, model_name, latents)``. The name is echoed back so the UI
    can store it in ``model_state``: picking a dropdown entry has no effect
    until this handler runs.
    """
    model = models[model_name]
    img, latents = model.random_sample()
    return cv_to_pil(img), model_name, latents


def transform(model_state, latents_state, dx=0, dy=0, dz=0, sxsy=(128, 128)):
    """Transform the current latents by the slider values and re-render.

    Args:
        model_state: key into ``models`` for the active model.
        latents_state: latents returned by a previous handler.
        dx, dy: slider values forwarded to ``Model.transform`` as ``dxy``.
        dz: slider value forwarded as ``Model.transform``'s second argument.
        sxsy: 2-element position forwarded as ``sxsy``. In the UI this is the
            last clicked image position (``[128, 128]`` until the first click).
            The default is a tuple to avoid a shared mutable default.

    Returns:
        ``(image, latents)`` — the rendered PIL image and the updated latents.
    """
    model = models[model_state]
    sx, sy = sxsy[0], sxsy[1]
    img, latents_state = model.transform(
        latents_state, dz, dxy=[dx, dy], sxsy=[sx, sy], stop_points=[]
    )
    return cv_to_pil(img), latents_state


def change_style(image: Image.Image, model_state, latents_state):
    """Ask the active model to change the style of the current latents.

    ``image`` is not used; it is accepted because the button wiring passes the
    displayed image as the first input.

    Returns ``(image, latents)``.
    """
    model = models[model_state]
    img, latents_state = model.change_style(latents_state)
    return cv_to_pil(img), latents_state


def reset(model_state, latents_state):
    """Let the active model reset the current latents; returns ``(image, latents)``."""
    model = models[model_state]
    img, latents_state = model.reset(latents_state)
    return cv_to_pil(img), latents_state


def image_click(evt: gr.SelectData):
    """Return the clicked position on the image; it is stored in ``sxsy``."""
    return evt.index


with gr.Blocks() as block:
    # Per-session state: active model name, its latents, and the last clicked
    # image position (used as ``sxsy`` by ``transform``).
    model_state = gr.State(value="cat")
    latents_state = gr.State({})
    sxsy = gr.State([128, 128])

    gr.Markdown("# UserControllableLT: User controllable latent transformer")
    gr.Markdown("## Select model")
    # NOTE(review): component nesting was reconstructed from a whitespace-collapsed
    # source — confirm the intended layout.
    with gr.Row():
        with gr.Column():
            model_name = gr.Dropdown(
                choices=list(models_files.keys()),
                label="Select Pretrained Model",
                value="cat",
            )
            with gr.Row():
                button = gr.Button("Random sample")
                reset_btn = gr.Button("Reset")
                change_style_bt = gr.Button("Change style")
            # Gradio's Slider keyword is `step`; `step_size` was silently ignored,
            # so the intended granularity was never applied.
            dx = gr.Slider(minimum=-256, maximum=256, step=0.1, label="dx", value=0.0)
            dy = gr.Slider(minimum=-256, maximum=256, step=0.1, label="dy", value=0.0)
            dz = gr.Slider(minimum=-5, maximum=5, step=0.01, label="dz", value=0.0)
            image = gr.Image(type="pil", label="").style(height=500)
        with gr.Column():
            html = gr.HTML(label="output")

    image.select(image_click, inputs=None, outputs=sxsy)
    button.click(
        random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
    )
    reset_btn.click(
        reset,
        inputs=[model_state, latents_state],
        outputs=[image, latents_state],
    )
    change_style_bt.click(
        change_style,
        inputs=[image, model_state, latents_state],
        outputs=[image, latents_state],
    )
    # Every slider triggers the same re-render with the current values of all three.
    for slider in (dx, dy, dz):
        slider.change(
            transform,
            inputs=[model_state, latents_state, dx, dy, dz, sxsy],
            outputs=[image, latents_state],
            show_progress=False,
        )

block.queue()
block.launch()