Spaces:
Runtime error
Runtime error
from huggingface_hub import from_pretrained_keras | |
import keras_cv | |
import gradio as gr | |
from tensorflow import keras | |
# Switch Keras to mixed precision; the policy is set before the model below
# is constructed.
keras.mixed_precision.set_global_policy("mixed_float16")

# Build the base Stable Diffusion pipeline with a square output size.
resolution = 512
dreambooth_model = keras_cv.models.StableDiffusion(
    img_height=resolution,
    img_width=resolution,
    jit_compile=True,
)

# Fetch the fine-tuned diffusion model by its repo id and install it in place
# of the pipeline's own diffusion model.
loaded_diffusion_model = from_pretrained_keras(
    "keras-dreambooth/monkey_island_style"
)
dreambooth_model._diffusion_model = loaded_diffusion_model
def generate_images(prompt: str, negative_prompt: str, num_imgs_to_gen: int, num_steps: int, guidance_scale: float):
    """Generate images for *prompt* with the module-level ``dreambooth_model``.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        num_imgs_to_gen: Number of images to produce; forwarded as
            ``batch_size``.
        num_steps: Number of diffusion steps to run.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.

    Returns:
        Whatever ``dreambooth_model.text_to_image`` returns (presumably a
        batch of images that the Gradio gallery can display -- TODO confirm).
    """
    # Gradio sliders may hand their value over as a float (e.g. 2.0), while
    # the batch size and step count must be whole numbers, so coerce them.
    batch_size = int(num_imgs_to_gen)
    step_count = int(num_steps)

    return dreambooth_model.text_to_image(
        prompt,
        negative_prompt=negative_prompt,
        batch_size=batch_size,
        num_steps=step_count,
        unconditional_guidance_scale=guidance_scale,
    )
# Build the web UI: two prompt text boxes and three sliders feed
# `generate_images`, whose result is shown in a gallery. The call chain at the
# end enables request queuing and starts the server.
gr.Interface(
    title="Keras Dreambooth - Monkey Island Style π",
    description="""This SD model has been fine-tuned to learn the art style of the game *Return To Monkey Island* using Dreambooth.
While not being perfect at imitating the stlye, my experiments have shown that it works best on fictional characters, such as Geralt of Rivia, Frodo Baggins or Harry Potter.
To use the new style make sure to include `in mnky style` to your prompt.
""",
    fn=generate_images,
    # Input order must match the positional parameters of `generate_images`.
    inputs=[
        gr.Textbox(label="Positive Prompt", value="han solo in mnky style, high quality, 4k, trending on artstation"),
        gr.Textbox(label="Negative Prompt", value="bad, ugly"),
        gr.Slider(label='Number of gen image', minimum=1, maximum=4, value=2, step=1),
        # Explicit bounds: without them the slider falls back to Gradio's
        # default range, which starts at 0 and lets a user request zero
        # diffusion steps.
        gr.Slider(label="Inference Steps", minimum=1, maximum=100, value=50, step=1),
        gr.Slider(label='Guidance scale', value=7, maximum=15, minimum=0, step=0.5),
    ],
    outputs=[
        gr.Gallery(show_label=False).style(grid=(1, 2)),
    ],
    # One pre-filled example row, in the same order as `inputs`.
    examples=[["geralt of rivia in mnky style, high quality, 4k, trending on artstation", "bad, ugly", 2, 50, 7]],
).queue().launch(debug=True)